# @Time : 2020/11/17
# @Author : Xiaoxuan Hu
# @Email : huxiaoxuan@ruc.edu.cn
r"""
MaliGAN Discriminator
#####################
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from textbox.model.abstract_generator import UnconditionalGenerator
class MaliGANDiscriminator(UnconditionalGenerator):
    r"""LSTM-based discriminator for MaliGAN.

    A sentence is embedded, encoded with a (possibly multi-layer) LSTM, and the final hidden
    states of all layers are concatenated and projected to a single probability that the
    sentence is realistic.
    """

    def __init__(self, config, dataset):
        super(MaliGANDiscriminator, self).__init__(config, dataset)

        # Hyper-parameters read from the config.
        self.hidden_size = config['hidden_size']
        self.embedding_size = config['discriminator_embedding_size']
        self.max_length = config['seq_len'] + 2
        self.num_dis_layers = config['num_dis_layers']
        self.dropout_rate = config['dropout_rate']

        # NOTE(review): ``self.vocab_size`` and ``self.padding_token_idx`` are not assigned in this
        # class; presumably they are set by the parent constructor -- confirm.
        # The construction order of the sub-modules is kept as-is so that parameter
        # initialisation under a fixed random seed does not change.
        self.LSTM = nn.LSTM(self.embedding_size, self.hidden_size, self.num_dis_layers, batch_first=True)
        self.word_embedding = nn.Embedding(self.vocab_size, self.embedding_size, padding_idx=self.padding_token_idx)
        # Not used by ``forward``; kept so that existing checkpoints (state_dict keys) still load.
        self.vocab_projection = nn.Linear(self.hidden_size, self.vocab_size)
        self.hidden_linear = nn.Linear(self.num_dis_layers * self.hidden_size, self.hidden_size)
        self.label_linear = nn.Linear(self.hidden_size, 1)
        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, data):
        r"""Calculate the probability that the data is realistic.

        Args:
            data (torch.Tensor): The sentence data, shape: [batch_size, max_seq_len].

        Returns:
            torch.Tensor: The probability that each sentence is realistic, shape: [batch_size].
        """
        data_embedding = self.word_embedding(data)  # b * l * e

        # ``batch_first=True`` only affects the input/output sequences of nn.LSTM; the final
        # hidden state is always laid out as [num_layers, b, h].
        _, (hidden, _) = self.LSTM(data_embedding)  # hidden: num_layers * b * h

        # Bring the batch axis to the front *before* flattening, so that every row holds the
        # layer states of a single sentence. Flattening the [num_layers, b, h] tensor directly
        # would mix hidden states of different sentences whenever num_layers > 1.
        hidden = hidden.transpose(0, 1).reshape(-1, self.num_dis_layers * self.hidden_size)  # b * (num_layers * h)

        out = self.hidden_linear(hidden)  # b * (num_layers * h) -> b * h
        pred = self.label_linear(self.dropout(torch.tanh(out))).squeeze(1)  # b * h -> b
        return torch.sigmoid(pred)

    def calculate_loss(self, real_data, fake_data):
        r"""Calculate the loss for real data and fake data.

        The discriminator is trained with the standard objective that GAN employs:
        binary cross entropy with target 1 for real sentences and 0 for generated ones.

        Args:
            real_data (torch.Tensor): The realistic sentence data, shape: [batch_size, max_seq_len].
            fake_data (torch.Tensor): The generated sentence data, shape: [batch_size, max_seq_len].

        Returns:
            torch.Tensor: The calculated loss of real data and fake data, shape: [].
        """
        real_y = self.forward(real_data)  # b * l --> b
        fake_y = self.forward(fake_data)

        # ``forward`` already applies the sigmoid, so these are probabilities, not logits.
        probs = torch.cat((real_y, fake_y), dim=0)
        target = torch.cat((torch.ones_like(real_y), torch.zeros_like(fake_y)), dim=0)

        return F.binary_cross_entropy(probs, target)