# Source code for textbox.module.Attention.attention_mechanism

# @Time   : 2020/11/14
# @Author : Junyi Li
# @Email  :

# @Time   : 2020/12/26
# @Author : Jinhao Jiang
# @Email  :

"""Attention Layers"""

import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import math

class LuongAttention(torch.nn.Module):
    r"""Luong Attention is proposed in the following paper:
    Effective Approaches to Attention-based Neural Machine Translation.

    Three alignment (scoring) methods are supported: ``'general'``,
    ``'concat'`` and ``'dot'``. When ``is_coverage`` is set, an accumulated
    coverage tensor is fed into the ``'concat'`` score and returned updated.
    """

    def __init__(self, source_size, target_size, alignment_method='concat', is_coverage=False):
        r"""
        Args:
            source_size (int): feature size of the encoder outputs.
            target_size (int): feature size of the decoder hidden states.
            alignment_method (str): one of ``'general'``, ``'concat'``, ``'dot'``.
            is_coverage (bool): whether to track attention coverage.

        Raises:
            ValueError: if ``alignment_method`` is not a supported value.
        """
        super(LuongAttention, self).__init__()
        self.source_size = source_size
        self.target_size = target_size
        self.alignment_method = alignment_method
        self.is_coverage = is_coverage

        if self.is_coverage:
            # Projects each scalar coverage value into the energy space.
            self.coverage_linear = nn.Linear(1, target_size, bias=False)

        if self.alignment_method == 'general':
            self.energy_linear = nn.Linear(target_size, source_size, bias=False)
        elif self.alignment_method == 'concat':
            self.energy_linear = nn.Linear(source_size + target_size, target_size)
            self.v = nn.Parameter(torch.rand(target_size, dtype=torch.float32))
        elif self.alignment_method == 'dot':
            # A plain dot product needs both sides to share the feature size.
            assert self.source_size == target_size
        else:
            raise ValueError("The alignment method for Luong Attention must be in ['general', 'concat', 'dot'].")

    def score(self, hidden_states, encoder_outputs, coverages=None):
        r"""Calculate the attention scores between encoder outputs and decoder states.

        Args:
            hidden_states: shape: [batch_size, tgt_len, target_size]
            encoder_outputs: shape: [batch_size, src_len, source_size]
            coverages: shape: [batch_size, tgt_len, src_len] or None. Only used by
                the ``'concat'`` method when ``is_coverage`` is set; ``None`` is
                treated as an all-zero coverage.

        Return:
            energy: shape: [batch_size, tgt_len, src_len]
        """
        if self.alignment_method == 'general':
            energy = self.energy_linear(hidden_states)
            return energy.bmm(encoder_outputs.permute(0, 2, 1))

        if self.alignment_method == 'concat':
            tgt_len = hidden_states.size(1)
            src_len = encoder_outputs.size(1)
            # Pair every decoder state with every encoder output:
            # both become B * tgt_len * src_len * feature_size.
            hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, src_len, -1)
            encoder_outputs = encoder_outputs.unsqueeze(1).expand(-1, tgt_len, -1, -1)
            energy = self.energy_linear(torch.cat((hidden_states, encoder_outputs), dim=-1))
            # coverage_linear has no bias, so skipping it for a missing coverage
            # is identical to feeding an all-zero coverage.
            if self.is_coverage and coverages is not None:
                energy = energy + self.coverage_linear(coverages.unsqueeze(3))
            energy = torch.tanh(energy)
            return self.v.mul(energy).sum(dim=-1)

        if self.alignment_method == 'dot':
            return hidden_states.bmm(encoder_outputs.permute(0, 2, 1))

        raise NotImplementedError(
            "No such alignment method {} for computing Luong scores.".format(self.alignment_method)
        )

    def forward(self, hidden_states, encoder_outputs, encoder_masks, coverages=None):
        r"""
        Luong attention

        Args:
            hidden_states: shape: [batch_size, tgt_len, target_size]
            encoder_outputs: shape: [batch_size, src_len, source_size]
            encoder_masks: shape: [batch_size, src_len]
            coverages: shape: [batch_size, tgt_len, src_len] or None (treated as zeros)

        Return:
            tuple:
                - context: shape: [batch_size, tgt_len, source_size]
                - probs: shape: [batch_size, tgt_len, src_len]
                - coverages: shape: [batch_size, tgt_len, src_len]
                  (only returned when ``is_coverage`` is set)
        """
        energy = self.score(hidden_states, encoder_outputs, coverages=coverages)
        # Zero out masked source positions, then renormalize over the rest.
        probs = F.softmax(energy, dim=-1) * encoder_masks.unsqueeze(1)
        normalization_factor = probs.sum(-1, keepdim=True) + 1e-12
        probs = probs / normalization_factor
        context = probs.bmm(encoder_outputs)

        if self.is_coverage:
            if coverages is None:
                coverages = torch.zeros_like(probs)
            coverages = probs + coverages
            return context, probs, coverages
        return context, probs
