MaskGAN

Reference:

Fedus et al. “MaskGAN: Better Text Generation via Filling in the ________” in ICLR 2018.

class textbox.model.GAN.maskgan.MaskGAN(config, dataset)

Bases: GenerativeAdversarialNet

MaskGAN is a generative adversarial network designed to improve sample quality. It introduces an actor-critic conditional GAN that fills in missing text conditioned on the surrounding context.
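To make "filling in missing text conditioned on the surrounding context" concrete, the sketch below builds the masked context the generator is conditioned on: kept tokens pass through and blanked positions are replaced by a placeholder. The `<m>` symbol and the `apply_mask` helper are hypothetical illustrations, not part of the TextBox API.

```python
MASK_TOKEN = "<m>"  # hypothetical placeholder symbol for a blank


def apply_mask(tokens, present):
    """Return the masked context: keep tokens where present[i] is True,
    replace every other position with MASK_TOKEN."""
    if len(tokens) != len(present):
        raise ValueError("tokens and present must have the same length")
    return [tok if keep else MASK_TOKEN for tok, keep in zip(tokens, present)]


sentence = "the movie was really great".split()
present = [True, True, False, False, True]
print(apply_mask(sentence, present))
# ['the', 'movie', '<m>', '<m>', 'great']
```

The generator's job is then to propose tokens for the `<m>` slots that the discriminator cannot tell apart from the original words.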

calculate_d_train_loss(data, epoch_idx)

Calculate the MaskGAN discriminator training loss from the discriminator's predictions on the masked tokens.
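A minimal sketch of a per-token discriminator loss, assuming (as in the MaskGAN paper) that the discriminator outputs, for every position, the probability that the token there is real, and further assuming the loss is restricted to masked positions. `discriminator_masked_loss` and its list-of-lists inputs are hypothetical; the actual method works on TextBox batches and PyTorch tensors.

```python
import math


def discriminator_masked_loss(real_probs, fake_probs, present, eps=1e-8):
    """Binary cross-entropy over masked positions (present == False) only.

    real_probs[b][t]: assumed D output for the real token at position t.
    fake_probs[b][t]: assumed D output for the generator's filled-in token at t.
    Real tokens are labelled 1, generated tokens 0. The result is averaged
    over masked positions, each contributing one real and one fake term.
    """
    total, count = 0.0, 0
    for real_row, fake_row, mask_row in zip(real_probs, fake_probs, present):
        for p_real, p_fake, keep in zip(real_row, fake_row, mask_row):
            if keep:
                continue  # unmasked tokens are given, not generated
            total -= math.log(p_real + eps)        # push D(real) toward 1
            total -= math.log(1.0 - p_fake + eps)  # push D(fake) toward 0
            count += 1
    return total / max(count, 1)


loss = discriminator_masked_loss([[0.9, 0.5]], [[0.2, 0.5]], [[False, True]])
print(round(loss, 4))  # 0.3285
```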

calculate_g_adversarial_loss(data, epoch_idx)

Calculate the MaskGAN generator's adversarial loss on the masked-token predictions.
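The MaskGAN paper trains the generator with a policy gradient: the reward at a filled-in position is log D(x_t), rewards are accumulated into discounted returns, and a critic's value estimate b_t serves as a baseline. The pure-Python sketch below shows that REINFORCE-style surrogate loss for one sequence. The function names, the default discount of 0.95 and the convention R_t = Σ_{s≥t} γ^(s−t) r_s are assumptions for illustration; this is not TextBox's implementation.

```python
import math


def discounted_returns(rewards, gamma):
    """R_t = sum over s >= t of gamma**(s - t) * r_s, computed right to left."""
    returns, running = [0.0] * len(rewards), 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def generator_adversarial_loss(log_probs, d_probs, baselines, present,
                               gamma=0.95, eps=1e-8):
    """REINFORCE surrogate loss for a single sequence.

    log_probs[t]: log G(x_t) of the token the generator sampled at t.
    d_probs[t]:   discriminator probability that token t is real.
    baselines[t]: critic's value estimate b_t.
    Only masked positions (present[t] is False) receive a reward.
    """
    rewards = [0.0 if keep else math.log(p + eps)
               for p, keep in zip(d_probs, present)]
    returns = discounted_returns(rewards, gamma)
    loss = 0.0
    for lp, ret, b, keep in zip(log_probs, returns, baselines, present):
        if not keep:
            # The advantage (R_t - b_t) acts as a constant weight; in an
            # autograd framework it would be detached from the graph.
            loss -= (ret - b) * lp
    return loss
```

Subtracting the baseline leaves the expected gradient unchanged while reducing its variance, which is why the paper pairs the generator with a critic.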

calculate_g_train_loss(corpus, epoch_idx=0, validate=False)

Calculate the MaskGAN generator training loss on the masked-token predictions.
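A sketch of a maximum-likelihood loss on the blanks, assuming this method computes the cross-entropy of the ground-truth tokens at masked positions (the MaskGAN paper pre-trains the generator with maximum likelihood before adversarial training). Per-position distributions are plain dicts purely for illustration, and `masked_nll_loss` is a hypothetical name.

```python
import math


def masked_nll_loss(token_probs, targets, present, eps=1e-8):
    """Mean negative log-likelihood of the target tokens at masked positions.

    token_probs[b][t]: dict mapping token -> predicted probability at t.
    targets[b][t]:     ground-truth token at position t.
    present[b][t]:     True if the token was kept (not a blank).
    """
    total, count = 0.0, 0
    for probs_row, tgt_row, mask_row in zip(token_probs, targets, present):
        for dist, tgt, keep in zip(probs_row, tgt_row, mask_row):
            if not keep:
                # Tokens missing from the distribution get probability 0 (plus eps).
                total -= math.log(dist.get(tgt, 0.0) + eps)
                count += 1
    return total / max(count, 1)
```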

calculate_nll_test(eval_batch, epoch_idx)

Calculate the negative log-likelihood of the batch for MaskGAN.
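A sketch of a per-sequence negative log-likelihood, assuming the model's log-probabilities of the reference tokens are already available and that padding beyond each sequence's length is excluded. Whether `calculate_nll_test` returns a sum, a mean or a batch total is not stated here; this sketch returns one sum per sequence, and `sequence_nll` and `corpus_perplexity` are hypothetical helpers.

```python
import math


def sequence_nll(log_probs, lengths):
    """Per-sequence NLL: minus the sum of log p(x_t | x_<t) over real tokens.

    log_probs[b][t] is the log-probability of the reference token at t;
    positions at or beyond lengths[b] are padding and are ignored.
    """
    return [-sum(row[:n]) for row, n in zip(log_probs, lengths)]


def corpus_perplexity(log_probs, lengths):
    """exp(total NLL / total number of real tokens)."""
    return math.exp(sum(sequence_nll(log_probs, lengths)) / sum(lengths))


print(sequence_nll([[-1.0, -2.0, -9.0], [-0.5, -0.5, -0.5]], [2, 3]))
# [3.0, 1.5]
```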

generate(batch_data, eval_data)

Generate text conditioned on noise or on an input sequence.

Parameters
  • batch_data (Corpus) – A Corpus object holding a single batch.

  • eval_data – Data shared by all batches.

Returns

Generated text, shape: [batch_size, max_len]

Return type

torch.Tensor
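The sketch below illustrates infilling generation with the documented [batch_size, max_len] output layout: present tokens are copied, and each blank is filled left to right by a sampler that sees the tokens produced so far and the masked context. `sample_next` stands in for the generator network and, like the other helper names here, is hypothetical. It returns Python lists rather than a torch.Tensor.

```python
def fill_in(tokens, present, sample_next, mask_token="<m>"):
    """Left-to-right infilling for one sequence.

    Present tokens are copied through; each masked position is filled by
    sample_next(prefix, masked_context), a stand-in for the generator.
    """
    context = [tok if keep else mask_token for tok, keep in zip(tokens, present)]
    output = []
    for tok, keep in zip(tokens, present):
        # Pass a copy of the prefix so the sampler cannot mutate it.
        output.append(tok if keep else sample_next(list(output), context))
    return output


def generate_batch(batch_tokens, batch_present, sample_next):
    # One row per example, each as long as its input: [batch_size][max_len].
    return [fill_in(toks, pres, sample_next)
            for toks, pres in zip(batch_tokens, batch_present)]


print(generate_batch([["the", "cat", "sat"]], [[True, False, True]],
                     lambda prefix, ctx: "dog"))
# [['the', 'dog', 'sat']]
```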

generate_mask(batch_size, seq_len, mask_strategy)

Generate the mask to be fed into the model.
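The `mask_strategy` values accepted here are not listed. The sketch below assumes two common strategies, independent random masking and masking one contiguous block, and assumes a keep probability named `is_present_rate` (a name suggested by `update_is_present_rate`, not confirmed by this page). `make_mask_sketch` is hypothetical; its mask marks kept tokens with True.

```python
import random


def make_mask_sketch(batch_size, seq_len, mask_strategy, is_present_rate, rng=None):
    """Return a batch_size x seq_len list of booleans (True = token kept).

    'random':     each position is kept independently with prob is_present_rate.
    'contiguous': one block of round((1 - is_present_rate) * seq_len) tokens
                  is masked at a random offset; everything else is kept.
    """
    rng = rng if rng is not None else random.Random()
    masks = []
    for _ in range(batch_size):
        if mask_strategy == "random":
            row = [rng.random() < is_present_rate for _ in range(seq_len)]
        elif mask_strategy == "contiguous":
            n_masked = round((1.0 - is_present_rate) * seq_len)
            start = rng.randint(0, seq_len - n_masked)
            row = [not (start <= t < start + n_masked) for t in range(seq_len)]
        else:
            raise ValueError(f"unknown mask_strategy: {mask_strategy!r}")
        masks.append(row)
    return masks
```

Contiguous blanks force the model to produce multi-token spans, while random masking mostly asks for isolated words.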

training: bool
update_is_present_rate()
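This entry carries no description. Judging only from its name, the method adjusts the rate at which tokens stay present (unmasked). The sketch below assumes a linear decay toward a floor, a curriculum that leaves progressively more blanks; the schedule, the class and its parameters are hypothetical, not TextBox's documented behavior.

```python
class PresentRateSchedule:
    """Hypothetical stateful schedule: each update() lowers the keep rate."""

    def __init__(self, is_present_rate, decay, min_rate=0.0):
        self.is_present_rate = is_present_rate
        self.decay = decay
        self.min_rate = min_rate

    def update(self):
        # Linear decay, clipped so the rate never drops below min_rate.
        self.is_present_rate = max(self.min_rate, self.is_present_rate - self.decay)
        return self.is_present_rate


schedule = PresentRateSchedule(0.5, 0.25)
print([schedule.update() for _ in range(3)])  # [0.25, 0.0, 0.0]
```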