TextGAN Discriminator
- class textbox.module.Discriminator.TextGANDiscriminator.TextGANDiscriminator(config, dataset)
Bases: UnconditionalGenerator
The discriminator of TextGAN.
- calculate_g_loss(real_data, fake_data)
Calculate the maximum mean discrepancy loss for real data and fake data.
- Parameters
real_data (torch.Tensor) – The real sentence data, shape: [batch_size, max_seq_len].
fake_data (torch.Tensor) – The generated sentence data, shape: [batch_size, max_seq_len].
- Returns
The maximum mean discrepancy (MMD) loss between real data and fake data, shape: [] (a scalar).
- Return type
torch.Tensor
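For reference, the following is a minimal sketch of a biased squared-MMD estimator over two batches of feature vectors (for example, the [batch_size, total_filter_num] outputs of feature). The Gaussian kernel, the bandwidth sigma, and the helper names gaussian_kernel and mmd_loss are assumptions made for illustration; this page does not state which kernel or estimator TextGANDiscriminator uses.

```python
import torch


def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # Assumed kernel: k(a, b) = exp(-||a - b||^2 / (2 * sigma^2)).
    # Squared distances between every row of x [n, d] and every row of y [m, d] -> [n, m].
    dist_sq = (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.exp(-dist_sq / (2.0 * sigma ** 2))


def mmd_loss(real_feature: torch.Tensor, fake_feature: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: biased squared-MMD estimate between two feature batches (0-dim tensor)."""
    k_rr = gaussian_kernel(real_feature, real_feature, sigma).mean()  # E[k(real, real)]
    k_ff = gaussian_kernel(fake_feature, fake_feature, sigma).mean()  # E[k(fake, fake)]
    k_rf = gaussian_kernel(real_feature, fake_feature, sigma).mean()  # E[k(real, fake)]
    return k_rr + k_ff - 2.0 * k_rf
```

A lower value means the two feature distributions are closer under the chosen kernel; passing the same batch twice gives 0.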
- calculate_loss(real_data, fake_data, z)
Calculate the discriminator loss for real data and fake data.
- Parameters
real_data (torch.Tensor) – The real sentence data, shape: [batch_size, max_seq_len].
fake_data (torch.Tensor) – The generated sentence data, shape: [batch_size, max_seq_len].
z (torch.Tensor) – The latent code for generation, shape: [batch_size, hidden_size].
- Returns
The calculated loss on real and fake data, shape: [] (a scalar).
- Return type
torch.Tensor
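This page does not say how z enters the loss. As a hypothetical sketch only, the function below combines a standard binary cross-entropy discriminator term (real sentences labeled 1, generated sentences labeled 0) with a mean-squared reconstruction term for z, assuming some component produces an estimate z_recon of the latent code from the generated sentences. The names discriminator_loss, z_recon and recon_weight are invented for illustration and are not part of TextBox.

```python
import torch
import torch.nn.functional as F


def discriminator_loss(real_prob: torch.Tensor, fake_prob: torch.Tensor,
                       z: torch.Tensor, z_recon: torch.Tensor,
                       recon_weight: float = 1.0) -> torch.Tensor:
    """Hypothetical loss. real_prob, fake_prob: [batch_size] probabilities in (0, 1);
    z, z_recon: [batch_size, hidden_size]. Returns a 0-dim tensor."""
    # GAN term: push real probabilities toward 1 and fake probabilities toward 0.
    real_loss = F.binary_cross_entropy(real_prob, torch.ones_like(real_prob))
    fake_loss = F.binary_cross_entropy(fake_prob, torch.zeros_like(fake_prob))
    # Assumed reconstruction term: penalize failure to recover z from the generated data.
    recon_loss = F.mse_loss(z_recon, z)
    return real_loss + fake_loss + recon_weight * recon_loss
```

A discriminator that scores real sentences near 1, fake sentences near 0, and recovers z exactly gets a loss close to 0.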
- feature(data)
Get the feature map that the CNN extracts from the data.
- Parameters
data (torch.Tensor) – The data from which features are extracted, shape: [batch_size, max_seq_len, vocab_size].
- Returns
The extracted features of the data, shape: [batch_size, total_filter_num].
- Return type
torch.Tensor
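The page only says that a CNN produces a [batch_size, total_filter_num] feature map. A common way to build such an extractor for text is a TextCNN-style design: embed each position, apply convolutions of several window sizes, and max-pool over time. The sketch below assumes that design; CNNFeature, embedding_size, filter_sizes and filter_nums are hypothetical names and defaults, and total_filter_num here is simply sum(filter_nums).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNFeature(nn.Module):
    """Hypothetical TextCNN-style extractor: [batch, seq_len, vocab] -> [batch, total_filter_num]."""

    def __init__(self, vocab_size, embedding_size=32, filter_sizes=(2, 3, 4), filter_nums=(16, 16, 16)):
        super().__init__()
        # A bias-free linear layer embeds one-hot or soft token distributions alike.
        self.embed = nn.Linear(vocab_size, embedding_size, bias=False)
        # One convolution per window size; each kernel spans the full embedding width.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, n, kernel_size=(k, embedding_size)) for k, n in zip(filter_sizes, filter_nums)]
        )

    def forward(self, data):
        x = self.embed(data).unsqueeze(1)  # [batch, 1, seq_len, embedding_size]
        pooled = []
        for conv in self.convs:
            h = F.relu(conv(x)).squeeze(3)  # [batch, filter_num, seq_len - k + 1]
            pooled.append(h.max(dim=2).values)  # max-pool over time -> [batch, filter_num]
        return torch.cat(pooled, dim=1)  # [batch, total_filter_num]
```

Embedding through a linear layer rather than nn.Embedding lets the same extractor accept both one-hot real sentences and probability-valued generated sentences, which keeps the path from generator to loss differentiable. Inputs must be at least max(filter_sizes) tokens long.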
- forward(data)
Calculate the probability that the data is realistic.
- Parameters
data (torch.Tensor) – The sentence data, shape: [batch_size, max_seq_len, vocab_size].
- Returns
The probability that each sentence is realistic, shape: [batch_size].
- Return type
torch.Tensor
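One plausible way to turn [batch_size, total_filter_num] features into the [batch_size] probabilities returned by forward is a single linear layer followed by a sigmoid. The sketch below assumes that head; ProbabilityHead is a hypothetical name, and the actual layers in TextGANDiscriminator may differ.

```python
import torch
import torch.nn as nn


class ProbabilityHead(nn.Module):
    """Hypothetical head: [batch, total_filter_num] features -> [batch] probabilities."""

    def __init__(self, total_filter_num: int):
        super().__init__()
        self.linear = nn.Linear(total_filter_num, 1)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        logit = self.linear(feature).squeeze(1)  # [batch]
        return torch.sigmoid(logit)  # probability that each sentence is real
```

If a loss is computed directly from the logits, F.binary_cross_entropy_with_logits is usually preferred over applying the sigmoid first, for numerical stability.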
- training: bool