# @Time : 2021/1/27
# @Author : Zhuohao Yu
# @Email : zhuohao@ruc.edu.cn
r"""
C2S
################################################
Reference:
Jian Tang et al. "Context-aware Natural Language Generation with Recurrent Neural Networks" in 2016.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from textbox.model.abstract_generator import AttributeGenerator
from textbox.module.Decoder.rnn_decoder import BasicRNNDecoder
from textbox.model.init import xavier_normal_initialization
from textbox.module.strategy import topk_sampling, greedy_search, Beam_Search_Hypothesis


class C2S(AttributeGenerator):
    r"""Context-aware Natural Language Generation with Recurrent Neural Networks (C2S).
    """

    def __init__(self, config, dataset):
        super(C2S, self).__init__(config, dataset)

        # Load hyperparameters
        self.embedding_size = config['embedding_size']
        self.hidden_size = config['hidden_size']
        self.num_dec_layers = config['num_dec_layers']
        self.dropout_ratio = config['dropout_ratio']
        self.rnn_type = config['rnn_type']
        self.is_gated = config['gated']
        self.decoding_strategy = config['decoding_strategy']

        if self.decoding_strategy == 'beam_search':
            self.beam_size = config['beam_size']

        # Layers
        self.token_embedder = nn.Embedding(self.vocab_size, self.embedding_size, padding_idx=self.padding_token_idx)
        self.attr_embedder = nn.ModuleList([
            nn.Embedding(self.attribute_size[i], self.embedding_size) for i in range(self.attribute_num)
        ])
        self.decoder = BasicRNNDecoder(
            self.embedding_size, self.hidden_size, self.num_dec_layers, self.rnn_type, self.dropout_ratio
        )
        self.vocab_linear = nn.Linear(self.hidden_size, self.vocab_size)
        self.attr_linear = nn.Linear(self.attribute_num * self.embedding_size, self.hidden_size)
        if self.is_gated:
            self.gate_linear = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_ratio)

        # Loss
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx, reduction='none')

        # Initialize parameters
        self.apply(xavier_normal_initialization)

    def encoder(self, attr_data):
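        r"""Embed the discrete attributes and build the decoder's initial context.

        Each of the ``attribute_num`` attribute fields is embedded separately; the
        embeddings are concatenated and projected through ``attr_linear`` with a
        tanh to form the context vector ``h_c`` of shape ``[batch_size, hidden_size]``.
        """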
        attr_embeddings = []
        for attr_idx in range(self.attribute_num):
            kth_dim_attr = attr_data[:, attr_idx]
            kth_dim_embeddings = self.attr_embedder[attr_idx](kth_dim_attr)
            attr_embeddings.append(kth_dim_embeddings)
        attr_embeddings = torch.cat(attr_embeddings, dim=1)
        h_c = torch.tanh(self.attr_linear(attr_embeddings)).contiguous()
        return attr_embeddings, h_c

    def forward(self, corpus, epoch_idx=-1, nll_test=False):
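        r"""Teacher-forced training step.

        The target sequence is shifted by one position to form decoder inputs and
        prediction targets, the attribute context initializes the decoder, and the
        method returns the mean per-sequence negative log-likelihood.
        """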
        input_text = corpus['target_idx'][:, :-1]
        input_attr = corpus['source_idx']
        target_text = corpus['target_idx'][:, 1:]

        attr_embeddings, h_c_1D = self.encoder(input_attr)
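        # Use the attribute context as the initial hidden state of every decoder layer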
        h_c = h_c_1D.repeat(self.num_dec_layers, 1, 1)

        input_embeddings = self.token_embedder(input_text)
        outputs, _ = self.decoder(input_embeddings, h_c)
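        # Gated variant: a sigmoid gate m_t decides, per step and dimension, how much
        # of the attribute context h_c is re-injected into the decoder outputs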
        if self.is_gated:
            m_t = torch.sigmoid(self.gate_linear(outputs)).permute(1, 0, 2)
            outputs = outputs + (m_t * h_c_1D).permute(1, 0, 2)

        outputs = self.dropout(outputs)
        token_logits = self.vocab_linear(outputs)

        loss = self.loss(token_logits.view(-1, token_logits.size(-1)), target_text.contiguous().view(-1))
        loss = loss.reshape_as(target_text)
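        # Average the token-level NLL over each sequence's true length (minus one
        # for the shifted start token), then over the batch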
        length = corpus['target_length'] - 1
        loss = loss.sum(dim=1) / length.float()
        loss = loss.mean()
        return loss

    def generate(self, batch_data, eval_data):
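        r"""Generate token sequences for a batch of attribute inputs.

        Decoding starts from ``<sos>`` with the attribute context as the initial
        hidden state and proceeds with the configured strategy (top-k sampling,
        greedy search, or beam search) until ``<eos>`` or ``max_length``.
        """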
        generate_corpus = []
        idx2token = eval_data.idx2token
        batch_size = batch_data['source_idx'].size(0)

        attr_embeddings, h_c_1D = self.encoder(batch_data['source_idx'])
        h_c = h_c_1D.repeat(self.num_dec_layers, 1, 1)

        for bid in range(batch_size):
            hidden_states = h_c[:, bid, :].unsqueeze(1).contiguous()
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

            if self.decoding_strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )

            for gen_idx in range(self.max_length):
                decoder_input = self.token_embedder(input_seq)
                outputs, hidden_states = self.decoder(decoder_input, hidden_states)
                if self.is_gated:
                    m_t = torch.sigmoid(self.gate_linear(outputs))
                    outputs = outputs + m_t * h_c_1D[bid]
                token_logits = self.vocab_linear(outputs)

                if self.decoding_strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.decoding_strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.decoding_strategy == 'beam_search':
                    input_seq, hidden_states = hypothesis.step(gen_idx, token_logits, hidden_states)

                if self.decoding_strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif self.decoding_strategy == 'beam_search':
                    if hypothesis.stop():
                        break

            if self.decoding_strategy == 'beam_search':
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)

        return generate_corpus
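

# A minimal usage sketch (hypothetical: in TextBox, `config`, `dataset`, `batch`,
# and `eval_data` come from the framework's configuration/data pipeline rather
# than being built by hand, so this only illustrates the call pattern):
#
#     model = C2S(config, dataset).to(config['device'])
#     loss = model(batch)                          # training: mean per-sequence NLL
#     texts = model.generate(batch, eval_data)     # inference: list of token lists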