Source code for textbox.model.abstract_generator

# @Time   : 2020/11/5
# @Author : Junyi Li, Gaole He
# @Email  : lijunyi@ruc.edu.cn

# UPDATE:
# @Time   : 2021/1/30
# @Author : Tianyi Tang
# @Email  : steventang@ruc.edu.cn

# UPDATE:
# @Time   : 2021/3/23
# @Author : Zhuohao Yu
# @Email  : zhuohao@ruc.edu.cn

"""
textbox.model.abstract_generator
##################################
"""

import numpy as np
import torch
import torch.nn as nn

from textbox.utils import ModelType


class AbstractModel(nn.Module):
    r"""Base class for all models.

    Stores the device and training batch size read from ``config`` and keeps a
    reference to ``dataset``. Attributes that are not found on the model are
    looked up on that dataset (see :meth:`__getattr__`).
    """

    def __init__(self, config, dataset):
        # load parameters info
        super().__init__()
        self.device = config['device']
        self.batch_size = config['train_batch_size']
        self.dataset = dataset

    def __getattr__(self, name):
        r"""Fall back to the dataset for attributes missing on the model.

        Python calls this only after normal attribute lookup has failed. If the
        dataset has a non-``None`` attribute ``name``, that value is returned;
        otherwise lookup is delegated to :class:`torch.nn.Module`, which serves
        registered parameters, buffers and submodules or raises
        ``AttributeError``.

        NOTE: because the dataset is consulted first, a dataset attribute
        shadows a parameter/buffer/submodule registered under the same name.

        Raises:
            AttributeError: if neither the dataset nor the module provides
                ``name``.
        """
        # Read ``dataset`` straight from the instance ``__dict__``: going
        # through ``self.dataset`` would re-enter ``__getattr__`` and recurse
        # forever whenever ``dataset`` has not been assigned yet (e.g. before
        # ``__init__`` finishes, or on an instance created via ``__new__``).
        dataset = self.__dict__.get('dataset')
        if dataset is not None:
            value = getattr(dataset, name, None)
            # ``None``-valued dataset attributes are treated as absent.
            if value is not None:
                return value
        return super().__getattr__(name)

    def generate(self, batch_data, eval_data):
        r"""Predict the texts conditioned on a noise or sequence.

        Args:
            batch_data (Corpus): Corpus class of a single batch.
            eval_data: Common data of all the batches.

        Returns:
            torch.Tensor: Generated text, shape: [batch_size, max_len]
        """
        raise NotImplementedError

    def __str__(self):
        """Return the module representation followed by the trainable parameter count."""
        # ``numel()`` yields a plain int even for 0-dim parameters, whereas
        # ``np.prod(torch.Size([]))`` is the float ``1.0`` and would turn the
        # whole total into a float.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
class UnconditionalGenerator(AbstractModel):
    """Abstract base class for unconditional generators.

    Every unconditional model should subclass this class. It tags the model
    with ``ModelType.UNCONDITIONAL`` and delegates all initialisation to
    ``AbstractModel``.
    """

    # Public class attribute read by callers to identify the model family.
    type = ModelType.UNCONDITIONAL

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
class Seq2SeqGenerator(AbstractModel):
    """Abstract base class for sequence-to-sequence generators.

    Every seq2seq model should subclass this class. It tags the model with
    ``ModelType.SEQ2SEQ`` and delegates all initialisation to
    ``AbstractModel``.
    """

    # Public class attribute read by callers to identify the model family.
    type = ModelType.SEQ2SEQ

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
class AttributeGenerator(AbstractModel):
    """Abstract base class for attribute-conditioned generators.

    Every attribute model should subclass this class. It tags the model with
    ``ModelType.ATTRIBUTE`` and delegates all initialisation to
    ``AbstractModel``.
    """

    # Public class attribute read by callers to identify the model family.
    type = ModelType.ATTRIBUTE

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
class GenerativeAdversarialNet(UnconditionalGenerator):
    """Abstract base class for generative adversarial networks (GANs).

    Every GAN model should subclass this class. It tags the model with
    ``ModelType.GAN``, inherits its initialisation from
    ``UnconditionalGenerator``, and declares the generator/discriminator hooks
    that concrete GANs must implement.
    """

    # Public class attribute read by callers to identify the model family.
    type = ModelType.GAN

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

    def calculate_g_train_loss(self, corpus):
        r"""Compute the generator training loss on one batch.

        Args:
            corpus (Corpus): Corpus class of the batch.

        Returns:
            torch.Tensor: Training loss, shape: []
        """
        raise NotImplementedError()

    def calculate_d_train_loss(self, real_data, fake_data):
        r"""Compute the discriminator training loss on one batch.

        Args:
            real_data (torch.LongTensor): Real data of the batch, shape: [batch_size, max_length]
            fake_data (torch.LongTensor): Fake data of the batch, shape: [batch_size, max_length]

        Returns:
            torch.Tensor: Training loss, shape: []
        """
        raise NotImplementedError()

    def calculate_g_adversarial_loss(self):
        r"""Compute the adversarial training loss of the generator on one batch.

        Returns:
            torch.Tensor: Training loss, shape: []
        """
        raise NotImplementedError()

    def calculate_nll_test(self, eval_data):
        r"""Compute the negative log-likelihood of a batch.

        Args:
            eval_data (Corpus): Corpus class of the batch.

        Returns:
            torch.FloatTensor: NLL_test of eval data
        """
        raise NotImplementedError()

    def sample(self, sample_num):
        r"""Draw ``sample_num`` padded fake sequences from the generator.

        Args:
            sample_num (int): The number of padded fake sequences to generate.

        Returns:
            torch.LongTensor: Fake data generated by generator, shape: [sample_num, max_length]
        """
        raise NotImplementedError()