# @Time : 2020/11/4
# @Author : Gaole He
# @email : hegaole@ruc.edu.cn
# UPDATE:
# @Time : 2021/10/10, 2021/1/29
# @Author : Tianyi Tang
# @Email : steven_tang@ruc.edu.cn
"""
textbox.data.dataloader.abstract_dataloader
################################################
"""
import math
import torch
import random
from logging import getLogger
from textbox.utils.enum_type import SpecialTokens
from textbox.data.utils import pad_sequence
class AbstractDataLoader(object):
    """:class:`AbstractDataLoader` is an abstract object which would return a batch of data.
    And it is also the ancestor of all other dataloader.

    Attributes that are not found on the loader itself are looked up on
    :attr:`dataset` (see :meth:`__getattr__`), so ``self.target_text``,
    ``self.target_idx`` etc. resolve to the corpus fields, or to ``None`` when
    the corpus does not define them.

    Args:
        config (Config): The config of dataloader. The keys ``'DDP'`` and ``'device'``
            are read here, ``'tokenize_strategy'`` in :meth:`get_reference`.
        dataset (Corpus): The corpus for partition of dataset.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        shuffle (bool, optional): If ``True``, dataloader will shuffle before every epoch.
            Defaults to ``False``.
        drop_last (bool, optional): If ``True``, a trailing batch that is smaller than
            ``batch_size`` is skipped. Defaults to ``True``.
        DDP (bool, optional): Distributed mode is used only when both this flag and
            ``config['DDP']`` are truthy. Defaults to ``False``.

    Attributes:
        dataset (dict): The necessary elements of this dataloader.
        pr (int): Pointer of dataloader (rank specific in distributed mode).
        std_pr (int): Rank independent pointer, advanced by ``batch_size`` per batch.
        pr_end (int): Number of examples, i.e. ``len(target_text)``.
        step (int): The number of examples sliced by this process for each batch.
        batch_size (int): The max interaction number for all batch.
    """

    def __init__(self, config, dataset, batch_size=1, shuffle=False, drop_last=True, DDP=False):
        self.DDP = config['DDP'] and DDP
        self.config = config
        self.device = config['device']
        self.logger = getLogger()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # In distributed mode every process only slices its own share of a batch.
        if self.DDP:
            self.step = batch_size // torch.distributed.get_world_size()
        else:
            self.step = batch_size
        self.pr = self._start_pointer()
        self.std_pr = 0
        # `target_text` is resolved on the dataset through `__getattr__`.
        self.pr_end = len(self.target_text)

    def _start_pointer(self):
        """Return the value :attr:`pr` must have at the beginning of an epoch."""
        if self.DDP:
            # Each rank starts at its own offset inside the first global batch.
            return self.batch_size // torch.distributed.get_world_size() * torch.distributed.get_rank()
        return 0

    def __getattr__(self, name):
        """Fall back to the wrapped dataset for attributes missing on the loader.

        Returns:
            The dataset attribute called ``name``, or ``None`` if the dataset
            does not define it.

        Raises:
            AttributeError: For special (dunder) names, and whenever no dataset
                is attached yet. Protocol probes such as ``__setstate__`` issued by
                ``copy``/``pickle`` must not receive ``None``, and reading
                ``self.dataset`` here before it exists would recurse endlessly.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        try:
            # Go through the instance dict so that a missing `dataset` cannot
            # re-enter `__getattr__`.
            dataset = self.__dict__['dataset']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(dataset, name, None)

    def __len__(self):
        """Return the number of batches produced in one epoch."""
        if self.drop_last:
            return self.pr_end // self.batch_size
        return math.ceil(self.pr_end / self.batch_size)

    def __iter__(self):
        """Shuffle the data if requested and return the loader itself as iterator."""
        if self.shuffle:
            self._shuffle()
        return self

    def __next__(self):
        """Return the next batch, or reset the pointers and stop at the end of the epoch.

        Raises:
            StopIteration: When the epoch is exhausted.
        """
        if self.drop_last:
            # Stop only when a complete batch no longer fits. A batch ending
            # exactly at `pr_end` is still complete, which keeps the number of
            # yielded batches equal to `len(self)`.
            exhausted = self.std_pr + self.batch_size > self.pr_end
        else:
            exhausted = self.pr >= self.pr_end

        if exhausted:
            self.pr = self._start_pointer()
            self.std_pr = 0
            raise StopIteration()

        next_batch = self._next_batch_data()
        self.pr += self.batch_size
        self.std_pr += self.batch_size
        return next_batch

    def _shuffle(self):
        r"""Shuffle the order of data, and it will be called by :meth:`__iter__()` if self.shuffle is True.

        Every non-empty list attribute of the dataset whose name starts with
        ``source`` or ``target`` and whose first element is a ``list``, ``str``
        or ``int`` is permuted in place with the same permutation, so the
        examples stay aligned across fields.
        """
        keys = []
        columns = []
        for key, value in vars(self.dataset).items():
            if not key.startswith(('source', 'target')):
                continue
            # Skip empty lists: there is nothing to permute and no first
            # element to inspect.
            if isinstance(value, list) and value and isinstance(value[0], (list, str, int)):
                keys.append(key)
                columns.append(value)
        if not columns:
            return

        rows = list(zip(*columns))
        random.shuffle(rows)
        for key, column in zip(keys, zip(*rows)):
            # Slice assignment keeps the identity of the dataset's list objects.
            getattr(self.dataset, key)[:] = column

    def _next_source_patch(self):
        r"""Assemble next batch of source data in form of Interaction, and return these data.

        Returns:
            Interaction: The next batch of source data.

        Raises:
            NotImplementedError: Always; subclasses have to override this method.
        """
        raise NotImplementedError('Method [_next_source_patch] should be implemented.')

    def _next_target_patch(self):
        r"""Assemble next batch of target data in form of Interaction, and return these data.

        Returns:
            Interaction: The next batch of target data. It always holds
            ``'target_text'``; ``'target_idx'`` and ``'target_length'`` are added
            when the dataset provides ``target_idx``, and ``'target_num'`` when
            the padded ``target_num`` is not ``None``.
        """
        window = slice(self.pr, self.pr + self.step)
        target_text = self.target_text[window]
        if self.target_idx is None:
            return {'target_text': target_text}

        target_idx = self.target_idx[window]
        target_length = self.target_length[window]
        target_num = self.target_num[window] if self.target_num is not None else None
        target_idx, target_length, target_num = pad_sequence(
            target_idx, target_length, self.padding_token_idx, target_num
        )
        batch_data = {
            'target_text': target_text,
            'target_idx': target_idx.to(self.device),
            'target_length': target_length.to(self.device)
        }
        if target_num is not None:
            # NOTE(review): unlike idx/length, target_num is not moved to
            # `self.device` -- confirm whether that is intended.
            batch_data['target_num'] = target_num
        return batch_data

    def _next_batch_data(self):
        r"""Assemble next batch of data in form of Interaction, and return these data.

        Returns:
            Interaction: The next batch of data, i.e. the merged source and target batches.
        """
        source_batch = self._next_source_patch()
        target_batch = self._next_target_patch()
        return dict(**source_batch, **target_batch)

    def get_reference(self):
        r"""Get reference documents for current data loader

        Documents stored as lists of sentences are flattened into one word list
        per document. With ``config['tokenize_strategy'] == 'none'`` every
        document is split on single spaces.

        Returns:
            list: reference_corpus as list -> list -> word. An empty list is
            returned for an empty corpus.
        """
        target_text = self.target_text
        if not target_text:
            return []

        # Probe the first non-empty document so that an empty leading document
        # does not raise an IndexError.
        probe = next((doc for doc in target_text if doc), None)
        if probe is not None and not isinstance(probe[0], str):
            target_text = [[word for sent in doc for word in sent] for doc in target_text]

        if self.config['tokenize_strategy'] == 'none':
            return [text.split(' ') for text in target_text]
        return target_text