From 31b7ac000c9720d7637711f400e758b88bf67fee Mon Sep 17 00:00:00 2001 From: Kaushik Ram Sadagopan Date: Fri, 6 Oct 2023 14:45:32 -0700 Subject: [PATCH] Introduce ShawRelativePositionSDPA. --- src/fairseq2/models/conformer/convolution.py | 43 +++++ .../models/s2t_transformer/builder.py | 1 + src/fairseq2/models/w2vbert/builder.py | 2 + src/fairseq2/models/wav2vec2/builder.py | 31 +++- src/fairseq2/nn/transformer/__init__.py | 3 + .../relative_position_attention.py | 170 ++++++++++++++++++ 6 files changed, 249 insertions(+), 1 deletion(-) create mode 100644 src/fairseq2/nn/transformer/relative_position_attention.py diff --git a/src/fairseq2/models/conformer/convolution.py b/src/fairseq2/models/conformer/convolution.py index 9b59ce6fd..a1baf813b 100644 --- a/src/fairseq2/models/conformer/convolution.py +++ b/src/fairseq2/models/conformer/convolution.py @@ -15,6 +15,49 @@ from fairseq2.typing import DataType, Device +class CausalConv1d(Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + *, + stride: int = 1, + padding: str = "same", + dilation: int = 1, + groups: int = 1, + bias: bool = False, + padding_mode: str = "zeros", + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ): + if padding != "same": + raise ValueError( + f"We currently only support 'same' padding, and not {padding} padding." + ) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + # No padding, since we are manually padding. + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + # Pad the last dimension entirely on the left to make the convolution causal. + self.pad = (kernel_size - 1, 0) + + def forward(self, x: Tensor) -> Tensor: + x = F.pad(x, self.pad) + return super().forward(x) + + class ConformerConvolution(Module): """Represents a Conformer convolution module as described in :cite:t:`https://doi.org/10.48550/arxiv.2005.08100`.""" diff --git a/src/fairseq2/models/s2t_transformer/builder.py b/src/fairseq2/models/s2t_transformer/builder.py index 928a8dd54..3bb8333dd 100644 --- a/src/fairseq2/models/s2t_transformer/builder.py +++ b/src/fairseq2/models/s2t_transformer/builder.py @@ -380,6 +380,7 @@ def build_conformer_block(self) -> TransformerEncoderLayer: conv = ConformerConvolution( self.config.model_dim, self.config.depthwise_conv_kernel_size, + norm_type=self.config.conv_norm_type, device=self.device, dtype=self.dtype, ) diff --git a/src/fairseq2/models/w2vbert/builder.py b/src/fairseq2/models/w2vbert/builder.py index 03e784582..967637e60 100644 --- a/src/fairseq2/models/w2vbert/builder.py +++ b/src/fairseq2/models/w2vbert/builder.py @@ -45,6 +45,7 @@ def _encoder_600m() -> Wav2Vec2EncoderConfig: depthwise_conv_kernel_size=31, causal_depthwise_conv=False, conv_norm_type="batch_norm", + shaw_rel_position_sdpa_config=None, ) @@ -78,6 +79,7 @@ def _encoder_300m() -> Wav2Vec2EncoderConfig: depthwise_conv_kernel_size=31, causal_depthwise_conv=False, conv_norm_type="batch_norm", + shaw_rel_position_sdpa_config=None, ) diff --git a/src/fairseq2/models/wav2vec2/builder.py b/src/fairseq2/models/wav2vec2/builder.py index 297b9c482..aa9bde0ed 100644 --- a/src/fairseq2/models/wav2vec2/builder.py +++ b/src/fairseq2/models/wav2vec2/builder.py @@ -34,6 +34,7 @@ MultiheadAttention, RelativePositionalEncoding, RelativePositionSDPA, + ShawRelativePositionSDPA, StandardFeedForwardNetwork, StandardMultiheadAttention, StandardTransformerEncoder, @@ -46,6 +47,18 @@ from fairseq2.typing import DataType, Device +@dataclass +class ShawRelativePositionSDPAConfig: + max_left_rel_position: int + """The left clipping value for relative positions.""" + + max_right_rel_position: Optional[int] + """The right clipping value for relative positions.""" + + use_rel_position_values: bool = False + """Whether to use relative position values to compute relative attention.""" + + @dataclass class Wav2Vec2EncoderConfig: """Holds the configuration of a wav2vec 2.0 encoder.""" @@ -97,7 +110,7 @@ class Wav2Vec2EncoderConfig: sample_fbank_every_k: int # Position Encoder - pos_encoder_type: Literal["conv", "relative", "rotary"] + pos_encoder_type: Literal["conv", "relative", "relative_shaw", "rotary"] """The type of position encoder.""" # Convolutional Position Encoder @@ -146,6 +159,9 @@ class Wav2Vec2EncoderConfig: conv_norm_type: Literal["batch_norm", "layer_norm"] """The type of normalization to use in the Conformer convolution module.""" + shaw_rel_position_sdpa_config: Optional[ShawRelativePositionSDPAConfig] + """The parameters for ShawRelativePositionSDPA.""" + def _encoder_base() -> Wav2Vec2EncoderConfig: layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 @@ -179,6 +195,7 @@ def _encoder_base() -> Wav2Vec2EncoderConfig: depthwise_conv_kernel_size=0, causal_depthwise_conv=False, conv_norm_type="batch_norm", + shaw_rel_position_sdpa_config=None, ) @@ -369,6 +386,18 @@ def build_attention(self) -> MultiheadAttention: device=self.device, dtype=self.dtype, ) + elif self.config.pos_encoder_type == "relative_shaw": + sdpa_config = self.config.shaw_rel_position_sdpa_config + sdpa = ShawRelativePositionSDPA( + self.config.model_dim, + self.config.num_encoder_attn_heads, + sdpa_config.max_left_rel_position, + max_right_rel_position=sdpa_config.max_right_rel_position, + use_rel_position_values=sdpa_config.use_rel_position_values, + attn_dropout_p=self.config.attn_dropout_p, + device=self.device, + dtype=self.dtype, + ) else: sdpa = create_default_sdpa(self.config.attn_dropout_p) diff --git a/src/fairseq2/nn/transformer/__init__.py b/src/fairseq2/nn/transformer/__init__.py index 0fae90fa7..a28b4716e 100644 --- a/src/fairseq2/nn/transformer/__init__.py +++ b/src/fairseq2/nn/transformer/__init__.py @@ -86,3 +86,6 @@ from fairseq2.nn.transformer.relative_attention import ( RelativePositionSDPA as RelativePositionSDPA, ) +from fairseq2.nn.transformer.relative_position_attention import ( + ShawRelativePositionSDPA as ShawRelativePositionSDPA, +) diff --git a/src/fairseq2/nn/transformer/relative_position_attention.py b/src/fairseq2/nn/transformer/relative_position_attention.py new file mode 100644 index 000000000..9c780a8ac --- /dev/null +++ b/src/fairseq2/nn/transformer/relative_position_attention.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, final + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Embedding +from torch.nn.functional import dropout, softmax + +from fairseq2.nn.transformer.attention import SDPA +from fairseq2.typing import DataType, Device, finaloverride + + +@final +class ShawRelativePositionSDPA(SDPA): + """Computes relative position scaled dot-product attention + as described in :cite:t:`https://arxiv.org/pdf/1803.02155v2.pdf`.""" + + model_dim: int + num_heads: int + max_left_rel_position: int + max_right_rel_position: Optional[int] + rel_k_embedding: Embedding + rel_v_embedding: Optional[Embedding] + device: Optional[Device] + + def __init__( + self, + model_dim: int, + num_heads: int, + max_left_rel_position: int, + *, + max_right_rel_position: Optional[int] = None, + use_rel_position_values: bool = False, + attn_dropout_p: float = 0.0, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param model_dim: + The dimensionality of the model. + :param: num_heads: + The number of attention heads. + :param: max_left_rel_position: + The left clipping value for relative positions. + :param: max_right_rel_position: + The right clipping value for relative positions. + :param: use_rel_position_values: + Whether to use relative position values to compute relative attention. + :param attn_dropout_p: + The dropout probability on attention weights. + """ + super().__init__(attn_dropout_p=attn_dropout_p) + + if model_dim % num_heads != 0: + raise ValueError( + f"`model_dim` must be a multiple of `num_heads` ({num_heads}), but is {model_dim} instead." + ) + + self.model_dim = model_dim + self.num_heads = num_heads + + head_dim = model_dim // num_heads + + self.max_left_rel_position = max_left_rel_position + self.max_right_rel_position = ( + max_right_rel_position + if max_right_rel_position is not None + else max_left_rel_position + ) + num_positions = self.max_left_rel_position + 1 + self.max_right_rel_position + + self.rel_k_embedding = Embedding( + num_positions, head_dim, device=device, dtype=dtype + ) + + if use_rel_position_values: + self.rel_v_embedding = Embedding( + num_positions, head_dim, device=device, dtype=dtype + ) + else: + self.register_module("rel_v_embedding", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + nn.init.xavier_uniform_(self.rel_k_embedding.weight) + if self.rel_v_embedding is not None: + nn.init.xavier_uniform_(self.rel_v_embedding.weight) + + def rel_position_indices(self, seq_len: int) -> Tensor: + positions = torch.arange(seq_len).unsqueeze(0) + rel_dist = positions - positions.t() + rel_dist = torch.clamp( + rel_dist, -self.max_left_rel_position, self.max_right_rel_position + ) + return rel_dist + self.max_left_rel_position + + @finaloverride + def forward( + self, + queries: Tensor, + keys: Tensor, + values: Tensor, + *, + mask: Optional[Tensor] = None, + needs_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + if queries.ndim != 4 or keys.ndim != 4 or values.ndim != 4: + raise ValueError( + "`ShawRelativePositionSDPA` can only be used as part of a multi-head attention layer and expects its input tensors to be 4 dimensional." + ) + + # (N, H, S, head_dim) @ (N, H, head_dim, S_kv) = (N, H, S, S_kv) + attn_weights = torch.matmul(queries, keys.transpose(-1, -2)) + + query_length, kv_length = queries.shape[2], keys.shape[2] + + # (S_kv, S_kv) + rel_position_indices = self.rel_position_indices(kv_length) + + rel_position_indices = rel_position_indices.to(device=queries.device) + + # (S, S_kv, head_dim) + rel_position_keys = self.rel_k_embedding(rel_position_indices)[-query_length:] + + # (N, H, S, head_dim) @ (S, S_kv, head_dim) = (N, H, S, S_kv) + rel_attn_weights = torch.einsum("nhsm,stm->nhst", queries, rel_position_keys) + + attn_weights += rel_attn_weights + + attn_weights = attn_weights * (queries.size(-1) ** -0.5) + + if mask is not None: + attn_weights = attn_weights + mask + + attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32) + + attn_weights = attn_weights.type_as(queries) + + if self.training and self.attn_dropout_p > 0.0: + attn_weights = dropout(attn_weights, self.attn_dropout_p) + + # (N, H, S, S_kv) @ (N, H, S_kv, head_dim) = (N, H, S, head_dim) + attn = torch.matmul(attn_weights, values) + + if self.rel_v_embedding is not None: + # (S, S_kv, head_dim) + rel_position_values = self.rel_v_embedding(rel_position_indices)[ + -query_length: + ] + + # (N, H, S, S_kv) @ (S, S_kv, head_dim) = (N, H, S, head_dim) + rel_attn = torch.einsum("nhst,stm->nhsm", attn_weights, rel_position_values) + + attn += rel_attn + + return attn, attn_weights if needs_weights else None + + def extra_repr(self) -> str: + """:meta private:""" + s = super().extra_repr() + + return f"{s}, model_dim={self.model_dim}, num_heads={self.num_heads}"