RNN Encoder

class textbox.module.Encoder.rnn_encoder.BasicRNNEncoder(embedding_size, hidden_size, num_enc_layers, rnn_type, dropout_ratio, bidirectional=True)

Bases: Module

Basic Recurrent Neural Network (RNN) encoder.

forward(input_embeddings, input_length, hidden_states=None)

Encode a batch of embedded source sequences.

Parameters
  • input_embeddings (torch.Tensor) – source sequence embeddings, shape: [batch_size, sequence_length, embedding_size].

  • input_length (torch.Tensor) – lengths of the input sequences, shape: [batch_size].

  • hidden_states (torch.Tensor) – initial hidden states, default: None.

Returns

  • torch.Tensor: output features, shape: [batch_size, sequence_length, num_directions * hidden_size].

  • torch.Tensor: hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].

Return type

tuple
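
The shapes documented above can be reproduced with a plain PyTorch recurrent layer. The following is a minimal sketch, not the TextBox implementation: it assumes a bidirectional torch.nn.GRU with batch_first=True and assumes that padded batches are handled with pack_padded_sequence and pad_packed_sequence. The sizes and variable names are illustrative only.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Illustrative sizes (assumptions, not TextBox defaults).
batch_size, sequence_length, embedding_size = 3, 5, 8
hidden_size, num_layers, num_directions = 16, 2, 2

# Stand-in for the encoder's internal RNN: a 2-layer bidirectional GRU.
rnn = nn.GRU(embedding_size, hidden_size, num_layers,
             batch_first=True, bidirectional=True)

input_embeddings = torch.randn(batch_size, sequence_length, embedding_size)
input_length = torch.tensor([5, 3, 2])  # true length of each sequence (CPU tensor)

# Pack the batch so the RNN skips the padded positions.
packed = pack_padded_sequence(input_embeddings, input_length,
                              batch_first=True, enforce_sorted=False)
packed_outputs, hidden_states = rnn(packed)

# Unpack back to a padded tensor; padded positions are filled with zeros.
outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True)

print(tuple(outputs.shape))        # [batch_size, sequence_length, num_directions * hidden_size]
print(tuple(hidden_states.shape))  # [num_layers * num_directions, batch_size, hidden_size]
```

In this sketch the padded output length equals the longest entry in input_length, so it matches sequence_length only when at least one sequence in the batch is full length.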

init_hidden(input_embeddings)

Create the initial hidden states of the RNN.

Parameters

input_embeddings (torch.Tensor) – input sequence embeddings, shape: [batch_size, sequence_length, embedding_size].

Returns

The initial hidden states.

Return type

torch.Tensor
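
The reference does not state which values the initial hidden states hold. As a minimal sketch, assuming zero initialization and the hidden-state shape documented for forward, a helper of this kind could look as follows. The function name zero_hidden and its extra parameters are hypothetical and not part of TextBox.

```python
import torch

def zero_hidden(input_embeddings, hidden_size, num_layers, num_directions):
    """Hypothetical helper: build zero initial hidden states for a batch.

    Assumes zero initialization; the batch size, device, and dtype are
    taken from the input embeddings.
    """
    batch_size = input_embeddings.size(0)
    shape = (num_layers * num_directions, batch_size, hidden_size)
    # new_zeros creates the tensor on the same device and with the same dtype.
    return input_embeddings.new_zeros(shape)

embeddings = torch.randn(3, 5, 8)  # [batch_size, sequence_length, embedding_size]
h_0 = zero_hidden(embeddings, hidden_size=16, num_layers=2, num_directions=2)
print(tuple(h_0.shape))  # [num_layers * num_directions, batch_size, hidden_size]
```

Deriving the batch size from the embeddings rather than passing it separately keeps the hidden states consistent with whatever batch is being encoded.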

training: bool