textbox.module.strategy
Common strategies in text generation.
- class textbox.module.strategy.Beam_Search_Hypothesis(beam_size, sos_token_idx, eos_token_idx, device, idx2token)
Bases: object
Class designed for beam search.
- generate()
Pick the hypothesis with the highest probability among the beam_size hypotheses.
- Returns
the generated tokens
- Return type
List[str]
- step(gen_idx, token_logits, decoder_states=None, encoder_output=None, encoder_mask=None, input_type='token')
A step for beam search.
- Parameters
gen_idx (int) – the index of the current generation step.
token_logits (torch.Tensor) – logits distribution, shape: [hyp_num, sequence_length, vocab_size].
decoder_states (torch.Tensor, optional) – the decoder states from which the states of the new hypotheses are chosen, shape: [hyp_num, sequence_length, hidden_size], default: None.
encoder_output (torch.Tensor, optional) – the encoder output to be copied for the new hypotheses, shape: [hyp_num, sequence_length, hidden_size], default: None.
encoder_mask (torch.Tensor, optional) – the encoder mask to be copied for the new hypotheses, shape: [hyp_num, sequence_length], default: None.
- Returns
torch.Tensor – the next input sequence, shape: [hyp_num].
torch.Tensor, optional – the chosen decoder states, shape: [new_hyp_num, sequence_length, hidden_size].
torch.Tensor, optional – the copied encoder output, shape: [new_hyp_num, sequence_length, hidden_size].
torch.Tensor, optional – the copied encoder mask, shape: [new_hyp_num, sequence_length].
- Return type
torch.Tensor
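The idea behind step() and generate() can be illustrated with a minimal, library-free sketch. It is not the TextBox implementation: it works on plain Python lists instead of torch tensors, ignores decoder states and encoder copies, and applies no length normalization. The names beam_step, pick_best, and the toy idx2token mapping are hypothetical and exist only for this example.

```python
import math


def beam_step(hypotheses, log_probs, beam_size, eos_idx):
    """Expand every hypothesis by one token and keep the beam_size best.

    hypotheses: list of (token_ids, score) pairs; score is a summed log-prob.
    log_probs:  one list of per-token log-probabilities per hypothesis.
    Returns (live, finished), both lists of (token_ids, score) pairs.
    """
    candidates = []
    for (tokens, score), row in zip(hypotheses, log_probs):
        for token_idx, log_prob in enumerate(row):
            candidates.append((tokens + [token_idx], score + log_prob))

    # Highest cumulative log-probability first.
    candidates.sort(key=lambda c: c[1], reverse=True)

    live, finished = [], []
    for tokens, score in candidates[:beam_size]:
        # A hypothesis that just produced EOS is complete (assumed convention).
        if tokens[-1] == eos_idx:
            finished.append((tokens, score))
        else:
            live.append((tokens, score))
    return live, finished


def pick_best(hypotheses):
    """Return the token ids of the highest-scoring hypothesis."""
    return max(hypotheses, key=lambda h: h[1])[0]


# Toy vocabulary (hypothetical): 0 = <sos>, 1 = "a", 2 = <eos>.
idx2token = {0: "<sos>", 1: "a", 2: "<eos>"}
start = [([0], 0.0)]
row = [math.log(0.1), math.log(0.6), math.log(0.3)]

live, finished = beam_step(start, [row], beam_size=2, eos_idx=2)
best_ids = pick_best(live + finished)
best_tokens = [idx2token[i] for i in best_ids]
```

In a real decoder, the step would be repeated with fresh log-probabilities for the live hypotheses until all of them finish or a maximum length is reached.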
- class textbox.module.strategy.Copy_Beam_Search(beam_size, sos_token_idx, eos_token_idx, unknown_token_idx, device, idx2token, is_attention=False, is_pgen=False, is_coverage=False)
Bases: object
- textbox.module.strategy.greedy_search(logits)
Find the index of the maximum logit.
- Parameters
logits (torch.Tensor) – logits distribution
- Returns
the chosen index of token
- Return type
torch.Tensor
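As a rough illustration of greedy selection, the sketch below picks the index of the largest value from a plain Python list. The helper name greedy_pick is hypothetical and this is not the TextBox code; on a torch.Tensor the equivalent operation would typically be an argmax over the vocabulary dimension.

```python
def greedy_pick(logits):
    """Return the index of the largest logit in a list of floats.

    Hypothetical helper for illustration only; ties resolve to the
    first maximal index.
    """
    return max(range(len(logits)), key=lambda i: logits[i])


example_logits = [0.1, 2.5, -1.0, 0.7]
chosen = greedy_pick(example_logits)
```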
- textbox.module.strategy.topk_sampling(logits, temperature=1.0, top_k=0, top_p=0.9)
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
- Parameters
logits (torch.Tensor) – logits distribution
top_k (int) – if > 0, keep only the top_k tokens with the highest probability (top-k filtering), default: 0.
top_p (float) – if > 0.0, keep the top tokens whose cumulative probability is >= top_p (nucleus filtering), default: 0.9.
- Returns
the chosen index of token.
- Return type
torch.Tensor
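The filtering described above can be sketched without any tensor library. The code below is an assumption-laden illustration, not the TextBox implementation: filter_logits and sample_index are hypothetical names, logits are a plain list of floats, and dividing by temperature before filtering is an assumed ordering that the documentation does not specify.

```python
import math
import random


def filter_logits(logits, top_k=0, top_p=0.0):
    """Return a copy of logits with filtered-out entries set to -inf."""
    filtered = list(logits)
    # Token indices ordered from most to least likely.
    order = sorted(range(len(filtered)), key=lambda i: filtered[i], reverse=True)

    if top_k > 0:
        # Top-k filtering: drop everything after the k best tokens.
        for i in order[top_k:]:
            filtered[i] = -math.inf

    if top_p > 0.0:
        # Nucleus filtering: softmax, then keep tokens until the cumulative
        # probability reaches top_p (the token that crosses it is kept).
        peak = max(filtered)
        exps = [math.exp(x - peak) for x in filtered]
        total = sum(exps)
        cumulative = 0.0
        for i in order:
            if cumulative >= top_p:
                filtered[i] = -math.inf
            cumulative += exps[i] / total
    return filtered


def sample_index(logits, temperature=1.0, top_k=0, top_p=0.9, rng=random):
    """Sample one token index from the filtered distribution."""
    # Assumed order of operations: temperature scaling, then filtering.
    scaled = [x / temperature for x in logits]
    filtered = filter_logits(scaled, top_k=top_k, top_p=top_p)
    peak = max(filtered)
    weights = [math.exp(x - peak) for x in filtered]  # exp(-inf) == 0.0
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


# Probabilities 0.5, 0.3, 0.2 expressed as logits.
demo_logits = [math.log(0.5), math.log(0.3), math.log(0.2)]
nucleus = filter_logits(demo_logits, top_p=0.7)
token = sample_index(demo_logits, top_p=0.7, rng=random.Random(0))
```

With top_p=0.7 the first two tokens already cover 0.8 of the probability mass, so the third is masked out and can never be sampled. Passing a seeded random.Random instance as rng makes the draw reproducible.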