Source code for textbox.data.utils

# @Time   : 2020/11/14
# @Author : Junyi Li, Gaole He
# @Email  : lijunyi@ruc.edu.cn

# UPDATE:
# @Time   : 2021/10/9, 2021/1/29, 2020/12/04
# @Author : Tianyi Tang, Gaole He2021/10/10
# @Email  : steven_tang@ruc.edu.cn, hegaole@ruc.edu.cn

"""
textbox.data.utils
########################
"""

import os
from posix import listdir
import nltk
import collections
import torch
import copy
import shutil
from logging import getLogger
from textbox.utils.enum_type import SpecialTokens


def get_dataset(config):
    """Select the dataset class that matches the configured task.

    The choice is driven by ``config['task_type']`` (compared case-insensitively).
    For translation/summarization ``config['model']`` is consulted as well, and
    ``config['dataset']`` is only examined when the task type is none of the
    recognised names.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Dataset: The dataset class (not an instance) for the configured task.

    Raises:
        NotImplementedError: If neither the task type nor the dataset name is recognised.
    """
    task_type = config['task_type'].lower()

    # Imports are function-local: a class is only imported on the branch that returns it.
    if task_type == "unconditional":
        from .dataset import SingleSentenceDataset
        return SingleSentenceDataset

    if task_type == "attribute":
        from .dataset import AttributedSentenceDataset
        return AttributedSentenceDataset

    if task_type in ("translation", "summarization"):
        if config['model'] != 'PointerNet':
            from .dataset import PairedSentenceDataset
            return PairedSentenceDataset
        from .dataset import CopyPairedSentenceDataset
        return CopyPairedSentenceDataset

    if task_type == "kg2text":
        from .dataset import KGSentenceDataset
        return KGSentenceDataset

    # Unrecognised task type: fall back to selecting by dataset name.
    dataset_name = config['dataset']
    if dataset_name == 'WikiBio':
        from .dataset import WikiBioSentenceDataset
        return WikiBioSentenceDataset

    if dataset_name == 'RotoWire':
        from .dataset import RotoWireSentenceDataset
        return RotoWireSentenceDataset

    raise NotImplementedError("No such dataset for TASK_TYPE: {}".format(task_type))
def dataloader_construct(name, config, dataset, batch_size=1, shuffle=False, drop_last=True, DDP=False):
    """Instantiate the dataloader class chosen by :func:`get_dataloader`.

    Two log records are emitted (task type / stage, then batching options)
    before the dataloader is created.

    Args:
        name (str): The stage of dataloader; only used in the log message here.
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset or list of Dataset): The split dataset for constructing dataloader.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``.
        drop_last (bool, optional): Whether the dataloader will drop the last batch. Defaults to ``True``.
        DDP (bool, optional): Whether the dataloader will distribute in different GPU. Defaults to ``False``.

    Returns:
        AbstractDataLoader or list of AbstractDataLoader: Constructed dataloader in split dataset.
    """
    task_name = config['task_type'].lower()

    log = getLogger()
    log.info('Build [{}] DataLoader for [{}]'.format(task_name, name))
    log.info('batch_size = [{}], shuffle = [{}], drop_last = [{}]\n'.format(batch_size, shuffle, drop_last))

    loader_cls = get_dataloader(config)
    loader_options = {
        'config': config,
        'dataset': dataset,
        'batch_size': batch_size,
        'shuffle': shuffle,
        'drop_last': drop_last,
        'DDP': DDP,
    }
    return loader_cls(**loader_options)
