Source code for textbox.model.GAN.maskgan

# @Time   : 2020/12/26
# @Author : Jinhao Jiang
# @Email  : jiangjinhao@std.uestc.edu.cn

r"""
MaskGAN
################################################
Reference:
    Fedus et al. "MaskGAN: Better Text Generation via Filling in the ________" in ICLR 2018.
"""

import torch
import numpy as np

from textbox.model.abstract_generator import GenerativeAdversarialNet
from textbox.module.Generator.MaskGANGenerator import MaskGANGenerator
from textbox.module.Discriminator.MaskGANDiscriminator import MaskGANDiscriminator


[docs]class MaskGAN(GenerativeAdversarialNet): r"""MaskGAN is a generative adversarial network to improve sample quality, which introduces an actor-critic conditional GAN that fills in missing text conditioned on the surrounding context. """ def __init__(self, config, dataset): super(MaskGAN, self).__init__(config, dataset) self.source_vocab_size = self.vocab_size self.target_vocab_size = self.vocab_size self.generator = MaskGANGenerator(config, dataset) self.discriminator = MaskGANDiscriminator(config, dataset) self.mask_strategy = config['mask_strategy'] self.is_present_rate = config['is_present_rate'] self.is_present_rate_decay = config['is_present_rate_decay']
[docs] def calculate_g_train_loss(self, corpus, epoch_idx=0, validate=False): r"""Specified for maskgan calculate generator masked token predicted """ real_inputs = corpus[:, :-1] # bs * self.max_len - 1 target_inputs = corpus[:, 1:] bs, seq_len = target_inputs.size() lengths = torch.tensor([seq_len] * bs) target_present = self.generate_mask(bs, seq_len, "continuous") device = target_inputs.device lengths = lengths.cuda(device) target_present = target_present.cuda(device) return self.generator.calculate_train_loss( real_inputs, lengths, target_inputs, target_present, validate=validate )
[docs] def calculate_d_train_loss(self, data, epoch_idx): r"""Specified for maskgan calculate discriminator masked token predicted """ self.generator.eval() inputs = data[:, :-1] targets = data[:, 1:] batch_size, seq_len = inputs.size() lengths = torch.tensor([seq_len] * batch_size) targets_present = self.generate_mask(batch_size, seq_len, "continuous") device = inputs.device targets_present = targets_present.cuda(device) lengths = lengths.cuda(device) fake_sequence, _, _ = self.generator.forward(inputs, lengths, targets, targets_present) self.generator.train() return self.discriminator.calculate_loss( inputs, lengths, fake_sequence, targets_present, self.generator.embedder )
[docs] def generate_mask(self, batch_size, seq_len, mask_strategy): r"""Generate the mask to be fed into the model. """ if mask_strategy == 'random': p = np.random.choice([True, False], size=[batch_size, seq_len], p=[self.is_present_rate, 1. - self.is_present_rate]) elif mask_strategy == 'continuous': masked_length = int((1 - self.is_present_rate) * seq_len) - 1 # Determine location to start masking. start_mask = np.random.randint(1, seq_len - masked_length + 1, size=batch_size) p = np.full([batch_size, seq_len], True, dtype=bool) # Create contiguous masked section to be False. for i, index in enumerate(start_mask): p[i, index:index + masked_length] = False else: raise NotImplementedError p = torch.from_numpy(p) return p
[docs] def calculate_g_adversarial_loss(self, data, epoch_idx): r"""Specified for maskgan calculate adversarial masked token predicted """ real_inputs = data[:, :-1] target_inputs = data[:, 1:] batch_size, seq_len = real_inputs.size() lengths = torch.tensor([seq_len] * batch_size) targets_present = self.generate_mask(batch_size, seq_len, "continuous") device = real_inputs.device targets_present = targets_present.cuda(device) lengths = lengths.cuda(device) loss = self.generator.adversarial_loss(real_inputs, lengths, target_inputs, targets_present, self.discriminator) return loss
[docs] def calculate_nll_test(self, eval_batch, epoch_idx): r"""Specified for maskgan calculating the negative log-likelihood of the batch. """ real_inputs = eval_batch[:, :-1] target_inputs = eval_batch[:, 1:] batch_size, seq_len = real_inputs.size() lengths = torch.tensor([seq_len] * batch_size) targets_present = torch.zeros_like(target_inputs).byte() device = real_inputs.device lengths = lengths.cuda(device) outputs, log_probs, logits = self.generator.forward(real_inputs, lengths, target_inputs, targets_present) return self.generator.calculate_loss(logits, target_inputs)
[docs] def generate(self, batch_data, eval_data): return self.generator.generate(batch_data, eval_data)
[docs] def update_is_present_rate(self): self.is_present_rate *= (1. - self.is_present_rate_decay)