textbox.model.abstract_generator

class textbox.model.abstract_generator.AbstractModel(config, dataset)[source]

Bases: Module

Base class for all models.

generate(batch_data, eval_data)[source]

Generate text conditioned on either noise or an input sequence.

Parameters
  • batch_data (Corpus) – Corpus object for a single batch.

  • eval_data – Data shared across all batches.

Returns

Generated text, shape: [batch_size, max_len]

Return type

torch.Tensor

training: bool

class textbox.model.abstract_generator.AttributeGenerator(config, dataset)[source]

Bases: AbstractModel

This is an abstract, general attribute generator. All attribute models should inherit from this class. The base class provides the basic parameter information.

training: bool

type = 4

class textbox.model.abstract_generator.GenerativeAdversarialNet(config, dataset)[source]

Bases: UnconditionalGenerator

This is an abstract, general generative adversarial network (GAN). All GAN models should inherit from this class. The base class provides the basic parameter information.

calculate_d_train_loss(real_data, fake_data)[source]

Calculate the discriminator training loss for a batch of data.

Parameters
  • real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]

  • fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]

Returns

Training loss, shape: []

Return type

torch.Tensor
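The reference does not say which loss the discriminator uses; each subclass decides that. As a rough illustration only, one common choice is binary cross-entropy that pushes discriminator scores toward 1 for real data and toward 0 for fake data. The sketch below uses plain Python floats instead of tensors so that it runs without PyTorch. The function `d_train_loss_sketch` and its probability inputs are hypothetical and not part of TextBox.

```python
import math

def d_train_loss_sketch(real_probs, fake_probs):
    """Binary cross-entropy over discriminator outputs (illustrative only).

    real_probs / fake_probs: hypothetical per-example probabilities that the
    discriminator assigns to "this sample is real", one float per example.
    Returns a single scalar, mirroring the documented shape [].
    """
    eps = 1e-12  # guards against log(0)
    real_loss = -sum(math.log(p + eps) for p in real_probs) / len(real_probs)
    fake_loss = -sum(math.log(1.0 - p + eps) for p in fake_probs) / len(fake_probs)
    return real_loss + fake_loss

# A discriminator that separates real from fake well gets a lower loss
# than one that outputs 0.5 for everything.
confident = d_train_loss_sketch([0.9, 0.8], [0.1, 0.2])
confused = d_train_loss_sketch([0.5, 0.5], [0.5, 0.5])
```

In a real subclass the inputs would be the `real_data` and `fake_data` token tensors, and the probabilities would come from the model's own discriminator network.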

calculate_g_adversarial_loss()[source]

Calculate the adversarial generator training loss for a batch of data.

Returns

Training loss, shape: []

Return type

torch.Tensor

calculate_g_train_loss(corpus)[source]

Calculate the generator training loss for a batch of data.

Parameters

corpus (Corpus) – Corpus object for the batch.

Returns

Training loss, shape: []

Return type

torch.Tensor

calculate_nll_test(eval_data)[source]

Calculate the negative log-likelihood of the batch.

Parameters

eval_data (Corpus) – Corpus object for the batch.

Returns

NLL_test of the evaluation data

Return type

torch.FloatTensor
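To make the quantity concrete, here is a minimal sketch of a negative log-likelihood averaged over target tokens. The reference does not specify the normalization (per token or per sequence) or how padding is handled, so both are assumptions here, and `nll_sketch` is a hypothetical helper rather than a TextBox function.

```python
import math

def nll_sketch(token_probs):
    """Average negative log-likelihood over all target tokens.

    token_probs: hypothetical nested list where token_probs[i][t] is the
    probability the model assigned to the reference token at position t of
    sequence i. Padding positions are assumed to be excluded already.
    """
    log_probs = [math.log(p) for seq in token_probs for p in seq]
    return -sum(log_probs) / len(log_probs)

# Two sequences of different lengths, three target tokens in total.
batch = [[0.5, 0.25], [0.125]]
nll = nll_sketch(batch)
```

Lower values mean the model assigns higher probability to the evaluation text.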

sample(sample_num)[source]

Draw sample_num padded fake sequences from the generator.

Parameters

sample_num (int) – The number of padded fake sequences to generate.

Returns

Fake data generated by the generator, shape: [sample_num, max_length]

Return type

torch.LongTensor
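The fixed shape [sample_num, max_length] implies that variable-length samples are padded to a common length. The following sketch shows one way that padding could work on plain lists of token ids. The helper `pad_samples`, the padding index 0, and the truncation of over-long samples are all assumptions made for illustration; the reference does not describe how subclasses do this.

```python
def pad_samples(sequences, max_length, pad_idx=0):
    """Right-pad or truncate each token-id list to exactly max_length."""
    padded = []
    for seq in sequences:
        clipped = seq[:max_length]  # drop tokens beyond max_length
        padded.append(clipped + [pad_idx] * (max_length - len(clipped)))
    return padded

# Three hypothetical generator outputs of different lengths.
raw = [[5, 7, 9], [4], [8, 2, 6, 3, 1]]
fake_data = pad_samples(raw, max_length=4)
```

Every row of `fake_data` now has the same length, so the result could be converted to a single integer tensor.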

training: bool

type = 2

class textbox.model.abstract_generator.Seq2SeqGenerator(config, dataset)[source]

Bases: AbstractModel

This is an abstract, general seq2seq generator. All seq2seq models should inherit from this class. The base class provides the basic parameter information.

training: bool

type = 3

class textbox.model.abstract_generator.UnconditionalGenerator(config, dataset)[source]

Bases: AbstractModel

This is an abstract, general unconditional generator. All unconditional models should inherit from this class. The base class provides the basic parameter information.

training: bool

type = 1
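Each abstract class above carries a distinct integer `type` (1 to 4), which calling code can use to tell model families apart without checking concrete classes. The sketch below mirrors the documented hierarchy and values with stand-in classes. All names in it (`ModelTypeSketch`, the `*Stub` classes, `describe`) are hypothetical; TextBox may define its own enumeration for these values.

```python
from enum import IntEnum

class ModelTypeSketch(IntEnum):
    """Hypothetical labels for the documented `type` values."""
    UNCONDITIONAL = 1
    GAN = 2
    SEQ2SEQ = 3
    ATTRIBUTE = 4

class AbstractModelStub:
    type = None  # the base class documents no `type`

class UnconditionalStub(AbstractModelStub):
    type = ModelTypeSketch.UNCONDITIONAL

class GanStub(UnconditionalStub):
    # Inherits from the unconditional generator but overrides `type`.
    type = ModelTypeSketch.GAN

class Seq2SeqStub(AbstractModelStub):
    type = ModelTypeSketch.SEQ2SEQ

class AttributeStub(AbstractModelStub):
    type = ModelTypeSketch.ATTRIBUTE

def describe(model_cls):
    """Return a lowercase family name based on the class-level `type`."""
    return ModelTypeSketch(model_cls.type).name.lower()

families = [describe(c) for c in (UnconditionalStub, GanStub, Seq2SeqStub, AttributeStub)]
```

Because `type` is a class attribute, it can be read before a model is instantiated, for example when choosing a trainer or data loader.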