SeqGAN
- Reference:
Yu et al. “SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient” in AAAI 2017.
- class textbox.model.GAN.seqgan.SeqGAN(config, dataset)
Bases: GenerativeAdversarialNet
SeqGAN is a generative adversarial network consisting of a generator and a discriminator. By modeling the data generator as a stochastic policy in reinforcement learning (RL), SeqGAN bypasses the generator differentiation problem by directly performing policy gradient updates. The RL reward signal comes from the GAN discriminator's judgment of a complete sequence and is passed back to the intermediate state-action steps using Monte Carlo search.
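The reward propagation described above can be illustrated with a small, dependency-free sketch. It is not the TextBox implementation: `toy_policy`, `toy_discriminator`, `VOCAB`, and `MAX_LENGTH` are hypothetical stand-ins, and plain Python lists replace tensors. It only shows how a per-step reward can be estimated by completing each prefix with Monte Carlo rollouts and how those rewards weight the log-probabilities in a REINFORCE-style loss.

```python
import math
import random

VOCAB = [0, 1, 2, 3]   # toy vocabulary (hypothetical)
MAX_LENGTH = 5         # toy sequence length (hypothetical)


def toy_policy(prefix):
    """Hypothetical generator: uniform next-token distribution over VOCAB."""
    return [1.0 / len(VOCAB)] * len(VOCAB)


def toy_discriminator(sequence):
    """Hypothetical discriminator: fraction of even tokens, read as P(real)."""
    return sum(1 for t in sequence if t % 2 == 0) / len(sequence)


def complete(prefix, policy, rng):
    """Roll the policy forward until the sequence reaches MAX_LENGTH."""
    seq = list(prefix)
    while len(seq) < MAX_LENGTH:
        seq.append(rng.choices(VOCAB, weights=policy(seq))[0])
    return seq


def step_rewards(sequence, policy, discriminator, n_rollouts, rng):
    """Reward for step t: mean discriminator score over rollouts from sequence[:t+1]."""
    rewards = []
    for t in range(len(sequence)):
        prefix = sequence[: t + 1]
        if len(prefix) == MAX_LENGTH:
            # The final step is already a complete sequence: score it directly.
            rewards.append(discriminator(prefix))
        else:
            scores = [discriminator(complete(prefix, policy, rng))
                      for _ in range(n_rollouts)]
            rewards.append(sum(scores) / n_rollouts)
    return rewards


def policy_gradient_loss(sequence, rewards, policy):
    """REINFORCE-style loss: -sum_t log pi(y_t | y_<t) * reward_t."""
    loss = 0.0
    for t, token in enumerate(sequence):
        prob = policy(sequence[:t])[VOCAB.index(token)]
        loss -= math.log(prob) * rewards[t]
    return loss


rng = random.Random(0)
sample = complete([], toy_policy, rng)
rewards = step_rewards(sample, toy_policy, toy_discriminator, n_rollouts=16, rng=rng)
loss = policy_gradient_loss(sample, rewards, toy_policy)
```

In an actual model the policy would be a trainable network, and the gradient of this loss with respect to its parameters would drive the generator update.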
- calculate_d_train_loss(real_data, fake_data, epoch_idx)
Calculate the discriminator training loss for a batch of data.
- Parameters
real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]
fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]
- Returns
Training loss, shape: []
- Return type
torch.Tensor
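As a rough illustration of what a discriminator loss over real and fake batches can look like, the sketch below assumes a binary cross-entropy objective (real sequences labeled 1, generated sequences labeled 0), as in the original SeqGAN paper. The function names `bce` and `discriminator_loss` are hypothetical, the inputs are plain probabilities rather than token tensors, and the actual TextBox loss may differ.

```python
import math


def bce(prob, label):
    """Binary cross-entropy for one predicted probability and a 0/1 label."""
    eps = 1e-12
    prob = min(max(prob, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))


def discriminator_loss(real_probs, fake_probs):
    """Mean BCE over a batch: real samples labeled 1, fake samples labeled 0."""
    losses = [bce(p, 1.0) for p in real_probs] + [bce(p, 0.0) for p in fake_probs]
    return sum(losses) / len(losses)


# A discriminator that separates the two batches well gets a lower loss.
loss_good = discriminator_loss([0.9, 0.8], [0.1, 0.2])
loss_bad = discriminator_loss([0.1, 0.2], [0.9, 0.8])
```

The result is a single scalar, which matches the documented return shape `[]`.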
- calculate_g_adversarial_loss(epoch_idx)
Calculate the adversarial generator training loss for a batch of data.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
- calculate_g_train_loss(corpus, epoch_idx)
Calculate the generator training loss for a batch of data.
- Parameters
corpus (Corpus) – Corpus class of the batch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
- calculate_nll_test(corpus, epoch_idx)
Calculate the negative log-likelihood of the batch.
- Parameters
corpus (Corpus) – Corpus class of the batch.
- Returns
Negative log-likelihood (NLL_test) of the evaluation data.
- Return type
torch.FloatTensor
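For intuition, negative log-likelihood over a sequence can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: the helper `sequence_nll` is hypothetical, padding is assumed to use index 0 and to be excluded, and the result is assumed to be averaged over non-padding tokens.

```python
import math


def sequence_nll(token_probs, targets, pad_idx=0):
    """Mean negative log-likelihood of target tokens, skipping padding.

    token_probs[t] is the predicted distribution at step t;
    targets[t] is the index of the reference token at step t.
    """
    total, count = 0.0, 0
    for probs, target in zip(token_probs, targets):
        if target == pad_idx:
            continue  # assumed convention: padded positions do not contribute
        total -= math.log(probs[target])
        count += 1
    return total / count if count else 0.0


# Two real tokens followed by one padding position.
example_nll = sequence_nll(
    [[0.1, 0.9], [0.2, 0.8], [1.0, 0.0]],
    [1, 1, 0],
)
```

A lower value means the generator assigns higher probability to the held-out text.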
- generate(batch_data, eval_data)
Generate text conditioned on noise or an input sequence.
- Parameters
batch_data (Corpus) – Corpus class of a single batch.
eval_data – Common data of all the batches.
- Returns
Generated text, shape: [batch_size, max_length]
- Return type
torch.Tensor
- sample(sample_num)
Sample sample_num padded fake sequences from the generator.
- Parameters
sample_num (int) – The number of padded fake sequences to generate.
- Returns
Fake data generated by the generator, shape: [sample_num, max_length]
- Return type
torch.LongTensor
- training: bool