# Source code for textbox.data.dataloader.attr_sent_dataloader

# @Time   : 2021/1/30
# @Author : Tianyi Tang
# @Email  : steven_tang@ruc.edu.cn

# UPDATE:
# @Time   : 2021/10/10
# @Author : Tianyi Tang
# @Email  : steven_tang@ruc.edu.cn

"""
textbox.data.dataloader.attr_sent_dataloader
################################################
"""

import torch
from textbox.data.dataloader import AbstractDataLoader


class AttributedSentenceDataLoader(AbstractDataLoader):
    """:class:`AttributedSentenceDataLoader` serves the source (attribute) side of a dataset.

    Each source batch it builds holds the raw source text together with the
    matching indices converted to a ``torch.LongTensor`` and moved to
    ``self.device``.

    Args:
        config (Config): The config of dataloader.
        dataset (AttributedSentenceDataset): The dataset of dataloader. Corpus,
            see textbox.data.corpus for more details.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        shuffle (bool, optional): Whether the dataloader will be shuffled after a
            round. Defaults to ``False``.
        drop_last (bool, optional): Forwarded unchanged to
            :class:`AbstractDataLoader`; presumably controls whether a trailing
            incomplete batch is discarded (confirm against the base class).
            Defaults to ``True``.
        DDP (bool, optional): Forwarded unchanged to :class:`AbstractDataLoader`;
            presumably toggles distributed data parallel handling (confirm
            against the base class). Defaults to ``False``.

    Attributes:
        attribute_size (list[int]): ``len()`` of every entry of
            ``self.source_idx2token``.
        attribute_num (int): Number of entries in ``self.source_idx2token``.
    """

    def __init__(self, config, dataset, batch_size=1, shuffle=False, drop_last=True, DDP=False):
        super().__init__(config, dataset, batch_size, shuffle, drop_last, DDP)
        # NOTE(review): ``source_idx2token`` is not assigned in this class; it is
        # assumed to be populated by the base initializer above.
        self.attribute_size = [len(a2t) for a2t in self.source_idx2token]
        self.attribute_num = len(self.attribute_size)

    def _next_source_patch(self):
        """Build the source part of the batch starting at ``self.pr``.

        Takes the window ``[self.pr, self.pr + self.step)`` from both
        ``self.source_text`` and ``self.source_idx``. This method does not
        modify ``self.pr`` itself.

        Returns:
            dict: ``'source_text'`` maps to the sliced source text and
            ``'source_idx'`` maps to the sliced indices as a
            ``torch.LongTensor`` on ``self.device``.
        """
        window = slice(self.pr, self.pr + self.step)
        source_text = self.source_text[window]
        # LongTensor requires a rectangular nested sequence of integers.
        source_idx = torch.LongTensor(self.source_idx[window])
        return {
            'source_text': source_text,
            'source_idx': source_idx.to(self.device),
        }