textbox.model.abstract_generator
- class textbox.model.abstract_generator.AbstractModel(config, dataset)
Bases: torch.nn.Module
Base class for all models.
- generate(batch_data, eval_data)
Predict texts conditioned on noise or on an input sequence.
- Parameters
batch_data (Corpus) – Corpus object holding a single batch.
eval_data – Data shared by all batches.
- Returns
Generated text, shape: [batch_size, max_len]
- Return type
torch.Tensor
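A minimal sketch of what a generate implementation might do for a sequence-conditioned model, written in plain Python rather than torch: greedy decoding that stops at an end-of-sequence id and right-pads every row to max_len, which gives the documented [batch_size, max_len] layout. The special-token ids, greedy_generate, and the toy copy_scores scorer (a stand-in for a model's forward pass) are hypothetical; concrete TextBox models may decode differently, for example by sampling or beam search.

```python
from typing import Callable, List, Sequence

# Hypothetical special-token ids; a real vocabulary defines its own.
PAD, SOS, EOS = 0, 1, 2

# Scorer signature: (source sequence, prefix generated so far) -> one score per vocab id.
Scorer = Callable[[Sequence[int], Sequence[int]], Sequence[float]]


def greedy_generate(next_token_scores: Scorer,
                    sources: Sequence[Sequence[int]],
                    max_len: int) -> List[List[int]]:
    """Greedily decode one sequence per source, right-padded to max_len."""
    batch = []
    for source in sources:
        prefix = [SOS]
        tokens: List[int] = []
        while len(tokens) < max_len:
            scores = next_token_scores(source, prefix)
            # Greedy step: take the highest-scoring vocabulary id.
            token = max(range(len(scores)), key=lambda i: scores[i])
            if token == EOS:
                break
            tokens.append(token)
            prefix.append(token)
        # Pad so every row has exactly max_len entries.
        batch.append(tokens + [PAD] * (max_len - len(tokens)))
    return batch


def copy_scores(source: Sequence[int], prefix: Sequence[int]) -> List[float]:
    """Toy scorer over a 10-id vocabulary: copy the source, then emit EOS."""
    step = len(prefix) - 1  # number of tokens generated so far
    target = source[step] if step < len(source) else EOS
    scores = [0.0] * 10
    scores[target] = 1.0
    return scores


print(greedy_generate(copy_scores, [[5, 6, 7], [8]], max_len=4))
# [[5, 6, 7, 0], [8, 0, 0, 0]]
```

Because every row has the same length, the nested list could be converted directly into a [batch_size, max_len] tensor, for example with torch.tensor.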
- training: bool
- class textbox.model.abstract_generator.AttributeGenerator(config, dataset)
Bases: AbstractModel
This is an abstract general attribute generator. All attribute models should inherit from this class. This base class provides the basic parameter information.
- training: bool
- type = 4
- class textbox.model.abstract_generator.GenerativeAdversarialNet(config, dataset)
Bases: UnconditionalGenerator
This is an abstract general generative adversarial network. All GAN models should inherit from this class. This base class provides the basic parameter information.
- calculate_d_train_loss(real_data, fake_data)
Calculate the discriminator training loss for a batch of data.
- Parameters
real_data (torch.LongTensor) – Real data of the batch, shape: [batch_size, max_length]
fake_data (torch.LongTensor) – Fake data of the batch, shape: [batch_size, max_length]
- Returns
Training loss, shape: []
- Return type
torch.Tensor
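calculate_d_train_loss receives token ids, so an implementation first has to score real_data and fake_data with the discriminator. The sketch below starts from those scores and assumes the common binary cross-entropy objective (real labelled 1, fake labelled 0), averaged within each group and returned as a single number to match the documented shape []. The function name and the objective are assumptions; individual TextBox GAN models may define this loss differently.

```python
import math
from typing import Sequence


def discriminator_bce_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Binary cross-entropy loss for a discriminator.

    d_real[i] / d_fake[i] are the discriminator's probabilities that the i-th
    real / fake sequence is real. Targets are 1 for real and 0 for fake.
    """
    eps = 1e-12  # keeps log() finite for probabilities of exactly 0 or 1
    real_term = sum(-math.log(max(p, eps)) for p in d_real) / len(d_real)
    fake_term = sum(-math.log(max(1.0 - p, eps)) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# An undecided discriminator (0.5 everywhere) scores 2 * ln 2, about 1.386.
print(discriminator_bce_loss([0.5, 0.5], [0.5, 0.5]))
```

A discriminator that separates real from fake confidently drives both terms toward zero, so lower values mean a stronger discriminator.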
- calculate_g_adversarial_loss()
Calculate the adversarial generator training loss for a batch of data.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
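calculate_g_adversarial_loss takes no arguments, which suggests the model samples its own sequences and has them scored internally. Because sampled tokens are discrete, text GANs such as SeqGAN commonly train the generator with a REINFORCE-style policy gradient rather than backpropagating through the discriminator. The sketch below assumes that formulation; policy_gradient_loss and its input layout are hypothetical, and the objective a given TextBox model uses may differ.

```python
from typing import Sequence


def policy_gradient_loss(log_probs: Sequence[Sequence[float]],
                         rewards: Sequence[Sequence[float]]) -> float:
    """REINFORCE-style generator loss, averaged over sampled sequences.

    log_probs[i][t] is log p(y_t | y_<t) for token t of sampled sequence i, and
    rewards[i][t] is the reward credited to that token (for example a
    discriminator score). Minimising the loss raises the probability of
    high-reward tokens.
    """
    per_sequence = [
        -sum(lp * r for lp, r in zip(seq_lp, seq_r))
        for seq_lp, seq_r in zip(log_probs, rewards)
    ]
    return sum(per_sequence) / len(per_sequence)


print(policy_gradient_loss([[-1.0, -2.0], [-0.5]], [[0.5, 1.0], [2.0]]))  # 1.75
```

In a real model the log-probabilities would be tensors that require gradients, so that minimising this value updates the generator's parameters.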
- calculate_g_train_loss(corpus)
Calculate the generator training loss for a batch of data.
- Parameters
corpus (Corpus) – Corpus class of the batch.
- Returns
Training loss, shape: []
- Return type
torch.Tensor
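Judging by its name and the separate adversarial loss, calculate_g_train_loss is presumably the maximum-likelihood (teacher-forcing) loss used to pretrain the generator on real text. Such a loss usually starts by shifting each sequence one position: the model reads row[:-1] and predicts row[1:], with padded targets excluded. The sketch below shows only that preparation step; the padding id, the assumed [SOS, ..., EOS, PAD, ...] row layout, and teacher_forcing_pairs are hypothetical.

```python
from typing import List, Sequence, Tuple

PAD = 0  # hypothetical padding id


def teacher_forcing_pairs(
    batch: Sequence[Sequence[int]],
) -> Tuple[List[List[int]], List[List[int]], List[List[bool]]]:
    """Split padded rows into decoder inputs, targets, and a loss mask.

    Rows are assumed to look like [SOS, w1, ..., wn, EOS, PAD, ...]. The model
    reads row[:-1] and is trained to predict row[1:]; targets equal to PAD
    are masked out so they do not contribute to the loss.
    """
    inputs = [list(row[:-1]) for row in batch]
    targets = [list(row[1:]) for row in batch]
    mask = [[tok != PAD for tok in row] for row in targets]
    return inputs, targets, mask


inputs, targets, mask = teacher_forcing_pairs([[1, 5, 6, 2, 0]])
print(inputs, targets, mask)
# [[1, 5, 6, 2]] [[5, 6, 2, 0]] [[True, True, True, False]]
```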
- calculate_nll_test(eval_data)
Calculate the negative log-likelihood of the batch.
- Parameters
eval_data (Corpus) – Corpus class of the batch.
- Returns
Negative log-likelihood (NLL_test) of the evaluation data.
- Return type
torch.FloatTensor
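One way calculate_nll_test could compute the negative log-likelihood: sum -log p over each sentence's real tokens, skipping padding, then average over the batch. This section does not say whether TextBox sums or averages per token, so this reduction, the batch_nll name, and the input layout (probabilities of the reference tokens plus a padding mask) are assumptions.

```python
import math
from typing import Sequence


def batch_nll(token_probs: Sequence[Sequence[float]],
              mask: Sequence[Sequence[bool]]) -> float:
    """Average per-sentence negative log-likelihood over a batch.

    token_probs[i][t] is the model's probability of the reference token at
    step t of sentence i; mask[i][t] is False at padding positions, which are
    skipped. Each sentence contributes the sum of -log p over its real tokens.
    """
    sentence_nll = [
        -sum(math.log(p) for p, keep in zip(probs, keep_row) if keep)
        for probs, keep_row in zip(token_probs, mask)
    ]
    return sum(sentence_nll) / len(sentence_nll)


# Two real tokens at probability 0.5 each; the 0.9 entry is padding and ignored.
print(batch_nll([[0.5, 0.5, 0.9]], [[True, True, False]]))  # 2 * ln 2, about 1.386
```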
- sample(sample_num)
Sample sample_num padded fake sequences from the generator.
- Parameters
sample_num (int) – Number of padded fake sequences to generate.
- Returns
Fake data generated by the generator, shape: [sample_num, max_length]
- Return type
torch.LongTensor
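A sketch of what sample might do: draw each token from the generator's next-token distribution until an end-of-sequence id appears or max_length is reached, then right-pad, which yields the documented [sample_num, max_length] layout. The next-token function, the special-token ids, and sample_padded are hypothetical; passing a seeded random.Random makes the draws reproducible.

```python
import random
from typing import Callable, List, Sequence

PAD, SOS, EOS = 0, 1, 2  # hypothetical special-token ids

# Maps the prefix generated so far to a probability per vocabulary id.
NextTokenProbs = Callable[[Sequence[int]], Sequence[float]]


def sample_padded(next_token_probs: NextTokenProbs, sample_num: int,
                  max_length: int, rng: random.Random) -> List[List[int]]:
    """Draw sample_num sequences token by token, right-padded to max_length."""
    samples = []
    for _ in range(sample_num):
        prefix = [SOS]
        tokens: List[int] = []
        while len(tokens) < max_length:
            probs = next_token_probs(prefix)
            # Draw the next id with probability proportional to probs[id].
            token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
            if token == EOS:
                break
            tokens.append(token)
            prefix.append(token)
        samples.append(tokens + [PAD] * (max_length - len(tokens)))
    return samples


def toy_probs(prefix: Sequence[int]) -> List[float]:
    """Toy distribution: EOS with probability 0.2, otherwise id 3 or 4."""
    return [0.0, 0.0, 0.2, 0.4, 0.4]


print(sample_padded(toy_probs, sample_num=3, max_length=6, rng=random.Random(0)))
```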
- training: bool
- type = 2
- class textbox.model.abstract_generator.Seq2SeqGenerator(config, dataset)
Bases: AbstractModel
This is an abstract general seq2seq generator. All seq2seq models should inherit from this class. This base class provides the basic parameter information.
- training: bool
- type = 3
- class textbox.model.abstract_generator.UnconditionalGenerator(config, dataset)
Bases: AbstractModel
This is an abstract general unconditional generator. All unconditional models should inherit from this class. This base class provides the basic parameter information.
- training: bool
- type = 1
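The documented class hierarchy and type values can be mirrored with plain-Python stand-ins. One consequence worth noting: GenerativeAdversarialNet subclasses UnconditionalGenerator, so an isinstance or issubclass check alone cannot tell a GAN from a plain unconditional model, whereas the type attribute (2 versus 1) can. These stub classes, TASK_NAMES, and task_of are illustrative only; how TextBox itself consumes type is not described in this section.

```python
class AbstractModel:
    """Stand-in for the real AbstractModel (which is a torch.nn.Module)."""

    def generate(self, batch_data, eval_data):
        # Placeholder: concrete models are expected to override this.
        raise NotImplementedError


class UnconditionalGenerator(AbstractModel):
    type = 1


class GenerativeAdversarialNet(UnconditionalGenerator):
    type = 2  # overrides the value inherited from UnconditionalGenerator


class Seq2SeqGenerator(AbstractModel):
    type = 3


class AttributeGenerator(AbstractModel):
    type = 4


# Hypothetical labels for the documented type values.
TASK_NAMES = {1: "unconditional", 2: "GAN", 3: "seq2seq", 4: "attribute"}


def task_of(model_cls: type) -> str:
    """Dispatch on the class-level type attribute rather than on isinstance."""
    return TASK_NAMES[model_cls.type]


print(task_of(GenerativeAdversarialNet), issubclass(GenerativeAdversarialNet, UnconditionalGenerator))
# GAN True
```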