Source code for textbox.model.Seq2Seq.transformerencdec
# @Time : 2020/11/14
# @Author : Junyi Li
# @Email : lijunyi@ruc.edu.cn
# UPDATE:
# @Time : 2020/12/27
# @Author : Tianyi Tang
# @Email : steventang@ruc.edu.cn
r"""
TransformerEncDec
################################################
Reference:
Vaswani et al. "Attention Is All You Need" in NIPS 2017.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from textbox.model.abstract_generator import Seq2SeqGenerator
from textbox.module.Encoder.transformer_encoder import TransformerEncoder
from textbox.module.Decoder.transformer_decoder import TransformerDecoder
from textbox.module.Embedder.position_embedder import LearnedPositionalEmbedding, SinusoidalPositionalEmbedding
from textbox.module.Attention.attention_mechanism import SelfAttentionMask
from textbox.model.init import xavier_normal_initialization
from textbox.module.strategy import topk_sampling, greedy_search, Beam_Search_Hypothesis
class TransformerEncDec(Seq2SeqGenerator):
    r"""Transformer-based Encoder-Decoder architecture is a powerful framework for Seq2Seq text generation.
    """
    def __init__(self, config, dataset):
        super(TransformerEncDec, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.ffn_size = config['ffn_size']
        self.num_heads = config['num_heads']
        self.num_enc_layers = config['num_enc_layers']
        self.num_dec_layers = config['num_dec_layers']
        self.attn_dropout_ratio = config['attn_dropout_ratio']
        self.attn_weight_dropout_ratio = config['attn_weight_dropout_ratio']
        self.ffn_dropout_ratio = config['ffn_dropout_ratio']
        self.decoding_strategy = config['decoding_strategy']
        if (self.decoding_strategy not in ['topk_sampling', 'greedy_search', 'beam_search']):
            raise NotImplementedError("{} decoding strategy not implemented".format(self.decoding_strategy))

        if (self.decoding_strategy == 'beam_search'):
            self.beam_size = config['beam_size']
        # define layers and loss
        self.source_token_embedder = nn.Embedding(
            self.source_vocab_size, self.embedding_size, padding_idx=self.padding_token_idx
        )

        if config['share_vocab']:
            self.target_token_embedder = self.source_token_embedder
        else:
            self.target_token_embedder = nn.Embedding(
                self.target_vocab_size, self.embedding_size, padding_idx=self.padding_token_idx
            )

        if config['learned_position_embedder']:
            self.position_embedder = LearnedPositionalEmbedding(self.embedding_size)
        else:
            self.position_embedder = SinusoidalPositionalEmbedding(self.embedding_size)
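        # Note: the sinusoidal variant follows the fixed scheme of Vaswani et
        # al. (2017) and trains no parameters:
        #     PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #     PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        # whereas the learned variant trains one embedding vector per position.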
        self.self_attn_mask = SelfAttentionMask()

        self.encoder = TransformerEncoder(
            self.embedding_size, self.ffn_size, self.num_enc_layers, self.num_heads, self.attn_dropout_ratio,
            self.attn_weight_dropout_ratio, self.ffn_dropout_ratio
        )
        self.decoder = TransformerDecoder(
            self.embedding_size,
            self.ffn_size,
            self.num_dec_layers,
            self.num_heads,
            self.attn_dropout_ratio,
            self.attn_weight_dropout_ratio,
            self.ffn_dropout_ratio,
            with_external=True
        )

        self.vocab_linear = nn.Linear(self.embedding_size, self.target_vocab_size)
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx, reduction='none')
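        # reduction='none' keeps a per-token loss so that forward() can
        # normalize by each sequence's own target length before averaging
        # over the batch.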
        # parameters initialization
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.normal_(self.vocab_linear.weight, std=0.02)
        nn.init.constant_(self.vocab_linear.bias, 0.)
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.target_idx2token

        source_text = batch_data['source_idx']
        source_embeddings = self.source_token_embedder(source_text) + \
            self.position_embedder(source_text).to(self.device)
        source_padding_mask = torch.eq(source_text, self.padding_token_idx).to(self.device)
        encoder_outputs = self.encoder(
            source_embeddings, self_padding_mask=source_padding_mask, output_all_encoded_layers=False
        )
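        # The encoder runs once for the whole batch; its outputs serve as the
        # fixed memory that every decoding step below attends to via the
        # decoder's external (cross-) attention.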
        for bid in range(source_text.size(0)):
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            encoder_mask = source_padding_mask[bid, :].unsqueeze(0)
            generate_tokens = []
            prev_token_ids = [self.sos_token_idx]
            input_seq = torch.LongTensor([prev_token_ids]).to(self.device)

            if (self.decoding_strategy == 'beam_search'):
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.target_max_length):
                self_attn_mask = self.self_attn_mask(input_seq.size(-1)).bool().to(self.device)
                decoder_input = self.target_token_embedder(input_seq) + \
                    self.position_embedder(input_seq).to(self.device)
                decoder_outputs = self.decoder(
                    decoder_input,
                    self_attn_mask=self_attn_mask,
                    external_states=encoder_output,
                    external_padding_mask=encoder_mask
                )
                token_logits = self.vocab_linear(decoder_outputs[:, -1, :].unsqueeze(1))
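                # Only the last position's hidden state is projected: it
                # defines the distribution over the next token given the
                # entire generated prefix.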
                if (self.decoding_strategy == 'topk_sampling'):
                    token_idx = topk_sampling(token_logits).item()
                elif (self.decoding_strategy == 'greedy_search'):
                    token_idx = greedy_search(token_logits).item()
                elif (self.decoding_strategy == 'beam_search'):
                    input_seq, encoder_output, encoder_mask = hypothesis.step(
                        gen_idx, token_logits, encoder_output=encoder_output, encoder_mask=encoder_mask,
                        input_type='whole'
                    )

                if (self.decoding_strategy in ['topk_sampling', 'greedy_search']):
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        prev_token_ids.append(token_idx)
                        input_seq = torch.LongTensor([prev_token_ids]).to(self.device)
                elif (self.decoding_strategy == 'beam_search'):
                    if (hypothesis.stop()):
                        break

            if (self.decoding_strategy == 'beam_search'):
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)

        return generate_corpus
    def forward(self, corpus, epoch_idx=0):
        source_text = corpus['source_idx']
        input_text = corpus['target_idx'][:, :-1]
        target_text = corpus['target_idx'][:, 1:]
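        # Teacher forcing: the decoder input drops the final token and the
        # label drops the first, e.g. for a target "<sos> w1 w2 <eos>",
        # input_text is "<sos> w1 w2" and target_text is "w1 w2 <eos>".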
        source_embeddings = self.source_token_embedder(source_text) + \
            self.position_embedder(source_text).to(self.device)
        source_padding_mask = torch.eq(source_text, self.padding_token_idx).to(self.device)
        encoder_outputs = self.encoder(source_embeddings, self_padding_mask=source_padding_mask)

        input_embeddings = self.target_token_embedder(input_text) + self.position_embedder(input_text).to(self.device)
        self_padding_mask = torch.eq(input_text, self.padding_token_idx).to(self.device)
        self_attn_mask = self.self_attn_mask(input_text.size(-1)).bool().to(self.device)
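        # self_attn_mask is the causal mask: position t may attend only to
        # positions <= t, so every next-token prediction is conditioned on its
        # prefix alone. Conceptually the allowed pattern for a length-3 input
        # is the lower triangle (whether True marks "keep" or "mask out" is a
        # convention internal to SelfAttentionMask):
        #     [[1, 0, 0],
        #      [1, 1, 0],
        #      [1, 1, 1]]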
        decoder_outputs = self.decoder(
            input_embeddings,
            self_padding_mask=self_padding_mask,
            self_attn_mask=self_attn_mask,
            external_states=encoder_outputs,
            external_padding_mask=source_padding_mask
        )

        token_logits = self.vocab_linear(decoder_outputs)
        loss = self.loss(token_logits.view(-1, token_logits.size(-1)), target_text.contiguous().view(-1))
        loss = loss.reshape_as(target_text)

        length = corpus['target_length'] - 1
        loss = loss.sum(dim=1) / length.float()
        return loss.mean()
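
# A minimal usage sketch (hypothetical variable names; in TextBox, `config`,
# `dataset`, and the batches come from the library's configuration and data
# loading machinery rather than being built by hand):
#
#   model = TransformerEncDec(config, dataset).to(config['device'])
#   loss = model(train_batch)                      # averaged per-sequence loss
#   loss.backward()
#   texts = model.generate(test_batch, test_data)  # list of token lists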