TextGAN
- Reference:
Zhang et al. “Adversarial Feature Matching for Text Generation” in ICML 2017.
- class textbox.model.GAN.textgan.TextGAN(config, dataset)
Bases: GenerativeAdversarialNet
TextGAN is a generative adversarial network for text that trains its generator to match the high-dimensional latent feature distributions of real and synthetic sentences, using a kernelized discrepancy metric (Maximum Mean Discrepancy, MMD).
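To make the kernelized discrepancy concrete, the sketch below computes a biased (V-statistic) estimate of the squared MMD between two sets of feature vectors. It uses plain Python lists for readability, whereas TextBox operates on PyTorch tensors. The Gaussian kernel, the bandwidth `sigma`, and the names `gaussian_kernel` and `mmd2` are illustrative assumptions, not the TextBox implementation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def gaussian_kernel(x: Vector, y: Vector, sigma: float = 1.0) -> float:
    """Assumed kernel: k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2(real: List[Vector], fake: List[Vector], sigma: float = 1.0) -> float:
    """Biased estimate of squared MMD between real and fake feature sets.

    MMD^2 = E[k(r, r')] + E[k(f, f')] - 2 * E[k(r, f)], with every
    expectation replaced by the mean over all pairs in the two lists.
    """
    def mean_kernel(xs: List[Vector], ys: List[Vector]) -> float:
        total = sum(gaussian_kernel(x, y, sigma) for x in xs for y in ys)
        return total / (len(xs) * len(ys))

    return mean_kernel(real, real) + mean_kernel(fake, fake) - 2.0 * mean_kernel(real, fake)
```

The estimate is zero when both sets contain the same vectors and grows as the two feature clouds move apart, which is why minimizing it pushes synthetic-sentence features toward real-sentence features.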
- calculate_d_train_loss(real_data, fake_data, z, epoch_idx)
Calculate the discriminator training loss for a batch of data.
- Parameters
real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]
fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]
- Returns
Training loss, shape: []
- Return type
torch.Tensor
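As a reference point, a standard GAN discriminator minimizes a binary cross-entropy over its output probabilities on real and fake samples. The sketch below shows only that generic term, computed from precomputed probabilities; the name `d_bce_loss` and the `eps` clamp are hypothetical, and TextGAN's discriminator objective is not limited to this term, so this is not necessarily what `calculate_d_train_loss` returns.

```python
import math
from typing import Sequence


def d_bce_loss(d_real: Sequence[float], d_fake: Sequence[float], eps: float = 1e-12) -> float:
    """Generic GAN discriminator loss on probabilities D(x) in [0, 1]:
    -mean(log D(real)) - mean(log(1 - D(fake))).
    Probabilities are clamped to eps so log never receives 0.
    """
    real_term = -sum(math.log(max(p, eps)) for p in d_real) / len(d_real)
    fake_term = -sum(math.log(max(1.0 - p, eps)) for p in d_fake) / len(d_fake)
    return real_term + fake_term
```

A discriminator that cannot tell the batches apart (D = 0.5 everywhere) scores 2·log 2, and the loss falls as it assigns high probability to real data and low probability to fake data.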
- calculate_g_adversarial_loss(real_data, epoch_idx)
Calculate the adversarial generator training loss for a batch of data.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
- calculate_g_train_loss(corpus, epoch_idx)
Calculate the generator training loss for a batch of data.
- Parameters
corpus (Corpus) – Corpus object holding the batch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
- calculate_nll_test(corpus, epoch_idx)
Calculate the negative log-likelihood of the batch.
- Parameters
corpus (Corpus) – Corpus object holding the batch.
- Returns
Negative log-likelihood of the batch.
- Return type
torch.FloatTensor
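The negative log-likelihood of a padded batch is commonly the sum of per-token negative log-probabilities up to each sequence's true length, with padding positions ignored. A minimal sketch under that assumption, taking precomputed log-probabilities instead of running a model; `sequence_nll`, its arguments, and the choice to sum rather than average over tokens are hypothetical, not details confirmed for TextBox.

```python
import math
from typing import List, Sequence


def sequence_nll(token_log_probs: List[Sequence[float]], lengths: Sequence[int]) -> List[float]:
    """Per-sequence negative log-likelihood.

    token_log_probs[i][t] holds log p(y_t | y_<t) for sequence i, padded to a
    common length; positions t >= lengths[i] are padding and are skipped.
    """
    return [-sum(lp[:n]) for lp, n in zip(token_log_probs, lengths)]


# Example: two sequences padded to length 3; padding positions hold a dummy 0.0.
batch_log_probs = [
    [math.log(0.5), math.log(0.5), 0.0],
    [math.log(0.25), 0.0, 0.0],
]
nll = sequence_nll(batch_log_probs, lengths=[2, 1])
```

Averaging the returned list, or dividing each entry by its length for a per-token figure, are both common reductions; which one a given evaluation reports should be checked against the implementation.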
- generate(batch_data, eval_data)
Generate texts conditioned on a noise vector or an input sequence.
- Parameters
batch_data (Corpus) – Corpus object of a single batch.
eval_data – Data shared by all batches.
- Returns
Generated text, shape: [batch_size, max_len]
- Return type
torch.Tensor
- sample()
Sample sample_num padded fake sequences from the generator.
- Parameters
sample_num (int) – Number of padded fake sequences to generate.
- Returns
Fake data generated by generator, shape: [sample_num, max_length]
- Return type
torch.LongTensor
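The returned [sample_num, max_length] shape implies that variable-length generated sequences are brought to a common length. The sketch below shows one common way to do that on token-id lists: cut each sequence after its first end-of-sequence token (or at max_length) and right-pad it. The helper `pad_samples`, the `pad_idx`/`eos_idx` parameters, and the decision to keep the EOS token are assumptions for illustration.

```python
from typing import List, Sequence


def pad_samples(
    samples: Sequence[Sequence[int]], max_length: int, pad_idx: int, eos_idx: int
) -> List[List[int]]:
    """Truncate each token-id sequence after its first EOS (or at max_length)
    and right-pad it with pad_idx, giving a [len(samples), max_length] grid."""
    padded = []
    for seq in samples:
        tokens = list(seq[:max_length])
        if eos_idx in tokens:
            # Keep the EOS token itself and drop everything generated after it.
            tokens = tokens[: tokens.index(eos_idx) + 1]
        padded.append(tokens + [pad_idx] * (max_length - len(tokens)))
    return padded
```

For example, `pad_samples([[5, 7, 2, 9], [4, 6]], max_length=5, pad_idx=0, eos_idx=2)` returns `[[5, 7, 2, 0, 0], [4, 6, 0, 0, 0]]`, which could then be converted to a LongTensor.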
- training: bool
Whether the module is in training mode (attribute inherited from torch.nn.Module).