RNN Decoder
- class textbox.module.Decoder.rnn_decoder.AttentionalRNNDecoder(embedding_size, hidden_size, context_size, num_dec_layers, rnn_type, dropout_ratio=0.0, attention_type='LuongAttention', alignment_method='concat')
Bases: Module
Attention-based Recurrent Neural Network (RNN) decoder.
- forward(input_embeddings, hidden_states=None, encoder_outputs=None, encoder_masks=None, previous_probs=None)
Implement the attention-based decoding process.
- Parameters
input_embeddings (torch.Tensor) – target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
hidden_states (torch.Tensor) – initial hidden states, default: None.
encoder_outputs (torch.Tensor) – encoder output features, shape: [batch_size, sequence_length, hidden_size], default: None.
encoder_masks (torch.Tensor) – encoder state masks, shape: [batch_size, sequence_length], default: None.
- Returns
torch.Tensor: output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
torch.Tensor: hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
- Return type
tuple
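The attention step inside this decoding process can be illustrated with a minimal PyTorch sketch of Luong-style attention using the `concat` alignment method, score(h_t, h_s) = v^T tanh(W [h_t; h_s]), for a single decoder step. `ConcatAttentionSketch` is a hypothetical class, not the TextBox implementation, and it assumes `encoder_masks` is a boolean tensor that is True for real tokens and False for padding; the library's own mask convention may differ.

```python
import torch
import torch.nn as nn


class ConcatAttentionSketch(nn.Module):
    """Hypothetical Luong 'concat' attention: score = v^T tanh(W [h_t; h_s])."""

    def __init__(self, hidden_size: int, context_size: int):
        super().__init__()
        self.W = nn.Linear(hidden_size + context_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs, encoder_masks):
        # decoder_state:   [batch_size, hidden_size] (one decoding step)
        # encoder_outputs: [batch_size, src_len, context_size]
        # encoder_masks:   [batch_size, src_len], bool, True = real token (assumption)
        src_len = encoder_outputs.size(1)
        expanded = decoder_state.unsqueeze(1).expand(-1, src_len, -1)
        features = torch.tanh(self.W(torch.cat([expanded, encoder_outputs], dim=-1)))
        scores = self.v(features).squeeze(-1)  # [batch_size, src_len]
        # Padding positions get -inf, so softmax gives them exactly zero weight.
        scores = scores.masked_fill(~encoder_masks, float("-inf"))
        probs = torch.softmax(scores, dim=-1)  # [batch_size, src_len]
        # Context vector: attention-weighted sum of encoder outputs.
        context = torch.bmm(probs.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, probs


torch.manual_seed(0)
attn = ConcatAttentionSketch(hidden_size=8, context_size=6)
state = torch.randn(2, 8)
enc = torch.randn(2, 5, 6)
mask = torch.tensor([[True] * 5, [True, True, True, False, False]])
context, probs = attn(state, enc, mask)
print(context.shape, probs.shape)  # torch.Size([2, 6]) torch.Size([2, 5])
```

A row whose mask is entirely False would produce NaNs in this sketch; real inputs always contain at least one valid source token.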
Initialize the hidden states of the RNN.
- Parameters
input_embeddings (torch.Tensor) – input sequence embedding, shape: [batch_size, sequence_length, embedding_size].
- Returns
The initial hidden states.
- Return type
torch.Tensor
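A common way to initialize decoder hidden states is with zeros, shaped the way `torch.nn.RNN`, `torch.nn.GRU` and `torch.nn.LSTM` expect: [num_layers * num_directions, batch_size, hidden_size]. The sketch below assumes zero initialization and a unidirectional decoder; `init_hidden_sketch` and its `rnn_type` strings are hypothetical, and only the batch size, dtype and device are read from `input_embeddings`.

```python
import torch


def init_hidden_sketch(input_embeddings, num_layers, hidden_size, rnn_type="gru"):
    """Hypothetical zero initialization of RNN hidden states.

    Returns a tensor of shape [num_layers, batch_size, hidden_size]
    (num_directions = 1), or an (h_0, c_0) tuple for an LSTM.
    """
    batch_size = input_embeddings.size(0)
    h_0 = torch.zeros(num_layers, batch_size, hidden_size,
                      dtype=input_embeddings.dtype, device=input_embeddings.device)
    if rnn_type == "lstm":
        # An LSTM carries both a hidden state and a cell state.
        return h_0, torch.zeros_like(h_0)
    return h_0


emb = torch.randn(4, 7, 16)  # [batch_size, sequence_length, embedding_size]
h = init_hidden_sketch(emb, num_layers=2, hidden_size=32)
h_lstm, c_lstm = init_hidden_sketch(emb, num_layers=2, hidden_size=32, rnn_type="lstm")
print(h.shape)  # torch.Size([2, 4, 32])
```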
- training: bool
- class textbox.module.Decoder.rnn_decoder.BasicRNNDecoder(embedding_size, hidden_size, num_dec_layers, rnn_type, dropout_ratio=0.0)
Bases: Module
Basic Recurrent Neural Network (RNN) decoder.
- forward(input_embeddings, hidden_states=None)
Implement the decoding process.
- Parameters
input_embeddings (torch.Tensor) – target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
hidden_states (torch.Tensor) – initial hidden states, default: None.
- Returns
torch.Tensor: output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
torch.Tensor: hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
- Return type
tuple
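The two ways of calling a basic decoder, teacher forcing over the whole target sequence or step-by-step decoding that feeds `hidden_states` back in, can be checked with a plain `torch.nn.GRU`. This sketch assumes the decoder wraps a `batch_first` RNN; it is not TextBox code, but it reproduces the documented shapes and shows that both modes give the same outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, embedding_size, hidden_size, num_layers = 3, 6, 10, 16, 2

# Hypothetical stand-in for the decoder's internal RNN.
rnn = nn.GRU(embedding_size, hidden_size, num_layers, batch_first=True)
rnn.eval()  # disable dropout so both passes are deterministic

target_embeddings = torch.randn(batch_size, seq_len, embedding_size)

# Teacher forcing: feed the whole target sequence at once.
outputs, hidden = rnn(target_embeddings)

# Incremental decoding: one step at a time, passing the hidden states back in.
step_hidden = None
step_outputs = []
for t in range(seq_len):
    out_t, step_hidden = rnn(target_embeddings[:, t:t + 1, :], step_hidden)
    step_outputs.append(out_t)
step_outputs = torch.cat(step_outputs, dim=1)

print(outputs.shape)  # torch.Size([3, 6, 16])
print(hidden.shape)   # torch.Size([2, 3, 16])
```

`batch_first=True` affects only the input and output tensors; the returned hidden states stay in [num_layers * num_directions, batch_size, hidden_size] order.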
Initialize the hidden states of the RNN.
- Parameters
input_embeddings (torch.Tensor) – input sequence embedding, shape: [batch_size, sequence_length, embedding_size].
- Returns
The initial hidden states.
- Return type
torch.Tensor
- training: bool
- class textbox.module.Decoder.rnn_decoder.PointerRNNDecoder(vocab_size, embedding_size, hidden_size, context_size, num_dec_layers, rnn_type, dropout_ratio=0.0, is_attention=False, is_pgen=False, is_coverage=False)
Bases: Module
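The `is_pgen` flag suggests a pointer-generator decoder. In the formulation of See et al. (2017), the final distribution is P(w) = p_gen * P_vocab(w) + (1 - p_gen) * (sum of the attention weights on source positions holding w). The sketch below implements that mixture for a single step; `pointer_generator_mix` is hypothetical, it assumes every source token already has an in-vocabulary id (no extended vocabulary for out-of-vocabulary words), and it does not describe how `PointerRNNDecoder` is wired internally.

```python
import torch


def pointer_generator_mix(vocab_probs, attn_probs, source_ids, p_gen):
    """Hypothetical pointer-generator output distribution for one step.

    vocab_probs: [batch_size, vocab_size]  generation distribution
    attn_probs:  [batch_size, src_len]     attention over source tokens
    source_ids:  [batch_size, src_len]     vocabulary ids of the source tokens (int64)
    p_gen:       [batch_size, 1]           probability of generating instead of copying
    """
    copy_probs = torch.zeros_like(vocab_probs)
    # Add each source position's attention weight to the id of the token it holds;
    # repeated source tokens accumulate their weights.
    copy_probs.scatter_add_(1, source_ids, attn_probs)
    return p_gen * vocab_probs + (1 - p_gen) * copy_probs


vocab_probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
attn_probs = torch.tensor([[0.5, 0.25, 0.25]])
source_ids = torch.tensor([[2, 0, 2]])  # token id 2 appears twice in the source
p_gen = torch.tensor([[0.8]])
final = pointer_generator_mix(vocab_probs, attn_probs, source_ids, p_gen)
print(final)  # approximately [[0.13, 0.16, 0.39, 0.32]]
```

Because both input distributions sum to one and p_gen lies in [0, 1], the mixture is itself a valid distribution.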
- forward(input_embeddings, decoder_hidden_states, kwargs=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
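The `is_coverage` flag points to the coverage mechanism of See et al. (2017): a coverage vector c^t accumulates the attention distributions of previous steps, and the coverage loss sum_i min(a_i^t, c_i^t) penalizes attending again to source positions that are already covered. The following is a hypothetical sketch of that bookkeeping only (`coverage_loss_sketch` is not a TextBox function); it does not show how coverage is fed back into the attention scores, and whether TextBox computes it this way is an assumption.

```python
import torch


def coverage_loss_sketch(attn_history):
    """Hypothetical coverage loss in the style of See et al. (2017).

    attn_history: [batch_size, tgt_len, src_len], attention weights per decoding step.
    Returns the per-example coverage loss summed over decoding steps.
    """
    batch_size, tgt_len, src_len = attn_history.shape
    coverage = attn_history.new_zeros(batch_size, src_len)  # c^0 = 0
    total = attn_history.new_zeros(batch_size)
    for t in range(tgt_len):
        attn_t = attn_history[:, t, :]
        # Overlap between current attention and what is already covered.
        total = total + torch.minimum(attn_t, coverage).sum(dim=-1)
        coverage = coverage + attn_t  # c^{t+1} = c^t + a^t
    return total


# Steps 1 and 2 attend to the same source token, step 3 moves on.
repeat = torch.tensor([[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]])
print(coverage_loss_sketch(repeat))  # tensor([1.])
```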
- training: bool