textbox.module.strategy

Common strategies in text generation.

class textbox.module.strategy.Beam_Search_Hypothesis(beam_size, sos_token_idx, eos_token_idx, device, idx2token)

Bases: object

A class that maintains the hypotheses of a beam search.

generate()

Pick the hypothesis with the highest probability among the beam_size hypotheses.

Returns

the generated tokens

Return type

List[str]
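The plain-Python sketch below illustrates what generate() does. It is not TextBox's implementation. It assumes each finished hypothesis is stored as a (token_ids, log_prob) pair and that the sos/eos markers are dropped before the ids are mapped through idx2token. The function name pick_best_hypothesis and this storage layout are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def pick_best_hypothesis(
    hypotheses: Sequence[Tuple[List[int], float]],
    idx2token: Dict[int, str],
    sos_token_idx: int,
    eos_token_idx: int,
) -> List[str]:
    """Return the tokens of the hypothesis with the highest cumulative log-probability.

    Hypothetical helper: each hypothesis is assumed to be a (token_ids, log_prob) pair.
    """
    best_ids, _ = max(hypotheses, key=lambda hyp: hyp[1])
    # Drop the start/end markers before mapping ids to token strings.
    return [idx2token[i] for i in best_ids if i not in (sos_token_idx, eos_token_idx)]


vocab = {0: "<sos>", 1: "<eos>", 2: "hello", 3: "world"}
finished = [([0, 2, 1], -1.2), ([0, 2, 3, 1], -0.7)]
print(pick_best_hypothesis(finished, vocab, sos_token_idx=0, eos_token_idx=1))
# ['hello', 'world']
```

Comparing raw cumulative log-probabilities favors shorter hypotheses. The description above does not say whether generate() applies any length normalization.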

step(gen_idx, token_logits, decoder_states=None, encoder_output=None, encoder_mask=None, input_type='token')

Perform one step of beam search.

Parameters
  • gen_idx (int) – the index of the current generation step.

  • token_logits (torch.Tensor) – logits distribution, shape: [hyp_num, sequence_length, vocab_size].

  • decoder_states (torch.Tensor, optional) – the decoder states from which the states of the surviving hypotheses are chosen, shape: [hyp_num, sequence_length, hidden_size], default: None.

  • encoder_output (torch.Tensor, optional) – the encoder output to be copied for the surviving hypotheses, shape: [hyp_num, sequence_length, hidden_size], default: None.

  • encoder_mask (torch.Tensor, optional) – the encoder mask to be copied for the surviving hypotheses, shape: [hyp_num, sequence_length], default: None.

Returns

  • torch.Tensor: the next input sequence, shape: [hyp_num].

  • torch.Tensor, optional: the chosen decoder states, shape: [new_hyp_num, sequence_length, hidden_size].

  • torch.Tensor, optional: the copied encoder output, shape: [new_hyp_num, sequence_length, hidden_size].

  • torch.Tensor, optional: the copied encoder mask, shape: [new_hyp_num, sequence_length].

Return type

torch.Tensor
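The core of a beam-search step can be sketched in plain Python. This is not TextBox's torch implementation. The names beam_step and log_softmax are hypothetical, and the sketch takes only the last-position logits of each hypothesis, with shape [hyp_num][vocab_size].

The sketch works in three stages:

  • It expands every hypothesis by every vocabulary token.

  • It keeps the beam_size candidates with the highest cumulative log-probability.

  • It returns the parent index of each survivor.

Selecting decoder states, or copying encoder output and mask, per new hypothesis amounts to indexing those tensors with the parent indices.

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def beam_step(
    beam_scores: Sequence[float],
    last_logits: Sequence[Sequence[float]],
    beam_size: int,
) -> Tuple[List[int], List[int], List[float]]:
    """Hypothetical single beam-search expansion step.

    beam_scores[h]: cumulative log-probability of hypothesis h.
    last_logits[h]: logits for the next token of hypothesis h.
    Returns (parent_indices, next_tokens, new_scores) of the survivors.
    """
    candidates = []
    for h, (score, logits) in enumerate(zip(beam_scores, last_logits)):
        for token, log_p in enumerate(log_softmax(logits)):
            candidates.append((score + log_p, h, token))
    # Keep the beam_size extensions with the highest cumulative score.
    candidates.sort(key=lambda c: c[0], reverse=True)
    best = candidates[:beam_size]
    parents = [h for _, h, _ in best]
    tokens = [t for _, _, t in best]
    scores = [s for s, _, _ in best]
    return parents, tokens, scores


# Two hypotheses over a two-token vocabulary.
parents, tokens, scores = beam_step([0.0, -2.0], [[2.0, 1.0], [0.5, 0.0]], beam_size=2)
# Reorder per-hypothesis state by parent index (duplicates are expected).
decoder_states = ["state-of-hyp-0", "state-of-hyp-1"]
next_states = [decoder_states[p] for p in parents]
print(parents, tokens)  # [0, 0] [0, 1]
print(next_states)      # ['state-of-hyp-0', 'state-of-hyp-0']
```

Both survivors here descend from hypothesis 0, so its state is duplicated. This is presumably why the returned states are described with new_hyp_num rather than hyp_num.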

stop()

Determine whether the beam search has finished.

Returns

True if the search has finished, False if it is still in progress.

Return type

Bool
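The exact stopping rule is not spelled out here. One common choice is to end the search in either of two cases:

  • beam_size hypotheses have emitted eos_token_idx.

  • A maximum number of steps has been taken.

The sketch below encodes that rule as an assumption, not as TextBox's documented behavior. The names search_is_over and max_length are hypothetical.

```python
from typing import List, Sequence


def search_is_over(
    finished_hyps: Sequence[List[int]],
    beam_size: int,
    num_steps: int,
    max_length: int,
) -> bool:
    """Assumed stop rule: enough finished hypotheses, or the step budget is spent."""
    return len(finished_hyps) >= beam_size or num_steps >= max_length


print(search_is_over([[0, 5, 1]], beam_size=2, num_steps=3, max_length=20))             # False
print(search_is_over([[0, 5, 1], [0, 7, 1]], beam_size=2, num_steps=3, max_length=20))  # True
print(search_is_over([], beam_size=2, num_steps=20, max_length=20))                     # True
```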

Bases: object

generate()

step(gen_idx, vocab_dists, decoder_hidden_states, kwargs=None)

stop()

Find the index of the maximum logit.

Parameters

logits (torch.Tensor) – logits distribution

Returns

the index of the chosen token

Return type

torch.Tensor
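Greedy selection reduces to an argmax. Softmax is monotonic, so the index of the largest logit is also the index of the most probable token, and no normalization is needed. Below is a plain-Python sketch for a single row of logits. The name greedy_pick is hypothetical, and the documented function operates on torch.Tensor.

```python
from typing import Sequence


def greedy_pick(logits: Sequence[float]) -> int:
    """Return the index of the largest logit (greedy decoding)."""
    # max() over indices, keyed by logit value; on ties the first index wins.
    return max(range(len(logits)), key=lambda i: logits[i])


print(greedy_pick([0.1, 2.5, -1.0]))  # 1
```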

textbox.module.strategy.topk_sampling(logits, temperature=1.0, top_k=0, top_p=0.9)

Filter a distribution of logits using top-k and/or nucleus (top-p) filtering, then sample a token from the filtered distribution.

Parameters
  • logits (torch.Tensor) – logits distribution

  • top_k (int) – if >0, keep only the top k tokens with the highest probability (top-k filtering).

  • top_p (float) – if >0.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).

Returns

the index of the sampled token.

Return type

torch.Tensor
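The filtering and sampling can be sketched in plain Python for a single row of logits. The sketch mirrors the documented parameters (temperature, top_k, top_p) but is an illustrative reimplementation, not TextBox's code. It makes these assumptions:

  • top_k_top_p_sample is a hypothetical name.

  • temperature divides the logits.

  • top-p is applied to the distribution left after top-k, keeping the shortest prefix of the sorted tokens whose cumulative probability reaches top_p.

```python
import math
import random
from typing import Optional, Sequence


def top_k_top_p_sample(
    logits: Sequence[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.9,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample a token index after temperature scaling and top-k / top-p filtering."""
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]  # assumes temperature > 0
    # Token ids ordered from most to least likely.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # top-k filtering
    # Softmax restricted to the surviving tokens (max-subtracted for stability).
    m = scaled[order[0]]
    weights = [math.exp(scaled[i] - m) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    if top_p > 0.0:
        # Nucleus filtering: keep the shortest prefix whose mass reaches top_p.
        cumulative, cutoff = 0.0, len(order)
        for n, p in enumerate(probs, start=1):
            cumulative += p
            if cumulative >= top_p:
                cutoff = n
                break
        order, probs = order[:cutoff], probs[:cutoff]
    # random.choices accepts unnormalized weights, so no renormalization is needed.
    return rng.choices(order, weights=probs, k=1)[0]


rng = random.Random(0)
logits = [1.0, 3.0, 2.0]
print(top_k_top_p_sample(logits, top_k=1, rng=rng))  # 1 (only the argmax survives)
print({top_k_top_p_sample(logits, rng=rng) for _ in range(200)} <= {1, 2})  # True
```

With these logits the softmax probabilities are about 0.665, 0.245 and 0.090 for tokens 1, 2 and 0. With the default top_p=0.9, the nucleus therefore contains only tokens 1 and 2, and token 0 can never be sampled.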