def get_dataloader(config):
    """Select the dataloader class that matches the configured task.

    The choice is driven by ``config['task_type']`` (compared case-insensitively).
    For translation/summarization ``config['model']`` is consulted as well, and
    ``config['dataset']`` is only examined when the task type is none of the
    recognised names.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        type: The dataloader class (not an instance) for the configured task.

    Raises:
        NotImplementedError: If neither the task type nor the dataset name is recognised.
    """
    task_type = config['task_type'].lower()

    # Imports are function-local: a class is only imported on the branch that returns it.
    if task_type == "unconditional":
        from .dataloader import SingleSentenceDataLoader
        return SingleSentenceDataLoader

    if task_type == "attribute":
        from .dataloader import AttributedSentenceDataLoader
        return AttributedSentenceDataLoader

    if task_type in ("translation", "summarization"):
        if config['model'] != 'PointerNet':
            from .dataloader import PairedSentenceDataLoader
            return PairedSentenceDataLoader
        from .dataloader import CopyPairedSentenceDataLoader
        return CopyPairedSentenceDataLoader

    if task_type == "kg2text":
        from .dataloader import KGSentenceDataLoader
        return KGSentenceDataLoader

    # Unrecognised task type: fall back to selecting by dataset name.
    dataset_name = config['dataset']
    if dataset_name == 'WikiBio':
        from .dataloader import WikiBioSentenceDataLoader
        return WikiBioSentenceDataLoader

    if dataset_name == 'RotoWire':
        from .dataloader import RotoWireSentenceDataLoader
        return RotoWireSentenceDataLoader

    raise NotImplementedError("No such dataloader for TASK_TYPE: {}".format(task_type))
def construct_quick_test_dataset(dataset_path, num_lines=10):
    """Shrink the data files in ``dataset_path`` to their first lines for a quick test.

    ``.bin`` files are deleted. Every other regular file is backed up to
    ``<name>.tmp`` and then overwritten with the first ``num_lines`` lines of
    that backup. :func:`deconstruct_quick_test_dataset` undoes the change.

    Args:
        dataset_path (str): path of dataset dir.
        num_lines (int, optional): number of leading lines kept in each file.
            Defaults to ``10``.
    """
    data_files = []
    # ``os.listdir`` rather than ``posix.listdir``: the ``posix`` module does
    # not exist on every platform.
    for entry in os.listdir(dataset_path):
        filename = os.path.join(dataset_path, entry)
        if not os.path.isfile(filename):
            # Sub-directories cannot be copied/truncated as text files.
            continue
        if filename.endswith('.bin'):
            os.remove(filename)
        elif filename.endswith('.tmp'):
            # Leftover backup from an interrupted run: it still holds the full
            # data, so it must neither be truncated nor backed up again.
            continue
        else:
            data_files.append(filename)

    for filename in data_files:
        backup = filename + '.tmp'
        if not os.path.exists(backup):
            shutil.copy(filename, backup)
        # An existing backup is the untruncated original; never overwrite it
        # with the (possibly already truncated) working file.
        with open(backup, 'r') as fin, open(filename, 'w') as fout:
            # Stream the file instead of loading every line into memory.
            for line_no, line in enumerate(fin):
                if line_no >= num_lines:
                    break
                fout.write(line)
def deconstruct_quick_test_dataset(dataset_path):
    """Restore the data files shrunk by :func:`construct_quick_test_dataset`.

    ``.bin`` files are deleted, and every other non-``.tmp`` file that has a
    ``<name>.tmp`` backup is replaced by that backup. Files without a backup
    are left untouched.

    Args:
        dataset_path (str): path of dataset dir.
    """
    # ``os.listdir`` rather than ``posix.listdir``: the ``posix`` module does
    # not exist on every platform.
    for entry in os.listdir(dataset_path):
        filename = os.path.join(dataset_path, entry)
        if filename.endswith('.bin'):
            os.remove(filename)
        elif not filename.endswith('.tmp'):
            backup = filename + '.tmp'
            # A file that appeared after the backup step has nothing to
            # restore; moving a missing backup would raise FileNotFoundError.
            if os.path.exists(backup):
                shutil.move(backup, filename)
