# Source code for textbox.model.Seq2Seq.rnnencdec

# @Time   : 2020/11/14
# @Author : Junyi Li
# @Email  : lijunyi@ruc.edu.cn

# UPDATE:
# @Time   : 2020/12/25
# @Author : Tianyi Tang
# @Email  : steventang@ruc.edu.cn

r"""
RNNEncDec
################################################
Reference:
    Sutskever et al. "Sequence to Sequence Learning with Neural Networks" in NIPS 2014.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from textbox.model.abstract_generator import Seq2SeqGenerator
from textbox.module.Encoder.rnn_encoder import BasicRNNEncoder
from textbox.module.Decoder.rnn_decoder import BasicRNNDecoder, AttentionalRNNDecoder
from textbox.model.init import xavier_normal_initialization
from textbox.module.strategy import topk_sampling, greedy_search, Beam_Search_Hypothesis


class RNNEncDec(Seq2SeqGenerator):
    r"""RNN-based Encoder-Decoder architecture is a basic framework for Seq2Seq text generation.

    The model embeds the source tokens, encodes them with a ``BasicRNNEncoder`` and decodes
    with either an ``AttentionalRNNDecoder`` (when ``config['attention_type']`` is not ``None``)
    or a ``BasicRNNDecoder``. Decoder outputs are projected onto the target vocabulary by a
    linear layer.
    """

    def __init__(self, config, dataset):
        r"""Read hyper-parameters from ``config`` and build the embeddings, encoder, decoder and loss.

        Args:
            config: Mapping-like configuration. The keys read here are ``embedding_size``,
                ``hidden_size``, ``num_enc_layers``, ``num_dec_layers``, ``rnn_type``,
                ``bidirectional``, ``dropout_ratio``, ``attention_type``, ``alignment_method``,
                ``decoding_strategy``, ``share_vocab`` and, for beam search only, ``beam_size``.
            dataset: Forwarded unchanged to the parent constructor.

        Raises:
            NotImplementedError: If ``decoding_strategy`` is not one of ``'topk_sampling'``,
                ``'greedy_search'`` or ``'beam_search'``.
        """
        super(RNNEncDec, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size = config['hidden_size']
        self.num_enc_layers = config['num_enc_layers']
        self.num_dec_layers = config['num_dec_layers']
        self.rnn_type = config['rnn_type']
        self.bidirectional = config['bidirectional']
        self.dropout_ratio = config['dropout_ratio']
        self.attention_type = config['attention_type']
        self.alignment_method = config['alignment_method']
        self.strategy = config['decoding_strategy']

        if self.strategy not in ['topk_sampling', 'greedy_search', 'beam_search']:
            raise NotImplementedError("{} decoding strategy not implemented".format(self.strategy))
        if self.strategy == 'beam_search':
            self.beam_size = config['beam_size']

        # The attention context has the same width as the decoder hidden state.
        self.context_size = self.hidden_size

        # define layers and loss
        self.source_token_embedder = nn.Embedding(
            self.source_vocab_size, self.embedding_size, padding_idx=self.padding_token_idx
        )

        if config['share_vocab']:
            # Source and target share one embedding table (same module object).
            self.target_token_embedder = self.source_token_embedder
        else:
            self.target_token_embedder = nn.Embedding(
                self.target_vocab_size, self.embedding_size, padding_idx=self.padding_token_idx
            )

        self.encoder = BasicRNNEncoder(
            self.embedding_size, self.hidden_size, self.num_enc_layers, self.rnn_type, self.dropout_ratio,
            self.bidirectional
        )

        if self.attention_type is not None:
            self.decoder = AttentionalRNNDecoder(
                self.embedding_size, self.hidden_size, self.context_size, self.num_dec_layers, self.rnn_type,
                self.dropout_ratio, self.attention_type, self.alignment_method
            )
        else:
            self.decoder = BasicRNNDecoder(
                self.embedding_size, self.hidden_size, self.num_dec_layers, self.rnn_type, self.dropout_ratio
            )

        self.dropout = nn.Dropout(self.dropout_ratio)
        self.vocab_linear = nn.Linear(self.hidden_size, self.target_vocab_size)
        # reduction='none' keeps per-token losses so ``forward`` can normalise per sequence.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx, reduction='none')

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def _encode(self, source_text, source_length):
        r"""Embed and encode the source batch; shared by ``generate`` and ``forward``.

        Args:
            source_text: Source token indices; indexed below as (batch, sequence).
            source_length: Source lengths, passed straight to the encoder.

        Returns:
            tuple: ``(encoder_outputs, encoder_states, encoder_masks)``. For a bidirectional
            encoder the two halves of the last output dimension are summed, and every second
            entry along the first dimension of the states is kept. ``encoder_masks`` is
            ``True`` wherever the source token is not padding.
        """
        source_embeddings = self.source_token_embedder(source_text)
        encoder_outputs, encoder_states = self.encoder(source_embeddings, source_length)

        if self.bidirectional:
            # Fold the 2 * hidden_size output back to hidden_size by summing both halves.
            encoder_outputs = encoder_outputs[:, :, self.hidden_size:] + encoder_outputs[:, :, :self.hidden_size]
            # Keep every second state so the decoder receives one state per layer.
            # NOTE(review): presumably these are the forward-direction states, assuming the
            # encoder returns PyTorch's interleaved (layer, direction) layout -- confirm.
            if self.rnn_type == 'lstm':
                encoder_states = (encoder_states[0][::2], encoder_states[1][::2])
            else:
                encoder_states = encoder_states[::2]

        encoder_masks = torch.ne(source_text, self.padding_token_idx)
        return encoder_outputs, encoder_states, encoder_masks

    @staticmethod
    def _select_sample_states(encoder_states, bid):
        r"""Return the encoder states of batch element ``bid`` with a batch dimension of 1.

        LSTM states arrive as an ``(h, c)`` tuple, which cannot be sliced with
        ``[:, bid, :]`` directly, so each tensor of the tuple is sliced on its own.
        """

        def pick(states):
            # .contiguous() does not change values; it only guarantees a dense layout for
            # the sliced hidden state before it is handed to the decoder RNN.
            return states[:, bid, :].unsqueeze(1).contiguous()

        if isinstance(encoder_states, tuple):
            return tuple(pick(states) for states in encoder_states)
        return pick(encoder_states)

    @torch.no_grad()
    def generate(self, batch_data, eval_data):
        r"""Generate one token sequence per source sample with the configured decoding strategy.

        Samples are decoded one at a time. Gradients are disabled because only Python-level
        tokens are returned.

        Args:
            batch_data: Mapping with ``'source_idx'`` and ``'source_length'``.
            eval_data: Object exposing ``target_idx2token`` for index-to-token lookup.

        Returns:
            list: One entry per source sample. For ``topk_sampling`` / ``greedy_search`` it is
            the list of generated tokens (the end-of-sequence token is not included); for
            ``beam_search`` it is whatever ``Beam_Search_Hypothesis.generate()`` returns.
        """
        generate_corpus = []
        idx2token = eval_data.target_idx2token

        source_text = batch_data['source_idx']
        source_length = batch_data['source_length']
        encoder_outputs, encoder_states, encoder_masks = self._encode(source_text, source_length)

        for bid in range(source_text.size(0)):
            # Fix: also works when the states are an LSTM (h, c) tuple.
            decoder_states = self._select_sample_states(encoder_states, bid)
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            encoder_mask = encoder_masks[bid, :].unsqueeze(0)

            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

            if self.strategy == 'beam_search':
                # NOTE(review): whether Beam_Search_Hypothesis.step accepts tuple (LSTM)
                # states cannot be verified from this file -- confirm before relying on it.
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )

            for gen_idx in range(self.target_max_length):
                decoder_input = self.target_token_embedder(input_seq)
                if self.attention_type is not None:
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_output, encoder_mask
                    )
                else:
                    decoder_outputs, decoder_states = self.decoder(decoder_input, decoder_states)

                token_logits = self.vocab_linear(decoder_outputs)

                if self.strategy == 'beam_search':
                    # The hypothesis object returns the next inputs/states for the beam; with
                    # attention it also returns the encoder tensors to use at the next step.
                    if self.attention_type is not None:
                        input_seq, decoder_states, encoder_output, encoder_mask = hypothesis.step(
                            gen_idx, token_logits, decoder_states, encoder_output, encoder_mask
                        )
                    else:
                        input_seq, decoder_states = hypothesis.step(gen_idx, token_logits, decoder_states)

                    if hypothesis.stop():
                        break
                else:
                    if self.strategy == 'topk_sampling':
                        token_idx = topk_sampling(token_logits).item()
                    else:  # 'greedy_search' -- the constructor rejects every other strategy
                        token_idx = greedy_search(token_logits).item()

                    if token_idx == self.eos_token_idx:
                        break
                    generate_tokens.append(idx2token[token_idx])
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)

            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)

        return generate_corpus

    def forward(self, corpus, epoch_idx=0):
        r"""Compute the teacher-forced training loss for one batch.

        Args:
            corpus: Mapping with ``'source_idx'``, ``'source_length'``, ``'target_idx'`` and
                ``'target_length'``.
            epoch_idx: Unused here; kept for interface compatibility.

        Returns:
            torch.Tensor: Scalar loss. Per-token cross entropy (padding ignored) is summed
            over each sequence, divided by ``target_length - 1`` and averaged over the batch.
        """
        source_text = corpus['source_idx']
        source_length = corpus['source_length']

        # Teacher forcing: the decoder reads tokens [0, T-1) and predicts tokens [1, T).
        input_text = corpus['target_idx'][:, :-1]
        target_text = corpus['target_idx'][:, 1:]

        input_embeddings = self.target_token_embedder(input_text)
        encoder_outputs, encoder_states, encoder_masks = self._encode(source_text, source_length)

        if self.attention_type is not None:
            decoder_outputs, decoder_states, _ = self.decoder(
                input_embeddings, encoder_states, encoder_outputs, encoder_masks
            )
        else:
            decoder_outputs, decoder_states = self.decoder(input_embeddings, encoder_states)

        token_logits = self.vocab_linear(self.dropout(decoder_outputs))

        # Flatten to (batch * steps, vocab) for the loss, then restore the (batch, steps) layout.
        loss = self.loss(token_logits.view(-1, token_logits.size(-1)), target_text.contiguous().view(-1))
        loss = loss.reshape_as(target_text)

        # One position was dropped by the input/target shift above, hence "- 1".
        length = corpus['target_length'] - 1
        loss = loss.sum(dim=1) / length.float()
        loss = loss.mean()
        return loss