class BahdanauAttention(torch.nn.Module):
    r"""Bahdanau Attention is proposed in the following paper:
    Neural Machine Translation by Jointly Learning to Align and Translate.

    Scores a single decoder state against every encoder output with an
    additive (concat + tanh) energy function.
    """

    def __init__(self, source_size, target_size):
        r"""
        Args:
            source_size (int): feature size of the encoder outputs.
            target_size (int): feature size of the decoder hidden state.
        """
        super(BahdanauAttention, self).__init__()
        self.source_size = source_size
        self.target_size = target_size

        self.energy_linear = nn.Linear(source_size + target_size, target_size)
        # torch.FloatTensor(n) returns uninitialized memory (possibly NaN/Inf);
        # draw the scoring vector from U[0, 1) as LuongAttention does.
        self.v = nn.Parameter(torch.rand(target_size, dtype=torch.float32))

    def score(self, hidden_states, encoder_outputs):
        r"""Calculate the attention scores between encoder outputs and decoder states.

        Args:
            hidden_states: shape: [batch_size, target_size]
            encoder_outputs: shape: [batch_size, src_len, source_size]

        Return:
            energy: shape: [batch_size, src_len]
        """
        src_len = encoder_outputs.size(1)
        # Pair the decoder state with every source position.
        hidden_states = hidden_states.unsqueeze(1).expand(-1, src_len, -1)
        energy = torch.tanh(self.energy_linear(torch.cat((hidden_states, encoder_outputs), dim=-1)))
        return self.v.mul(energy).sum(dim=-1)

    def forward(self, hidden_states, encoder_outputs, encoder_masks):
        r"""
        Bahdanau attention

        Args:
            hidden_states: shape: [batch_size, target_size]
            encoder_outputs: shape: [batch_size, src_len, source_size]
            encoder_masks: shape: [batch_size, src_len]

        Return:
            tuple:
                - context: shape: [batch_size, 1, source_size]
                - probs: shape: [batch_size, 1, src_len]
        """
        energy = self.score(hidden_states, encoder_outputs)
        # Zero out masked source positions, then renormalize over the rest.
        probs = F.softmax(energy, dim=-1) * encoder_masks
        normalization_factor = probs.sum(-1, keepdim=True) + 1e-12
        probs = probs / normalization_factor
        probs = probs.unsqueeze(1)
        context = probs.bmm(encoder_outputs)
        return context, probs