def data_preparation(config, save=False):
    """Build the train, validation and test dataloaders described by ``config``.

    The dataset is constructed once, shallow-copied per split, and each copy
    gets the entries of its ``<split>_data`` mapping set as attributes before
    being handed to :func:`dataloader_construct`.

    When ``config['quick_test']`` is truthy the files under
    ``config['data_path']`` are temporarily shrunk while the dataset is built
    and are restored afterwards, even if building the dataset fails.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        save (bool, optional): Not used by this function. Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    quick_test = config['quick_test']
    if quick_test:
        construct_quick_test_dataset(config['data_path'])
    try:
        dataset = get_dataset(config)(config)
    finally:
        # Always put the full data files back; otherwise a failure while
        # building the dataset would leave them truncated on disk.
        if quick_test:
            deconstruct_quick_test_dataset(config['data_path'])

    split_datasets = {}
    for prefix in ('train', 'valid', 'test'):
        split = copy.copy(dataset)
        # Promote the split-specific fields to attributes of that split's copy.
        for key, value in getattr(split, f'{prefix}_data').items():
            setattr(split, key, value)
        split_datasets[prefix] = split

    train_data = dataloader_construct(
        name='train',
        config=config,
        dataset=split_datasets['train'],
        batch_size=config['train_batch_size'],
        shuffle=True,
        DDP=True
    )
    # NOTE(review): validation reuses the training batch size and shuffling;
    # kept as in the original, confirm that this is intended.
    valid_data = dataloader_construct(
        name='valid',
        config=config,
        dataset=split_datasets['valid'],
        batch_size=config['train_batch_size'],
        shuffle=True,
        DDP=True
    )
    test_data = dataloader_construct(
        name='test',
        config=config,
        dataset=split_datasets['test'],
        batch_size=config['eval_batch_size'],
        drop_last=False,
    )

    return train_data, valid_data, test_data
def tokenize(text, tokenize_strategy, language, multi_sentence):
    """Tokenize text data.

    Args:
        text (str): text data.
        tokenize_strategy (str): strategy of tokenizer, one of ``'none'``,
            ``'by_space'`` or ``'nltk'``.
        language (str): language of text, passed to ``nltk.word_tokenize``.
        multi_sentence (bool): whether to split text into sentence level
            (sentences are separated by tab characters).

    Returns:
        str or List[str] or List[List[str]]: the tokenized text data. Without
        ``multi_sentence`` this is the text itself (``'none'``) or its list of
        tokens; with ``multi_sentence`` it is one such entry per sentence.

    Raises:
        ValueError: If ``tokenize_strategy`` is not a supported strategy.
    """
    if tokenize_strategy not in ('none', 'by_space', 'nltk'):
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError('Unknown tokenize strategy: {}'.format(tokenize_strategy))

    if multi_sentence:
        sentences = text.split('\t')
        if tokenize_strategy == 'none':
            return sentences
        if tokenize_strategy == 'by_space':
            return [sentence.split() for sentence in sentences]
        return [nltk.word_tokenize(sentence, language=language) for sentence in sentences]

    # str.replace returns a new string, so the result has to be kept.
    text = text.replace('\t', ' ')
    if tokenize_strategy == 'none':
        return text
    if tokenize_strategy == 'by_space':
        return text.split()
    return nltk.word_tokenize(text, language=language)
def load_data(dataset_path, tokenize_strategy, max_length, language, multi_sentence, max_num):
    """Load one dataset file and tokenize it line by line.

    Each line is stripped and lower-cased before it is passed to
    :func:`tokenize`. Token lists are truncated to ``max_length``; in the
    multi-sentence case at most ``max_num`` sentences are kept per line.

    Args:
        dataset_path (str): path of the dataset file.
        tokenize_strategy (str): strategy of tokenizer.
        max_length (int): max length of sequence.
        language (str): language of text.
        multi_sentence (bool): whether to split text into sentence level.
        max_num (int): max number of sequence.

    Returns:
        list: one entry per line of the file: the untokenized string, a list
        of tokens, or a list of token lists, depending on what
        :func:`tokenize` returned for that line.

    Raises:
        ValueError: If ``dataset_path`` is not an existing file.
    """
    if not os.path.isfile(dataset_path):
        raise ValueError('File {} not exist'.format(dataset_path))

    text = []
    with open(dataset_path, "r") as fin:
        for line in fin:
            words = tokenize(line.strip().lower(), tokenize_strategy, language, multi_sentence)
            if isinstance(words, str):  # no split
                text.append(words)
            elif not words or isinstance(words[0], str):  # single sentence
                # ``not words`` covers blank lines, which tokenize to an empty
                # list and would otherwise fail on ``words[0]``.
                text.append(words[:max_length])
            else:  # multiple sentences
                text.append([sentence[:max_length] for sentence in words[:max_num]])
    return text
def build_vocab(text, max_vocab_size, special_token_list):
    """Build vocabulary of list of text data.

    Tokens are ranked by descending frequency (ties broken by descending token
    order), placed after the special tokens, and the combined list is cut to
    ``max_vocab_size`` entries. Empty groups and empty documents contribute no
    tokens.

    Args:
        text (List[List[List[str]]] or List[List[List[List[str]]]]): list of text data, consisting of multiple groups.
        max_vocab_size (int): max size of vocabulary.
        special_token_list (List[str]): list of special tokens.

    Returns:
        tuple:
            - idx2token (dict): map index to token.
            - token2idx (dict): map token to index.
            - max_vocab_size (int): updated max size of vocabulary.
    """
    counter = collections.Counter()
    for group in text:  # train, valid, test
        if not group:
            continue
        if isinstance(group[0], str):  # single word
            counter.update(group)
            continue
        for doc in group:
            if not doc:
                # e.g. a blank line: nothing to count, and doc[0] would fail.
                continue
            if isinstance(doc[0], str):  # single sentence
                counter.update(doc)
                continue
            for sent in doc:
                if isinstance(sent, tuple):  # kg
                    counter.update(sent[0] + [sent[1]] + sent[2])
                else:  # multiple sentences
                    counter.update(sent)

    # Same ordering as sorting (count, token) pairs in reverse.
    ranked = sorted(counter.items(), key=lambda item: (item[1], item[0]), reverse=True)
    tokens = (special_token_list + [token for token, _ in ranked])[:max_vocab_size]
    max_vocab_size = len(tokens)

    idx2token = dict(enumerate(tokens))
    token2idx = {token: index for index, token in enumerate(tokens)}
    return idx2token, token2idx, max_vocab_size
def build_attribute_vocab(text):
    """Build attribute vocabulary of list of attribute data.

    One vocabulary is built per attribute position. Values are indexed in
    sorted order so that the mapping is identical from run to run (iterating
    a set of strings directly gives an order that can change between
    interpreter runs).

    Args:
        text (List[List[List[str]]] or List[List[List[List[str]]]]): list of attribute data, consisting of multiple groups.

    Returns:
        tuple:
            - idx2token (List[dict]): per attribute position, map index to token.
            - token2idx (List[dict]): per attribute position, map token to index.
    """
    first_doc = text[0][0]
    # The number of attribute positions is read from the first document
    # (or from its first sentence in the multi-sentence layout).
    attribute_num = len(first_doc) if isinstance(first_doc[0], str) else len(first_doc[0])
    attribute_sets = [set() for _ in range(attribute_num)]

    def _collect(attributes):
        # Record one attribute tuple, checking it has the expected width.
        assert len(attributes) == attribute_num
        for position, attr in enumerate(attributes):
            attribute_sets[position].add(attr)

    for group in text:
        for doc in group:
            if isinstance(doc[0], str):
                _collect(doc)
            else:
                for sent in doc:
                    _collect(sent)

    idx2token = []
    token2idx = []
    for values in attribute_sets:
        ordered = sorted(values)
        idx2token.append(dict(enumerate(ordered)))
        token2idx.append({attr: index for index, attr in enumerate(ordered)})
    return idx2token, token2idx
def text2idx(text, token2idx, tokenize_strategy):
    r"""Map tokens to vocabulary indices, optionally wrapping sentences in sos/eos.

    Tokens missing from ``token2idx`` are mapped to the index of the unknown
    token. The sos and eos indices are added around every sentence unless
    ``tokenize_strategy`` is ``'none'``.

    Args:
        text (List[List[List[str]]] or List[List[List[List[str]]]]): list of text data, consisting of multiple groups.
        token2idx (dict): map token to index
        tokenize_strategy (str): strategy of tokenizer.

    Returns:
        idx (List[List[List[int]]] or List[List[List[List[int]]]]): word index
        length (List[List[int]] or List[List[List[int]]]): sequence length
        num (None or List[List[int]]): sequence number
    """
    add_boundaries = tokenize_strategy != 'none'
    sos_idx = token2idx[SpecialTokens.SOS]
    eos_idx = token2idx[SpecialTokens.EOS]
    unknown_idx = token2idx[SpecialTokens.UNK]

    def _encode(sentence):
        # Index one sentence, falling back to the unknown index.
        encoded = [token2idx.get(token, unknown_idx) for token in sentence]
        if add_boundaries:
            encoded = [sos_idx] + encoded + [eos_idx]
        return encoded

    all_idx = []
    all_length = []

    # The layout is decided once, from the very first element of the data.
    if isinstance(text[0][0][0], str):  # single sentence
        for group in text:
            group_idx = [_encode(sentence) for sentence in group]
            all_idx.append(group_idx)
            all_length.append([len(encoded) for encoded in group_idx])
        return all_idx, all_length, None

    # multiple sentences
    all_num = []
    for group in text:
        group_idx = []
        group_length = []
        group_num = []
        for doc in group:
            doc_idx = [_encode(sentence) for sentence in doc]
            group_idx.append(doc_idx)
            group_length.append([len(encoded) for encoded in doc_idx])
            group_num.append(len(doc))
        all_idx.append(group_idx)
        all_length.append(group_length)
        all_num.append(group_num)
    return all_idx, all_length, all_num
def attribute2idx(text, token2idx):
    r"""Map attribute values to indices using one vocabulary per attribute position.

    Args:
        text (List[List[List[str]]] or List[List[List[List[str]]]]): list of attribute data, consisting of multiple groups.
        token2idx (List[dict]): per attribute position, map token to index

    Returns:
        idx (List[List[List[int]]] or List[List[List[List[int]]]]): attribute index
        length (List[List[int]]): number of sentences per document; only
        returned (as the second element of a tuple) when the first group
        recorded any, otherwise ``idx`` is returned on its own.
    """
    def _encode(attributes):
        # The i-th value is looked up in the i-th vocabulary.
        return [token2idx[position][attr] for position, attr in enumerate(attributes)]

    all_idx = []
    all_length = []
    for group in text:
        group_idx = []
        group_length = []
        for doc in group:
            if isinstance(doc[0], str):
                group_idx.append(_encode(doc))
                continue
            # Multi-sentence document: encode each sentence and record the count.
            group_idx.append([_encode(sent) for sent in doc])
            group_length.append(len(doc))
        all_idx.append(group_idx)
        all_length.append(group_length)

    if not all_length[0]:
        return all_idx
    return all_idx, all_length
def pad_sequence(idx, length, padding_idx, num=None):
    r"""padding a batch of word index data, to make them have equivalent length

    Without ``num`` every sentence is right-padded with ``padding_idx`` to the
    longest sentence in the batch. With ``num`` every sentence is padded the
    same way and, in addition, every document is extended with all-padding
    sentences (of length 0) up to the largest sentence count in the batch.

    Args:
        idx (List[List[int]] or List[List[List[int]]]): word index
        length (List[int] or List[List[int]]): sequence length
        padding_idx (int): the index of padding token
        num (List[int]): sequence number

    Returns:
        idx (torch.LongTensor): padded word index
        length (torch.LongTensor): sequence length
        num (torch.LongTensor or None): sequence number
    """
    if num is None:
        max_length = max(length)
        padded = [
            sent_idx + [padding_idx] * (max_length - sent_length)
            for sent_idx, sent_length in zip(idx, length)
        ]
        return torch.LongTensor(padded), torch.LongTensor(length), None

    max_length = max(max(doc_length) for doc_length in length)
    max_num = max(num)

    padded_idx = []
    padded_length = []
    for doc_idx, doc_length, doc_num in zip(idx, length, num):
        missing = max_num - doc_num
        padded_length.append(doc_length + [0] * missing)

        doc_padded = [
            sent_idx + [padding_idx] * (max_length - sent_length)
            for sent_idx, sent_length in zip(doc_idx, doc_length)
        ]
        # Filler sentences must hold the padding index as well; a literal 0
        # is only correct when the padding token happens to have index 0.
        doc_padded.extend([padding_idx] * max_length for _ in range(missing))
        padded_idx.append(doc_padded)

    return torch.LongTensor(padded_idx), torch.LongTensor(padded_length), torch.LongTensor(num)