LeakGAN Generator

class textbox.module.Generator.LeakGANGenerator.LeakGANGenerator(config, dataset)[source]

Bases: UnconditionalGenerator

The LeakGAN generator consists of a worker (LSTM) and a manager (LSTM).

adversarial_loss(dis)[source]

Generate data and calculate adversarial loss

calculate_loss(targets, dis)[source]

Return the NLL loss for predicting the target sequence.

Parameters
  • targets – target token indices, batch_size * seq_len

  • dis – discriminator model

Returns

the generator NLL on the target sequence

Return type

worker_loss

forward(idx, inp, work_hidden, mana_hidden, feature, real_goal, train=False, pretrain=False)[source]

Embed the input and sample one token at a time (seq_len = 1).

Parameters
  • idx – index of current token in sentence

  • inp – current input token for a batch [batch_size]

  • work_hidden – 1 * batch_size * hidden_dim

  • mana_hidden – 1 * batch_size * hidden_dim

  • feature – 1 * batch_size * total_num_filters, feature of current sentence

  • real_goal – batch_size * goal_out_size, real_goal in LeakGAN source code

  • train – training mode if True, inference otherwise

  • pretrain – whether in pretraining mode

Returns

current output probability over the vocabulary (log_softmax or softmax), batch_size * vocab_size
cur_goal: batch_size * 1 * goal_out_size
work_hidden: 1 * batch_size * hidden_dim
mana_hidden: 1 * batch_size * hidden_dim

Return type

out
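
A conceptual, self-contained sketch of what one such step computes, following the worker/manager decomposition of the LeakGAN paper rather than the exact TextBox implementation; every module name and size below is an illustrative assumption:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative sizes; the real values come from the TextBox config.
    vocab_size, embed_dim, hidden_dim = 100, 32, 32
    goal_out_size, total_num_filters, batch_size = 16, 64, 4

    embedding = nn.Embedding(vocab_size, embed_dim)
    worker = nn.LSTM(embed_dim, hidden_dim)            # token-level LSTM
    manager = nn.LSTM(total_num_filters, hidden_dim)   # consumes the leaked feature
    mana2goal = nn.Linear(hidden_dim, goal_out_size)   # manager hidden -> goal vector
    goal2embed = nn.Linear(goal_out_size, embed_dim, bias=False)   # phi in the paper
    work2mat = nn.Linear(hidden_dim, vocab_size * embed_dim)       # worker output matrix O_t

    inp = torch.zeros(batch_size, dtype=torch.long)            # current input tokens
    feature = torch.zeros(1, batch_size, total_num_filters)    # leaked discriminator feature
    real_goal = torch.zeros(batch_size, goal_out_size)         # accumulated goal
    work_hidden = (torch.zeros(1, batch_size, hidden_dim), torch.zeros(1, batch_size, hidden_dim))
    mana_hidden = (torch.zeros(1, batch_size, hidden_dim), torch.zeros(1, batch_size, hidden_dim))

    # Manager step: leaked feature -> goal for the current position.
    mana_out, mana_hidden = manager(feature, mana_hidden)
    cur_goal = F.normalize(mana2goal(mana_out.squeeze(0)), dim=-1)   # batch_size * goal_out_size

    # Worker step: embed the input token and build the output matrix O_t.
    work_out, work_hidden = worker(embedding(inp).unsqueeze(0), work_hidden)
    O_t = work2mat(work_out.squeeze(0)).view(batch_size, vocab_size, embed_dim)

    # Combine: score every vocabulary entry against the projected goal, then log-softmax.
    goal_embed = goal2embed(real_goal)                               # batch_size * embed_dim
    out = F.log_softmax(torch.bmm(O_t, goal_embed.unsqueeze(-1)).squeeze(-1), dim=-1)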

generate(batch_data, eval_data, dis)[source]

Generate sentences

get_adv_loss(target, rewards, dis)[source]

Return a pseudo-loss that gives corresponding policy gradients (on calling .backward()). Inspired by the example in http://karpathy.github.io/2016/05/31/rl/

Parameters
  • target – batch_size * seq_len

  • rewards – batch_size * seq_len (discriminator rewards for each token)

  • dis – discriminator model
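
A minimal, self-contained sketch of such a REINFORCE-style pseudo-loss (illustrative tensors, not the TextBox code): the log-probability of each sampled token is weighted by its reward, so calling .backward() yields the corresponding policy gradients.

    import torch
    import torch.nn.functional as F

    batch_size, seq_len, vocab_size = 4, 20, 100

    # Illustrative stand-ins for the generator's outputs and sampled data.
    logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
    log_probs = F.log_softmax(logits, dim=-1)
    target = torch.randint(vocab_size, (batch_size, seq_len))   # sampled tokens
    rewards = torch.rand(batch_size, seq_len)                   # per-token discriminator rewards

    # Pseudo-loss: -sum(log p(token) * reward); its gradient matches REINFORCE.
    picked = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    adv_loss = -(picked * rewards).sum()
    adv_loss.backward()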

get_reward_leakgan(sentences, rollout_num, dis, current_k=0)[source]

Get reward via Monte Carlo search for LeakGAN

Parameters
  • sentences – batch_size * max_seq_len

  • rollout_num – number of rollouts

  • dis – discriminator

  • current_k – current training gen

Returns

batch_size * (max_seq_len / step_size)

Return type

reward
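
A conceptual sketch of this Monte Carlo estimate, where rollout_fn (standing in for rollout_mc_search_leakgan) and dis_score (the discriminator's per-sentence reward) are hypothetical helpers; it is not the TextBox implementation:

    import torch

    def mc_rewards(sentences, rollout_num, rollout_fn, dis_score, step_size=4):
        # sentences: batch_size * max_seq_len
        batch_size, max_seq_len = sentences.size()
        rewards = torch.zeros(batch_size, max_seq_len // step_size)
        for _ in range(rollout_num):
            for i, given_num in enumerate(range(step_size, max_seq_len + 1, step_size)):
                completed = rollout_fn(sentences, given_num)   # finish each prefix by sampling
                rewards[:, i] += dis_score(completed)          # discriminator reward per sentence
        return rewards / rollout_num                           # batch_size * (max_seq_len / step_size)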

init_hidden(batch_size=1)[source]

Initialize hidden states for the LSTMs.
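
As a shape reference, a zero-initialised state matching the documented 1 * batch_size * hidden_dim layout might look as follows (hidden_dim is an assumed size; the real module also keeps an LSTM cell state):

    import torch

    batch_size, hidden_dim = 4, 32
    h = torch.zeros(1, batch_size, hidden_dim)   # hidden state
    c = torch.zeros(1, batch_size, hidden_dim)   # cell state
    hidden = (h, c)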

leakgan_forward(targets, dis, train=False, pretrain=False)[source]

Get all features and goals for the given sentences.

Parameters
  • targets – batch_size * max_seq_len, padded with the EOS token if the original sentence is shorter than max_seq_len

  • dis – discriminator model

  • train – whether to use the temperature parameter

  • pretrain – whether in pretraining mode

Returns

batch_size * (seq_len + 1) * total_num_filters
goal_array: batch_size * (seq_len + 1) * goal_out_size
leak_out_array: batch_size * seq_len * vocab_size (log_softmax)

Return type

feature_array
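
A conceptual per-position loop behind this method, where leak_feature and step_fn are hypothetical stand-ins for the discriminator's leaked feature extractor and a single forward() step (illustrative sizes, not the TextBox code):

    import torch

    def collect_features_and_goals(targets, leak_feature, step_fn,
                                   total_num_filters=64, goal_out_size=16, vocab_size=100):
        batch_size, seq_len = targets.size()
        feature_array = torch.zeros(batch_size, seq_len + 1, total_num_filters)
        goal_array = torch.zeros(batch_size, seq_len + 1, goal_out_size)
        leak_out_array = torch.zeros(batch_size, seq_len, vocab_size)
        for t in range(seq_len):
            feature = leak_feature(targets, t)              # leaked feature of the current prefix
            out, cur_goal = step_fn(targets[:, t], feature)
            feature_array[:, t] = feature
            goal_array[:, t] = cur_goal
            leak_out_array[:, t] = out                      # log_softmax over the vocabulary
        feature_array[:, seq_len] = leak_feature(targets, seq_len)   # feature of the full sentence
        return feature_array, goal_array, leak_out_array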

leakgan_generate(targets, dis, train=False)[source]
manager_cos_loss(batch_size, feature_array, goal_array)[source]

Get manager cosine distance loss

Returns

batch_size * (seq_len / step_size)

Return type

cos_loss
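
An illustrative sketch of such a cosine loss, assuming the goal vectors live in the same space as the leaked features and seq_len is divisible by step_size; this mirrors the LeakGAN formulation rather than the exact TextBox code:

    import torch
    import torch.nn.functional as F

    def manager_cos_loss(feature_array, goal_array, step_size=4):
        # feature_array, goal_array: batch_size * (seq_len + 1) * dim
        seq_len = feature_array.size(1) - 1
        losses = []
        for t in range(0, seq_len, step_size):
            # feature change over the block vs. the goal emitted at its start
            delta_feature = feature_array[:, t + step_size] - feature_array[:, t]
            losses.append(-F.cosine_similarity(delta_feature, goal_array[:, t], dim=-1))
        return torch.stack(losses, dim=1)   # batch_size * (seq_len / step_size)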

pretrain_loss(corpus, dis)[source]

Return the generator pretraining loss for predicting the target sequence.

Parameters
  • corpus – target text, batch_size * seq_len

  • dis – discriminator model

Returns

manager loss
work_cn_loss: worker loss

Return type

manager_loss

redistribution(idx, total, min_v)[source]
rescale(reward, rollout_num=1.0)[source]

Rescale rewards according to the original LeakGAN paper.
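
The LeakGAN paper calls this bootstrapped rescaled activation: at each step, rewards are ranked within the batch and squashed through a sigmoid. A hedged sketch of that form (delta is an assumed constant; not the exact TextBox code):

    import torch

    def rescale_rewards(reward, delta=16.0):
        # reward: batch_size * n_steps
        batch_size = reward.size(0)
        order = reward.argsort(dim=0, descending=True)   # per-step ranking, best first
        rank = torch.empty_like(order)
        rank.scatter_(0, order,
                      torch.arange(1, batch_size + 1).unsqueeze(1).expand_as(order))
        return torch.sigmoid(delta * (0.5 - rank.float() / batch_size))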

rollout_mc_search_leakgan(targets, dis, given_num)[source]

Roll out to get Monte Carlo search results.

sample(sample_num, dis, start_letter, train=False)[source]

Sample sentences

sample_batch()[source]

Sample a batch of data

split_params()[source]

Split parameters into worker and manager groups.

training: bool
worker_cos_reward(feature_array, goal_array)[source]

Get reward for worker (cosine distance)

Returns

batch_size * seq_len

Return type

cos_loss
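
An illustrative sketch of this intrinsic reward in the feudal-network style used by LeakGAN (goals assumed to share the feature dimension; not the exact TextBox code): the cosine similarity between recent feature changes and the goals emitted at those positions, averaged over a window of step_size steps.

    import torch
    import torch.nn.functional as F

    def worker_cos_reward(feature_array, goal_array, step_size=4):
        # feature_array, goal_array: batch_size * (seq_len + 1) * dim
        batch_size, seq_len = feature_array.size(0), feature_array.size(1) - 1
        rewards = torch.zeros(batch_size, seq_len)
        for t in range(seq_len):
            sims = []
            for i in range(1, step_size + 1):
                prev = max(t - i, 0)
                delta = feature_array[:, t] - feature_array[:, prev]
                sims.append(F.cosine_similarity(delta, goal_array[:, prev], dim=-1))
            rewards[:, t] = torch.stack(sims, dim=0).mean(dim=0)
        return rewards   # batch_size * seq_len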

worker_cross_entropy_loss(target, leak_out_array, reduction='mean')[source]

Get the cross-entropy loss for the worker.

worker_nll_loss(target, leak_out_array)[source]

Get the NLL loss for the worker.
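
A minimal sketch of an NLL loss over log-softmax outputs with the documented shapes (illustrative tensors; leak_out_array is assumed to already hold log-probabilities):

    import torch
    import torch.nn.functional as F

    batch_size, seq_len, vocab_size = 4, 20, 100
    leak_out_array = F.log_softmax(torch.randn(batch_size, seq_len, vocab_size), dim=-1)
    target = torch.randint(vocab_size, (batch_size, seq_len))

    # Flatten batch and time, then average the negative log-likelihood of each target token.
    nll = F.nll_loss(leak_out_array.view(-1, vocab_size), target.view(-1))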