Source code for textbox.data.dataset.paired_sent_dataset

# @Time   : 2020/11/16
# @Author : Junyi Li
# @Email  : lijunyi@ruc.edu.cn

# UPDATE:
# @Time   : 2021/10/10, 2021/1/29, 2020/12/04
# @Author : Tianyi Tang, Gaole He
# @Email  : steven_tang@ruc.edu.cn, hegaole@ruc.edu.cn

"""
textbox.data.dataset.paired_sent_dataset
########################################
"""

import os
from textbox.data.dataset import AbstractDataset
from textbox.data.utils import load_data, build_vocab, text2idx


class PairedSentenceDataset(AbstractDataset):
    r"""Dataset in which every source sentence is paired with a target sentence."""

    def __init__(self, config):
        self.share_vocab = config['share_vocab']
        super().__init__(config)

    def _get_preset(self):
        super()._get_preset()
        self.source_text = []
        self.target_text = []

    def _load_source_data(self):
        for i, prefix in enumerate(['train', 'valid', 'test']):
            filename = os.path.join(self.dataset_path, f'{prefix}.src')

            text_data = load_data(
                filename, self.tokenize_strategy, self.source_max_length, self.source_language,
                self.source_multi_sentence, self.source_max_num
            )
            # self.target_text has already been filled at this point, so the two
            # sides of each split can be checked for sentence-level alignment.
            assert len(text_data) == len(self.target_text[i])
            self.source_text.append(text_data)

    def _build_vocab(self):
        if self.share_vocab:
            # Build a single vocabulary over both sides and reuse it for the
            # target, which requires the two configured vocab sizes to agree.
            assert self.source_vocab_size == self.target_vocab_size
            text_data = self.source_text + self.target_text
            self.source_idx2token, self.source_token2idx, self.source_vocab_size = build_vocab(
                text_data, self.source_vocab_size, self.special_token_list
            )
            self.target_idx2token, self.target_token2idx, self.target_vocab_size = \
                self.source_idx2token, self.source_token2idx, self.source_vocab_size
        else:
            self.source_idx2token, self.source_token2idx, self.source_vocab_size = build_vocab(
                self.source_text, self.source_vocab_size, self.special_token_list
            )
            self.target_idx2token, self.target_token2idx, self.target_vocab_size = build_vocab(
                self.target_text, self.target_vocab_size, self.special_token_list
            )

    def _text2idx(self):
        self.source_idx, self.source_length, self.source_num = text2idx(
            self.source_text, self.source_token2idx, self.tokenize_strategy
        )
        self.target_idx, self.target_length, self.target_num = text2idx(
            self.target_text, self.target_token2idx, self.tokenize_strategy
        )
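# Usage sketch (illustrative, not part of the library source): the dataset is
# driven entirely by a config mapping. Only 'share_vocab' is read directly
# here; the remaining attributes used above (dataset_path, the vocab sizes,
# max lengths, special_token_list, ...) are presumably set up from the same
# config by AbstractDataset. A hypothetical setup:
#
#     config = {'share_vocab': True, ...}  # plus the AbstractDataset keys
#     dataset = PairedSentenceDataset(config)
#
# With share_vocab=True a single vocabulary is built over source + target text
# and used for both sides, so source_token2idx is target_token2idx.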
class CopyPairedSentenceDataset(PairedSentenceDataset):

    def __init__(self, config):
        super(CopyPairedSentenceDataset, self).__init__(config)

    def _text2idx(self):
        data_dict = self.text2idx(
            self.source_text, self.target_text, self.target_token2idx, self.sos_token_idx,
            self.eos_token_idx, self.unknown_token_idx, self.config['is_pgen']
        )
        for key, value in data_dict.items():
            setattr(self, key, value)
    @staticmethod
    def text2idx(source_text, target_text, token2idx, sos_idx, eos_idx, unk_idx, is_pgen=False):
        data_dict = {
            'source_idx': [], 'source_length': [],
            'target_input_idx': [], 'target_output_idx': [], 'target_length': []
        }
        if is_pgen:
            data_dict['source_extended_idx'] = []
            data_dict['source_oovs'] = []

        def article2ids(article_words):
            # Map source words to ids; each distinct OOV word gets a temporary
            # id just past the end of the vocabulary.
            ids = []
            oovs = []
            for w in article_words:
                i = token2idx.get(w, unk_idx)
                if i == unk_idx:
                    if w not in oovs:
                        oovs.append(w)
                    oov_num = oovs.index(w)
                    ids.append(len(token2idx) + oov_num)
                else:
                    ids.append(i)
            return ids, oovs

        def abstract2ids(abstract_words, article_oovs):
            # Map target words to ids; an unknown word keeps the temporary id
            # of the matching source OOV if it can be copied, else <unk>.
            ids = []
            for w in abstract_words:
                i = token2idx.get(w, unk_idx)
                if i == unk_idx:
                    if w in article_oovs:
                        vocab_idx = len(token2idx) + article_oovs.index(w)
                        ids.append(vocab_idx)
                    else:
                        ids.append(unk_idx)
                else:
                    ids.append(i)
            return ids

        for i, prefix in enumerate(['train', 'valid', 'test']):
            new_source_idx = []
            new_source_length = []
            new_target_input_idx = []
            new_target_output_idx = []
            new_target_length = []
            if is_pgen:
                new_source_extended_idx = []
                new_source_oovs = []
            for source_sent, target_sent in zip(source_text[i], target_text[i]):
                source_idx = [token2idx.get(word, unk_idx) for word in source_sent]
                target_input_idx = [sos_idx] + [token2idx.get(word, unk_idx) for word in target_sent]
                if is_pgen:
                    source_extended_idx, source_oovs = article2ids(source_sent)
                    target_output_idx = abstract2ids(target_sent, source_oovs) + [eos_idx]
                    new_source_extended_idx.append(source_extended_idx)
                    new_source_oovs.append(source_oovs)
                else:
                    target_output_idx = [token2idx.get(word, unk_idx) for word in target_sent] + [eos_idx]
                new_source_idx.append(source_idx)
                new_source_length.append(len(source_idx))
                new_target_input_idx.append(target_input_idx)
                new_target_output_idx.append(target_output_idx)
                new_target_length.append(len(target_input_idx))
            data_dict['source_idx'].append(new_source_idx)
            data_dict['source_length'].append(new_source_length)
            data_dict['target_input_idx'].append(new_target_input_idx)
            data_dict['target_output_idx'].append(new_target_output_idx)
            data_dict['target_length'].append(new_target_length)
            if is_pgen:
                data_dict['source_extended_idx'].append(new_source_extended_idx)
                data_dict['source_oovs'].append(new_source_oovs)
        return data_dict
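
# Worked example (an illustrative sketch, not part of the original module):
# text2idx is a pure static method, so the pointer-generator indexing can be
# checked on toy data, given the module's textbox imports resolve. The
# vocabulary and sentences below are made up.
if __name__ == '__main__':
    token2idx = {'<sos>': 0, '<eos>': 1, '<unk>': 2, 'the': 3, 'cat': 4, 'sat': 5}
    # One tokenized sentence per split (train/valid/test); 'mat' is OOV.
    source = [[['the', 'cat', 'sat', 'mat']]] * 3
    target = [[['the', 'mat']]] * 3
    out = CopyPairedSentenceDataset.text2idx(
        source, target, token2idx, sos_idx=0, eos_idx=1, unk_idx=2, is_pgen=True
    )
    # 'mat' receives the temporary id len(token2idx) + 0 == 6 in the extended
    # source, and the same id in the target output because it can be copied.
    print(out['source_extended_idx'][0])  # [[3, 4, 5, 6]]
    print(out['source_oovs'][0])          # [['mat']]
    print(out['target_output_idx'][0])    # [[3, 6, 1]]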