MaskGAN Generator¶
- class textbox.module.Generator.MaskGANGenerator.MaskGANGenerator(config, dataset)[source]¶
Bases:
GenerativeAdversarialNet
RNN-based encoder-decoder architecture for the MaskGAN generator.
- adversarial_loss(inputs, lengths, targets, targets_present, discriminator)[source]¶
Calculate the adversarial loss.
- calculate_reinforce_objective(log_probs, dis_predictions, mask_present, estimated_values=None)[source]¶
Calculate the REINFORCE objectives. The REINFORCE objective should cover only the tokens that were missing: the final Generator reward is based on the Discriminator's predictions on missing tokens, the log probabilities are taken only over missing tokens, and the baseline is calculated only on the missing tokens. For this model, the reward we optimize is the log of the conditional probability the Discriminator assigns to the distribution. Specifically, for a Discriminator D which outputs the probability of being real given the past context, r_t = log D(x_t | x_0, x_1, …, x_{t-1}), and the policy for Generator G is the log-probability of taking action x_t given the past context. A sketch of this computation follows the return description below.
- Parameters
log_probs – Tensor of log probabilities of the tokens selected by the Generator. Shape [batch_size, sequence_length].
dis_predictions – Tensor of the predictions from the Discriminator. Shape [batch_size, sequence_length].
mask_present – Tensor indicating which tokens are present. Shape [batch_size, sequence_length].
estimated_values – Tensor of estimated state values of tokens. Shape [batch_size, sequence_length].
- Returns
final_gen_objective: final REINFORCE objective for the sequence.
rewards: Tensor of rewards for the sequence, shape [batch_size, sequence_length].
advantages: Tensor of advantages for the sequence, shape [batch_size, sequence_length].
baselines: Tensor of baselines for the sequence, shape [batch_size, sequence_length].
maintain_averages_op: ExponentialMovingAverage apply op used to maintain the baseline.
- Return type
final_gen_objective
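The following is a minimal PyTorch sketch of the computation described above. The exponential-moving-average baseline mentioned in the returns is omitted for brevity, and the helper name, the discount factor gamma, and the boolean mask convention are assumptions for illustration, not TextBox's actual internals:

```python
import torch

def reinforce_objective(log_probs, dis_predictions, mask_present,
                        estimated_values=None, gamma=0.95):
    # mask_present: bool tensor, True where the token was given (not masked).
    missing = (~mask_present).float()                     # [batch, seq_len]
    # r_t = log D(x_t | x_0 ... x_{t-1}); clamp to avoid log(0).
    rewards = torch.log(dis_predictions.clamp(min=1e-7)) * missing

    # Discounted return from each step to the end of the sequence.
    batch_size, seq_len = rewards.shape
    cumulative = torch.zeros_like(rewards)
    running = torch.zeros(batch_size, device=rewards.device)
    for t in reversed(range(seq_len)):
        running = rewards[:, t] + gamma * running
        cumulative[:, t] = running

    # Baseline: the critic's value estimates when available, else zero.
    baselines = estimated_values if estimated_values is not None \
        else torch.zeros_like(cumulative)
    advantages = (cumulative - baselines) * missing

    # Policy gradient: log pi(x_t) * advantage on missing tokens only; the
    # advantage is detached so gradients flow through the log-probs alone.
    final_gen_objective = (log_probs * advantages.detach()).sum() \
        / missing.sum().clamp(min=1.0)
    return final_gen_objective, cumulative, advantages, baselines
```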
- calculate_train_loss(inputs, lengths, targets, targets_present, validate=False)[source]¶
Calculate the training loss for the generator.
- create_critic_loss(cumulative_rewards, estimated_values, target_present)[source]¶
Compute the Critic's loss for estimating the value function. The estimate should cover only the missing elements.
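A plausible sketch of this loss, assuming a mean-squared error between the critic's value estimates and the cumulative rewards, restricted to the missing positions (names are illustrative, not the library's exact code):

```python
import torch

def critic_loss(cumulative_rewards, estimated_values, targets_present):
    # targets_present: bool tensor, True where the token was given.
    # Train the critic only on positions the generator had to fill in.
    missing = (~targets_present).float()
    # Regress the value estimates toward the observed returns; detach the
    # rewards so this loss updates only the critic.
    sq_err = (estimated_values - cumulative_rewards.detach()) ** 2
    return (sq_err * missing).sum() / missing.sum().clamp(min=1.0)
```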
- forward(inputs, input_length, targets, targets_present, pretrain=False, validate=False)[source]¶
Takes the real padded inputs and the target sentence, which (following the original code) neither starts with sos nor ends with eos, together with the input lengths used by the LSTM. See the shape sketch after the return description below.
- Parameters
inputs – Tensor of shape [batch_size, sequence_length].
input_length – list of sequence lengths, of length batch_size.
targets – Tensor of shape [batch_size, sequence_length].
targets_present – target presence matrix of shape [batch_size, sequence_length]; 1: not masked, 0: masked.
pretrain – whether to run language-model pretraining.
- Returns
samples: sampled tokens.
log_probs: log probabilities of the sampled tokens.
logits: output logits of the sampled tokens.
- Return type
output
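A hypothetical call illustrating the expected shapes, assuming the three return values listed above and an already-constructed generator (all tensors here are fabricated for demonstration):

```python
import torch

# Illustrative shapes only; `generator` is assumed to be an already
# constructed MaskGANGenerator(config, dataset).
vocab_size, batch_size, seq_len = 5000, 4, 20
inputs = torch.randint(0, vocab_size, (batch_size, seq_len))
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
input_length = [seq_len] * batch_size
# 1 = token given, 0 = token masked out and to be filled by the generator.
targets_present = torch.bernoulli(torch.full((batch_size, seq_len), 0.5)).bool()

samples, log_probs, logits = generator(
    inputs, input_length, targets, targets_present, pretrain=False)
```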
- mask_cross_entropy_loss(targets, logits, targets_present)[source]¶
Calculate the cross-entropy loss over the filled-in tokens.
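One way to realize this loss, assuming only the missing positions contribute (a sketch, not TextBox's exact implementation):

```python
import torch
import torch.nn.functional as F

def mask_cross_entropy(targets, logits, targets_present):
    # logits: [batch, seq_len, vocab]; targets: [batch, seq_len] (long).
    ce = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
    # Count only the positions the generator had to fill in.
    missing = (~targets_present).float()
    return (ce * missing).sum() / missing.sum().clamp(min=1.0)
```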
- mask_input(inputs, targets_present)[source]¶
Transforms the inputs to have missing tokens where they are masked out. The mask applies to the targets, so to determine whether an input at time t is masked, we check whether the target at time t - 1 is masked out. A code sketch follows the return description below.
e.g.
inputs = [a, b, c, d]
targets = [b, c, d, e]
targets_present = [1, 0, 1, 0]
then,
masked_input = [a, b, <missing>, d]
- Parameters
inputs – Tensor of shape [batch_size, sequence_length]
targets_present – Bool tensor of shape [batch_size, sequence_length] with 1 representing the presence of the word.
- Returns
Tensor of shape [batch_size, sequence_length] that takes the value of inputs where the input is present, and the value mask_token_idx to indicate a missing token.
- Return type
masked_input
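A minimal sketch of this time-shift masking, assuming PyTorch tensors and treating mask_token_idx as a configurable sentinel (illustrative, not the library's exact implementation):

```python
import torch

def mask_input(inputs, targets_present, mask_token_idx=0):
    # Input at time t lines up with target at time t-1, so shift the
    # presence mask right by one step; the first input is always present.
    first = torch.ones(inputs.size(0), 1, dtype=torch.bool, device=inputs.device)
    inputs_present = torch.cat([first, targets_present[:, :-1]], dim=1)
    return torch.where(inputs_present, inputs,
                       torch.full_like(inputs, mask_token_idx))

# Reproducing the example above with a=1, b=2, c=3, d=4:
inputs = torch.tensor([[1, 2, 3, 4]])
targets_present = torch.tensor([[True, False, True, False]])
mask_input(inputs, targets_present)  # tensor([[1, 2, 0, 4]])
```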
- training: bool¶