SeqGAN

Reference:

Yu et al. “SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient” in AAAI 2017.

class textbox.model.GAN.seqgan.SeqGAN(config, dataset)

Bases: GenerativeAdversarialNet

SeqGAN is a generative adversarial network consisting of a generator and a discriminator. By modeling the data generator as a stochastic policy in reinforcement learning (RL), SeqGAN bypasses the problem of differentiating through the generator's discrete outputs by directly performing policy gradient updates. The RL reward signal comes from the GAN discriminator's judgment of a complete sequence and is passed back to the intermediate state-action steps using Monte Carlo search.
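The Monte Carlo reward step can be illustrated with a short plain-Python sketch. It assumes a hypothetical `rollout` function that completes a prefix to full length and a hypothetical `discriminator` that returns the probability that a complete sequence is real; neither is a TextBox API. Following Yu et al., the reward for every prefix length t < T is the discriminator score averaged over N rollouts, and the final step uses the score of the full sequence.

```python
import random
from typing import Callable, List, Sequence

# Hypothetical stand-ins (not TextBox APIs): a rollout policy that completes a
# prefix to full length, and a discriminator returning P(real) for a sequence.
RolloutFn = Callable[[List[int], int, random.Random], List[int]]
DiscriminatorFn = Callable[[Sequence[int]], float]


def mc_rewards(
    sequence: Sequence[int],
    rollout: RolloutFn,
    discriminator: DiscriminatorFn,
    num_rollouts: int = 16,
    seed: int = 0,
) -> List[float]:
    """Estimate the reward for every generation step of `sequence`."""
    rng = random.Random(seed)
    max_len = len(sequence)
    rewards = []
    for t in range(1, max_len + 1):
        if t == max_len:
            # Complete sequence: the discriminator scores it directly.
            rewards.append(discriminator(sequence))
            continue
        # Intermediate step: average the discriminator score over
        # num_rollouts completions of the current prefix.
        prefix = list(sequence[:t])
        total = 0.0
        for _ in range(num_rollouts):
            total += discriminator(rollout(prefix, max_len, rng))
        rewards.append(total / num_rollouts)
    return rewards


# Toy components for illustration only.
def toy_rollout(prefix: List[int], max_len: int, rng: random.Random) -> List[int]:
    return prefix + [rng.randint(0, 1) for _ in range(max_len - len(prefix))]


def toy_discriminator(seq: Sequence[int]) -> float:
    return sum(seq) / len(seq)  # "realness" = fraction of 1-tokens


print(mc_rewards([1, 1, 0, 1], toy_rollout, toy_discriminator))
```

Only the last reward is exact; the earlier ones are noisy estimates whose variance shrinks as `num_rollouts` grows, which is the usual cost/accuracy trade-off of Monte Carlo search.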

calculate_d_train_loss(real_data, fake_data, epoch_idx)

Calculate the discriminator training loss for a batch of data.

Parameters
  • real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]

  • fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]

Returns

Training loss, shape: []

Return type

torch.Tensor
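A discriminator trained to separate real sequences from generated ones is commonly optimized with binary cross-entropy. The sketch below is a minimal illustration, assuming the discriminator already outputs a probability of being real for each sequence and that the real and fake terms are averaged together; how TextBox actually reduces the loss, and what it does with `epoch_idx`, is not documented here.

```python
import math
from typing import Sequence


def discriminator_loss(
    real_probs: Sequence[float],
    fake_probs: Sequence[float],
    eps: float = 1e-12,
) -> float:
    """Binary cross-entropy with real sequences labeled 1 and fake ones labeled 0."""
    # Real sequences should get P(real) close to 1.
    real_term = -sum(math.log(max(p, eps)) for p in real_probs)
    # Fake sequences should get P(real) close to 0.
    fake_term = -sum(math.log(max(1.0 - p, eps)) for p in fake_probs)
    # Average over every sequence in the combined batch (an assumed reduction).
    return (real_term + fake_term) / (len(real_probs) + len(fake_probs))
```

A discriminator that outputs 0.5 for everything scores ln 2 (about 0.693), while one that confidently labels both sets correctly approaches zero.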

calculate_g_adversarial_loss(epoch_idx)

Calculate the adversarial generator training loss for a batch of data.

Returns

Training loss, shape: []

Return type

torch.Tensor
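In SeqGAN the adversarial generator update is a policy gradient: each sampled token's log-probability is weighted by its Monte Carlo reward, and the negated sum is minimized. The sketch uses plain floats, so it only computes the value of the objective; in a real implementation the log-probabilities would be autograd tensors, so minimizing this value yields the REINFORCE gradient. The function name and the averaging over the batch are assumptions.

```python
import math
from typing import Sequence


def policy_gradient_loss(
    log_probs: Sequence[Sequence[float]],
    rewards: Sequence[Sequence[float]],
) -> float:
    """-(1/B) * sum over sequences of sum_t log p(y_t | y_<t) * Q_t."""
    batch_loss = 0.0
    for seq_log_probs, seq_rewards in zip(log_probs, rewards):
        # Weight each token's log-probability by its reward Q_t.
        batch_loss -= sum(lp * q for lp, q in zip(seq_log_probs, seq_rewards))
    return batch_loss / len(log_probs)


# One sequence of two tokens with probabilities 0.5 and 0.25 and rewards 1.0 and 0.5.
print(policy_gradient_loss([[math.log(0.5), math.log(0.25)]], [[1.0, 0.5]]))  # ln 4
```

Tokens with high reward push their log-probabilities up more strongly, which is how the discriminator's sequence-level judgment reaches individual generation steps.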

calculate_g_train_loss(corpus, epoch_idx)

Calculate the generator training loss for a batch of data.

Parameters

corpus (Corpus) – Corpus class of the batch.

Returns

Training loss, shape: []

Return type

torch.Tensor
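Assuming `calculate_g_train_loss` is the maximum-likelihood pretraining objective that SeqGAN applies to its generator before adversarial training (the docstring does not say so explicitly), it amounts to teacher-forced, token-level cross-entropy. The sketch computes the mean negative log-likelihood from per-step logits and skips padded positions; the padding index `pad_idx=0` is a hypothetical choice.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def mle_loss(
    logits: Sequence[Sequence[Sequence[float]]],
    targets: Sequence[Sequence[int]],
    pad_idx: int = 0,
) -> float:
    """Mean token-level NLL under teacher forcing, ignoring padded targets."""
    total, count = 0.0, 0
    for seq_logits, seq_targets in zip(logits, targets):
        for step_logits, target in zip(seq_logits, seq_targets):
            if target == pad_idx:
                continue  # padded positions contribute nothing
            total -= log_softmax(step_logits)[target]
            count += 1
    return total / max(count, 1)


# Uniform logits over a 4-token vocabulary give ln 4 per real token.
print(mle_loss([[[0.0] * 4, [0.0] * 4, [0.0] * 4]], [[2, 3, 0]]))
```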

calculate_nll_test(corpus, epoch_idx)

Calculate the negative log-likelihood of the batch.

Parameters

corpus (Corpus) – Corpus class of the batch.

Returns

Negative log-likelihood (NLL_test) of the evaluation batch

Return type

torch.FloatTensor

generate(batch_data, eval_data)

Generate texts conditioned on noise or an input sequence.

Parameters
  • batch_data (Corpus) – Corpus class of a single batch.

  • eval_data – Common data of all the batches.

Returns

Generated text, shape: [batch_size, max_len]

Return type

torch.Tensor
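One simple way to turn a trained generator into text is greedy autoregressive decoding. This is only a sketch under assumptions: `next_token_scores`, `sos_idx` and `eos_idx` are hypothetical stand-ins for the generator and its special tokens, and the actual `generate` method may sample tokens instead of taking the argmax.

```python
from typing import Callable, List, Sequence

# Hypothetical next-token scorer: maps a prefix to one score per vocabulary entry.
NextTokenFn = Callable[[Sequence[int]], Sequence[float]]


def greedy_generate(
    next_token_scores: NextTokenFn,
    sos_idx: int,
    eos_idx: int,
    max_len: int,
) -> List[int]:
    """Decode one sequence token by token, stopping at EOS or after max_len tokens."""
    prefix = [sos_idx]
    output: List[int] = []
    for _ in range(max_len):
        scores = next_token_scores(prefix)
        # Pick the highest-scoring token (first one wins on ties).
        token = max(range(len(scores)), key=scores.__getitem__)
        if token == eos_idx:
            break
        output.append(token)
        prefix.append(token)
    return output


# Toy scorer over a 5-token vocabulary: prefers token len(prefix) + 2, then EOS.
def toy_scores(prefix: Sequence[int]) -> List[float]:
    target = len(prefix) + 2 if len(prefix) < 3 else 2
    return [1.0 if i == target else 0.0 for i in range(5)]


print(greedy_generate(toy_scores, sos_idx=1, eos_idx=2, max_len=10))  # [3, 4]
```

Greedy decoding is deterministic, which makes evaluation reproducible, but it tends to produce less diverse text than sampling from the generator's distribution.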

sample(sample_num)

Sample sample_num padded fake sequences from the generator.

Parameters

sample_num (int) – The number of padded fake sequences to generate.

Returns

Fake data generated by the generator, shape: [sample_num, max_length]

Return type

torch.LongTensor
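Producing padded fake data can be sketched by sampling each sequence token by token from the generator's distribution and right-padding it to `max_length`. The generator interface `next_token_probs`, the special-token indices and the fixed seed are assumptions for illustration; the documented `sample` method returns a `torch.LongTensor` rather than nested lists.

```python
import random
from typing import Callable, List, Sequence

# Hypothetical generator: maps a prefix to a probability per vocabulary entry.
NextTokenProbs = Callable[[Sequence[int]], Sequence[float]]


def sample_padded(
    next_token_probs: NextTokenProbs,
    sample_num: int,
    max_length: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    seed: int = 0,
) -> List[List[int]]:
    """Draw sample_num sequences and right-pad each one to max_length with pad_idx."""
    rng = random.Random(seed)
    batch = []
    for _ in range(sample_num):
        prefix, tokens = [sos_idx], []
        while len(tokens) < max_length:
            probs = next_token_probs(prefix)
            # Sample the next token from the categorical distribution.
            token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
            if token == eos_idx:
                break
            tokens.append(token)
            prefix.append(token)
        # Pad so that every row has exactly max_length entries.
        batch.append(tokens + [pad_idx] * (max_length - len(tokens)))
    return batch


# Vocabulary: 0 = PAD, 1 = SOS, 2 = EOS, 3 = a word; EOS and the word are equally likely.
rows = sample_padded(lambda p: [0.0, 0.0, 0.5, 0.5], 4, 6, sos_idx=1, eos_idx=2, pad_idx=0)
print(rows)
```

Padding every row to the same length is what allows the batch to be stacked into a single [sample_num, max_length] tensor for the discriminator.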
