# Source code for textbox.module.Decoder.rnn_decoder

# @Time   : 2020/11/14
# @Author : Junyi Li
# @Email  : lijunyi@ruc.edu.cn

# UPDATE:
# @Time   : 2020/12/26
# @Author : Jinhao Jiang
# @Email  : jiangjinhao@std.uestc.edu.cn

r"""
RNN Decoder
###############
"""

import torch
from torch import nn
import torch.nn.functional as F
from textbox.module.Attention.attention_mechanism import LuongAttention, BahdanauAttention, MonotonicAttention


class BasicRNNDecoder(torch.nn.Module):
    r"""
    Basic Recurrent Neural Network (RNN) decoder.

    Wraps a unidirectional, batch-first ``nn.LSTM`` / ``nn.GRU`` / ``nn.RNN``
    stack selected by ``rnn_type``.

    Args:
        embedding_size (int): size of the input embeddings fed to the RNN.
        hidden_size (int): size of the RNN hidden state.
        num_dec_layers (int): number of stacked RNN layers.
        rnn_type (str): one of ``'lstm'``, ``'gru'``, ``'rnn'``.
        dropout_ratio (float): dropout passed to the RNN constructor, default: 0.0.

    Raises:
        ValueError: if ``rnn_type`` is not one of the supported types.
    """

    def __init__(self, embedding_size, hidden_size, num_dec_layers, rnn_type, dropout_ratio=0.0):
        super(BasicRNNDecoder, self).__init__()
        self.rnn_type = rnn_type
        self.num_dec_layers = num_dec_layers
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size

        # Single lookup instead of three near-identical constructor branches.
        rnn_classes = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
        if rnn_type not in rnn_classes:
            raise ValueError("The RNN type in decoder must in ['lstm', 'gru', 'rnn'].")
        self.decoder = rnn_classes[rnn_type](
            embedding_size, hidden_size, num_dec_layers, batch_first=True, dropout=dropout_ratio
        )

    def init_hidden(self, input_embeddings):
        r""" Initialize initial hidden states of RNN.

        The zero states are allocated directly on the device of ``input_embeddings`` and use its
        dtype when it is a floating point tensor (otherwise the default dtype), so they stay
        compatible with a decoder that was cast with ``.double()`` / ``.half()``.

        Args:
            input_embeddings (Torch.Tensor): input sequence embedding, shape: [batch_size, sequence_length, embedding_size].

        Returns:
            Torch.Tensor: the initial hidden states, shape: [num_dec_layers, batch_size, hidden_size];
            a tuple ``(h_0, c_0)`` of two such tensors when ``rnn_type`` is ``'lstm'``.

        Raises:
            NotImplementedError: if ``self.rnn_type`` is not a supported type.
        """
        batch_size = input_embeddings.size(0)
        device = input_embeddings.device
        if input_embeddings.is_floating_point():
            dtype = input_embeddings.dtype
        else:
            dtype = torch.get_default_dtype()
        shape = (self.num_dec_layers, batch_size, self.hidden_size)

        if self.rnn_type == 'lstm':
            h_0 = torch.zeros(shape, dtype=dtype, device=device)
            c_0 = torch.zeros(shape, dtype=dtype, device=device)
            return h_0, c_0
        if self.rnn_type in ('gru', 'rnn'):
            return torch.zeros(shape, dtype=dtype, device=device)
        raise NotImplementedError("No such rnn type {} for initializing decoder states.".format(self.rnn_type))

    def forward(self, input_embeddings, hidden_states=None):
        r""" Implement the decoding process.

        Args:
            input_embeddings (Torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            hidden_states (Torch.Tensor): initial hidden states (a tuple ``(h, c)`` for LSTM), default: None,
                in which case zero states from :meth:`init_hidden` are used.

        Returns:
            tuple:
                - Torch.Tensor: output features, shape: [batch_size, sequence_length, hidden_size].
                - Torch.Tensor: hidden states, shape: [num_dec_layers, batch_size, hidden_size]
                  (a tuple ``(h, c)`` for LSTM).
        """
        if hidden_states is None:
            hidden_states = self.init_hidden(input_embeddings)

        # States handed over from an encoder are often transposed/sliced views; make them
        # contiguous for the RNN kernels. LSTM states come as a tuple and need the same treatment.
        if isinstance(hidden_states, tuple):
            hidden_states = tuple(state.contiguous() for state in hidden_states)
        else:
            hidden_states = hidden_states.contiguous()

        outputs, hidden_states = self.decoder(input_embeddings, hidden_states)
        return outputs, hidden_states
