LeakGAN

Reference:

Guo et al. “Long Text Generation via Adversarial Training with Leaked Information” in AAAI 2018.

class textbox.model.GAN.leakgan.LeakGAN(config, dataset)

Bases: GenerativeAdversarialNet

LeakGAN is a generative adversarial network designed for long text generation. The discriminator leaks its own high-level extracted features to the generator to provide additional guidance. The generator incorporates these signals into every generation step through an additional Manager module, which takes the features extracted from the words generated so far and outputs a latent vector that guides the Worker module in generating the next word.
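As a toy illustration of how a latent vector from the Manager could steer the Worker's next-word choice, the sketch below scores each vocabulary word by the dot product of a per-word Worker vector with the Manager's goal vector, then applies a softmax. It is written in plain Python and is not TextBox's implementation; the names `worker_output`, `goal`, and `worker_next_word_probs`, and the toy values, are assumptions made for this example.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def worker_next_word_probs(worker_output, goal):
    """Combine per-word Worker vectors with the Manager's goal vector.

    worker_output: one vector of length k per vocabulary word (hypothetical).
    goal: the Manager's latent vector of length k (hypothetical).
    """
    logits = [sum(o * g for o, g in zip(row, goal)) for row in worker_output]
    return softmax(logits)


# Toy values: a vocabulary of 3 words and a goal dimension of 2.
worker_output = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
goal = [2.0, 0.0]
probs = worker_next_word_probs(worker_output, goal)
```

Words whose Worker vectors align with the goal (here the first and third) receive more probability mass than the word that does not.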

calculate_d_train_loss(real_data, fake_data, epoch_idx)

Calculate the discriminator training loss for a batch of data.

Parameters
  • real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]

  • fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]

Returns

Training loss, shape: []

Return type

torch.Tensor
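The documentation above does not state which loss the discriminator uses. As a minimal sketch, assuming a standard binary cross-entropy objective in which real samples are labelled 1 and fake samples 0, a scalar loss could be computed as follows. The function `discriminator_bce_loss` and its inputs (per-sample probabilities of being real) are hypothetical and written in plain Python rather than PyTorch.

```python
import math


def discriminator_bce_loss(real_probs, fake_probs, eps=1e-12):
    """Mean binary cross-entropy over a batch (assumed objective).

    real_probs: discriminator's P(real) for each real sample.
    fake_probs: discriminator's P(real) for each generated sample.
    """
    # Real samples should be scored close to 1.
    real_terms = [-math.log(max(p, eps)) for p in real_probs]
    # Fake samples should be scored close to 0.
    fake_terms = [-math.log(max(1.0 - p, eps)) for p in fake_probs]
    return (sum(real_terms) + sum(fake_terms)) / (len(real_terms) + len(fake_terms))


# A discriminator that separates the two sets well has a lower loss
# than one that outputs 0.5 for everything.
loss_confident = discriminator_bce_loss([0.9, 0.8], [0.1, 0.2])
loss_chance = discriminator_bce_loss([0.5, 0.5], [0.5, 0.5])
```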

calculate_g_adversarial_loss(epoch_idx)

Calculate the adversarial generator training loss for a batch of data.

Returns

Training loss, shape: []

Return type

torch.Tensor

calculate_g_train_loss(corpus, epoch_idx)

Calculate the generator training loss for a batch of data.

Parameters

corpus (Corpus) – Corpus class of the batch.

Returns

Training loss, shape: []

Return type

torch.Tensor

calculate_nll_test(corpus, epoch_idx)

Calculate the negative log-likelihood of the batch.

Parameters

corpus (Corpus) – Corpus class of the batch.

Returns

Negative log-likelihood (NLL_test) of the evaluation data

Return type

torch.FloatTensor
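For reference, the negative log-likelihood of a sequence is the average of -log p(target token) over its positions. The sketch below computes it in plain Python for a single sequence. Skipping padding positions and the `pad_idx` value are assumptions for this example, not a statement of how TextBox handles padding, and `sequence_nll` is a hypothetical helper.

```python
import math


def sequence_nll(step_probs, targets, pad_idx=0):
    """Mean negative log-likelihood over non-padding target tokens.

    step_probs: one probability distribution over the vocabulary per step.
    targets: the reference token index at each step.
    pad_idx: index treated as padding and skipped (assumed value).
    """
    total, count = 0.0, 0
    for probs, target in zip(step_probs, targets):
        if target == pad_idx:
            continue
        total += -math.log(probs[target])
        count += 1
    return total / count if count else 0.0


# Two real tokens followed by one padding position.
step_probs = [
    [0.1, 0.7, 0.2],
    [0.1, 0.2, 0.7],
    [0.8, 0.1, 0.1],
]
targets = [1, 2, 0]
nll = sequence_nll(step_probs, targets)
```

Both real tokens were assigned probability 0.7, so the result equals -log(0.7); the padding step does not contribute.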

generate(batch_data, eval_data)

Generate text conditioned on noise or on an input sequence.

Parameters
  • batch_data (Corpus) – Corpus class of a single batch.

  • eval_data – Common data of all the batches.

Returns

Generated text, shape: [batch_size, max_len]

Return type

torch.Tensor

sample(sample_num)

Sample sample_num padded fake sequences from the generator.

Parameters

sample_num (int) – The number of padded fake sequences to generate.

Returns

Fake data generated by generator, shape: [sample_num, max_length]

Return type

torch.LongTensor
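The returned shape [sample_num, max_length] implies that every generated sequence is brought to a common length. A minimal sketch of that padding step in plain Python, assuming right-padding with a hypothetical `pad_idx` of 0 and truncation of over-long sequences (neither is confirmed by the documentation above):

```python
def pad_samples(samples, max_length, pad_idx=0):
    """Right-pad (or truncate) token-index sequences to max_length.

    pad_idx is an assumed padding index chosen for illustration.
    """
    padded = []
    for seq in samples:
        seq = list(seq)[:max_length]
        padded.append(seq + [pad_idx] * (max_length - len(seq)))
    return padded


# Three generated sequences of different lengths, padded to length 4.
fake = pad_samples([[5, 3, 2], [7], [4, 4, 4, 4, 4]], max_length=4)
```

Each row of `fake` now has exactly `max_length` entries, so the result could be stacked into a single rectangular tensor.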

training: bool