LeakGAN
- Reference:
Guo et al. “Long Text Generation via Adversarial Training with Leaked Information” in AAAI 2018.
- class textbox.model.GAN.leakgan.LeakGAN(config, dataset)
Bases:
GenerativeAdversarialNet
LeakGAN is a generative adversarial network designed for long text generation. The discriminator leaks its own high-level extracted features to the generator, which gives the generator additional guidance. The generator incorporates these signals into every generation step through an additional Manager module, which takes the features extracted from the words generated so far and outputs a latent vector that guides the Worker module in generating the next word.
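The Manager/Worker interaction described above can be illustrated with a small, framework-free sketch. This is not TextBox's implementation: the function names, the toy dimensions, and the `temperature` parameter are assumptions made for illustration. Following the formulation in the referenced paper, the Worker produces one embedding per vocabulary token, and the next-word distribution is a softmax over the dot products of those embeddings with the Manager's goal vector.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def next_word_distribution(worker_embeddings, goal, temperature=1.0):
    """Combine Worker output with the Manager's goal vector.

    worker_embeddings: one k-dimensional vector per vocabulary token,
        shape [vocab_size][k] (hypothetical toy input).
    goal: k-dimensional latent vector produced by the Manager.
    """
    logits = [
        sum(o * g for o, g in zip(row, goal)) / temperature
        for row in worker_embeddings
    ]
    return softmax(logits)


# Toy example: vocabulary of 3 tokens, goal dimension k = 2.
worker_embeddings = [[0.2, -0.1], [1.0, 0.5], [-0.3, 0.8]]
goal = [1.0, 0.0]
probs = next_word_distribution(worker_embeddings, goal)
```

With this goal vector, the token whose embedding is most aligned with the goal (index 1) receives the highest probability.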
- calculate_d_train_loss(real_data, fake_data, epoch_idx)
Calculate the discriminator training loss for a batch of data.
- Parameters
real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]
fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]
epoch_idx – Index of the current training epoch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
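As a rough illustration of what a discriminator loss over real and fake batches can look like, the sketch below computes a binary cross-entropy objective from raw discriminator scores. It assumes a standard GAN discriminator objective; the exact loss used by this method is not stated here, and the function names and inputs are hypothetical.

```python
import math


def softplus(x):
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def discriminator_bce_loss(real_logits, fake_logits):
    """Mean binary cross-entropy with real labelled 1 and fake labelled 0.

    -log(sigmoid(z)) equals softplus(-z), and
    -log(1 - sigmoid(z)) equals softplus(z).
    """
    real_terms = [softplus(-z) for z in real_logits]
    fake_terms = [softplus(z) for z in fake_logits]
    return (sum(real_terms) + sum(fake_terms)) / (len(real_terms) + len(fake_terms))


# An undecided discriminator (all scores 0) versus a confident, correct one.
undecided = discriminator_bce_loss([0.0, 0.0], [0.0, 0.0])
confident = discriminator_bce_loss([5.0, 5.0], [-5.0, -5.0])
```

The result is a single scalar, matching the documented return shape `[]`.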
- calculate_g_adversarial_loss(epoch_idx)
Calculate the adversarial generator training loss for a batch of data.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
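Adversarial training of text generators is commonly driven by a policy-gradient (REINFORCE-style) surrogate loss, because sampled tokens are discrete. The sketch below shows that generic surrogate only. It is an assumption that this is representative: the actual method also involves separate Manager and Worker objectives, which are not reproduced here, and all names are hypothetical.

```python
def policy_gradient_loss(log_probs, rewards):
    """Generic REINFORCE surrogate: -sum_t log p(y_t) * r_t, averaged over the batch.

    log_probs: log-probabilities of the sampled tokens, shape [batch_size][seq_len].
    rewards: per-token rewards (e.g. derived from discriminator scores),
        same shape as log_probs.
    """
    total = 0.0
    for lp_seq, r_seq in zip(log_probs, rewards):
        total += -sum(lp * r for lp, r in zip(lp_seq, r_seq))
    return total / len(log_probs)


# One sampled sequence of two tokens, each with reward 0.5.
loss = policy_gradient_loss([[-1.0, -2.0]], [[0.5, 0.5]])
```

Minimizing this quantity raises the log-probability of tokens that received high reward.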
- calculate_g_train_loss(corpus, epoch_idx)
Calculate the generator training loss for a batch of data.
- Parameters
corpus (Corpus) – Corpus class of the batch.
epoch_idx – Index of the current training epoch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
- calculate_nll_test(corpus, epoch_idx)
Calculate the negative log-likelihood of the batch.
- Parameters
corpus (Corpus) – Corpus class of the batch.
epoch_idx – Index of the current training epoch.
- Returns
Negative log-likelihood (NLL_test) of the evaluation data
- Return type
torch.FloatTensor
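For reference, negative log-likelihood over a padded batch can be sketched as follows. The per-token averaging and the padding mask are assumptions; whether this method averages per token or per sequence is not stated here, and the function name is hypothetical.

```python
import math


def mean_token_nll(token_probs, pad_mask):
    """Average -log p over non-padding tokens.

    token_probs: probability the model assigned to each reference token,
        shape [batch_size][max_length].
    pad_mask: 1 for real tokens, 0 for padding, same shape.
    """
    total, count = 0.0, 0
    for p_seq, m_seq in zip(token_probs, pad_mask):
        for p, m in zip(p_seq, m_seq):
            if m:  # skip padded positions entirely
                total += -math.log(p)
                count += 1
    return total / count


# One sequence with two real tokens and one padding position.
nll = mean_token_nll([[0.5, 0.25, 1.0]], [[1, 1, 0]])
```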
- generate(batch_data, eval_data)
Generate text conditioned on noise or on an input sequence.
- Parameters
batch_data (Corpus) – Corpus class of a single batch.
eval_data – Data shared by all batches.
- Returns
Generated text, shape: [batch_size, max_len]
- Return type
torch.Tensor
- sample(sample_num)
Sample sample_num padded fake sequences from the generator.
- Parameters
sample_num (int) – The number of padded fake sequences to generate.
- Returns
Fake data generated by the generator, shape: [sample_num, max_length]
- Return type
torch.LongTensor
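The padded output shape `[sample_num, max_length]` can be illustrated with a small helper that pads or truncates variable-length token sequences. The helper name and the padding index `0` are assumptions for illustration, not part of this class.

```python
def pad_samples(samples, max_length, pad_idx=0):
    """Pad or truncate each token-id sequence to exactly max_length.

    pad_idx=0 is an assumed padding token id.
    """
    padded = []
    for seq in samples:
        seq = list(seq[:max_length])  # truncate sequences that are too long
        padded.append(seq + [pad_idx] * (max_length - len(seq)))
    return padded


# Two generated sequences of different lengths, padded to length 3.
batch = pad_samples([[5, 6], [7, 8, 9, 10]], max_length=3)
```

Every row of the result has the same length, so the batch can be stored as a rectangular integer tensor.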
- training: bool