Source code for textbox.module.layers
# -*- coding: utf-8 -*-
# @Time : 2020/11/14
# @Author : Junyi Li
# @Email : lijunyi@ruc.edu.cn
"""
textbox.module.layers
#############################
Common Layers in text generation
"""
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import normal_
from textbox.module.Attention.attention_mechanism import MultiHeadAttention
class Highway(nn.Module):
    r"""Highway Layers

    Args:
        - num_highway_layers(int): number of highway layers.
        - input_size(int): size of the highway input.
    """

    def __init__(self, num_highway_layers, input_size):
        super(Highway, self).__init__()
        self.num_highway_layers = num_highway_layers
        self.non_linear = nn.ModuleList([nn.Linear(input_size, input_size) for _ in range(self.num_highway_layers)])
        self.linear = nn.ModuleList([nn.Linear(input_size, input_size) for _ in range(self.num_highway_layers)])
        self.gate = nn.ModuleList([nn.Linear(input_size, input_size) for _ in range(self.num_highway_layers)])
    def forward(self, x):
        for layer in range(self.num_highway_layers):
            # Gate: percentage of non-linear information allowed for each element of x
            gate = torch.sigmoid(self.gate[layer](x))
            # Non-linear path
            non_linear = F.relu(self.non_linear[layer](x))
            # Linear path
            linear = self.linear[layer](x)
            # Combine the non-linear and linear paths according to the gate
            x = gate * non_linear + (1 - gate) * linear
        return x
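
# Illustrative usage sketch added for clarity; not part of the original
# textbox.module.layers API. The batch size and feature size are assumptions.
def _highway_usage_example():
    """Run a 2-layer Highway block over a random batch and check the shape."""
    highway = Highway(num_highway_layers=2, input_size=128)
    x = torch.randn(16, 128)  # assumed (batch_size, input_size)
    out = highway(x)          # gated mix of the linear and non-linear transforms
    assert out.shape == x.shape
    return out
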
class TransformerLayer(torch.nn.Module):
    r"""Transformer Layer, including
        a multi-head self-attention sublayer,
        an external multi-head attention sublayer (only for the conditional decoder) and
        a point-wise feed-forward sublayer.

    Args:
        self_padding_mask (torch.bool): the padding mask for the multi-head attention sublayer.
        self_attn_mask (torch.bool): the attention mask for the multi-head attention sublayer.
        external_states (torch.Tensor): the external context for the decoder, e.g., hidden states from the encoder.
        external_padding_mask (torch.bool): the padding mask for the external states.

    Returns:
        feedforward_output (torch.Tensor): the output of the point-wise feed-forward sublayer,
        which is also the output of the whole transformer layer.
    """
    def __init__(
        self,
        embedding_size,
        ffn_size,
        num_heads,
        attn_dropout_ratio=0.0,
        attn_weight_dropout_ratio=0.0,
        ffn_dropout_ratio=0.0,
        with_external=False
    ):
        super(TransformerLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(embedding_size, num_heads, attn_weight_dropout_ratio)
        self.feed_forward_1 = nn.Linear(embedding_size, ffn_size)
        self.feed_forward_2 = nn.Linear(ffn_size, embedding_size)

        self.attn_layer_norm = nn.LayerNorm(embedding_size, eps=1e-6)
        self.ffn_layer_norm = nn.LayerNorm(embedding_size, eps=1e-6)

        self.attn_dropout = nn.Dropout(attn_dropout_ratio)
        self.ffn_dropout = nn.Dropout(ffn_dropout_ratio)

        self.with_external = with_external
        if self.with_external:
            self.external_multi_head_attention = MultiHeadAttention(
                embedding_size, num_heads, attn_weight_dropout_ratio
            )
            self.external_layer_norm = nn.LayerNorm(embedding_size)

        self.reset_parameters()
    def reset_parameters(self):
        nn.init.normal_(self.feed_forward_1.weight, std=0.02)
        nn.init.normal_(self.feed_forward_2.weight, std=0.02)
        nn.init.constant_(self.feed_forward_1.bias, 0.)
        nn.init.constant_(self.feed_forward_2.bias, 0.)
    def forward(
        self,
        x,
        kv=None,
        self_padding_mask=None,
        self_attn_mask=None,
        external_states=None,
        external_padding_mask=None
    ):
        # Multi-head (self-)attention sublayer with residual connection and layer norm
        residual = x
        if kv is None:
            x, self_attn_weights = self.multi_head_attention(
                query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask
            )
        else:
            x, self_attn_weights = self.multi_head_attention(
                query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask
            )
        x = self.attn_dropout(x)
        x = self.attn_layer_norm(residual + x)

        # External (encoder-decoder) attention sublayer, only used by the conditional decoder
        if self.with_external:
            residual = x
            x, external_attn_weights = self.external_multi_head_attention(
                query=x, key=external_states, value=external_states, key_padding_mask=external_padding_mask
            )
            x = self.attn_dropout(x)
            x = self.external_layer_norm(residual + x)
        else:
            external_attn_weights = None

        # Point-wise feed-forward sublayer with residual connection and layer norm
        residual = x
        x = self.feed_forward_2(F.gelu(self.feed_forward_1(x)))
        x = self.ffn_dropout(x)
        x = self.ffn_layer_norm(residual + x)
        return x, self_attn_weights, external_attn_weights
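
# Illustrative usage sketch added for clarity; not part of the original
# textbox.module.layers API. The hyper-parameters and the batch-first tensor
# layout (batch_size, seq_len, embedding_size) are assumptions for this example.
def _transformer_layer_usage_example():
    """Run one decoder-style layer over random decoder and encoder states."""
    layer = TransformerLayer(embedding_size=512, ffn_size=2048, num_heads=8, with_external=True)
    decoder_states = torch.randn(16, 20, 512)  # assumed (batch_size, tgt_len, embedding_size)
    encoder_states = torch.randn(16, 30, 512)  # assumed (batch_size, src_len, embedding_size)
    out, self_attn, external_attn = layer(decoder_states, external_states=encoder_states)
    assert out.shape == decoder_states.shape
    return out, self_attn, external_attn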