RNN Encoder
- class textbox.module.Encoder.rnn_encoder.BasicRNNEncoder(embedding_size, hidden_size, num_enc_layers, rnn_type, dropout_ratio, bidirectional=True)
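Bases: torch.nn.Module
Basic recurrent neural network (RNN) encoder. It maps a batch of embedded source sequences to per-step output features and final hidden states.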
- forward(input_embeddings, input_length, hidden_states=None)
Encode a batch of embedded source sequences.
- Parameters
input_embeddings (torch.Tensor) – embeddings of the source sequences, shape: [batch_size, sequence_length, embedding_size].
input_length (torch.Tensor) – length of each input sequence, shape: [batch_size].
hidden_states (torch.Tensor, optional) – initial hidden states, default: None.
- Returns
torch.Tensor: output features, shape: [batch_size, sequence_length, num_directions * hidden_size].
torch.Tensor: hidden states, shape: [num_enc_layers * num_directions, batch_size, hidden_size].
- Return type
tuple
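The shapes listed under Returns follow directly from the constructor arguments. A bidirectional RNN has two directions whose per-step features are concatenated along the last axis, and one final hidden state is kept for each (layer, direction) pair. The helper below is a hypothetical illustration (the name `expected_encoder_shapes` is not part of TextBox). It assumes the shapes documented above and only does the arithmetic, without calling torch.

```python
from typing import Tuple

Shape = Tuple[int, int, int]


def expected_encoder_shapes(
    batch_size: int,
    sequence_length: int,
    hidden_size: int,
    num_enc_layers: int,
    bidirectional: bool = True,
) -> Tuple[Shape, Shape]:
    """Hypothetical helper: (outputs_shape, hidden_states_shape) as documented for forward()."""
    # A bidirectional RNN runs one pass left-to-right and one right-to-left.
    num_directions = 2 if bidirectional else 1
    # Per-step features of all directions are concatenated on the last axis.
    outputs_shape = (batch_size, sequence_length, num_directions * hidden_size)
    # One final hidden state per (layer, direction) pair.
    hidden_shape = (num_enc_layers * num_directions, batch_size, hidden_size)
    return outputs_shape, hidden_shape


print(expected_encoder_shapes(batch_size=4, sequence_length=10, hidden_size=256, num_enc_layers=2))
# → ((4, 10, 512), (4, 4, 256))
```

With the default `bidirectional=True`, the feature size of the outputs is twice `hidden_size`, which matters when a downstream decoder or attention layer expects a fixed input width.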
Create the initial hidden states of the RNN.
- Parameters
input_embeddings (torch.Tensor) – input sequence embeddings, shape: [batch_size, sequence_length, embedding_size].
- Returns
the initial hidden states.
- Return type
torch.Tensor
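The page does not state the shape or values of the initial hidden states. A common convention, assumed here, is all zeros with the same shape as the hidden states returned by forward(), where only the batch size is read from `input_embeddings`. The sketch below is hypothetical (`init_hidden_sketch` is not a TextBox function) and uses nested lists as a stand-in for a tensor; with PyTorch the same result would be `torch.zeros(num_enc_layers * num_directions, batch_size, hidden_size)`. Note that PyTorch's `nn.LSTM` expects a pair `(h_0, c_0)` of this shape, so when `rnn_type` selects an LSTM the real method may return a tuple.

```python
from typing import List, Sequence


def init_hidden_sketch(
    input_embeddings: Sequence,
    hidden_size: int,
    num_enc_layers: int,
    bidirectional: bool = True,
) -> List[List[List[float]]]:
    """Hypothetical stand-in: zero initial hidden states as nested lists."""
    # input_embeddings is [batch_size, sequence_length, embedding_size];
    # only its first dimension (the batch size) is needed here.
    batch_size = len(input_embeddings)
    num_directions = 2 if bidirectional else 1
    # Shape: [num_enc_layers * num_directions, batch_size, hidden_size], all zeros.
    return [
        [[0.0] * hidden_size for _ in range(batch_size)]
        for _ in range(num_enc_layers * num_directions)
    ]


# Batch of 3 sequences, each of length 5, with embedding size 8.
emb = [[[0.1] * 8 for _ in range(5)] for _ in range(3)]
h0 = init_hidden_sketch(emb, hidden_size=16, num_enc_layers=2)
print(len(h0), len(h0[0]), len(h0[0][0]))  # → 4 3 16
```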
- training: bool