MaskGAN
- Reference:
Fedus et al. “MaskGAN: Better Text Generation via Filling in the ________” in ICLR 2018.
- class textbox.model.GAN.maskgan.MaskGAN(config, dataset)
Bases:
GenerativeAdversarialNet
MaskGAN is a generative adversarial network designed to improve sample quality. It introduces an actor-critic conditional GAN that fills in missing text conditioned on the surrounding context.
- calculate_d_train_loss(data, epoch_idx)
MaskGAN-specific calculation of the discriminator training loss on the predicted masked tokens.
- calculate_g_adversarial_loss(data, epoch_idx)
MaskGAN-specific calculation of the generator's adversarial loss on the predicted masked tokens.
- calculate_g_train_loss(corpus, epoch_idx=0, validate=False)
MaskGAN-specific calculation of the generator training loss on the predicted masked tokens.
- calculate_nll_test(eval_batch, epoch_idx)
MaskGAN-specific calculation of the negative log-likelihood of the batch.
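As a reminder of the quantity involved, the negative log-likelihood of a sequence is the sum of -log p over its target tokens, where p is the probability the model assigns to each target token. The sketch below is illustrative only and does not reproduce the TextBox implementation: the helper names `sequence_nll` and `batch_nll` are hypothetical, and averaging over sequences (rather than over tokens) is an assumption.

```python
import math


def sequence_nll(target_probs):
    """Sum of -log p over the target tokens of one sequence.

    target_probs holds the probability the model assigned to each
    ground-truth token (hypothetical input format, padding already removed).
    """
    return -sum(math.log(p) for p in target_probs)


def batch_nll(batch_target_probs):
    """Mean per-sequence NLL over a batch.

    Averaging per sequence is an assumption made for this sketch;
    a per-token average is an equally common convention.
    """
    total = sum(sequence_nll(seq) for seq in batch_target_probs)
    return total / len(batch_target_probs)
```

A lower value means the model assigns higher probability to the reference text.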
- generate(batch_data, eval_data)
Generate text conditioned on noise or on an input sequence.
- Parameters
batch_data (Corpus) – Corpus class of a single batch.
eval_data – Common data of all the batches.
- Returns
Generated text, shape: [batch_size, max_len]
- Return type
torch.Tensor
- generate_mask(batch_size, seq_len, mask_strategy)
Generate the mask to be fed into the model.
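To make the idea of a mask concrete, here is a minimal, dependency-free sketch of how a `[batch_size, seq_len]` mask might be built. It is not the TextBox implementation: the function name `sketch_generate_mask`, the strategy names `"random"` and `"contiguous"`, the `is_present_rate` and `seed` parameters, and the convention that `True` marks a token that is kept are all assumptions made for illustration.

```python
import random


def sketch_generate_mask(batch_size, seq_len, mask_strategy,
                         is_present_rate=0.5, seed=None):
    """Return a batch_size x seq_len grid of booleans.

    True  = token is present (kept as context), an assumed convention.
    False = token is masked and must be filled in by the generator.
    """
    rng = random.Random(seed)

    if mask_strategy == "random":
        # Each position is kept independently with probability is_present_rate.
        return [[rng.random() < is_present_rate for _ in range(seq_len)]
                for _ in range(batch_size)]

    if mask_strategy == "contiguous":
        # Mask one contiguous span per sequence; the span length follows
        # from the fraction of tokens that should be missing.
        masked_len = int((1.0 - is_present_rate) * seq_len)
        rows = []
        for _ in range(batch_size):
            start = rng.randint(0, seq_len - masked_len)  # inclusive bounds
            rows.append([not (start <= i < start + masked_len)
                         for i in range(seq_len)])
        return rows

    raise ValueError(f"unknown mask_strategy: {mask_strategy!r}")
```

In a real model the grid would typically be converted to a tensor before being fed to the generator and discriminator.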
- training: bool