MaliGAN
- Reference:
Che et al. “Maximum-Likelihood Augmented Discrete Generative Adversarial Networks”.
- class textbox.model.GAN.maligan.MaliGAN(config, dataset)
Bases: GenerativeAdversarialNet
MaliGAN is a generative adversarial network for discrete sequence generation that trains the generator with a normalized maximum-likelihood objective.
- calculate_d_train_loss(real_data, fake_data, epoch_idx)
Calculate the discriminator training loss for a batch of data.
- Parameters
real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]
fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]
- Returns
Training loss, shape: []
- Return type
torch.Tensor
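The discriminator loss can be sketched as standard binary cross-entropy that labels real sequences 1 and generated sequences 0. This is a minimal illustration, not the TextBox implementation: it assumes a hypothetical discriminator has already mapped each sequence in real_data and fake_data to a single logit, and the helper name discriminator_loss is made up for this example.

```python
import torch
import torch.nn.functional as F


def discriminator_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy loss for a sequence discriminator.

    real_logits, fake_logits: one unnormalized score per sequence, shape [batch_size],
    assumed to come from a hypothetical discriminator network.
    Returns a scalar tensor (shape []).
    """
    # Real sequences are labelled 1, generated sequences 0.
    real_loss = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
    fake_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    # Average the real and fake terms into a single scalar.
    return (real_loss + fake_loss) / 2


real_logits = torch.tensor([2.0, 1.5, 3.0])
fake_logits = torch.tensor([-1.0, -2.5, 0.5])
loss = discriminator_loss(real_logits, fake_logits)
print(loss.shape)  # torch.Size([])
```

Working from logits with binary_cross_entropy_with_logits avoids taking the log of a sigmoid that has saturated to 0 or 1.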
- calculate_g_adversarial_loss(epoch_idx)
Calculate the adversarial generator training loss for a batch of data.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
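In Che et al., the generator gradient is a sum of log p_θ(x_i) gradients, each scaled by an importance weight r_D(x) = D(x) / (1 - D(x)). These weights are normalized over the m sampled sequences, and a baseline b is subtracted. Below is a minimal sketch of a surrogate loss whose gradient has that form. The function name, the use of discriminator logits, and the default baseline of 0 are assumptions for illustration, not the code behind calculate_g_adversarial_loss.

```python
import torch


def maligan_generator_loss(seq_log_probs: torch.Tensor,
                           disc_logits: torch.Tensor,
                           baseline: float = 0.0) -> torch.Tensor:
    """Surrogate loss whose gradient matches the MaliGAN-style estimator.

    seq_log_probs: log p_theta(x_i) for each sampled sequence, shape [m],
        differentiable w.r.t. the generator.
    disc_logits: discriminator logits z_i for the same sequences, shape [m],
        with D(x_i) = sigmoid(z_i).
    baseline: variance-reduction constant b (the value here is an assumption).
    """
    # r_D(x) = D(x) / (1 - D(x)) = exp(z), so normalizing r_D over the
    # batch is exactly a softmax over the logits, computed stably.
    weights = torch.softmax(disc_logits.detach(), dim=0)
    # Weights act as fixed rewards; only log p_theta carries gradient.
    return -((weights - baseline) * seq_log_probs).sum()


logits = torch.tensor([0.0, 1.0, -1.0])
log_probs = torch.tensor([-3.0, -2.0, -4.0], requires_grad=True)
loss = maligan_generator_loss(log_probs, logits)
loss.backward()
print(log_probs.grad)  # equals -softmax(logits) when baseline is 0
```

Rewriting D/(1 - D) as exp(z) is exact for a sigmoid discriminator. It avoids dividing by 1 - D when the discriminator is confident.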
- calculate_g_train_loss(corpus, epoch_idx)
Calculate the generator training loss for a batch of data.
- Parameters
corpus (Corpus) – Corpus object for the batch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
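Assume calculate_g_train_loss is a maximum-likelihood objective. In that case the computation reduces to token-level cross-entropy under teacher forcing, sketched below. mle_loss, PAD_IDX, and the [batch_size, seq_len, vocab_size] logit layout are hypothetical choices made only to show the computation.

```python
import torch
import torch.nn.functional as F

PAD_IDX = 0  # hypothetical padding index


def mle_loss(logits: torch.Tensor, target_ids: torch.Tensor, pad_idx: int = PAD_IDX) -> torch.Tensor:
    """Mean token-level negative log-likelihood under teacher forcing.

    logits: generator outputs, shape [batch_size, seq_len, vocab_size]
    target_ids: next-token targets, shape [batch_size, seq_len]
    """
    vocab_size = logits.size(-1)
    # Flatten batch and time so each token is one classification example;
    # padded positions are excluded from the average via ignore_index.
    return F.cross_entropy(logits.reshape(-1, vocab_size), target_ids.reshape(-1),
                           ignore_index=pad_idx)


targets = torch.tensor([[5, 3, 2, 0], [7, 1, 0, 0]])
logits = torch.randn(2, 4, 10, requires_grad=True)
loss = mle_loss(logits, targets)
loss.backward()
print(loss.shape)  # torch.Size([])
```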
- calculate_nll_test(corpus, epoch_idx)
Calculate the negative log-likelihood of the batch.
- Parameters
corpus (Corpus) – Corpus object for the batch.
- Returns
Negative log-likelihood (NLL_test) of the batch.
- Return type
torch.FloatTensor
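calculate_nll_test can be illustrated with one common definition of test NLL. Token negative log-likelihoods are summed within each sentence, and the sentence totals are averaged over the batch. Whether TextBox sums or averages over tokens is not stated here, so treat sentence_nll and its padding convention as assumptions.

```python
import torch
import torch.nn.functional as F


def sentence_nll(logits: torch.Tensor, target_ids: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """Negative log-likelihood summed over each sentence, averaged over the batch.

    logits: shape [batch_size, seq_len, vocab_size]
    target_ids: shape [batch_size, seq_len]
    """
    # cross_entropy wants [batch, vocab, seq_len]; reduction="none" keeps one
    # value per token and yields 0 at positions equal to ignore_index.
    token_nll = F.cross_entropy(logits.transpose(1, 2), target_ids,
                                ignore_index=pad_idx, reduction="none")
    return token_nll.sum(dim=1).mean()


targets = torch.tensor([[5, 3, 2, 0], [7, 1, 0, 0]])
with torch.no_grad():  # evaluation only, no gradients needed
    nll = sentence_nll(torch.zeros(2, 4, 10), targets)
```

With uniform logits over 10 tokens, each real token costs ln 10. The two sentences have 3 and 2 real tokens, so the result is 2.5 · ln 10 ≈ 5.756.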
- generate(batch_data, eval_data)
Generate texts conditioned on noise or an input sequence.
- Parameters
batch_data (Corpus) – Corpus object for a single batch.
eval_data – Common data of all the batches.
- Returns
Generated text, shape: [batch_size, max_len]
- Return type
torch.Tensor
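Suppose the output of generate is a [batch_size, max_len] matrix of token indices. Turning it into readable text then means mapping each index through the vocabulary and cutting at the end-of-sequence token. The helper ids_to_text and the toy vocabulary below are hypothetical; they only illustrate that post-processing step.

```python
from typing import List, Sequence


def ids_to_text(batch_ids: Sequence[Sequence[int]], idx2token: Sequence[str],
                eos_idx: int) -> List[str]:
    """Convert rows of token indices to space-joined strings, stopping at EOS."""
    texts = []
    for ids in batch_ids:
        tokens = []
        for idx in ids:
            if idx == eos_idx:
                break  # drop EOS and everything after it (padding)
            tokens.append(idx2token[idx])
        texts.append(" ".join(tokens))
    return texts


# Toy vocabulary (hypothetical): index 0 is padding, 1 is end-of-sequence.
vocab = ["<pad>", "<eos>", "the", "cat", "sat"]
generated = [[2, 3, 4, 1, 0], [3, 1, 0, 0, 0]]  # e.g. tensor.tolist()
print(ids_to_text(generated, vocab, eos_idx=1))  # ['the cat sat', 'cat']
```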
- sample(sample_num)
Sample sample_num padded fake sequences from the generator.
- Parameters
sample_num (int) – The number of padded fake sequences to generate.
- Returns
Fake data generated by the generator, shape: [sample_num, max_length]
- Return type
torch.LongTensor
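The behaviour described for sample can be sketched as ancestral (token-by-token) sampling. Each sequence is padded out to max_length once it emits an end-of-sequence token. The sketch makes several assumptions: step_logits_fn, the special-token indices, and keeping the EOS token in the output are all hypothetical. A real generator would also carry a recurrent hidden state between steps.

```python
import torch


def sample_padded(step_logits_fn, sample_num, max_length, bos_idx, eos_idx, pad_idx,
                  generator=None):
    """Draw sample_num sequences token by token and pad them to max_length.

    step_logits_fn: hypothetical callable mapping the previous tokens, shape
        [sample_num], to next-token logits, shape [sample_num, vocab_size].
    """
    samples = torch.full((sample_num, max_length), pad_idx, dtype=torch.long)
    prev = torch.full((sample_num,), bos_idx, dtype=torch.long)
    finished = torch.zeros(sample_num, dtype=torch.bool)
    for t in range(max_length):
        probs = torch.softmax(step_logits_fn(prev), dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(1)
        # Sequences that already emitted EOS only receive padding.
        next_tok = torch.where(finished, torch.full_like(next_tok, pad_idx), next_tok)
        samples[:, t] = next_tok
        finished |= next_tok == eos_idx
        prev = next_tok
        if finished.all():
            break  # remaining columns are already padding
    return samples


vocab_size, PAD, BOS, EOS = 6, 0, 1, 2
g = torch.Generator().manual_seed(0)
table = torch.randn(vocab_size, vocab_size, generator=g)  # toy bigram logits
table[:, PAD] = float("-inf")  # never sample padding or BOS
table[:, BOS] = float("-inf")
fake = sample_padded(lambda prev: table[prev], sample_num=4, max_length=8,
                     bos_idx=BOS, eos_idx=EOS, pad_idx=PAD, generator=g)
print(fake.shape)  # torch.Size([4, 8])
```

Pre-filling the output with pad_idx means an early exit leaves correctly padded rows without extra bookkeeping.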
- training: bool