# Source code for textbox.data.dataset.attr_sent_dataset

# @Time   : 2021/1/30
# @Author : Tianyi Tang
# @Email  : steven_tang@ruc.edu.cn

# UPDATE:
# @Time   : 2021/10/10
# @Author : Tianyi Tang
# @Email  : steven_tang@ruc.edu.cn

"""
textbox.data.dataset.attr_sent_dataset
########################################
"""

import os
from textbox.data.dataset import AbstractDataset
from textbox.data.utils import build_vocab, text2idx, build_attribute_vocab, attribute2idx


class AttributedSentenceDataset(AbstractDataset):
    """Dataset that pairs tab-separated attribute lines (source) with target text.

    For every split (train, valid, test) the source side is read from
    ``{split}.src`` under ``self.dataset_path``. Each line of that file holds
    the tab-separated attribute values for the target sentence at the same
    position in ``self.target_text``.
    """

    def __init__(self, config):
        super().__init__(config)

    def _get_preset(self):
        """Extend the base-class presets with empty per-split text containers."""
        super()._get_preset()
        # Filled with one entry per split, in the order train, valid, test.
        self.source_text = []
        self.target_text = []

    def _load_attribute(self, dataset_path):
        """Read an attribute file into a list of attribute lists.

        Args:
            dataset_path (str): path of the attribute file to read.

        Returns:
            list[list[str]]: one list per line, obtained by stripping the line's
            surrounding whitespace and splitting it on tab characters.

        Raises:
            ValueError: if ``dataset_path`` is not an existing regular file.
        """
        if not os.path.isfile(dataset_path):
            raise ValueError('File {} not exist'.format(dataset_path))
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(dataset_path, "r", encoding="utf-8") as fin:
            return [line.strip().split('\t') for line in fin]

    def _load_source_data(self):
        """Load the attribute file of every split and append it to ``self.source_text``.

        Reads ``train.src``, ``valid.src`` and ``test.src`` from
        ``self.dataset_path``. The i-th file must contain exactly as many lines
        as ``self.target_text[i]`` has entries. This means ``self.target_text``
        must already be populated when this runs. NOTE(review): presumably the
        base class loads the target side first; confirm the load order.

        Raises:
            ValueError: if an attribute file is missing, or if its line count
                differs from the number of target sentences of the same split.
        """
        for i, prefix in enumerate(['train', 'valid', 'test']):
            filename = os.path.join(self.dataset_path, f'{prefix}.src')
            text_data = self._load_attribute(filename)
            target_count = len(self.target_text[i])
            # An explicit raise (not ``assert``) so the alignment check still
            # runs under ``python -O`` and reports what went wrong.
            if len(text_data) != target_count:
                raise ValueError(
                    f'{filename} has {len(text_data)} attribute lines but the '
                    f'{prefix} split has {target_count} target sentences'
                )
            self.source_text.append(text_data)

    def _build_vocab(self):
        """Build the attribute vocabulary (source) and the token vocabulary (target).

        Sets ``source_idx2token`` and ``source_token2idx`` from
        ``build_attribute_vocab``. Sets ``target_idx2token``,
        ``target_token2idx`` and ``target_vocab_size`` from ``build_vocab``.
        ``build_vocab`` is called with the current ``self.target_vocab_size``
        and ``self.special_token_list``, which are expected to be set
        elsewhere (assumed: base-class config).
        """
        self.source_idx2token, self.source_token2idx = build_attribute_vocab(self.source_text)
        self.target_idx2token, self.target_token2idx, self.target_vocab_size = build_vocab(
            self.target_text, self.target_vocab_size, self.special_token_list
        )

    def _text2idx(self):
        """Convert source attributes and target text into index form.

        Sets ``source_idx`` via ``attribute2idx``. Sets ``target_idx`` and
        ``target_length`` via ``text2idx``; the third value that ``text2idx``
        returns is discarded.
        """
        self.source_idx = attribute2idx(self.source_text, self.source_token2idx)
        self.target_idx, self.target_length, _ = text2idx(
            self.target_text, self.target_token2idx, self.tokenize_strategy
        )