MaskGAN Discriminator

class textbox.module.Discriminator.MaskGANDiscriminator.MaskGANDiscriminator(config, dataset)[source]

Bases: GenerativeAdversarialNet

RNN-based encoder-decoder architecture for the MaskGAN discriminator.

calculate_dis_loss(fake_prediction, real_prediction, target_present)[source]

Compute the discriminator loss over the real and fake predictions.
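The library's exact loss is defined in the source; as a hedged illustration, a binary cross-entropy loss restricted to the filled-in (missing) positions could look like the sketch below (the function name and the assumption that the predictions are probabilities in [0, 1] are mine):

    import torch
    import torch.nn.functional as F

    def dis_loss_sketch(fake_prediction, real_prediction, target_present):
        # target_present is a bool mask: True where the original token was kept,
        # so only the complementary (filled-in) positions are judged real vs. fake.
        missing = ~target_present
        real_loss = F.binary_cross_entropy(
            real_prediction[missing], torch.ones_like(real_prediction[missing]))
        fake_loss = F.binary_cross_entropy(
            fake_prediction[missing], torch.zeros_like(fake_prediction[missing]))
        return (real_loss + fake_loss) / 2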

calculate_loss(real_sequence, lengths, fake_sequence, targets_present, embedder)[source]

Calculate the discriminator loss for a batch of real and generated sequences.
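A plausible (not verbatim) composition of the methods documented here: run forward on the generated and on the real sentence to obtain the two predictions, then hand them to calculate_dis_loss.

    # Hedged usage sketch; variable names follow the signatures on this page.
    fake_prediction = discriminator.forward(real_sequence, lengths, fake_sequence,
                                            targets_present, embedder)
    real_prediction = discriminator.forward(real_sequence, lengths, real_sequence,
                                            targets_present, embedder)
    loss = discriminator.calculate_dis_loss(fake_prediction, real_prediction,
                                            targets_present)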

create_critic_loss(cumulative_rewards, estimated_values, target_present)[source]

Compute the critic loss for estimating the value function. The estimate covers only the missing (masked-out) elements.
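A common formulation is a mean-squared error between the (detached) cumulative rewards and the critic's value estimates, taken only over the missing positions; a minimal sketch under that assumption:

    import torch
    import torch.nn.functional as F

    def critic_loss_sketch(cumulative_rewards, estimated_values, target_present):
        # target_present: bool mask, True where the token was given (not generated).
        missing = ~target_present
        # The cumulative reward is treated as a constant target for the critic update.
        return F.mse_loss(estimated_values[missing],
                          cumulative_rewards[missing].detach())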

critic(fake_sequence, embedder)[source]

Define the critic network, which is derived from the seq2seq discriminator. It is initialized with the same parameters as the language model and shares the forward RNN components with the discriminator. It estimates V(s_t), where the state is s_t = x_0, …, x_{t-1}.

Parameters

fake_sequence – generated sequence of shape bs*seq_len

Returns

value estimates of shape bs*seq_len

Return type

values
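A brief usage sketch (the variable names are illustrative only):

    # values[b, t] approximates V(s_t) for the prefix x_0, ..., x_{t-1} of sample b
    values = discriminator.critic(fake_sequence, embedder)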

forward(inputs, inputs_length, sequence, targets_present, embedder)[source]

Predict the probability that each filled-in token is real, using the real sentence and the fake sentence.

Parameters
  • inputs – real input of shape bs*seq_len

  • inputs_length – sentence lengths, a list of size bs

  • sequence – the real target or the sentence generated by the generator

  • targets_present – presence mask of the target sentences, shape bs*seq_len

  • embedder – embedding layer shared with the generator

Returns

the probability that each filled-in token is real, as predicted by the discriminator

Return type

prediction
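A hedged call sketch with the assumed shapes spelled out:

    # inputs, sequence, targets_present: [batch_size, seq_len];
    # inputs_length: list of batch_size sentence lengths.
    prediction = discriminator.forward(inputs, inputs_length, sequence,
                                       targets_present, embedder)
    # prediction: [batch_size, seq_len], probability each filled-in token is real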

mask_input(inputs, targets_present)[source]

Transforms the inputs so that masked-out positions contain a missing token. The mask applies to the targets, so to determine whether an input at time t is masked, we check whether the target at time t - 1 is masked out (a code sketch follows at the end of this entry).

e.g.

  • inputs = [a, b, c, d]

  • targets = [b, c, d, e]

  • targets_present = [1, 0, 1, 0]

then,

  • masked_input = [a, b, <missing>, d]

Parameters
  • inputs – Tensor of shape [batch_size, sequence_length]

  • targets_present – Bool tensor of shape [batch_size, sequence_length] with 1 representing the presence of the word.

Returns

Tensor of shape [batch_size, sequence_length]

which takes on the value of inputs where the input is present, and the value mask_token_idx to indicate a missing token.

Return type

masked_input
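The rule above can be reproduced by shifting the presence mask right by one step; a minimal sketch, where mask_token_idx stands for whatever index the model reserves for the <missing> token:

    import torch

    def mask_input_sketch(inputs, targets_present, mask_token_idx):
        # An input at time t is masked iff the target at time t - 1 is missing,
        # so shift the presence mask right by one; position 0 is never masked.
        batch_size = inputs.size(0)
        first_col = torch.ones(batch_size, 1, dtype=torch.bool, device=inputs.device)
        inputs_present = torch.cat([first_col, targets_present[:, :-1]], dim=1)
        return torch.where(inputs_present, inputs,
                           torch.full_like(inputs, mask_token_idx))

With the example above, targets_present = [1, 0, 1, 0] shifts to inputs_present = [1, 1, 0, 1], which yields masked_input = [a, b, <missing>, d].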

mask_target_present(targets_present, lengths)[source]

training: bool