# @Time : 2020/11/14
# @Author : Junyi Li
# @Email : lijunyi@ruc.edu.cn
r"""
CNN Decoder
###############
"""

import torch
from torch import nn
import torch.nn.functional as F


class BasicCNNDecoder(nn.Module):
    """
    Basic Convolutional Neural Network (CNN) decoder.

    Code Reference: https://github.com/kefirski/contiguous-succotash
    """

def __init__(self, input_size, latent_size, decoder_kernel_size, decoder_dilations, dropout_ratio):
super(BasicCNNDecoder, self).__init__()
self.latent_size = latent_size
self.input_size = input_size
self.dropout_ratio = dropout_ratio
self.decoder_dilations = decoder_dilations
if isinstance(decoder_kernel_size, int):
self.decoder_kernel_size = [decoder_kernel_size]
elif isinstance(decoder_kernel_size, list):
self.decoder_kernel_size = decoder_kernel_size
else:
            raise NotImplementedError("Unrecognized hyperparameter: {}".format(decoder_kernel_size))
self.dropout = nn.Dropout(self.dropout_ratio)
self.decoder_kernels, self.decoder_biases, self.decoder_paddings = self._module_def()

    def _module_def(self):
        # Each entry of `decoder_kernel_size` gives the output channels of one
        # convolutional layer; the kernel width itself is fixed to 3.
        assert len(self.decoder_kernel_size) <= 3
        decoder_kernels = []
        for i, out_channel in enumerate(self.decoder_kernel_size):
            if i == 0:
                in_channel = self.latent_size + self.input_size
            else:
                in_channel = self.decoder_kernel_size[i - 1]
            decoder_kernels.append(nn.Parameter(torch.Tensor(out_channel, in_channel, 3).normal_(0, 0.05)))
        decoder_biases = [
            nn.Parameter(torch.Tensor(out_channel).normal_(0, 0.05)) for out_channel in self.decoder_kernel_size
        ]
        # padding = (kernel_width - 1) * dilation = 2 * dilation; combined with
        # the truncation in forward(), this makes each convolution causal.
        decoder_paddings = [2 * self.decoder_dilations[i] for i in range(len(decoder_kernels))]
        # Wrap the parameters in ParameterList so they are registered with the
        # module (trained, saved in state_dict, and moved by .to(device));
        # plain Python lists of parameters are invisible to the optimizer.
        return nn.ParameterList(decoder_kernels), nn.ParameterList(decoder_biases), decoder_paddings

    def forward(self, decoder_input, noise):
        r""" Implement the decoding process.

        Args:
            decoder_input (torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            noise (torch.Tensor): latent code, shape: [batch_size, latent_size].

        Returns:
            torch.Tensor: output features, shape: [batch_size, sequence_length, feature_size].
        """
        seq_len = decoder_input.size(1)
        # Broadcast the latent code to every timestep and concatenate it with
        # the token embeddings along the feature dimension.
        z = noise.unsqueeze(1).expand(-1, seq_len, -1)
        decoder_input = torch.cat([decoder_input, z], 2)
        decoder_input = self.dropout(decoder_input)
        # x: [batch_size, input_size + latent_size, seq_len], channels-first for conv1d
        x = decoder_input.transpose(1, 2).contiguous()
        for layer, kernel in enumerate(self.decoder_kernels):
            # Apply the dilated convolution, then drop the trailing padded
            # positions so that output step t only sees inputs at steps <= t
            # (i.e. the convolution is causal / input-shifted).
            x = F.conv1d(
                x,
                weight=kernel,
                bias=self.decoder_biases[layer],
                dilation=self.decoder_dilations[layer],
                padding=self.decoder_paddings[layer]
            )
            x_width = x.size(2)
            x = x[:, :, :(x_width - self.decoder_paddings[layer])].contiguous()
            x = F.relu(x)
result = x.transpose(1, 2).contiguous()
return result
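
# A minimal usage sketch with hypothetical sizes (not from the original module):
#   decoder = BasicCNNDecoder(input_size=32, latent_size=16,
#                             decoder_kernel_size=[64, 64],
#                             decoder_dilations=[1, 2], dropout_ratio=0.1)
#   emb = torch.randn(4, 10, 32)   # [batch_size, seq_len, embedding_size]
#   z = torch.randn(4, 16)         # [batch_size, latent_size]
#   out = decoder(emb, z)          # [4, 10, 64]: last layer's output channels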


class HybridDecoder(nn.Module):
    """
    Hybrid Convolutional Neural Network (CNN) and Recurrent Neural Network (RNN) decoder.

    Code Reference: https://github.com/kefirski/hybrid_rvae
    """

def __init__(self, embedding_size, latent_size, hidden_size, num_dec_layers, rnn_type, vocab_size):
super(HybridDecoder, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.latent_size = latent_size
self.embedding_size = embedding_size
self.num_dec_layers = num_dec_layers
self.rnn_type = rnn_type
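        # Each ConvTranspose1d(kernel_size=4, stride=2, padding=0) maps length L
        # to 2 * (L - 1) + 4 (+ output_padding), so the length-1 latent code
        # grows as 1 -> 4 -> 11 -> 24 -> 51 -> 104 -> 210 timesteps; forward()
        # later truncates this to the target sequence length.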
self.cnn = nn.Sequential(
nn.ConvTranspose1d(self.latent_size, 512, 4, 2, 0), nn.BatchNorm1d(512), nn.ELU(),
nn.ConvTranspose1d(512, 512, 4, 2, 0, output_padding=1), nn.BatchNorm1d(512), nn.ELU(),
nn.ConvTranspose1d(512, 256, 4, 2, 0), nn.BatchNorm1d(256), nn.ELU(),
nn.ConvTranspose1d(256, 256, 4, 2, 0, output_padding=1), nn.BatchNorm1d(256), nn.ELU(),
nn.ConvTranspose1d(256, 128, 4, 2, 0), nn.BatchNorm1d(128), nn.ELU(),
nn.ConvTranspose1d(128, self.vocab_size, 4, 2, 0)
)
if rnn_type == 'lstm':
self.rnn = nn.LSTM(embedding_size + vocab_size, hidden_size, num_dec_layers, batch_first=True)
elif rnn_type == "gru":
self.rnn = nn.GRU(embedding_size + vocab_size, hidden_size, num_dec_layers, batch_first=True)
elif rnn_type == "rnn":
self.rnn = nn.RNN(embedding_size + vocab_size, hidden_size, num_dec_layers, batch_first=True)
else:
            raise ValueError("The RNN type in hybrid decoder must be one of ['lstm', 'gru', 'rnn'].")
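        # Project RNN hidden states to unnormalized logits over the vocabulary.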
self.token_vocab = nn.Linear(self.hidden_size, self.vocab_size)

    def forward(self, decoder_input, latent_variable):
        r""" Implement the decoding process.

        Args:
            decoder_input (torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            latent_variable (torch.Tensor): latent code, shape: [batch_size, latent_size].

        Returns:
            tuple:
                - torch.Tensor: RNN token logits, shape: [batch_size, sequence_length, vocab_size].
                - torch.Tensor: CNN token logits, shape: [batch_size, sequence_length, vocab_size].
        """
cnn_logits = self.conv_decoder(latent_variable)
        cnn_logits = cnn_logits[:, :decoder_input.size(1), :].contiguous()  # truncate to target seq_len
rnn_logits, _ = self.rnn_decoder(cnn_logits, decoder_input)
return rnn_logits, cnn_logits

    def conv_decoder(self, latent_variable):
        r""" Implement the CNN decoder.

        Args:
            latent_variable (torch.Tensor): latent code, shape: [batch_size, latent_size].

        Returns:
            torch.Tensor: token logits, shape: [batch_size, cnn_output_length, vocab_size].
        """
latent_variable = latent_variable.unsqueeze(2)
logits = self.cnn(latent_variable).permute(0, 2, 1)
return logits

    def rnn_decoder(self, cnn_logits, decoder_input, initial_state=None):
        r""" Implement the RNN decoder conditioned on the CNN output.

        Args:
            cnn_logits (torch.Tensor): CNN token logits, shape: [batch_size, sequence_length, vocab_size].
            decoder_input (torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            initial_state (torch.Tensor): initial hidden states, default: None.

        Returns:
            tuple:
                - torch.Tensor: token logits, shape: [batch_size, sequence_length, vocab_size].
                - torch.Tensor: hidden states, shape: [num_layers * num_directions, batch_size, hidden_size].
        """
        # Condition each RNN step on both the word embedding and the aligned
        # CNN logits (teacher forcing over the concatenated features).
        outputs, hidden_states = self.rnn(torch.cat([cnn_logits, decoder_input], 2), initial_state)
        logits = self.token_vocab(outputs)
        return logits, hidden_states
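

if __name__ == "__main__":
    # Minimal smoke test with hypothetical sizes (not part of the original
    # module); it only checks that HybridDecoder runs and emits the expected shapes.
    decoder = HybridDecoder(embedding_size=32, latent_size=16, hidden_size=64,
                            num_dec_layers=1, rnn_type='gru', vocab_size=100)
    emb = torch.randn(4, 10, 32)        # [batch_size, seq_len, embedding_size]
    z = torch.randn(4, 16)              # [batch_size, latent_size]
    rnn_logits, cnn_logits = decoder(emb, z)
    print(rnn_logits.shape, cnn_logits.shape)  # torch.Size([4, 10, 100]) twice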