# Source code for textbox.module.Encoder.transformer_encoder

# @Time   : 2020/11/14
# @Author : Junyi Li
# @Email  : lijunyi@ruc.edu.cn

r"""
Transformer Encoder
####################
"""

import torch
from torch import nn
from torch.nn import Parameter
from textbox.module.layers import TransformerLayer
import torch.nn.functional as F


class TransformerEncoder(torch.nn.Module):
    r"""The stacked Transformer encoder layers.

    Holds ``num_enc_layers`` identically configured ``TransformerLayer`` modules
    in an ``nn.ModuleList`` and applies them one after another in ``forward``.

    Args:
        embedding_size: forwarded to every ``TransformerLayer`` as its first argument.
        ffn_size: forwarded to every ``TransformerLayer`` as its second argument.
        num_enc_layers (int): number of ``TransformerLayer`` modules to stack.
        num_heads: forwarded to every ``TransformerLayer`` as its third argument.
        attn_dropout_ratio (float): forwarded to every layer, default: ``0.0``.
        attn_weight_dropout_ratio (float): forwarded to every layer, default: ``0.0``.
        ffn_dropout_ratio (float): forwarded to every layer, default: ``0.0``.
    """

    def __init__(
        self,
        embedding_size,
        ffn_size,
        num_enc_layers,
        num_heads,
        attn_dropout_ratio=0.0,
        attn_weight_dropout_ratio=0.0,
        ffn_dropout_ratio=0.0
    ):
        super().__init__()

        # Every layer gets the same hyper-parameters; only the learned weights differ.
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(
                embedding_size, ffn_size, num_heads, attn_dropout_ratio, attn_weight_dropout_ratio, ffn_dropout_ratio
            ) for _ in range(num_enc_layers)
        ])

    def forward(self, x, kv=None, self_padding_mask=None, output_all_encoded_layers=False):
        r"""Implement the encoding process step by step.

        Args:
            x (Torch.Tensor): target sequence embedding, shape: [batch_size, sequence_length, embedding_size].
            kv (Torch.Tensor): the cached history latent vector, shape: [batch_size, sequence_length, embedding_size],
                default: None.
            self_padding_mask (Torch.Tensor): padding mask of target sequence, shape: [batch_size, sequence_length],
                default: None.
            output_all_encoded_layers (Bool): whether to output all the encoder layers, default: ``False``.

        Returns:
            Torch.Tensor or list of Torch.Tensor: when ``output_all_encoded_layers`` is ``False``, the output of
            the last layer; otherwise a list holding the output of every layer, in order. With an empty layer
            stack the input ``x`` is returned unchanged (or an empty list when all layers are requested).

            NOTE(review): the output width is presumably ``embedding_size`` (each layer's output is fed to the next
            layer, which was built with the same ``embedding_size``); the previous docstring said ``ffn_size`` --
            confirm against ``TransformerLayer``.
        """
        all_encoded_layers = []
        for layer in self.transformer_layers:
            # Each layer returns a 3-tuple; only the first element (the hidden states) is used here.
            x, _, _ = layer(x, kv, self_padding_mask)
            if output_all_encoded_layers:
                all_encoded_layers.append(x)

        if output_all_encoded_layers:
            return all_encoded_layers
        # ``x`` is the last layer's output, or the untouched input when there are no layers,
        # which avoids indexing into an empty list.
        return x