# @Time : 2020/12/26
# @Author : Jinhao Jiang
# @Email : jiangjinhao@std.uestc.edu.cn
r"""
MaskGAN Discriminator
#####################
"""
import torch
import torch.nn as nn
from textbox.model.abstract_generator import GenerativeAdversarialNet
from textbox.module.Encoder.rnn_encoder import BasicRNNEncoder
from textbox.module.Decoder.rnn_decoder import BasicRNNDecoder, AttentionalRNNDecoder
from textbox.model.init import xavier_normal_initialization
class MaskGANDiscriminator(GenerativeAdversarialNet):
r""" RNN-based Encoder-Decoder architecture for MaskGAN discriminator
"""
def __init__(self, config, dataset):
super(MaskGANDiscriminator, self).__init__(config, dataset)
# load parameters info
self.embedding_size = config['embedding_size']
self.hidden_size = config['hidden_size']
self.num_enc_layers = config['num_enc_layers']
self.num_dec_layers = config['num_dec_layers']
self.rnn_type = config['rnn_type']
self.bidirectional = config['bidirectional']
self.alignment_method = config['alignment_method']
self.dropout_ratio = config['dropout_ratio']
self.attention_type = config['attention_type']
self.context_size = config['context_size']
self.mask_token_idx = dataset.user_token_idx[0]
self.encoder = BasicRNNEncoder(
self.embedding_size, self.hidden_size, self.num_enc_layers, self.rnn_type, self.dropout_ratio,
self.bidirectional
)
if self.attention_type is not None:
self.decoder = AttentionalRNNDecoder(
self.embedding_size, self.hidden_size, self.context_size, self.num_dec_layers, self.rnn_type,
self.dropout_ratio, self.attention_type, self.alignment_method
)
else:
self.decoder = BasicRNNDecoder(
self.embedding_size, self.hidden_size, self.num_dec_layers, self.rnn_type, self.dropout_ratio
)
self.dropout = nn.Dropout(self.dropout_ratio)
self.fc_linear = nn.Linear(self.hidden_size, 1)
self.critic_fc_linear = nn.Linear(self.hidden_size, 1)
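        # fc_linear maps each decoder state to a per-token real/fake logit, while
        # critic_fc_linear maps the same states to per-step value estimates
        # (used by critic() below).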
# parameters initialization
self.apply(xavier_normal_initialization)
    def forward(self, inputs, inputs_length, sequence, targets_present, embedder):
r"""Predict the real prob of the filled_in token using real sentence and fake sentence
Args:
inputs: real input bs*seq_len
inputs_length: sentences length list[bs]
sequence: real target or the generated sentence by Generator
targets_present: target sentences present matrix bs*seq_len
embedder: shared embedding with generator
Returns:
prediction: the real prob of filled_in token predicted by discriminator
"""
        masked_inputs = self.mask_input(inputs, targets_present)
        sequence_rnn_inputs = embedder(sequence)  # bs * seq_len * emb_dim
        masked_rnn_inputs = embedder(masked_inputs)
        # detach() is not in-place: reassign so the shared embedder receives no
        # gradient from the discriminator.
        sequence_rnn_inputs = sequence_rnn_inputs.detach()
        masked_rnn_inputs = masked_rnn_inputs.detach()
        encoder_outputs, encoder_final_state = self.encoder(masked_rnn_inputs, inputs_length)
        if self.rnn_type == "lstm":
            # The LSTM final state is an (h, c) tuple; make both contiguous
            # before passing them to the decoder.
            h, c = encoder_final_state
            encoder_final_state = (h.contiguous(), c.contiguous())
        seq_len = inputs.size()[1]
        hidden_state = encoder_final_state
        predictions = []
        values = []
        for t in range(seq_len):
            decoder_input = sequence_rnn_inputs[:, t, :].unsqueeze(dim=1)  # bs * 1 * emb_dim
            if self.attention_type is not None:
                encoder_mask = torch.ones_like(inputs)
                rnn_output, hidden_state, _ = self.decoder(decoder_input, hidden_state, encoder_outputs, encoder_mask)
            else:
                rnn_output, hidden_state = self.decoder(decoder_input, hidden_state)
            prediction = self.fc_linear(rnn_output)  # bs * 1 * 1
            prediction = prediction.squeeze(dim=2).squeeze(dim=1)  # bs
            predictions.append(prediction)
        # Per-step value estimates are computed separately in critic(); `values`
        # stays empty here and is returned only to keep the two-element signature.
        predictions = torch.stack(predictions, dim=1)  # bs * seq_len
return predictions, values
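    # Illustration (hypothetical values; mask_input is assumed to replace absent
    # positions with mask_token_idx, written M below):
    #   inputs          = [[5, 8, 2]]
    #   targets_present = [[True, False, True]]
    #   masked_inputs   = [[5, M, 2]]
    # The decoder then scores every position of `sequence`, yielding one
    # real/fake logit per token.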
    def critic(self, fake_sequence, embedder):
        r"""Define the critic, which is derived from the seq2seq discriminator. It is
        initialized with the same parameters as the language model and shares the forward
        RNN components with the discriminator. It estimates V(s_t), where the state
        s_t = x_0, ..., x_{t-1}.

        Args:
            fake_sequence: generated sequence, bs * seq_len
            embedder: embedding layer shared with the generator

        Returns:
            values: value estimates, bs * seq_len
        """
        fake_sequence = embedder(fake_sequence)
        fake_sequence = fake_sequence.detach()
        seq_len = fake_sequence.size()[1]
        values = []
        hidden_state = None
        for t in range(seq_len):
            # The value at step t conditions only on tokens x_0 .. x_{t-1},
            # so the input is shifted right by one (zeros at t = 0).
            if t == 0:
                decoder_input = torch.zeros_like(fake_sequence[:, 0].unsqueeze(dim=1))
            else:
                decoder_input = fake_sequence[:, t - 1].unsqueeze(dim=1)
            # Call the underlying RNN directly so any attention wrapper is bypassed.
            rnn_output, hidden_state = self.decoder.decoder(decoder_input, hidden_state)
            rnn_output_ = rnn_output.detach()  # only update critic_fc_linear
            value = self.critic_fc_linear(rnn_output_)  # bs * 1 * 1
            values.append(value)
        values = torch.stack(values, dim=1).squeeze(dim=-1).squeeze(dim=-1)  # bs * seq_len
return values
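    # Example timeline (hypothetical, seq_len = 3): the critic is fed
    #   t=0: zeros -> V(s_0),  t=1: x_0 -> V(s_1),  t=2: x_1 -> V(s_2)
    # so each estimate conditions only on tokens before the one being valued.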
    def calculate_dis_loss(self, fake_prediction, real_prediction, target_present):
        r"""Compute the discriminator loss across real and fake predictions, weighted so
        that only the missing (filled-in) positions contribute.
        """
missing = 1 - target_present.float() # bs*seq_len
real_label = torch.ones_like(real_prediction, dtype=torch.float) # bs*seq_len
fake_label = torch.zeros_like(fake_prediction, dtype=torch.float)
loss = nn.BCEWithLogitsLoss(weight=missing, reduction='none')
real_prediction = torch.clamp(real_prediction, min=-5, max=5)
fake_prediction = torch.clamp(fake_prediction, min=-5, max=5)
real_loss = loss(real_prediction, real_label)
real_loss = torch.sum(real_loss) / torch.sum(missing)
fake_loss = loss(fake_prediction, fake_label)
fake_loss = torch.sum(fake_loss) / torch.sum(missing)
loss = (real_loss + fake_loss) / 2
return loss
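    # Worked example (hypothetical, bs = 1, seq_len = 3):
    #   target_present = [[True, False, True]]  ->  missing = [[0., 1., 0.]]
    # Only position 1 is weighted into the BCE terms, and each term is divided
    # by sum(missing) = 1, so present positions never affect the loss.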
    def create_critic_loss(self, cumulative_rewards, estimated_values, target_present):
        r"""Compute the critic loss for estimating the value function. The estimate is
        only evaluated on the missing (filled-in) elements.
        """
        missing = 1 - target_present.float()
        # Pin present positions to their rewards so their squared error is zero;
        # only the missing positions contribute to the loss.
        estimated_values = torch.where(target_present, cumulative_rewards, estimated_values)
        loss = nn.MSELoss(reduction='none')
        critic_loss = loss(estimated_values, cumulative_rewards)
        critic_loss = torch.sum(critic_loss) / torch.sum(missing)
        return critic_loss
    def calculate_loss(self, real_sequence, lengths, fake_sequence, targets_present, embedder):
        r"""Calculate the discriminator loss from one real and one generated sequence."""
fake_prediction, _ = self.forward(real_sequence, lengths, fake_sequence, targets_present, embedder)
real_prediction, _ = self.forward(real_sequence, lengths, real_sequence, targets_present, embedder)
return self.calculate_dis_loss(fake_prediction, real_prediction, targets_present)
    def mask_target_present(self, targets_present, lengths):
        r"""Mark positions beyond each sequence's true length as present, so that
        padding is never treated as missing."""
        batch_size, seq_len = targets_present.size()
        for i in range(batch_size):
            length = int(lengths[i])
            targets_present[i, length:] = True
        return targets_present
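

if __name__ == '__main__':
    # Standalone sketch (not part of TextBox) reproducing the masked BCE
    # arithmetic from calculate_dis_loss on toy tensors, so the weighting
    # behaviour can be checked without constructing the full model. All
    # numbers below are illustrative.
    target_present = torch.tensor([[True, False, True, False],
                                   [False, False, True, True]])
    missing = 1 - target_present.float()  # 1 where a token was filled in
    real_logits = torch.clamp(torch.randn(2, 4), min=-5, max=5)
    fake_logits = torch.clamp(torch.randn(2, 4), min=-5, max=5)
    bce = nn.BCEWithLogitsLoss(weight=missing, reduction='none')
    real_loss = torch.sum(bce(real_logits, torch.ones_like(real_logits))) / torch.sum(missing)
    fake_loss = torch.sum(bce(fake_logits, torch.zeros_like(fake_logits))) / torch.sum(missing)
    print('discriminator loss:', ((real_loss + fake_loss) / 2).item())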