Source code for textbox.model.VAE.rnnvae

# @Time   : 2020/11/8
# @Author : Gaole He
# @Email  : hegaole@ruc.edu.cn

r"""
RNNVAE
################################################
Reference:
    Bowman et al. "Generating Sentences from a Continuous Space" in CoNLL 2016.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from textbox.model.abstract_generator import UnconditionalGenerator
from textbox.module.Encoder.rnn_encoder import BasicRNNEncoder
from textbox.module.Decoder.rnn_decoder import BasicRNNDecoder
from textbox.model.init import xavier_normal_initialization
from textbox.module.strategy import topk_sampling


class RNNVAE(UnconditionalGenerator):
    r"""LSTMVAE is the first text generation model based on the VAE framework. We modify its
    architecture to support any RNN type and rename it RNNVAE.
    """

    def __init__(self, config, dataset):
        super(RNNVAE, self).__init__(config, dataset)

        # load parameters info
        self.embedding_size = config['embedding_size']
        self.hidden_size = config['hidden_size']
        self.latent_size = config['latent_size']
        self.num_enc_layers = config['num_enc_layers']
        self.num_dec_layers = config['num_dec_layers']
        self.num_highway_layers = config['num_highway_layers']
        self.rnn_type = config['rnn_type']
        self.max_epoch = config['epochs']
        self.bidirectional = config['bidirectional']
        self.dropout_ratio = config['dropout_ratio']
        self.num_directions = 2 if self.bidirectional else 1

        # define layers and loss
        self.token_embedder = nn.Embedding(self.vocab_size, self.embedding_size, padding_idx=self.padding_token_idx)
        self.encoder = BasicRNNEncoder(
            self.embedding_size, self.hidden_size, self.num_enc_layers, self.rnn_type, self.dropout_ratio,
            self.bidirectional
        )
        self.decoder = BasicRNNDecoder(
            self.embedding_size, self.hidden_size, self.num_dec_layers, self.rnn_type, self.dropout_ratio
        )

        self.dropout = nn.Dropout(self.dropout_ratio)
        self.vocab_linear = nn.Linear(self.hidden_size, self.vocab_size)
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx, reduction='none')

        if self.rnn_type == "lstm":
            self.hidden_to_mean = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
            self.hidden_to_logvar = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
            self.latent_to_hidden = nn.Linear(self.latent_size, 2 * self.hidden_size)
        elif self.rnn_type == 'gru' or self.rnn_type == 'rnn':
            self.hidden_to_mean = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
            self.hidden_to_logvar = nn.Linear(self.num_directions * self.hidden_size, self.latent_size)
            self.latent_to_hidden = nn.Linear(self.latent_size, self.hidden_size)
        else:
            raise ValueError("No such rnn type {} for RNNVAE.".format(self.rnn_type))

        # parameters initialization
        self.apply(xavier_normal_initialization)
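The encoder's final hidden state is projected into the mean and log-variance of a diagonal Gaussian over the latent code, so both projections share the input width num_directions * hidden_size. A minimal shape sketch of these projections; the sizes below are illustrative assumptions, not TextBox defaults:

import torch
import torch.nn as nn

hidden_size, latent_size, num_directions, batch_size = 128, 64, 2, 8
hidden_to_mean = nn.Linear(num_directions * hidden_size, latent_size)
hidden_to_logvar = nn.Linear(num_directions * hidden_size, latent_size)

# h_n stands in for the concatenated final state of a bidirectional encoder
h_n = torch.randn(batch_size, num_directions * hidden_size)
mean, logvar = hidden_to_mean(h_n), hidden_to_logvar(h_n)
assert mean.shape == logvar.shape == (batch_size, latent_size)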
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token
        batch_size = len(batch_data['target_text'])
        for _ in range(batch_size):
            if self.rnn_type == "lstm":
                hidden_states = torch.randn(size=(1, 2 * self.hidden_size), device=self.device)
                hidden_states = torch.chunk(hidden_states, 2, dim=-1)
                h_0 = hidden_states[0].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                c_0 = hidden_states[1].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                hidden_states = (h_0, c_0)
            else:
                # draw noise from the standard Gaussian distribution
                hidden_states = torch.randn(size=(self.num_dec_layers, 1, self.hidden_size), device=self.device)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            for _ in range(self.max_length):
                decoder_input = self.token_embedder(input_seq)
                outputs, hidden_states = self.decoder(input_embeddings=decoder_input, hidden_states=hidden_states)
                token_logits = self.vocab_linear(outputs)
                token_idx = topk_sampling(token_logits)
                token_idx = token_idx.item()
                if token_idx == self.eos_token_idx:
                    break
                else:
                    generate_tokens.append(idx2token[token_idx])
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)
            generate_corpus.append(generate_tokens)
        return generate_corpus
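generate draws the decoder's initial hidden state directly from the standard Gaussian prior and then decodes one token at a time with topk_sampling. A self-contained sketch of top-k sampling, as a hypothetical stand-in for textbox.module.strategy.topk_sampling; the function name and the value of k are assumptions:

import torch
import torch.nn.functional as F

def topk_sampling_sketch(token_logits: torch.Tensor, k: int = 10) -> torch.Tensor:
    # keep only the k highest-scoring tokens, renormalize, and sample one of them
    values, indices = torch.topk(token_logits.squeeze(), k)
    probs = F.softmax(values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return indices[choice]  # a 1-element LongTensor, so .item() yields the token id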
    def forward(self, corpus, epoch_idx=0, nll_test=False):
        input_text = corpus['target_idx'][:, :-1]
        target_text = corpus['target_idx'][:, 1:]
        input_length = corpus['target_length'] - 1
        batch_size = input_text.size(0)

        input_emb = self.token_embedder(input_text)
        _, hidden_states = self.encoder(input_emb, input_length)

        if self.rnn_type == "lstm":
            h_n, c_n = hidden_states
        elif self.rnn_type == 'gru' or self.rnn_type == 'rnn':
            h_n = hidden_states
        else:
            raise NotImplementedError("No such rnn type {} for RNNVAE.".format(self.rnn_type))

        if self.bidirectional:
            h_n = h_n.view(self.num_enc_layers, 2, batch_size, self.hidden_size)
            h_n = h_n[-1]
            h_n = torch.cat([h_n[0], h_n[1]], dim=1)
        else:
            h_n = h_n[-1]

        mean = self.hidden_to_mean(h_n)
        logvar = self.hidden_to_logvar(h_n)

        z = torch.randn([batch_size, self.latent_size]).to(self.device)
        z = mean + z * torch.exp(0.5 * logvar)

        hidden = self.latent_to_hidden(z)

        if self.rnn_type == "lstm":
            decoder_hidden = torch.chunk(hidden, 2, dim=-1)
            h_0 = decoder_hidden[0].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
            c_0 = decoder_hidden[1].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
            decoder_hidden = (h_0, c_0)
        else:
            decoder_hidden = hidden.unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()

        input_emb = self.dropout(input_emb)
        outputs, hidden_states = self.decoder(input_embeddings=input_emb, hidden_states=decoder_hidden)

        token_logits = self.vocab_linear(outputs)

        loss = self.loss(token_logits.view(-1, token_logits.size(-1)), target_text.contiguous().view(-1))
        loss = loss.reshape_as(target_text)

        if nll_test:
            loss = loss.sum(dim=1).mean()
        else:
            length = corpus['target_length'] - 1
            loss = loss.sum(dim=1) / length.float()

        kld_coef = float(epoch_idx / self.max_epoch) + 1e-3  # gradually increase the kld weight
        kld = -0.5 * torch.sum(logvar - mean.pow(2) - logvar.exp() + 1, 1).mean()
        loss = loss.mean() + kld_coef * kld
        return loss
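forward implements the reparameterization trick, z = mean + eps * exp(0.5 * logvar) with eps ~ N(0, I), and the closed-form KL divergence between the approximate posterior N(mean, exp(logvar)) and the standard Gaussian prior. A minimal numeric check of that closed form against torch.distributions; the mean and logvar values are illustrative:

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.tensor([[0.5, -1.0]])
logvar = torch.tensor([[0.0, 0.2]])

# the expression used in forward(), summed over the latent dimensions
closed_form = -0.5 * torch.sum(logvar - mean.pow(2) - logvar.exp() + 1, 1)
# reference: KL(N(mean, std) || N(0, 1)) per dimension, then summed
reference = kl_divergence(Normal(mean, (0.5 * logvar).exp()), Normal(0.0, 1.0)).sum(1)
assert torch.allclose(closed_form, reference)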
    def calculate_nll_test(self, corpus, epoch_idx=0):
        return self.forward(corpus, nll_test=True)
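Both the training loss and calculate_nll_test go through forward above, whose KL weight kld_coef rises linearly with epoch_idx so that the KL term is phased in gradually. A minimal sketch of the schedule; the max_epoch value is illustrative:

max_epoch = 50  # illustrative; the model reads this from config['epochs']
for epoch_idx in (0, 25, 50):
    kld_coef = float(epoch_idx / max_epoch) + 1e-3
    print(f"epoch {epoch_idx:2d}: kld_coef = {kld_coef:.3f}")
# epoch  0: kld_coef = 0.001
# epoch 25: kld_coef = 0.501
# epoch 50: kld_coef = 1.001

Note that calculate_nll_test leaves epoch_idx at its default of 0, so the KL term is nearly switched off during NLL evaluation.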