Source code for textbox.module.Embedder.position_embedder

# @Time   : 2020/11/14
# @Author : Junyi Li
# @Email  :

r"""Positional Embedding"""

import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import math

class LearnedPositionalEmbedding(nn.Module):
    r"""This module produces Learned Positional Embedding.

    Each position index in ``[0, max_length)`` owns one trainable vector of
    size ``embedding_size``, stored in an ``nn.Embedding`` lookup table.

    Args:
        embedding_size (int): Size of each positional vector.
        max_length (int): Number of rows in the lookup table. Default: 512.
    """

    def __init__(self, embedding_size, max_length=512):
        super(LearnedPositionalEmbedding, self).__init__()
        self.weights = nn.Embedding(max_length, embedding_size)

    def forward(self, input_seq, offset=0):
        r"""Look up the positional vectors for every step of ``input_seq``.

        Args:
            input_seq (torch.Tensor): 2-D tensor; only its two sizes
                (batch size, sequence length) are read, never its values.
            offset (int): Index of the first position to look up. Default: 0.

        Returns:
            torch.Tensor: Tensor of shape
            ``(batch_size, seq_len, embedding_size)``. The batch dimension is
            an ``expand`` view, so every batch entry shares the same data.
        """
        batch_size, seq_len = input_seq.size()
        # Build the indices on the table's device: a CPU-only arange would
        # fail the lookup once the module has been moved to an accelerator.
        positions = offset + torch.arange(seq_len, device=self.weights.weight.device)
        pos_embeddings = self.weights(positions).unsqueeze(0).expand(batch_size, -1, -1)
        return pos_embeddings
class SinusoidalPositionalEmbedding(nn.Module):
    r"""This module produces sinusoidal positional embeddings of any length.

    The table is computed once for ``max_length`` positions and is rebuilt
    with more rows on demand when a longer sequence (or larger offset) is
    requested in :meth:`forward`.

    Args:
        embedding_size (int): Size of each positional vector.
        max_length (int): Number of positions precomputed at construction.
            Default: 512.
    """

    def __init__(self, embedding_size, max_length=512):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.embedding_size = embedding_size
        # Plain tensor attribute (not a parameter or buffer): it is not part
        # of the state dict and is not moved by ``Module.to``; ``forward``
        # moves it to the input's device when needed.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(max_length, embedding_size)

    @staticmethod
    def get_embedding(max_length, embedding_size):
        r"""Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            max_length (int): Number of positions (rows) to generate.
            embedding_size (int): Size of each positional vector (columns).

        Returns:
            torch.Tensor: Float tensor of shape ``(max_length, embedding_size)``
            whose first ``embedding_size // 2`` columns are sines, the next
            ``embedding_size // 2`` are cosines, plus one zero column when
            ``embedding_size`` is odd.
        """
        half_dim = embedding_size // 2
        # Clamp the divisor: with embedding_size of 2 or 3, half_dim - 1 is 0
        # and the unguarded division raises ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -log_step)
        # Outer product: (max_length, 1) * (1, half_dim) -> (max_length, half_dim)
        angles = torch.arange(max_length, dtype=torch.float).unsqueeze(1) * inv_freq.unsqueeze(0)
        # Concatenation already yields (max_length, 2 * half_dim).
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if embedding_size % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(max_length, 1)], dim=1)
        return emb

    def forward(self, input_seq, offset=0):
        r"""Return the positional vectors for every step of ``input_seq``.

        Args:
            input_seq (torch.Tensor): 2-D tensor; only its two sizes
                (batch size, sequence length) and its device are read.
            offset (int): Index of the first position to return. Default: 0.

        Returns:
            torch.Tensor: Detached tensor of shape
            ``(batch_size, seq_len, embedding_size)`` on ``input_seq``'s
            device. The batch dimension is an ``expand`` view.
        """
        batch_size, seq_len = input_seq.size()
        max_position = seq_len + offset
        if self.weights is None or max_position > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_position,
                self.embedding_size,
            )
        device = input_seq.device
        if self.weights.device != device:
            # The cached table does not follow Module.to(); move it here so
            # index_select sees the table and the indices on the same device.
            self.weights = self.weights.to(device)
        positions = offset + torch.arange(seq_len, device=device)
        pos_embeddings = (
            self.weights.index_select(0, positions)
            .unsqueeze(0)
            .expand(batch_size, -1, -1)
            .detach()
        )
        return pos_embeddings