class AttentionalRNNDecoder(torch.nn.Module):
    r"""
    Attention-based Recurrent Neural Network (RNN) decoder.

    Decodes step by step with a unidirectional, batch-first RNN and combines every step output
    with an attention context through a linear layer.

    Args:
        embedding_size (int): size of the input embeddings.
        hidden_size (int): size of the RNN hidden state (and of the returned features).
        context_size (int): size of the attention context vectors.
        num_dec_layers (int): number of stacked RNN layers.
        rnn_type (str): one of ``'lstm'``, ``'gru'``, ``'rnn'``.
        dropout_ratio (float): dropout passed to the RNN constructor, default: 0.0.
        attention_type (str): one of ``'LuongAttention'``, ``'BahdanauAttention'``,
            ``'MonotonicAttention'``, default: ``'LuongAttention'``.
        alignment_method (str): alignment method forwarded to ``LuongAttention``, default: ``'concat'``.

    Raises:
        ValueError: if ``attention_type`` or ``rnn_type`` is not supported.
    """

    def __init__(
        self,
        embedding_size,
        hidden_size,
        context_size,
        num_dec_layers,
        rnn_type,
        dropout_ratio=0.0,
        attention_type='LuongAttention',
        alignment_method='concat'
    ):
        super(AttentionalRNNDecoder, self).__init__()

        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.num_dec_layers = num_dec_layers
        self.rnn_type = rnn_type
        self.attention_type = attention_type
        self.alignment_method = alignment_method

        if attention_type == 'LuongAttention':
            self.attentioner = LuongAttention(self.context_size, self.hidden_size, self.alignment_method)
            dec_input_size = embedding_size
        elif attention_type == 'BahdanauAttention':
            self.attentioner = BahdanauAttention(self.context_size, self.hidden_size)
            # Bahdanau attention feeds the context into the RNN together with the embedding.
            dec_input_size = embedding_size + context_size
        elif attention_type == 'MonotonicAttention':
            self.attentioner = MonotonicAttention(self.context_size, self.hidden_size)
            dec_input_size = embedding_size
        else:
            raise ValueError("Attention type must be in ['LuongAttention', 'BahdanauAttention', 'MonotonicAttention'].")

        # Single lookup instead of three near-identical constructor branches.
        rnn_classes = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
        if rnn_type not in rnn_classes:
            raise ValueError("RNN type in attentional decoder must be in ['lstm', 'gru', 'rnn'].")
        self.decoder = rnn_classes[rnn_type](
            dec_input_size, hidden_size, num_dec_layers, batch_first=True, dropout=dropout_ratio
        )

        self.attention_dense = nn.Linear(hidden_size + context_size, hidden_size)

    def init_hidden(self, input_embeddings):
        r""" Initialize initial hidden states of RNN.

        The zero states are allocated directly on the device of ``input_embeddings`` and use its
        dtype when it is a floating point tensor (otherwise the default dtype).

        Args:
            input_embeddings (Torch.Tensor): input sequence embedding, shape: [batch_size, sequence_length, embedding_size].

        Returns:
            Torch.Tensor: the initial hidden states, shape: [num_dec_layers, batch_size, hidden_size];
            a tuple ``(h_0, c_0)`` of two such tensors when ``rnn_type`` is ``'lstm'``.

        Raises:
            NotImplementedError: if ``self.rnn_type`` is not a supported type.
        """
        batch_size = input_embeddings.size(0)
        device = input_embeddings.device
        if input_embeddings.is_floating_point():
            dtype = input_embeddings.dtype
        else:
            dtype = torch.get_default_dtype()
        shape = (self.num_dec_layers, batch_size, self.hidden_size)

        if self.rnn_type == 'lstm':
            h_0 = torch.zeros(shape, dtype=dtype, device=device)
            c_0 = torch.zeros(shape, dtype=dtype, device=device)
            return h_0, c_0
        if self.rnn_type in ('gru', 'rnn'):
            return torch.zeros(shape, dtype=dtype, device=device)
        raise NotImplementedError("No such rnn type {} for initializing decoder states.".format(self.rnn_type))

    def forward(
        self, input_embeddings, hidden_states=None, encoder_outputs=None, encoder_masks=None, previous_probs=None
    ):
        r""" Implement the attention-based decoding process.

        Args:
            input_embeddings (Torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            hidden_states (Torch.Tensor): initial hidden states (a tuple ``(h, c)`` for LSTM), default: None,
                in which case zero states from :meth:`init_hidden` are used.
            encoder_outputs (Torch.Tensor): encoder output features, shape: [batch_size, sequence_length, hidden_size], default: None.
                It is handed to the attention module at every decoding step.
            encoder_masks (Torch.Tensor): encoder state masks, shape: [batch_size, sequence_length], default: None.
                When omitted and ``encoder_outputs`` is given, an all-ones mask is used.
            previous_probs (Torch.Tensor): attention probabilities carried over from a previous call,
                forwarded to the monotonic attention on the first step, default: None.

        Returns:
            tuple:
                - Torch.Tensor: output features, shape: [batch_size, sequence_length, hidden_size].
                - Torch.Tensor: hidden states, shape: [num_dec_layers, batch_size, hidden_size]
                  (a tuple ``(h, c)`` for LSTM).
                - Torch.Tensor: attention probabilities of the last decoding step
                  (``previous_probs`` unchanged if no step was decoded).

        Raises:
            NotImplementedError: if ``self.attention_type`` is not a supported type.
        """
        if hidden_states is None:
            hidden_states = self.init_hidden(input_embeddings)

        if encoder_outputs is not None and encoder_masks is None:
            encoder_masks = torch.ones(
                encoder_outputs.size(0), encoder_outputs.size(1), device=encoder_outputs.device
            )

        decode_length = input_embeddings.size(1)
        probs = previous_probs

        if decode_length == 0:
            # Nothing to decode: torch.cat on an empty list would raise, so return an empty
            # feature tensor with the regular trailing dimension instead.
            empty_outputs = input_embeddings.new_zeros((input_embeddings.size(0), 0, self.hidden_size))
            return empty_outputs, hidden_states, probs

        # The attention flavour does not change during decoding; resolve it once.
        is_bahdanau = self.attention_type == 'BahdanauAttention'

        all_outputs = []
        for step in range(decode_length):
            embed = input_embeddings[:, step, :].unsqueeze(1)
            if is_bahdanau:
                # Attention is computed before the RNN step from the top layer's previous state.
                if self.rnn_type == 'lstm':
                    hidden = hidden_states[0][-1]
                else:
                    hidden = hidden_states[-1]
                context, probs = self.attentioner(hidden, encoder_outputs, encoder_masks)
                inputs = torch.cat((embed, context), dim=-1)
            else:
                inputs = embed
                context = None

            # Make the states contiguous for the RNN kernels (LSTM states are a tuple).
            if isinstance(hidden_states, tuple):
                hidden_states = tuple(state.contiguous() for state in hidden_states)
            else:
                hidden_states = hidden_states.contiguous()
            outputs, hidden_states = self.decoder(inputs, hidden_states)

            # Luong / monotonic attention is computed after the RNN step from its output.
            if self.attention_type == 'LuongAttention':
                context, probs = self.attentioner(outputs, encoder_outputs, encoder_masks)
            elif self.attention_type == 'MonotonicAttention':
                if self.training:
                    context, probs = self.attentioner.soft(outputs, encoder_outputs, encoder_masks, probs)
                else:
                    context, probs = self.attentioner.hard(outputs, encoder_outputs, encoder_masks, probs)
            elif not is_bahdanau:
                raise NotImplementedError("No such attention type {} for decoder output.".format(self.attention_type))

            outputs = self.attention_dense(torch.cat((outputs, context), dim=2))
            all_outputs.append(outputs)

        outputs = torch.cat(all_outputs, dim=1)
        return outputs, hidden_states, probs
