Source code for textbox.model.Seq2Seq.hred

# @Time   : 2021/1/26
# @Author : Lai Xu
# @Email  : tsui_lai@163.com

r"""
HierarchicalRNN
################################################
Reference:
    Serban et al. "Building End-To-End Dialogue Systems Using Generative Hierarchical Neural Network Models" in AAAI 2016.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from textbox.model.abstract_generator import Seq2SeqGenerator
from textbox.module.Encoder.rnn_encoder import BasicRNNEncoder
from textbox.module.Decoder.rnn_decoder import BasicRNNDecoder, AttentionalRNNDecoder
from textbox.model.init import xavier_normal_initialization
from textbox.module.strategy import topk_sampling, greedy_search, Beam_Search_Hypothesis


class HRED(Seq2SeqGenerator):
    r"""HRED (Hierarchical Recurrent Encoder-Decoder) models multi-turn dialogue with an
    utterance-level RNN encoder, a context-level RNN encoder over the utterance
    representations, and an RNN decoder conditioned on the final context state.
    """

    def __init__(self, config, dataset):
        super(HRED, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size = config['hidden_size']
        self.num_enc_layers = config['num_enc_layers']
        self.num_dec_layers = config['num_dec_layers']
        self.rnn_type = config['rnn_type']
        self.bidirectional = config['bidirectional']
        self.num_directions = 2 if self.bidirectional else 1
        self.dropout_ratio = config['dropout_ratio']
        self.strategy = config['decoding_strategy']
        self.attention_type = config['attention_type']
        self.alignment_method = config['alignment_method']

        if (self.strategy not in ['topk_sampling', 'greedy_search', 'beam_search']):
            raise NotImplementedError("{} decoding strategy not implemented".format(self.strategy))
        if (self.strategy == 'beam_search'):
            self.beam_size = config['beam_size']

        self.context_size = self.hidden_size * self.num_directions

        # define layers and loss
        self.token_embedder = nn.Embedding(self.vocab_size, self.embedding_size, padding_idx=self.padding_token_idx)
        self.utterance_encoder = BasicRNNEncoder(
            self.embedding_size, self.hidden_size, self.num_enc_layers, self.rnn_type, self.dropout_ratio,
            self.bidirectional
        )
        self.context_encoder = BasicRNNEncoder(
            self.hidden_size * 2, self.hidden_size, self.num_enc_layers, self.rnn_type, self.dropout_ratio,
            self.bidirectional
        )

        if self.attention_type is not None:
            self.decoder = AttentionalRNNDecoder(
                self.embedding_size + self.hidden_size, self.hidden_size, self.context_size, self.num_dec_layers,
                self.rnn_type, self.dropout_ratio, self.attention_type, self.alignment_method
            )
        else:
            self.decoder = BasicRNNDecoder(
                self.embedding_size + self.hidden_size, self.hidden_size, self.num_dec_layers, self.rnn_type,
                self.dropout_ratio
            )

        self.dropout = nn.Dropout(self.dropout_ratio)
        self.vocab_linear = nn.Linear(self.hidden_size, self.vocab_size)
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx, reduction='none')

        # parameters initialization
        self.apply(xavier_normal_initialization)
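
    # Illustrative sketch (not part of the original source): the keys read in
    # __init__ above could be supplied as a plain dict when experimenting
    # outside TextBox's own Config machinery. The values here are hypothetical
    # choices, not library defaults:
    #
    #   config = {
    #       'embedding_size': 128, 'hidden_size': 256,
    #       'num_enc_layers': 2, 'num_dec_layers': 2,
    #       'rnn_type': 'gru', 'bidirectional': True, 'dropout_ratio': 0.2,
    #       'decoding_strategy': 'beam_search', 'beam_size': 4,
    #       'attention_type': None, 'alignment_method': None,
    #   }
    #   model = HRED(config, dataset)  # dataset supplies vocab_size, token indices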

    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token
        utt_states, context_states = self.encode(batch_data)  # [b, t, nd * h], [nl, b, h]
        source_length = batch_data['source_length']  # [b, t]
        utt_masks = torch.ne(source_length, 0)  # [b, t]
        for bid in range(utt_states.size(0)):
            encoder_states = utt_states[bid].unsqueeze(0)  # [1, t, nd * h]
            decoder_states = context_states[:, bid, :].unsqueeze(1)  # [nl, 1, h]
            context_state = decoder_states[-1].unsqueeze(0)  # [1, 1, h]
            encoder_masks = utt_masks[bid].unsqueeze(0)  # [1, t]
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            if (self.strategy == 'beam_search'):
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.target_max_length):
                input_embedding = self.token_embedder(input_seq)  # [beam, 1, e]
                decoder_input = torch.cat(
                    (input_embedding, context_state.repeat(input_embedding.size(0), 1, 1)), dim=-1
                )  # [beam, 1, e + h]
                if self.attention_type is not None:
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_states, encoder_masks
                    )
                else:
                    decoder_outputs, decoder_states = self.decoder(decoder_input, decoder_states)

                token_logits = self.vocab_linear(decoder_outputs)
                if (self.strategy == 'topk_sampling'):
                    token_idx = topk_sampling(token_logits).item()
                elif (self.strategy == 'greedy_search'):
                    token_idx = greedy_search(token_logits).item()
                elif (self.strategy == 'beam_search'):
                    input_seq, decoder_states, encoder_states, encoder_masks = hypothesis.step(
                        gen_idx, token_logits, decoder_states, encoder_states, encoder_masks
                    )

                if (self.strategy in ['topk_sampling', 'greedy_search']):
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif (self.strategy == 'beam_search'):
                    if (hypothesis.stop()):
                        break

            if (self.strategy == 'beam_search'):
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)
        return generate_corpus
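
    # Illustrative trace (not part of the original source) of one greedy
    # decoding step in the loop above, assuming e=128 and h=256:
    #
    #   input_seq        [1, 1]        index of <sos>, then of the last token
    #   input_embedding  [1, 1, 128]   token_embedder lookup
    #   decoder_input    [1, 1, 384]   embedding concatenated with the
    #                                  repeated [1, 1, 256] context_state
    #   token_logits     [1, 1, v]     greedy_search takes the argmax; decoding
    #                                  stops at <eos> or after target_max_length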

    def encode(self, corpus):
        source_text = corpus['source_idx']  # [b, t, l]
        source_length = corpus['source_length']  # [b, t]
        source_sentence_num = corpus['source_num']  # [b]

        batch_size = source_text.size(0)
        turn_num = source_text.size(1)

        total_utt_states = []
        for turn in range(turn_num):
            utt_embeddings = self.token_embedder(source_text[:, turn, :])  # [b, len, e]
            utt_length = source_length[:, turn]  # [b]
            utt_mask = torch.ne(utt_length, 0)
            _, utt_states = self.utterance_encoder(utt_embeddings[utt_mask], utt_length[utt_mask])  # [nl * nd, b, h]
            tp_utt_states = torch.zeros(
                self.num_directions * self.num_enc_layers, batch_size, self.hidden_size
            ).to(self.device)
            tp_utt_states[:, utt_mask] = utt_states
            total_utt_states.append(tp_utt_states)

        utt_states = torch.stack(total_utt_states, dim=2)  # [nl * nd, b, t, h]
        utt_states = utt_states[-self.num_directions:]  # [nd, b, t, h]
        utt_states = utt_states.permute(1, 2, 0, 3).reshape(batch_size, turn_num, -1)  # [b, t, nd * h]

        _, context_states = self.context_encoder(utt_states, source_sentence_num)  # [nl * nd, b, h]
        if self.bidirectional:
            context_states = context_states.reshape(self.num_enc_layers, 2, batch_size, -1).sum(dim=1).contiguous()  # [nl, b, h]
        return utt_states, context_states
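
    # Illustrative shape trace (not part of the original source) through
    # encode() above, assuming b=2 dialogues, t=3 turns, utterance length
    # l=10, e=128, h=256, and a 2-layer bidirectional encoder (nl=2, nd=2):
    #
    #   source_idx             [2, 3, 10]
    #   per-turn utt_states    [4, 2, 256]     (nl * nd, b, h)
    #   stacked over turns     [4, 2, 3, 256]  (nl * nd, b, t, h)
    #   top layer kept         [2, 2, 3, 256]  (nd, b, t, h)
    #   utt_states (reshaped)  [2, 3, 512]     (b, t, nd * h), fed to the
    #                                          context encoder
    #   context_states         [2, 2, 256]     (nl, b, h) after summing the
    #                                          two directions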

    def forward(self, corpus, epoch_idx=0):
        utt_states, context_states = self.encode(corpus)  # [b, t, nd * h], [nl, b, h]

        input_text = corpus['target_idx'][:, :-1]
        target_text = corpus['target_idx'][:, 1:]
        target_length = corpus['target_length']

        input_embeddings = self.token_embedder(input_text)  # [b, l, e]
        context_state = context_states[-1].unsqueeze(1).repeat(1, input_embeddings.size(1), 1)  # [b, l, h]
        inputs = torch.cat((input_embeddings, context_state), dim=-1)  # [b, l, e + h]

        if self.attention_type is not None:
            source_length = corpus['source_length']
            utt_masks = torch.ne(source_length, 0)
            decoder_outputs, decoder_states, _ = self.decoder(inputs, context_states, utt_states, utt_masks)
        else:
            decoder_outputs, decoder_states = self.decoder(inputs, context_states)

        token_logits = self.vocab_linear(self.dropout(decoder_outputs))  # [b, l, v]

        loss = self.loss(token_logits.view(-1, token_logits.size(-1)), target_text.contiguous().view(-1))
        loss = loss.reshape_as(target_text)

        length = target_length - 1
        loss = loss.sum(dim=1) / length.float()
        loss = loss.mean()
        return loss
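
# Worked example (not part of the original source) of the loss normalization in
# forward() above. With reduction='none', CrossEntropyLoss returns one value per
# target position, and padded positions contribute 0 via ignore_index. For a
# target of length 5 (so 4 prediction positions after the <sos> shift) padded
# to l = 8:
#
#   per-position loss    [l1, l2, l3, l4, 0, 0, 0, 0]
#   .sum(dim=1)          l1 + l2 + l3 + l4
#   / (length - 1) = 4   per-token mean for this sequence
#
# Each sequence therefore carries equal weight in the final batch mean,
# regardless of how much padding it has.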