TextGAN Discriminator

class textbox.module.Discriminator.TextGANDiscriminator.TextGANDiscriminator(config, dataset)

Bases: UnconditionalGenerator

The discriminator of TextGAN.

calculate_g_loss(real_data, fake_data)

Calculate the maximum mean discrepancy (MMD) loss between real data and fake data.

Parameters
  • real_data (torch.Tensor) – The realistic sentence data, shape: [batch_size, max_seq_len].

  • fake_data (torch.Tensor) – The generated sentence data, shape: [batch_size, max_seq_len].

Returns

The calculated MMD loss between real data and fake data, a scalar (shape: []).

Return type

torch.Tensor
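As an illustration of the MMD objective, the following is a minimal sketch, not TextBox's implementation. It assumes the loss is computed on feature vectors such as those returned by feature(), and it uses a single Gaussian (RBF) kernel with a hypothetical bandwidth sigma; the kernel and bandwidth actually used by TextGANDiscriminator may differ. The biased estimate is MMD^2 = mean k(x, x') + mean k(y, y') - 2 mean k(x, y).

```python
import torch


def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Pairwise RBF kernel k(x_i, y_j) = exp(-||x_i - y_j||^2 / (2 * sigma^2)), shape [n, m]."""
    # torch.cdist gives pairwise Euclidean distances between the rows of x and y.
    sq_dist = torch.cdist(x, y, p=2.0) ** 2
    return torch.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd_loss(real_feature: torch.Tensor, fake_feature: torch.Tensor,
             sigma: float = 1.0) -> torch.Tensor:
    """Biased squared-MMD estimate between two feature batches; returns a scalar (shape [])."""
    k_rr = gaussian_kernel(real_feature, real_feature, sigma).mean()
    k_ff = gaussian_kernel(fake_feature, fake_feature, sigma).mean()
    k_rf = gaussian_kernel(real_feature, fake_feature, sigma).mean()
    return k_rr + k_ff - 2.0 * k_rf


# Usage with hypothetical feature batches of shape [batch_size, total_filter_num].
torch.manual_seed(0)
real = torch.randn(8, 16)
fake = torch.randn(8, 16) + 2.0
print(mmd_loss(real, real).item())  # identical batches give zero
print(mmd_loss(real, fake).item())  # shifted batches give a positive value
```

A generator trained to minimize this quantity is pushed to make the feature statistics of its sentences match those of real sentences.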

calculate_loss(real_data, fake_data, z)

Calculate the discriminator loss for real data and fake data.

Parameters
  • real_data (torch.Tensor) – The realistic sentence data, shape: [batch_size, max_seq_len].

  • fake_data (torch.Tensor) – The generated sentence data, shape: [batch_size, max_seq_len].

  • z (torch.Tensor) – The latent code for generation, shape: [batch_size, hidden_size].

Returns

The calculated loss of real data and fake data, shape: [].

Return type

torch.Tensor
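The documented input shapes can be illustrated with dummy tensors. This sketch uses hypothetical sizes and only prepares inputs; it does not construct or call TextGANDiscriminator. Since feature() and forward() take [batch_size, max_seq_len, vocab_size] tensors, index tensors of shape [batch_size, max_seq_len] presumably need a one-hot conversion at some point; where that happens inside calculate_loss is an assumption here.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
batch_size, max_seq_len, vocab_size, hidden_size = 4, 10, 50, 32

# real_data / fake_data as documented: token indices, shape [batch_size, max_seq_len].
real_data = torch.randint(0, vocab_size, (batch_size, max_seq_len))
fake_data = torch.randint(0, vocab_size, (batch_size, max_seq_len))

# z: latent code for generation, shape [batch_size, hidden_size].
z = torch.randn(batch_size, hidden_size)

# feature() and forward() expect [batch_size, max_seq_len, vocab_size],
# so index tensors can be turned into one-hot float tensors first.
real_one_hot = F.one_hot(real_data, num_classes=vocab_size).float()
fake_one_hot = F.one_hot(fake_data, num_classes=vocab_size).float()
print(real_one_hot.shape)  # torch.Size([4, 10, 50])
```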

feature(data)

Extract the feature map of the data with the CNN.

Parameters

data (torch.Tensor) – The data to extract features from, shape: [batch_size, max_seq_len, vocab_size].

Returns

The feature of data, shape: [batch_size, total_filter_num].

Return type

torch.Tensor
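A CNN feature extractor with these input and output shapes can be sketched as follows. This is a hypothetical text-CNN (embedding, parallel convolutions, max-over-time pooling), not TextBox's code; the class name TextCNNFeature, the filter sizes, the filter counts and the embedding size are all assumptions. In this sketch total_filter_num is simply the sum of the filter counts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNNFeature(nn.Module):
    """Hypothetical extractor: [batch, seq_len, vocab_size] -> [batch, total_filter_num]."""

    def __init__(self, vocab_size: int, embedding_size: int,
                 filter_sizes=(2, 3, 4), num_filters=(8, 8, 8)):
        super().__init__()
        # A bias-free linear layer acts as an embedding lookup for one-hot rows
        # and also accepts soft distributions over the vocabulary.
        self.embed = nn.Linear(vocab_size, embedding_size, bias=False)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, n, kernel_size=(f, embedding_size))
             for f, n in zip(filter_sizes, num_filters)]
        )
        self.total_filter_num = sum(num_filters)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        x = self.embed(data).unsqueeze(1)        # [b, 1, l, e]
        pooled = []
        for conv in self.convs:
            h = F.relu(conv(x)).squeeze(3)       # [b, n, l - f + 1]
            pooled.append(h.max(dim=2).values)   # max-over-time pooling: [b, n]
        return torch.cat(pooled, dim=1)          # [b, total_filter_num]


torch.manual_seed(0)
extractor = TextCNNFeature(vocab_size=50, embedding_size=16)
data = F.one_hot(torch.randint(0, 50, (4, 10)), num_classes=50).float()
features = extractor(data)
print(features.shape)  # torch.Size([4, 24])
```

Using a bias-free linear layer instead of nn.Embedding lets the same extractor accept soft token distributions for generated data, which an index lookup cannot. The sequence length must be at least the largest filter size.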

forward(data)

Calculate the probability that the data is realistic.

Parameters

data (torch.Tensor) – The sentence data, shape: [batch_size, max_seq_len, vocab_size].

Returns

The probability that each sentence is realistic, shape: [batch_size].

Return type

torch.Tensor
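The final step that turns extracted features into a per-sentence probability can be sketched as a linear layer followed by a sigmoid. ProbabilityHead is a hypothetical name, and the assumption that TextGANDiscriminator applies exactly one linear layer on top of feature() is not confirmed by this page.

```python
import torch
import torch.nn as nn


class ProbabilityHead(nn.Module):
    """Hypothetical head: features [batch_size, total_filter_num] -> probabilities [batch_size]."""

    def __init__(self, total_filter_num: int):
        super().__init__()
        self.linear = nn.Linear(total_filter_num, 1)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        logits = self.linear(feature).squeeze(1)  # [batch_size]
        return torch.sigmoid(logits)              # each value lies in (0, 1)


torch.manual_seed(0)
head = ProbabilityHead(total_filter_num=24)
probs = head(torch.randn(4, 24))
print(probs.shape)  # torch.Size([4])
```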

training: bool – Whether the module is in training mode (inherited from torch.nn.Module).