class PointerRNNDecoder(nn.Module):
    r"""
    RNN decoder that outputs vocabulary distributions, optionally with attention,
    a pointer-generator copy mechanism and coverage.

    Args:
        vocab_size (int): size of the output vocabulary.
        embedding_size (int): size of the input embeddings.
        hidden_size (int): size of the RNN hidden state.
        context_size (int): size of the attention context vectors.
        num_dec_layers (int): number of stacked RNN layers.
        rnn_type (str): one of ``'lstm'``, ``'gru'``, ``'rnn'``.
        dropout_ratio (float): dropout passed to the RNN constructor, default: 0.0.
        is_attention (bool): whether to attend over the encoder outputs, default: False.
        is_pgen (bool): whether to mix in copy probabilities; only effective together with
            ``is_attention``, default: False.
        is_coverage (bool): whether to track attention coverage; only effective together with
            ``is_attention``, default: False.

    Raises:
        ValueError: if ``rnn_type`` is not one of the supported types.
    """

    def __init__(
        self,
        vocab_size,
        embedding_size,
        hidden_size,
        context_size,
        num_dec_layers,
        rnn_type,
        dropout_ratio=0.0,
        is_attention=False,
        is_pgen=False,
        is_coverage=False
    ):
        super(PointerRNNDecoder, self).__init__()

        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.num_dec_layers = num_dec_layers
        self.rnn_type = rnn_type

        dec_input_size = embedding_size

        # Single lookup instead of three near-identical constructor branches.
        rnn_classes = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
        if rnn_type not in rnn_classes:
            raise ValueError("RNN type in pointer decoder must be in ['lstm', 'gru', 'rnn'].")
        self.decoder = rnn_classes[rnn_type](
            dec_input_size, hidden_size, num_dec_layers, batch_first=True, dropout=dropout_ratio
        )

        self.is_attention = is_attention
        # Copying and coverage both rely on attention distributions.
        self.is_pgen = is_pgen and is_attention
        self.is_coverage = is_coverage and is_attention

        self.vocab_linear = nn.Linear(hidden_size, vocab_size)

        if self.is_attention:
            self.x_context = nn.Linear(embedding_size + context_size, embedding_size)
            self.attention = LuongAttention(self.context_size, self.hidden_size, 'concat', self.is_coverage)
            self.attention_dense = nn.Linear(hidden_size + context_size, hidden_size)

        if self.is_pgen:
            self.p_gen_linear = nn.Linear(context_size + hidden_size + embedding_size, 1)

    def forward(self, input_embeddings, decoder_hidden_states, kwargs=None):
        r""" Implement the decoding process.

        Args:
            input_embeddings (Torch.Tensor): target sequence embedding, shape: [batch_size, tgt_len, embedding_size].
            decoder_hidden_states (Torch.Tensor): initial hidden states passed to the RNN.
            kwargs (dict): extra inputs, default: None. Required when attention is enabled; the keys
                read are ``'encoder_outputs'``, ``'encoder_masks'`` and ``'context'``, plus
                ``'extra_zeros'`` and ``'source_extended_idx'`` with the pointer-generator and
                ``'coverages'`` with coverage.

        Returns:
            tuple:
                - Torch.Tensor: vocabulary distributions, shape: [batch_size, tgt_len, vocab_size]
                  (the last dimension is extended by the width of ``extra_zeros`` with the pointer-generator).
                - Torch.Tensor: hidden states returned by the RNN.
                - dict: ``kwargs``; with attention it is updated in place with the last ``'context'``
                  and, with coverage, the per-step ``'attn_dists'`` and ``'coverages'``.

        Raises:
            ValueError: if attention is enabled and ``kwargs`` is None.
        """
        if not self.is_attention:
            decoder_outputs, decoder_hidden_states = self.decoder(input_embeddings, decoder_hidden_states)
            vocab_dists = F.softmax(self.vocab_linear(decoder_outputs), dim=-1)
            return vocab_dists, decoder_hidden_states, kwargs

        if kwargs is None:
            raise ValueError(
                "kwargs with 'encoder_outputs', 'encoder_masks' and 'context' is required when attention is enabled."
            )

        encoder_outputs = kwargs['encoder_outputs']  # B x src_len x context_size (per original comments)
        encoder_masks = kwargs['encoder_masks']  # B x src_len
        context = kwargs['context']  # B x 1 x context_size

        extra_zeros = None
        source_extended_idx = None
        if self.is_pgen:
            extra_zeros = kwargs['extra_zeros']  # B x max_oovs_num
            source_extended_idx = kwargs['source_extended_idx']  # B x src_len (contains oovs ids)

        coverage = None
        attn_dists = None
        coverages = None
        if self.is_coverage:
            coverage = kwargs['coverages']
            coverages = []
            attn_dists = []

        vocab_dists = []
        tgt_len = input_embeddings.size(1)
        for step in range(tgt_len):
            step_input_embeddings = input_embeddings[:, step, :].unsqueeze(1)  # B x 1 x embedding_size

            # Fuse the previous context with the current embedding before the RNN step.
            x = self.x_context(torch.cat((step_input_embeddings, context), dim=-1))  # B x 1 x embedding_size

            decoder_outputs, decoder_hidden_states = self.decoder(x, decoder_hidden_states)  # B x 1 x hidden_size

            # attn_dist: B x 1 x src_len
            context, attn_dist, coverage = self.attention(decoder_outputs, encoder_outputs, encoder_masks, coverage)

            vocab_logits = self.vocab_linear(self.attention_dense(torch.cat((decoder_outputs, context), dim=-1)))
            vocab_dist = F.softmax(vocab_logits, dim=-1)  # B x 1 x vocab_size

            if self.is_pgen:
                p_gen_input = torch.cat((context, decoder_outputs, x), dim=-1)
                p_gen = torch.sigmoid(self.p_gen_linear(p_gen_input))  # B x 1 x 1

                copy_attn_dist = (1 - p_gen) * attn_dist  # B x 1 x src_len

                # B x 1 x (vocab_size+max_oovs_num)
                extended_vocab_dist = torch.cat(((vocab_dist * p_gen), extra_zeros.unsqueeze(1)), dim=-1)

                # add copy probs to vocab dist
                vocab_dist = extended_vocab_dist.scatter_add(2, source_extended_idx.unsqueeze(1), copy_attn_dist)

            if self.is_coverage:
                attn_dists.append(attn_dist)
                coverages.append(coverage)

            vocab_dists.append(vocab_dist)

        vocab_dists = torch.cat(vocab_dists, dim=1)  # B x tgt_len x vocab_size+(max_oovs_num)

        kwargs['context'] = context

        if self.is_coverage:
            coverages = torch.cat(coverages, dim=1)  # B x tgt_len x src_len
            attn_dists = torch.cat(attn_dists, dim=1)  # B x tgt_len x src_len
            kwargs['attn_dists'] = attn_dists
            kwargs['coverages'] = coverages

        return vocab_dists, decoder_hidden_states, kwargs