class MonotonicAttention(torch.nn.Module):
    r"""Monotonic Attention is proposed in the following paper:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments.

    ``soft`` computes the expected (differentiable) alignment used for
    training; ``hard`` computes the thresholded alignment used at test time.
    All cumulative operations run along the last (source) axis.
    """

    def __init__(self, source_size, target_size, init_r=-4):
        r"""
        Args:
            source_size (int): feature size of the encoder outputs.
            target_size (int): feature size of the decoder hidden states.
            init_r (float): initial value of the scalar energy offset ``r``.
        """
        super(MonotonicAttention, self).__init__()
        self.source_size = source_size
        self.target_size = target_size

        self.w_linear = nn.Linear(source_size, target_size)
        self.v_linear = nn.Linear(target_size, target_size)
        self.bias = nn.Parameter(torch.Tensor(target_size).normal_())

        self.v = nn.utils.weight_norm(nn.Linear(target_size, 1))
        # Start the weight-norm gain at sqrt(1 / target_size).
        self.v.weight_g.data = torch.Tensor([1 / target_size]).sqrt()

        self.r = nn.Parameter(torch.Tensor([init_r]))

    def gaussian_noise(self, *size):
        r"""Additive gaussian noise to encourage discreteness"""
        return torch.empty(*size, dtype=torch.float32).normal_()

    def safe_cumprod(self, x):
        r"""Numerically stable cumulative product (along the last axis) by cumulative sum in log-space"""
        log_x = torch.log(torch.clamp(x, min=1e-10, max=1))
        return torch.exp(torch.cumsum(log_x, dim=-1))

    def exclusive_cumprod(self, x):
        r"""Exclusive cumulative product along the last axis [a, b, c] => [1, a, a * b]"""
        # Shift right by one and pad with ones; works for any number of leading axes.
        ones = torch.ones_like(x[..., :1])
        one_x = torch.cat((ones, x[..., :-1]), dim=-1)
        return torch.cumprod(one_x, dim=-1)

    def score(self, hidden_states, encoder_outputs):
        r"""Calculate the attention scores between encoder outputs and decoder states.

        Args:
            hidden_states: shape: [batch_size, tgt_len, target_size]
            encoder_outputs: shape: [batch_size, src_len, source_size]

        Return:
            energy: shape: [batch_size, tgt_len, src_len]
        """
        # Broadcasting pairs every target position with every source position.
        energy = torch.tanh(
            self.w_linear(encoder_outputs).unsqueeze(1) + self.v_linear(hidden_states).unsqueeze(2) + self.bias
        )
        energy = self.v(energy).squeeze(-1) + self.r
        return energy

    def _initial_probs(self, batch_size, tgt_len, src_len, device):
        r"""Alignment for the first step: all mass on the first source position."""
        probs = torch.zeros(batch_size, tgt_len, src_len, device=device)
        probs[:, :, 0] = 1
        return probs

    def _attend(self, probs, encoder_outputs, encoder_masks):
        r"""Mask and renormalize ``probs`` and compute the context vectors."""
        probs = probs * encoder_masks.unsqueeze(1)
        normalization_factor = probs.sum(-1, keepdim=True) + 1e-12
        probs = probs / normalization_factor
        context = probs.bmm(encoder_outputs)
        return context, probs

    def soft(self, hidden_states, encoder_outputs, encoder_masks, previous_probs=None):
        r"""
        Soft monotonic attention (Train)

        Args:
            hidden_states: shape: [batch_size, tgt_len, target_size]
            encoder_outputs: shape: [batch_size, src_len, source_size]
            encoder_masks: shape: [batch_size, src_len]
            previous_probs: shape: [batch_size, tgt_len, src_len]

        Return:
            tuple:
                - context: shape: [batch_size, tgt_len, source_size]
                - probs: shape: [batch_size, tgt_len, src_len]
        """
        device = hidden_states.device
        tgt_len = hidden_states.size(1)
        batch_size, src_len, _ = encoder_outputs.size()

        energy = self.score(hidden_states, encoder_outputs)
        p_select = torch.sigmoid(energy + self.gaussian_noise(energy.size()).to(device))
        cumprod_1_minus_p = self.safe_cumprod(1 - p_select)

        if previous_probs is None:
            probs = self._initial_probs(batch_size, tgt_len, src_len, device)
        else:
            # The cumulative product can underflow to zero; clamp the divisor.
            safe_divisor = torch.clamp(cumprod_1_minus_p, min=1e-10, max=1.0)
            probs = p_select * cumprod_1_minus_p * torch.cumsum(previous_probs / safe_divisor, dim=-1)

        return self._attend(probs, encoder_outputs, encoder_masks)

    def hard(self, hidden_states, encoder_outputs, encoder_masks, previous_probs=None):
        r"""
        Hard monotonic attention (Test)

        Args:
            hidden_states: shape: [batch_size, tgt_len, target_size]
            encoder_outputs: shape: [batch_size, src_len, source_size]
            encoder_masks: shape: [batch_size, src_len]
            previous_probs: shape: [batch_size, tgt_len, src_len]

        Return:
            tuple:
                - context: shape: [batch_size, tgt_len, source_size]
                - probs: shape: [batch_size, tgt_len, src_len]
        """
        device = hidden_states.device
        tgt_len = hidden_states.size(1)
        batch_size, src_len, _ = encoder_outputs.size()

        if previous_probs is None:
            probs = self._initial_probs(batch_size, tgt_len, src_len, device)
        else:
            energy = self.score(hidden_states, encoder_outputs)

            # Hard Sigmoid
            # Attend when monotonic energy is above threshold (Sigmoid > 0.5)
            above_threshold = (energy > 0).float()

            p_select = above_threshold * torch.cumsum(previous_probs, dim=-1)
            probs = p_select * self.exclusive_cumprod(1 - p_select)

            # Not attended => attend at last encoder output
            # Assume that encoder outputs are not padded
            attended = probs.sum(dim=-1)
            fallback = torch.zeros_like(probs)
            fallback[..., -1] = (attended == 0).to(probs.dtype)
            # Rows with no attention are all zero, so adding sets exactly one entry.
            probs = probs + fallback

        return self._attend(probs, encoder_outputs, encoder_masks)
