LeakGAN Generator
- class textbox.module.Generator.LeakGANGenerator.LeakGANGenerator(config, dataset)[source]
Bases:
UnconditionalGenerator
The LeakGAN generator consists of a worker (LSTM) and a manager (LSTM).
- calculate_loss(targets, dis)[source]
Returns the NLL loss for predicting the target sequence; a sketch of the computation follows this entry.
- Parameters
targets – target token indices, batch_size * seq_len
dis – discriminator model
- Returns
the generator's NLL loss on the target sequence
- Return type
worker_loss
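The computation this reduces to is a standard token-level negative log-likelihood. A minimal self-contained sketch, where the leaked log-probabilities are random placeholders rather than real generator output:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 4, 20, 100

# Per-step log-probabilities over the vocabulary; in the real method these
# come from running the generator with the discriminator's leaked features
# (random placeholders here).
leak_out = F.log_softmax(torch.randn(batch_size, seq_len, vocab_size), dim=-1)
targets = torch.randint(vocab_size, (batch_size, seq_len))

# NLL of the target tokens, averaged over every position in the batch.
worker_loss = F.nll_loss(leak_out.reshape(-1, vocab_size), targets.reshape(-1))
```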
- forward(idx, inp, work_hidden, mana_hidden, feature, real_goal, train=False, pretrain=False)[source]
Embed the input and sample one token at a time (seq_len = 1); see the sketch after this entry.
- Parameters
idx – index of the current token in the sentence
inp – current input token for a batch, [batch_size]
work_hidden – 1 * batch_size * hidden_dim
mana_hidden – 1 * batch_size * hidden_dim
feature – 1 * batch_size * total_num_filters, feature of the current sentence
real_goal – batch_size * goal_out_size, the real_goal in the LeakGAN source code
train – whether in training or inference mode
pretrain – whether in pretraining mode
- Returns
current output probability over the vocabulary (log_softmax or softmax), batch_size * vocab_size
cur_goal: batch_size * 1 * goal_out_size
work_hidden: 1 * batch_size * hidden_dim
mana_hidden: 1 * batch_size * hidden_dim
- Return type
out
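A self-contained sketch of what one such step looks like in a LeakGAN-style generator. The layer names, sizes, and the "psi" projection below are illustrative assumptions following the LeakGAN paper, not TextBox's actual attributes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, embed_dim, hidden_dim = 4, 32, 32
vocab_size, goal_size, goal_out_size, total_num_filters = 100, 16, 64, 64

embedding = nn.Embedding(vocab_size, embed_dim)
manager = nn.LSTM(total_num_filters, hidden_dim)
mana2goal = nn.Linear(hidden_dim, goal_out_size)
worker = nn.LSTM(embed_dim, hidden_dim)
work2goal = nn.Linear(hidden_dim, vocab_size * goal_size)
goal2goal = nn.Linear(goal_out_size, goal_size, bias=False)  # the "psi" projection

# Per-step inputs, shaped as documented above.
inp = torch.randint(vocab_size, (batch_size,))           # current input tokens
feature = torch.randn(1, batch_size, total_num_filters)  # leaked discriminator feature
real_goal = torch.randn(batch_size, goal_out_size)       # accumulated recent goals
work_hidden = (torch.zeros(1, batch_size, hidden_dim), torch.zeros(1, batch_size, hidden_dim))
mana_hidden = (torch.zeros(1, batch_size, hidden_dim), torch.zeros(1, batch_size, hidden_dim))

# Manager step: leaked feature -> hidden state -> current goal (bs * 1 * goal_out_size).
mana_out, mana_hidden = manager(feature, mana_hidden)
cur_goal = F.normalize(mana2goal(mana_out), dim=-1).permute(1, 0, 2)

# Worker step: token embedding -> per-goal-dimension scores (bs * vocab_size * goal_size).
work_out, work_hidden = worker(embedding(inp).unsqueeze(0), work_hidden)
work_out = work2goal(work_out).view(batch_size, vocab_size, goal_size)

# Combine: project the accumulated goal with psi and score every token against it.
goal_vec = goal2goal(real_goal)                                   # bs * goal_size
out = F.log_softmax(torch.bmm(work_out, goal_vec.unsqueeze(-1)).squeeze(-1), dim=-1)
```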
- get_adv_loss(target, rewards, dis)[source]
Return a pseudo-loss that yields the corresponding policy gradients (on calling .backward()), inspired by the example in http://karpathy.github.io/2016/05/31/rl/; a sketch follows this entry.
- Parameters
target – batch_size * seq_len
rewards – batch_size * seq_len (discriminator rewards for each token)
dis – discriminator model
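This is the standard REINFORCE trick. A minimal sketch with placeholder logits (the real method uses the generator's own log-probabilities):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 4, 20, 100

logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
log_probs = F.log_softmax(logits, dim=-1)
target = torch.randint(vocab_size, (batch_size, seq_len))
rewards = torch.rand(batch_size, seq_len)  # per-token discriminator rewards

# Log-probability of each generated token, weighted by its reward; the
# negated sum is a pseudo-loss whose gradient is the policy gradient.
picked = log_probs.gather(2, target.unsqueeze(-1)).squeeze(-1)  # bs * seq_len
adv_loss = -(picked * rewards).sum()
adv_loss.backward()  # populates logits.grad with REINFORCE gradients
```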
- get_reward_leakgan(sentences, rollout_num, dis, current_k=0)[source]
Get rewards via Monte Carlo search for LeakGAN; the rollout structure is sketched after this entry.
- Parameters
sentences – batch_size * max_seq_len
rollout_num – number of rollouts
dis – discriminator model
current_k – current training step of the generator
- Returns
batch_size * (max_seq_len / step_size)
- Return type
reward
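A hedged sketch of the Monte Carlo structure behind this, with the rollout and discriminator call stubbed out as random scores:

```python
import torch

batch_size, max_seq_len, step_size, rollout_num = 4, 20, 4, 8

def rollout_score(prefix_len):
    """Placeholder: resample completions of each sentence from position
    prefix_len and score them with the discriminator (random here)."""
    return torch.rand(batch_size)

# One reward per step_size-wide block, averaged over rollout_num rollouts.
rewards = torch.zeros(batch_size, max_seq_len // step_size)
for _ in range(rollout_num):
    for i, given_len in enumerate(range(step_size, max_seq_len + 1, step_size)):
        rewards[:, i] += rollout_score(given_len)
rewards /= rollout_num  # batch_size * (max_seq_len / step_size)
```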
Initialize the hidden state for the LSTM.
- leakgan_forward(targets, dis, train=False, pretrain=False)[source]
Get all features and goals for the given sentences; the shape bookkeeping is sketched after this entry.
- Parameters
targets – batch_size * max_seq_len, padded with the EOS token when the original sentence is shorter than max_seq_len
dis – discriminator model
train – whether to use the temperature parameter
pretrain – whether in pretraining mode
- Returns
batch_size * (seq_len + 1) * total_num_filters
goal_array: batch_size * (seq_len + 1) * goal_out_size
leak_out_array: batch_size * seq_len * vocab_size with log_softmax
- Return type
feature_array
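A sketch of the shape bookkeeping this implies, with the per-step forward call stubbed out; the start-token handling is an assumption:

```python
import torch

batch_size, max_seq_len = 4, 20
total_num_filters, goal_out_size, vocab_size = 64, 64, 100

def step(token):
    """Placeholder for one forward() step: returns the leaked feature, the
    manager's current goal, and log-probs over the vocabulary (random here)."""
    return (torch.randn(batch_size, total_num_filters),
            torch.randn(batch_size, goal_out_size),
            torch.randn(batch_size, vocab_size))

targets = torch.randint(vocab_size, (batch_size, max_seq_len))
features, goals, leak_out = [], [], []

# The feature/goal arrays carry one extra slot (seq_len + 1): the state seen
# before the first token is kept, but it has no output distribution.
f, g, _ = step(None)  # state at the start-of-sequence position
features.append(f); goals.append(g)
for t in range(max_seq_len):
    f, g, out = step(targets[:, t])
    features.append(f); goals.append(g); leak_out.append(out)

feature_array = torch.stack(features, dim=1)   # bs * (seq_len + 1) * total_num_filters
goal_array = torch.stack(goals, dim=1)         # bs * (seq_len + 1) * goal_out_size
leak_out_array = torch.stack(leak_out, dim=1)  # bs * seq_len * vocab_size
```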
- manager_cos_loss(batch_size, feature_array, goal_array)[source]
Get the manager's cosine-distance loss; see the sketch after this entry.
- Returns
batch_size * (seq_len / step_size)
- Return type
cos_loss
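The loss follows the LeakGAN paper's manager objective: negative cosine similarity between the goal at step t and the feature change over the following step_size steps. A minimal sketch with random tensors standing in for the real arrays (the cosine requires matching dimensions, so goal_out_size is assumed equal to total_num_filters):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, step_size = 4, 20, 4
dim = 64  # assumes goal_out_size == total_num_filters

feature_array = torch.randn(batch_size, seq_len + 1, dim)
goal_array = torch.randn(batch_size, seq_len + 1, dim, requires_grad=True)

# The manager is rewarded when its goal points in the direction the leaked
# features actually move over the next step_size tokens.
losses = []
for t in range(0, seq_len, step_size):
    delta = feature_array[:, t + step_size] - feature_array[:, t]
    losses.append(-F.cosine_similarity(delta, goal_array[:, t], dim=-1))
cos_loss = torch.stack(losses, dim=1)  # batch_size * (seq_len / step_size)
```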
- pretrain_loss(corpus, dis)[source]
Return the generator's pretraining loss for predicting the target sequence.
- Parameters
corpus – target_text, batch_size * seq_len
dis – discriminator model
- Returns
manager loss
work_cn_loss: worker loss
- Return type
manager_loss
- training: bool
- worker_cos_reward(feature_array, goal_array)[source]
Get the cosine-distance reward for the worker; see the sketch after this entry.
- Returns
batch_size * seq_len
- Return type
cos_loss
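The worker's intrinsic reward mirrors the manager loss above, but per token: it averages, over a step_size window, the cosine similarity between how the features have moved and the goals set earlier. A hedged sketch with random stand-ins and the same dimension assumption as before:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, step_size, dim = 4, 20, 4, 64

feature_array = torch.randn(batch_size, seq_len + 1, dim)
goal_array = torch.randn(batch_size, seq_len + 1, dim)

# Intrinsic reward at step t: average cosine similarity between the feature
# movement since each of the last step_size positions and the goal set there.
reward = torch.zeros(batch_size, seq_len)
for t in range(seq_len):
    sims = [F.cosine_similarity(feature_array[:, t + 1] - feature_array[:, i],
                                goal_array[:, i], dim=-1)
            for i in range(max(0, t - step_size + 1), t + 1)]
    reward[:, t] = torch.stack(sims).mean(dim=0)
```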