TextGAN

Reference:

Zhang et al. “Adversarial Feature Matching for Text Generation” in ICML 2017.

class textbox.model.GAN.textgan.TextGAN(config, dataset)

Bases: GenerativeAdversarialNet

TextGAN is a generative adversarial network that trains the generator by matching the high-dimensional latent feature distributions of real and synthetic sentences via a kernelized discrepancy metric.
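To illustrate what a kernelized discrepancy between two sets of sentence features can look like, here is a minimal sketch of a biased maximum mean discrepancy (MMD) estimate with a Gaussian kernel. It is written in plain Python so it runs without dependencies and is not taken from the TextBox source; the names gaussian_kernel, mmd_squared, and bandwidth are hypothetical, and the feature vectors are made-up toy values.

```python
import math


def gaussian_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) kernel between two feature vectors of equal length."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd_squared(real_feats, fake_feats, bandwidth=1.0):
    """Biased estimate of the squared MMD between two sets of feature vectors."""

    def mean_kernel(xs, ys):
        # Average kernel value over all pairs (x, y).
        total = sum(gaussian_kernel(x, y, bandwidth) for x in xs for y in ys)
        return total / (len(xs) * len(ys))

    return (
        mean_kernel(real_feats, real_feats)
        + mean_kernel(fake_feats, fake_feats)
        - 2.0 * mean_kernel(real_feats, fake_feats)
    )


# Toy latent features (assumed values, for illustration only).
real = [[0.0, 0.1], [0.2, 0.0], [0.1, 0.1]]
near = [[0.1, 0.0], [0.0, 0.2], [0.1, 0.2]]
far = [[3.0, 3.1], [2.9, 3.0], [3.1, 2.8]]

print(mmd_squared(real, near))  # small: the two sets overlap
print(mmd_squared(real, far))   # larger: the two sets are far apart
```

A generator trained against such a metric is pushed to produce sentences whose features are hard to tell apart from those of real sentences, so the discrepancy shrinks as the two feature distributions align.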

calculate_d_train_loss(real_data, fake_data, z, epoch_idx)

Calculate the discriminator training loss for a batch of data.

Parameters
  • real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]

  • fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]

Returns

Training loss, shape: []

Return type

torch.Tensor

calculate_g_adversarial_loss(real_data, epoch_idx)

Calculate the adversarial generator training loss for a batch of data.

Returns

Training loss, shape: []

Return type

torch.Tensor

calculate_g_train_loss(corpus, epoch_idx)

Calculate the generator training loss for a batch of data.

Parameters

corpus (Corpus) – Corpus class of the batch.

Returns

Training loss, shape: []

Return type

torch.Tensor

calculate_nll_test(corpus, epoch_idx)

Calculate the negative log-likelihood of the batch.

Parameters

corpus (Corpus) – Corpus class of the batch.

Returns

NLL_test of eval data

Return type

torch.FloatTensor
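As a rough illustration of what a batch negative log-likelihood measures, the sketch below averages the negative log-probability of each target token while skipping padding positions. It is a plain-Python approximation under stated assumptions, not the TextBox implementation: the function name masked_nll, the pad_idx value of 0, and the per-token averaging are all hypothetical choices.

```python
import math


def masked_nll(log_probs, targets, pad_idx=0):
    """Mean negative log-likelihood per non-padding token.

    log_probs: nested lists, shape [batch_size, max_length, vocab_size],
               holding log-probabilities for each position.
    targets:   nested lists, shape [batch_size, max_length], holding token ids.
    pad_idx:   token id treated as padding (assumed to be 0 here).
    """
    total, count = 0.0, 0
    for seq_log_probs, seq_targets in zip(log_probs, targets):
        for step_log_probs, token in zip(seq_log_probs, seq_targets):
            if token == pad_idx:
                continue  # padding does not contribute to the loss
            total -= step_log_probs[token]
            count += 1
    return total / count if count else 0.0


# One sequence of length 3 over a 3-token vocabulary with uniform predictions.
uniform = [math.log(1.0 / 3.0)] * 3
example_log_probs = [[uniform, uniform, uniform]]
example_targets = [[1, 2, 0]]  # the last position is padding

nll = masked_nll(example_log_probs, example_targets)
print(nll)
```

With uniform predictions over three tokens, the result equals log(3), the value expected for a model that has learned nothing about the data.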

generate(batch_data, eval_data)

Generate text conditioned on noise or an input sequence.

Parameters
  • batch_data (Corpus) – Corpus class of a single batch.

  • eval_data – Common data of all the batches.

Returns

Generated text, shape: [batch_size, max_len]

Return type

torch.Tensor

sample()

Sample sample_num padded fake sequences from the generator.

Parameters

sample_num (int) – The number of padded fake data generated by generator.

Returns

Fake data generated by generator, shape: [sample_num, max_length]

Return type

torch.LongTensor
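The [sample_num, max_length] shape implies that every generated sequence is brought to a common length. A minimal sketch of that padding step follows, in plain Python; the helper name pad_batch, the pad_idx value of 0, and the truncation of over-long sequences are assumptions for illustration, not behavior confirmed by the TextBox source.

```python
def pad_batch(sequences, max_length, pad_idx=0):
    """Pad (or clip) each token-id sequence to exactly max_length entries."""
    padded = []
    for seq in sequences:
        clipped = seq[:max_length]  # assumed: over-long sequences are cut
        padded.append(clipped + [pad_idx] * (max_length - len(clipped)))
    return padded


# Two generated sequences of different lengths (toy token ids).
batch = pad_batch([[5, 6], [7, 8, 9, 10]], max_length=3)
print(batch)  # [[5, 6, 0], [7, 8, 9]]
```

The resulting rectangular list can be passed to torch.tensor(batch, dtype=torch.long) to obtain a LongTensor of the documented shape.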

training: bool