class MultiHeadAttention(torch.nn.Module):
    r"""Multi-head Attention is proposed in the following paper:
    Attention Is All You Need.
    """

    def __init__(self, embedding_size, num_heads, attn_weight_dropout_ratio=0.0, return_distribute=False):
        r"""
        Args:
            embedding_size (int): model dimension; must be divisible by ``num_heads``.
            num_heads (int): number of attention heads.
            attn_weight_dropout_ratio (float): dropout applied to the attention weights.
            return_distribute (bool): if True, ``forward`` additionally returns the
                per-head log-softmax of the attention scores.
        """
        super(MultiHeadAttention, self).__init__()
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.head_size = embedding_size // num_heads
        assert self.head_size * num_heads == self.embedding_size, "embedding size must be divisible by num_heads"
        self.scaling = self.head_size ** -0.5  # d_k ** -0.5

        self.query_proj = nn.Linear(embedding_size, embedding_size)
        self.key_proj = nn.Linear(embedding_size, embedding_size)
        self.value_proj = nn.Linear(embedding_size, embedding_size)

        self.out_proj = nn.Linear(embedding_size, embedding_size)

        self.weight_dropout = nn.Dropout(attn_weight_dropout_ratio)
        self.return_distribute = return_distribute
        self.reset_parameters()

    def reset_parameters(self):
        r"""Initialize all projection weights from N(0, 0.02) and all biases to zero."""
        for proj in (self.query_proj, self.key_proj, self.value_proj, self.out_proj):
            nn.init.normal_(proj.weight, std=0.02)
        for proj in (self.query_proj, self.key_proj, self.value_proj, self.out_proj):
            nn.init.constant_(proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
        r"""
        Multi-head attention

        Args:
            query: shape: [batch_size, tgt_len, embedding_size]
            key and value: shape: [batch_size, src_len, embedding_size]
            key_padding_mask: shape: [batch_size, src_len]; non-zero/True entries are masked out
            attn_mask: shape: [tgt_len, src_len] or [batch_size, tgt_len, src_len];
                non-zero/True entries are masked out

        Return:
            tuple:
                - attn_repre: shape: [batch_size, tgt_len, embedding_size]
                - attn_weights: shape: [batch_size, tgt_len, src_len]
                  (maximum over heads, after dropout)
                - attn_weights_: shape: [batch_size, num_heads, tgt_len, src_len]
                  (log-softmax scores; only returned when ``return_distribute`` is set)
        """
        batch_size, tgt_len, embedding_size = query.size()
        src_len = key.size(1)
        assert key.size() == value.size()

        q = self.query_proj(query) * self.scaling
        k = self.key_proj(key)
        v = self.value_proj(value)

        # Split the feature axis into heads: q/v -> [B, H, len, d_k], k -> [B, H, d_k, src_len].
        q = q.view(batch_size, tgt_len, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        k = k.view(batch_size, src_len, self.num_heads, self.head_size).permute(0, 2, 3, 1)
        v = v.view(batch_size, src_len, self.num_heads, self.head_size).permute(0, 2, 1, 3)

        attn_weights = torch.matmul(q, k)
        assert list(attn_weights.size()) == [batch_size, self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            # masked_fill needs a boolean mask; uint8 masks are deprecated in PyTorch.
            attn_mask = attn_mask.bool()
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)  # shared across the batch
            attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float('-inf'))

        if key_padding_mask is not None:
            padding_mask = key_padding_mask.bool().unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(padding_mask, float('-inf'))

        attn_weights_ = torch.log_softmax(attn_weights, -1)
        attn_weights = self.weight_dropout(F.softmax(attn_weights, dim=-1))
        attn_repre = torch.matmul(attn_weights, v)

        assert list(attn_repre.size()) == [batch_size, self.num_heads, tgt_len, self.head_size]

        # Merge the heads back into the feature axis.
        attn_repre = attn_repre.transpose(1, 2).contiguous().view(batch_size, tgt_len, embedding_size)
        attn_repre = self.out_proj(attn_repre)

        # maximum attention weight over heads
        attn_weights, _ = attn_weights.max(dim=1)

        if self.return_distribute:
            return attn_repre, attn_weights, attn_weights_
        return attn_repre, attn_weights
class SelfAttentionMask(torch.nn.Module):
    r"""Cached mask with ones strictly above the diagonal.

    A square uint8 mask is built once and sliced on demand; it is rebuilt
    only when a larger size than the cached one is requested.
    """

    def __init__(self, init_size=100):
        r"""
        Args:
            init_size (int): side length of the initially cached mask.
        """
        super(SelfAttentionMask, self).__init__()
        self.weights = SelfAttentionMask.get_mask(init_size)

    @staticmethod
    def get_mask(size):
        r"""Build a ``[size, size]`` uint8 mask whose entries above the diagonal are 1."""
        weights = torch.ones((size, size), dtype=torch.uint8).triu_(1)  # above the diagonal == 1
        return weights

    def forward(self, size):
        r"""Return the top-left ``[size, size]`` slice of the cached mask, growing the cache if needed."""
        if self.weights is None or size > self.weights.size(0):
            self.weights = SelfAttentionMask.get_mask(size)
        masks = self.weights[:size, :size].detach()
        return masks