Source code for textbox.model.LM.gpt2

# @Time   : 2020/11/15
# @Author : Junyi Li
# @Email  : lijunyi@ruc.edu.cn

r"""
GPT-2
################################################
Reference:
    Radford et al. "Language Models are Unsupervised Multitask Learners".
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from textbox.model.abstract_generator import UnconditionalGenerator
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from math import ceil


class GPT2(UnconditionalGenerator):
    r"""GPT-2 is an auto-regressive language model with stacked Transformer decoders.
    """

    def __init__(self, config, dataset):
        super(GPT2, self).__init__(config, dataset)

        self.pretrained_model_path = config['pretrained_model_path']
        # Load the pretrained tokenizer and register the dataset's special tokens.
        self.tokenizer = GPT2Tokenizer.from_pretrained(
            self.pretrained_model_path,
            bos_token=dataset.sos_token,
            eos_token=dataset.eos_token,
            pad_token=dataset.padding_token
        )

        self.sos_token = self.tokenizer.bos_token
        self.eos_token = self.tokenizer.eos_token
        self.sos_token_idx = self.tokenizer.bos_token_id
        self.eos_token_idx = self.tokenizer.eos_token_id
        self.padding_token_idx = self.tokenizer.pad_token_id

        self.configuration = GPT2Config.from_pretrained(
            self.pretrained_model_path,
            bos_token_id=self.sos_token_idx,
            eos_token_id=self.eos_token_idx,
            pad_token_id=self.padding_token_idx
        )

        self.decoder = GPT2LMHeadModel.from_pretrained(self.pretrained_model_path, config=self.configuration)
        # The added special tokens enlarge the vocabulary, so the embeddings must be resized.
        self.decoder.resize_token_embeddings(len(self.tokenizer))

        # Per-token loss; padding positions are masked out via ignore_index.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx, reduction='none')
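For orientation, the following is a minimal standalone sketch of the same Hugging Face setup performed in __init__. It assumes the public "gpt2" checkpoint and illustrative special-token strings; TextBox reads these values from config and dataset instead.

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

# Assumption: public `gpt2` checkpoint; special-token strings are illustrative.
tokenizer = GPT2Tokenizer.from_pretrained(
    'gpt2', bos_token='<s>', eos_token='</s>', pad_token='<pad>'
)
configuration = GPT2Config.from_pretrained(
    'gpt2',
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
decoder = GPT2LMHeadModel.from_pretrained('gpt2', config=configuration)
# Adding special tokens grew the vocabulary, so the (tied) input and output
# embeddings have to be resized to match the tokenizer.
decoder.resize_token_embeddings(len(tokenizer))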
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        batch_size = len(batch_data['target_text'])
        # Draw `batch_size` unconditional samples, each starting from the BOS token.
        sample_outputs = self.decoder.generate(
            bos_token_id=self.sos_token_idx,
            do_sample=True,
            max_length=self.max_length,
            num_return_sequences=batch_size
        )
        generated_text = self.tokenizer.batch_decode(sample_outputs, skip_special_tokens=True)
        generate_corpus.extend([text.lower().split() for text in generated_text])
        return generate_corpus
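For reference, a minimal sketch of the same sampling-based decoding with the transformers API, assuming the public "gpt2" checkpoint; the cap of 50 tokens and the 4 returned samples stand in for self.max_length and the batch size used above.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Assumption: public `gpt2` checkpoint; 50 tokens / 4 samples are illustrative.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

with torch.no_grad():
    sample_outputs = model.generate(
        bos_token_id=tokenizer.bos_token_id,
        do_sample=True,           # ancestral sampling rather than greedy decoding
        max_length=50,
        num_return_sequences=4,
    )
generated_text = tokenizer.batch_decode(sample_outputs, skip_special_tokens=True)
generate_corpus = [text.lower().split() for text in generated_text]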
    def forward(self, corpus, epoch_idx=-1, nll_test=False):
        text_sequence = corpus['target_text']
        input_ids = []
        attn_masks = []
        for text in text_sequence:
            # Wrap each sequence with BOS/EOS and encode it to a fixed length.
            sentence = ' '.join([self.sos_token] + text + [self.eos_token])
            encoding_dict = self.tokenizer(
                sentence, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt"
            )
            input_ids.append(encoding_dict['input_ids'])
            attn_masks.append(encoding_dict['attention_mask'])
        input_ids = torch.cat(input_ids, dim=0).to(self.device)
        attn_masks = torch.cat(attn_masks, dim=0).to(self.device)

        # Standard language-modeling shift: predict token t+1 from tokens <= t.
        decoder_input_ids = input_ids[:, :-1].contiguous()
        decoder_target_ids = input_ids[:, 1:].contiguous()
        attn_masks = attn_masks[:, :-1].contiguous()

        outputs = self.decoder(decoder_input_ids, attention_mask=attn_masks, use_cache=False)
        token_logits = outputs.logits

        loss = self.loss(token_logits.view(-1, token_logits.size(-1)), decoder_target_ids.view(-1))
        loss = loss.reshape_as(decoder_target_ids)
        if nll_test:
            # Total negative log-likelihood of each sequence.
            loss = loss.sum(dim=1)
        else:
            # Per-token loss averaged over the non-padding length of each sequence.
            length = (decoder_target_ids != self.padding_token_idx).sum(dim=1).float()
            loss = loss.sum(dim=1) / length
        return loss.mean()
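For reference, a minimal sketch of the shifted-target NLL computation above, assuming the public "gpt2" checkpoint with a pad token added; the sample sentences, the maximum length of 32, and the variable names are illustrative.

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Assumption: public `gpt2` checkpoint; a pad token is added for batching.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<pad>')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='none')

sentences = [['a', 'small', 'example'], ['another', 'one']]  # illustrative data
batch = tokenizer(
    [' '.join([tokenizer.bos_token] + s + [tokenizer.eos_token]) for s in sentences],
    max_length=32, padding='max_length', truncation=True, return_tensors='pt'
)
input_ids, attn_mask = batch['input_ids'], batch['attention_mask']

# Shift: inputs drop the last position, targets drop the first.
decoder_input_ids = input_ids[:, :-1]
decoder_target_ids = input_ids[:, 1:]
logits = model(decoder_input_ids, attention_mask=attn_mask[:, :-1]).logits

token_loss = loss_fn(logits.reshape(-1, logits.size(-1)), decoder_target_ids.reshape(-1))
token_loss = token_loss.reshape_as(decoder_target_ids)
length = (decoder_target_ids != tokenizer.pad_token_id).sum(dim=1).float()
per_sequence_nll = token_loss.sum(dim=1) / length  # length-normalized NLL
loss = per_sequence_nll.mean()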
    def calculate_nll_test(self, corpus, epoch_idx=-1):
        return self.forward(corpus, epoch_idx, True)