MaskGAN Generator

class textbox.module.Generator.MaskGANGenerator.MaskGANGenerator(config, dataset)[source]

Bases: GenerativeAdversarialNet

RNN-based encoder-decoder architecture for the MaskGAN generator.

adversarial_loss(inputs, lengths, targets, targets_present, discriminator)[source]

Calculate the adversarial loss.

calculate_loss(logits, target_inputs)[source]

Calculate the negative log-likelihood (NLL) test loss.

calculate_reinforce_objective(log_probs, dis_predictions, mask_present, estimated_values=None)[source]

Calculate the REINFORCE objective. The REINFORCE objective applies only to the tokens that were missing: the final Generator reward is based on the Discriminator's predictions on missing tokens, the log probabilities are taken only for missing tokens, and the baseline is calculated only on missing tokens.

For this model, the reward we optimize is the log of the conditional probability that the Discriminator assigns to the generated token. Specifically, for a Discriminator D that outputs the probability of a token being real given the past context, the reward is r_t = log D(x_t | x_0, x_1, …, x_{t-1}), and the policy for Generator G is the log-probability of taking action x_t given the past context.

Parameters
  • log_probs – Tensor of log probabilities of the tokens selected by the Generator. Shape [batch_size, sequence_length].

  • dis_predictions – Tensor of the predictions from the Discriminator. Shape [batch_size, sequence_length].

  • mask_present – Tensor indicating which tokens are present. Shape [batch_size, sequence_length].

  • estimated_values – Optional tensor of estimated state values of tokens. Shape [batch_size, sequence_length].

Returns

  • final_gen_objective – Final REINFORCE objective for the sequence.

  • rewards – Tensor of rewards for the sequence. Shape [batch_size, sequence_length].

  • advantages – Tensor of advantages for the sequence. Shape [batch_size, sequence_length].

  • baselines – Tensor of baselines for the sequence. Shape [batch_size, sequence_length].

  • maintain_averages_op – ExponentialMovingAverage apply op that maintains the baseline.
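The objective described above can be sketched for a single sequence in plain Python. This is a minimal sketch, not the TextBox implementation: it assumes dis_probs already holds probabilities of real (not logits), introduces a hypothetical discount factor gamma for the cumulative returns, and accepts an optional list of baselines in place of estimated_values.

```python
import math
from typing import List, Optional, Tuple


def reinforce_objective(
    log_probs: List[float],
    dis_probs: List[float],
    present: List[bool],
    gamma: float = 1.0,  # hypothetical discount factor (assumption, not from the source)
    baselines: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float]]:
    """Sketch of a masked REINFORCE objective for one sequence."""
    eps = 1e-7
    n = len(log_probs)
    missing = [0.0 if p else 1.0 for p in present]
    # r_t = log D(x_t | x_0, ..., x_{t-1}), zeroed wherever the token was present.
    rewards = [m * math.log(d + eps) for m, d in zip(missing, dis_probs)]
    # Discounted cumulative return from step t onward (present steps add zero).
    returns = [
        sum(gamma ** (s - t) * rewards[s] for s in range(t, n)) for t in range(n)
    ]
    if baselines is None:
        baselines = [0.0] * n
    # Advantages are restricted to missing tokens as well.
    advantages = [m * (g - b) for m, g, b in zip(missing, returns, baselines)]
    # Surrogate objective to maximize: sum_t log pi(x_t) * A_t over missing tokens.
    objective = sum(lp * a for lp, a in zip(log_probs, advantages))
    return objective, rewards, advantages
```

Because present positions contribute zero reward and zero advantage, their log probabilities never affect the objective; negate it to use it as a loss for a minimizer.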

calculate_train_loss(inputs, lengths, targets, targets_present, validate=False)[source]

Calculate the training loss for the generator.

create_critic_loss(cumulative_rewards, estimated_values, target_present)[source]

Compute the Critic loss for estimating the value function. The estimate covers only the missing elements.
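A minimal sketch of such a critic loss for one sequence, assuming mean squared error averaged over the missing positions (both the squared error and the averaging are assumptions here, not taken from the source):

```python
from typing import List


def critic_loss(
    cumulative_rewards: List[float],
    estimated_values: List[float],
    targets_present: List[bool],
) -> float:
    """Mean squared error between returns and value estimates on missing tokens."""
    # Keep only positions whose target was masked out (i.e. missing).
    errors = [
        (g - v) ** 2
        for g, v, present in zip(cumulative_rewards, estimated_values, targets_present)
        if not present
    ]
    # Average over missing positions; a fully present sequence contributes no loss.
    return sum(errors) / len(errors) if errors else 0.0
```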

forward(inputs, input_length, targets, targets_present, pretrain=False, validate=False)[source]

Takes the real padded input sentence and the target sentence (which, following the original code, do not start with SOS or end with EOS), along with the input lengths used by the LSTM.

Parameters
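  • inputs – Tensor of shape [batch_size, sequence_length].

  • input_length – List of length batch_size giving the length of each input.

  • targets – Target sentence; shape [batch_size, sequence_length].

  • targets_present – Tensor of shape [batch_size, sequence_length] indicating which targets are present: 1 = not masked, 0 = masked.

  • pretrain – Whether to run language-model pretraining.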

Returns

output, consisting of:

  • samples – Sampled tokens.

  • log_probs – Log probabilities.

  • logits – Logits.

generate(batch_data, eval_data)[source]

Sample sentences.

mask_cross_entropy_loss(targets, logits, targets_present)[source]

Calculate the cross-entropy loss on the filled-in (masked) tokens.
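As an illustration, the sketch below computes a cross-entropy averaged over the masked (missing) positions of one sequence, where logits holds one row of unnormalized scores per position. Restricting the loss to missing positions and averaging over them are assumptions based on the method's description; the function name is hypothetical.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of scores."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def masked_cross_entropy(
    targets: List[int], logits: List[List[float]], targets_present: List[bool]
) -> float:
    """Average -log p(target) over positions whose target was masked out."""
    losses = [
        -log_softmax(row)[t]
        for t, row, present in zip(targets, logits, targets_present)
        if not present
    ]
    return sum(losses) / len(losses) if losses else 0.0
```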

mask_input(inputs, targets_present)[source]

Transforms the inputs so that masked-out tokens become missing tokens. The mask is defined on the targets, so to determine whether the input at time t is masked, we check whether the target at time t - 1 is masked out.

e.g.

  • inputs = [a, b, c, d]

  • targets = [b, c, d, e]

  • targets_present = [1, 0, 1, 0]

then,

  • masked_input = [a, b, <missing>, d]

Parameters
  • inputs – Tensor of shape [batch_size, sequence_length]

  • targets_present – Bool tensor of shape [batch_size, sequence_length] with 1 representing the presence of the word.

Returns

masked_input – Tensor of shape [batch_size, sequence_length] that takes the value of inputs where the input is present and the value mask_token_idx where the token is missing.
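The shifting rule in the example can be sketched for a single sequence as follows; mask_token stands in for mask_token_idx and, like the function name, is hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def mask_input_sketch(
    inputs: Sequence[T], targets_present: Sequence[bool], mask_token: T
) -> List[T]:
    """Replace inputs[t] with mask_token when targets[t - 1] is masked out."""
    masked: List[T] = []
    for t, token in enumerate(inputs):
        # The first input has no preceding target, so it is always kept.
        if t == 0 or targets_present[t - 1]:
            masked.append(token)
        else:
            masked.append(mask_token)
    return masked
```

Applied to the example above, this returns [a, b, <missing>, d].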

training: bool