# Source code for textbox.model.Attribute.attr2seq

# @Time   : 2021/2/3
# @Author : Zhipeng Chen
# @Email  : zhipeng_chen@ruc.edu.cn

r"""
Attr2Seq
################################################
Reference:
    Li Dong et al. "Learning to Generate Product Reviews from Attributes" in 2017.
"""

import torch
import torch.nn as nn

from textbox.model.abstract_generator import AttributeGenerator
from textbox.module.Decoder.rnn_decoder import AttentionalRNNDecoder
from textbox.model.init import xavier_normal_initialization
from textbox.module.strategy import topk_sampling, greedy_search, Beam_Search_Hypothesis


class Attr2Seq(AttributeGenerator):
    r"""Attribute encoder + attentional RNN decoder, a basic framework for
    attribute-to-sequence (Attr2Seq) text generation.

    Each attribute is embedded with its own embedding table. The attribute
    embeddings are (a) kept as a length-``attribute_num`` sequence that the
    decoder attends over, and (b) concatenated and projected through a
    ``tanh`` layer to initialise the hidden state of every decoder layer.
    """

    def __init__(self, config, dataset):
        super(Attr2Seq, self).__init__(config, dataset)

        # load parameters info
        self.rnn_type = config['rnn_type']
        self.attention_type = config['attention_type']
        self.alignment_method = config['alignment_method']
        self.embedding_size = config['embedding_size']
        self.hidden_size = config['hidden_size']
        self.num_dec_layers = config['num_dec_layers']
        self.dropout_ratio = config['dropout_ratio']
        self.strategy = config['decoding_strategy']

        if self.strategy not in ['topk_sampling', 'greedy_search', 'beam_search']:
            raise NotImplementedError("{} decoding strategy not implemented".format(self.strategy))

        if self.strategy == 'beam_search':
            self.beam_size = config['beam_size']

        # One embedding table per attribute (attribute_size[i] rows for attribute i).
        self.source_token_embedder = nn.ModuleList([
            nn.Embedding(self.attribute_size[i], self.embedding_size) for i in range(self.attribute_num)
        ])
        self.target_token_embedder = nn.Embedding(
            self.vocab_size, self.embedding_size, padding_idx=self.padding_token_idx
        )

        # The third argument is presumably the size of the attended encoder outputs;
        # here those are attribute embeddings, hence embedding_size (TODO confirm
        # against AttentionalRNNDecoder's signature).
        self.decoder = AttentionalRNNDecoder(
            self.embedding_size, self.hidden_size, self.embedding_size, self.num_dec_layers, self.rnn_type,
            self.dropout_ratio, self.attention_type, self.alignment_method
        )

        self.dropout = nn.Dropout(self.dropout_ratio)
        self.vocab_linear = nn.Linear(self.hidden_size, self.vocab_size)
        # Per-token loss (reduction='none') so forward() can normalise by sequence length.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx, reduction='none')

        # Projects the concatenated attribute embeddings to the initial hidden
        # state of every decoder layer (see encoder()).
        self.H = nn.Linear(self.attribute_num * self.embedding_size, self.num_dec_layers * self.hidden_size)

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def encoder(self, source_idx):
        r"""Embed the source attributes and derive the decoder's initial hidden state.

        Args:
            source_idx (Torch.Tensor): source attribute index, shape: [batch_size, attribute_num].

        Returns:
            tuple:
                - Torch.Tensor: output features, shape: [batch_size, attribute_num, embedding_size].
                - Torch.Tensor: hidden states, shape: [num_dec_layers, batch_size, hidden_size].
                  This is a transposed (non-contiguous) view.
        """
        # Take the batch size from the input itself rather than from
        # ``self.batch_size``, so the encoder does not depend on state set by callers.
        batch_size = source_idx.size(0)

        # outputs: [batch_size, attribute_num, embedding_size]; attribute i uses its own table.
        outputs = torch.stack(
            [self.source_token_embedder[i](source_idx[:, i]) for i in range(self.attribute_num)], dim=1
        )
        # g: [batch_size, attribute_num * embedding_size], i.e. the attribute embeddings concatenated.
        g = outputs.reshape(batch_size, self.attribute_num * self.embedding_size)
        # a: [batch_size, num_dec_layers * hidden_size]
        a = torch.tanh(self.H(g))
        # hidden_states: [num_dec_layers, batch_size, hidden_size]
        hidden_states = a.reshape(batch_size, self.num_dec_layers, self.hidden_size).transpose(0, 1)
        return outputs, hidden_states

    def _init_decoder_states(self, h):
        r"""Build the decoder's initial recurrent state from hidden states ``h``.

        ``h`` has shape [num_dec_layers, batch, hidden_size]. For an LSTM decoder
        this returns ``(h, c)`` with an all-zero cell state ``c`` of the same
        shape, dtype and device. For other RNN types (which take a single hidden
        tensor) it returns ``h`` unchanged.
        """
        if self.rnn_type == 'lstm':
            return h, h.new_zeros(h.size())
        return h

    def generate(self, batch_data, eval_data):
        r"""Decode one token sequence per example with the configured strategy.

        Args:
            batch_data (dict): batch containing ``'source_idx'`` of shape [batch_size, attribute_num].
            eval_data: evaluation data object; only its ``idx2token`` mapping is used.

        Returns:
            list: one generated token list per example in the batch. For greedy
            search and top-k sampling, generation stops at the EOS token (which is
            not included) or after ``target_max_length`` steps. For beam search,
            the result comes from ``Beam_Search_Hypothesis.generate()``.
        """
        generate_corpus = []
        idx2token = eval_data.idx2token

        source_idx = batch_data['source_idx']
        self.batch_size = source_idx.size(0)
        encoder_outputs, encoder_states = self.encoder(source_idx)

        # Decode each example separately (batch dimension of size 1), which
        # lets beam search manage its own beam-sized batch.
        for bid in range(self.batch_size):
            decoder_states = self._init_decoder_states(encoder_states[:, bid, :].unsqueeze(1))
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

            if self.strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )

            for gen_idx in range(self.target_max_length):
                decoder_input = self.target_token_embedder(input_seq)
                decoder_outputs, decoder_states, _ = self.decoder(decoder_input, decoder_states, encoder_output)
                token_logits = self.vocab_linear(decoder_outputs)

                if self.strategy == 'beam_search':
                    # The hypothesis object chooses the next inputs and may
                    # reorder/expand the decoder states and encoder output.
                    input_seq, decoder_states, encoder_output = \
                        hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output)
                    if hypothesis.stop():
                        break
                    continue

                # Strategy was validated in __init__, so this is topk_sampling or greedy_search.
                if self.strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                else:
                    token_idx = greedy_search(token_logits).item()

                if token_idx == self.eos_token_idx:
                    break
                generate_tokens.append(idx2token[token_idx])
                input_seq = torch.LongTensor([[token_idx]]).to(self.device)

            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()
            generate_corpus.append(generate_tokens)

        return generate_corpus

    def forward(self, corpus, epoch_idx=0):
        r"""Compute the teacher-forced training loss for a batch.

        Args:
            corpus (dict): batch with ``'source_idx'`` ([batch_size, attribute_num]),
                ``'target_idx'`` ([batch_size, length]) and ``'target_length'`` ([batch_size]).
            epoch_idx (int): current epoch index; not used by this model.

        Returns:
            Torch.Tensor: scalar loss. Per-token cross entropy (padding ignored) is
            summed over each sequence, divided by ``target_length - 1``, then averaged
            over the batch.
        """
        # target_length (Torch.Tensor): shape: [batch_size]
        target_length = corpus['target_length']
        # source_idx (Torch.Tensor): shape: [batch_size, attribute_num].
        source_idx = corpus['source_idx']
        # target_idx (torch.Tensor): shape: [batch_size, length].
        target_idx = corpus['target_idx']
        self.batch_size = source_idx.size(0)

        encoder_outputs, encoder_states = self.encoder(source_idx)
        # The encoder returns a transposed (non-contiguous) view; make it
        # contiguous before handing it to the RNN decoder as its initial state.
        encoder_states = encoder_states.contiguous()

        # Teacher forcing: feed tokens [0, n-1), predict tokens [1, n).
        input_text = target_idx[:, :-1]
        target_text = target_idx[:, 1:]
        input_embeddings = self.dropout(self.target_token_embedder(input_text))

        decoder_outputs, _, _ = self.decoder(
            input_embeddings, self._init_decoder_states(encoder_states), encoder_outputs
        )

        # token_logits (Torch.Tensor): shape: [batch_size, length - 1, vocabulary_size].
        token_logits = self.vocab_linear(decoder_outputs)
        loss = self.loss(token_logits.view(-1, token_logits.size(-1)), target_text.reshape(-1))
        loss = loss.reshape_as(target_text)

        # Length-normalise each sequence. This assumes target_length counts the
        # leading start token, so target_length - 1 predicted positions remain
        # (TODO confirm with the dataloader).
        loss = loss.sum(dim=1) / (target_length - 1).float()
        return loss.mean()