MaskGAN Discriminator
- class textbox.module.Discriminator.MaskGANDiscriminator.MaskGANDiscriminator(config, dataset)
Bases: GenerativeAdversarialNet
RNN-based encoder-decoder architecture for the MaskGAN discriminator.
- calculate_dis_loss(fake_prediction, real_prediction, target_present)
Compute the discriminator loss across the real and fake predictions.
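A minimal sketch of the masked real/fake loss, assuming the predictions are raw logits and that, as in the original MaskGAN formulation, only positions where the target is missing contribute: real predictions are labelled 1 and fake predictions are labelled with target_present, which is 0 at every missing position. Nested lists stand in for bs × seq_len tensors, and the helper names are hypothetical; this is not the TextBox implementation.

```python
import math
from typing import List

Matrix = List[List[float]]


def bce_with_logits(logit: float, label: float) -> float:
    # Numerically stable binary cross-entropy on a raw logit:
    # max(z, 0) - z * y + log(1 + exp(-|z|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def dis_loss(fake_logits: Matrix, real_logits: Matrix, target_present: List[List[int]]) -> float:
    """Mean of the real and fake BCE terms, counted only at missing positions."""
    total = 0.0
    n_missing = 0
    for fake_row, real_row, present_row in zip(fake_logits, real_logits, target_present):
        for fake, real, present in zip(fake_row, real_row, present_row):
            if present:  # present tokens were copied from the real text, so they are not scored
                continue
            n_missing += 1
            total += bce_with_logits(real, 1.0)  # the real token is labelled 1
            total += bce_with_logits(fake, 0.0)  # the fake label is target_present, i.e. 0 here
    # Averaging the two terms equals dividing the combined sum by 2 * n_missing.
    return total / (2 * n_missing) if n_missing else 0.0
```

Averaging over missing positions only is an assumption of this sketch; a weighted PyTorch loss with the default mean reduction would instead divide by the total number of positions.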
- calculate_loss(real_sequence, lengths, fake_sequence, targets_present, embedder)
Calculate the discriminator loss.
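The call flow of calculate_loss can be sketched as follows, assuming it scores the real sequence and the fake sequence with two forward passes that share the same real input, then hands both predictions to calculate_dis_loss. The predict and loss_fn callables are hypothetical stand-ins for forward (with embedder already bound) and calculate_dis_loss.

```python
from typing import Callable, List, Sequence

Tokens = Sequence[Sequence[int]]
Scores = List[List[float]]


def calculate_loss_sketch(
    predict: Callable[[Tokens, Sequence[int], Tokens, Tokens], Scores],
    loss_fn: Callable[[Scores, Scores, Tokens], float],
    real_sequence: Tokens,
    lengths: Sequence[int],
    fake_sequence: Tokens,
    targets_present: Tokens,
) -> float:
    # Score the real text: the judged sequence is the real sequence itself.
    real_prediction = predict(real_sequence, lengths, real_sequence, targets_present)
    # Score the generator's fill-in while conditioning on the same real (masked) input.
    fake_prediction = predict(real_sequence, lengths, fake_sequence, targets_present)
    # Argument order follows calculate_dis_loss(fake_prediction, real_prediction, target_present).
    return loss_fn(fake_prediction, real_prediction, targets_present)
```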
- create_critic_loss(cumulative_rewards, estimated_values, target_present)
Compute the critic loss for estimating the value function. The estimate should be scored only on the missing (masked-out) elements.
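The critic objective can be sketched as a masked mean squared error, assuming (as in the original MaskGAN formulation) that the value estimates are regressed onto the cumulative rewards at missing positions only. Nested lists stand in for bs × seq_len tensors; the function name is hypothetical.

```python
from typing import List


def critic_loss(cumulative_rewards: List[List[float]],
                estimated_values: List[List[float]],
                target_present: List[List[int]]) -> float:
    """Mean squared error between returns and value estimates over missing positions only."""
    squared_errors = [
        (ret - val) ** 2
        for ret_row, val_row, present_row in zip(cumulative_rewards, estimated_values, target_present)
        for ret, val, present in zip(ret_row, val_row, present_row)
        if not present  # present tokens are copied from the real text, so they are not scored
    ]
    return sum(squared_errors) / len(squared_errors) if squared_errors else 0.0
```

For example, with rewards [1, 2, 3], values [0, 0, 1] and present [0, 1, 0], only positions 0 and 2 count, giving (1 + 4) / 2 = 2.5.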
- critic(fake_sequence, embedder)
Define the critic network, which is derived from the seq2seq discriminator. It is initialized with the same parameters as the language model and shares the forward RNN components with the discriminator. It estimates the state value V(s_t), where the state s_t = x_0, …, x_{t-1}.
- Parameters
fake_sequence – generated sequence, shape [bs, seq_len]
- Returns
values – estimated state values V(s_t), shape [bs, seq_len]
- Return type
torch.Tensor
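To make the state definition concrete, here is a toy scalar recurrent critic for a single sequence. It is purely illustrative: the tanh cell, the scalar embedding table and the weights w_h, w_x, w_v are hypothetical and not the TextBox architecture. It assumes a zero input at t = 0 (the empty prefix) and x_{t-1} afterwards, so V(s_t) never sees x_t or any later token.

```python
import math
from typing import List


def toy_critic(fake_sequence: List[int], embedding: List[float],
               w_h: float = 0.5, w_x: float = 1.0, w_v: float = 2.0) -> List[float]:
    """Hypothetical scalar RNN critic returning one value V(s_t) per position."""
    hidden = 0.0
    values = []
    for t in range(len(fake_sequence)):
        # s_0 is the empty prefix, so feed a zero input; afterwards feed x_{t-1}.
        x_in = 0.0 if t == 0 else embedding[fake_sequence[t - 1]]
        hidden = math.tanh(w_h * hidden + w_x * x_in)
        values.append(w_v * hidden)  # linear value head
    return values
```

Because the last token is never fed to the cell, changing it leaves every value unchanged, while changing x_1 only affects the values from t = 2 onward.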
- forward(inputs, inputs_length, sequence, targets_present, embedder)
Predict the probability that each filled-in token is real, given the real sentence and the fake sentence.
- Parameters
inputs – real input, shape [bs, seq_len]
inputs_length – list of sentence lengths, length bs
sequence – real target or the sentence generated by the generator
targets_present – matrix marking which target tokens are present, shape [bs, seq_len]
embedder – embedding shared with the generator
- Returns
prediction – the discriminator's predicted probability that each filled-in token is real
- Return type
torch.Tensor
- mask_input(inputs, targets_present)
Transform the inputs so that masked-out tokens are replaced by a missing-token marker. The mask is defined on the targets, so to determine whether the input at time t is masked, check whether the target at time t - 1 is masked out.
For example, given
inputs = [a, b, c, d]
targets = [b, c, d, e]
targets_present = [1, 0, 1, 0]
the result is
masked_input = [a, b, <missing>, d]
- Parameters
inputs – Tensor of shape [batch_size, sequence_length]
targets_present – Bool tensor of shape [batch_size, sequence_length], where 1 indicates that the word is present.
- Returns
masked_input – Tensor of shape [batch_size, sequence_length] that equals inputs where the input is present and mask_token_idx where the token is missing.
- Return type
torch.Tensor
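The masking rule amounts to shifting targets_present right by one position and always keeping the first input, which has no preceding target. A minimal sketch on nested lists, with mask_token_idx passed in explicitly as a parameter; the function is illustrative, not the TextBox method itself.

```python
from typing import List


def mask_input(inputs: List[List[int]], targets_present: List[List[int]],
               mask_token_idx: int) -> List[List[int]]:
    """Replace input x_t with mask_token_idx whenever target t - 1 is missing."""
    masked = []
    for input_row, present_row in zip(inputs, targets_present):
        # inputs_present[t] = targets_present[t - 1]; the first input is always kept.
        inputs_present = [True] + [bool(p) for p in present_row[:-1]]
        masked.append([tok if keep else mask_token_idx
                       for tok, keep in zip(input_row, inputs_present)])
    return masked
```

Encoding a, b, c, d as 0, 1, 2, 3 and using 99 as the mask index, the docstring example becomes mask_input([[0, 1, 2, 3]], [[1, 0, 1, 0]], 99), which yields [[0, 1, 99, 3]].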
- training: bool