Skip to content

Commit

Permalink
Introduce ShawRelativePositionSDPA. (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
kauterry authored Oct 9, 2023
1 parent 4812198 commit 199cf93
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 1 deletion.
9 changes: 9 additions & 0 deletions bibliography.bib
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ @misc{https://doi.org/10.48550/arxiv.1706.03762
copyright = {arXiv.org perpetual, non-exclusive license}
}

@misc{https://doi.org/10.48550/arxiv.1803.02155,
title={Self-Attention with Relative Position Representations},
author={Peter Shaw and Jakob Uszkoreit and Ashish Vaswani},
year={2018},
eprint={1803.02155},
archivePrefix={arXiv},
primaryClass={cs.CL}
}

@misc{https://doi.org/10.48550/arxiv.1901.02860,
title={Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context},
author={Zihang Dai and Zhilin Yang and Yiming Yang and Jaime Carbonell and Quoc V. Le and Ruslan Salakhutdinov},
Expand Down
2 changes: 2 additions & 0 deletions src/fairseq2/models/w2vbert/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _encoder_600m() -> Wav2Vec2EncoderConfig:
depthwise_conv_kernel_size=31,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
shaw_rel_pos_sdpa_config=None,
)


Expand Down Expand Up @@ -78,6 +79,7 @@ def _encoder_300m() -> Wav2Vec2EncoderConfig:
depthwise_conv_kernel_size=31,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
shaw_rel_pos_sdpa_config=None,
)


Expand Down
33 changes: 32 additions & 1 deletion src/fairseq2/models/wav2vec2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
MultiheadAttention,
RelativePositionalEncoding,
RelativePositionSDPA,
ShawRelativePositionSDPA,
StandardFeedForwardNetwork,
StandardMultiheadAttention,
StandardTransformerEncoder,
Expand All @@ -46,6 +47,20 @@
from fairseq2.typing import DataType, Device


@dataclass
class ShawRelativePositionSDPAConfig:
    """Holds the configuration of the :class:`ShawRelativePositionSDPA` module."""

    max_left_rel_pos: int
    """The left clipping value for relative positions."""

    max_right_rel_pos: Optional[int] = None
    """The right clipping value for relative positions. If ``None``,
    :class:`ShawRelativePositionSDPA` falls back to ``max_left_rel_pos``."""

    use_rel_pos_values: bool = False
    """If ``True``, also uses relative position values to compute relative attention."""


@dataclass
class Wav2Vec2EncoderConfig:
"""Holds the configuration of a wav2vec 2.0 encoder."""
Expand Down Expand Up @@ -97,7 +112,7 @@ class Wav2Vec2EncoderConfig:
sample_fbank_every_k: int

# Position Encoder
pos_encoder_type: Literal["conv", "relative", "rotary"]
pos_encoder_type: Literal["conv", "relative", "relative_shaw", "rotary"]
"""The type of position encoder."""

# Convolutional Position Encoder
Expand Down Expand Up @@ -146,6 +161,9 @@ class Wav2Vec2EncoderConfig:
conv_norm_type: Literal["batch_norm", "layer_norm"]
"""The type of normalization to use in the Conformer convolution module."""

shaw_rel_pos_sdpa_config: Optional[ShawRelativePositionSDPAConfig]
"""The parameters for ShawRelativePositionSDPA."""


def _encoder_base() -> Wav2Vec2EncoderConfig:
layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
Expand Down Expand Up @@ -179,6 +197,7 @@ def _encoder_base() -> Wav2Vec2EncoderConfig:
depthwise_conv_kernel_size=0,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
shaw_rel_pos_sdpa_config=None,
)


Expand Down Expand Up @@ -369,6 +388,18 @@ def build_attention(self) -> MultiheadAttention:
device=self.device,
dtype=self.dtype,
)
elif self.config.pos_encoder_type == "relative_shaw":
sdpa_config = self.config.shaw_rel_pos_sdpa_config
sdpa = ShawRelativePositionSDPA(
self.config.model_dim,
self.config.num_encoder_attn_heads,
sdpa_config.max_left_rel_pos,
max_right_rel_pos=sdpa_config.max_right_rel_pos,
use_rel_pos_values=sdpa_config.use_rel_pos_values,
attn_dropout_p=self.config.attn_dropout_p,
device=self.device,
dtype=self.dtype,
)
else:
sdpa = create_default_sdpa(self.config.attn_dropout_p)

Expand Down
3 changes: 3 additions & 0 deletions src/fairseq2/nn/transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,6 @@
from fairseq2.nn.transformer.relative_attention import (
RelativePositionSDPA as RelativePositionSDPA,
)
from fairseq2.nn.transformer.shaw_attention import (
ShawRelativePositionSDPA as ShawRelativePositionSDPA,
)
168 changes: 168 additions & 0 deletions src/fairseq2/nn/transformer/shaw_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, final

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import dropout, softmax

from fairseq2.nn.embedding import StandardEmbedding
from fairseq2.nn.transformer.attention import SDPA
from fairseq2.typing import DataType, Device, finaloverride


@final
class ShawRelativePositionSDPA(SDPA):
    """Computes relative position scaled dot-product attention
    as described in :cite:t:`https://doi.org/10.48550/arxiv.1803.02155`.

    A learned embedding table indexed by the clipped relative distance between
    a query position and a key position is added to the attention logits and,
    optionally, to the attention output.
    """

    model_dim: int
    num_heads: int
    max_left_rel_pos: int
    # Always resolved to a concrete value in ``__init__``; never ``None``.
    max_right_rel_pos: int
    rel_k_embed: StandardEmbedding
    rel_v_embed: Optional[StandardEmbedding]

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        max_left_rel_pos: int,
        *,
        max_right_rel_pos: Optional[int] = None,
        use_rel_pos_values: bool = False,
        attn_dropout_p: float = 0.0,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param model_dim:
            The dimensionality of the model.
        :param num_heads:
            The number of attention heads.
        :param max_left_rel_pos:
            The left clipping value for relative positions.
        :param max_right_rel_pos:
            The right clipping value for relative positions. If ``None``,
            ``max_left_rel_pos`` is used.
        :param use_rel_pos_values:
            If ``True``, also uses relative position values to compute
            relative attention.
        :param attn_dropout_p:
            The dropout probability on attention weights.
        :param device:
            The device on which to initialize the embedding tables.
        :param dtype:
            The data type of the embedding tables.

        :raises ValueError:
            If ``model_dim`` is not a multiple of ``num_heads``.
        """
        super().__init__(attn_dropout_p=attn_dropout_p)

        if model_dim % num_heads != 0:
            raise ValueError(
                f"`model_dim` must be a multiple of `num_heads` ({num_heads}), but is {model_dim} instead."
            )

        self.model_dim = model_dim
        self.num_heads = num_heads

        head_dim = model_dim // num_heads

        self.max_left_rel_pos = max_left_rel_pos
        self.max_right_rel_pos = (
            max_right_rel_pos if max_right_rel_pos is not None else max_left_rel_pos
        )

        # One embedding per clipped distance in [-max_left, +max_right].
        num_pos = self.max_left_rel_pos + 1 + self.max_right_rel_pos

        self.rel_k_embed = StandardEmbedding(
            num_pos, head_dim, device=device, dtype=dtype
        )

        if use_rel_pos_values:
            self.rel_v_embed = StandardEmbedding(
                num_pos, head_dim, device=device, dtype=dtype
            )
        else:
            self.register_module("rel_v_embed", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters and buffers of the module."""
        nn.init.xavier_uniform_(self.rel_k_embed.weight)

        if self.rel_v_embed is not None:
            nn.init.xavier_uniform_(self.rel_v_embed.weight)

    @finaloverride
    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        values: Tensor,
        *,
        mask: Optional[Tensor] = None,
        needs_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute attention with relative position terms.

        :param queries:
            The queries. *Shape:* :math:`(N,H,S,D_h)`.
        :param keys:
            The keys. *Shape:* :math:`(N,H,S_{kv},D_h)`.
        :param values:
            The values. *Shape:* :math:`(N,H,S_{kv},D_h)`.
        :param mask:
            If not ``None``, added to the scaled attention logits before the
            softmax.
        :param needs_weights:
            If ``True``, the attention weights are returned as the second
            element of the result; otherwise it is ``None``.

        :returns:
            The attention output of shape :math:`(N,H,S,D_h)` and, optionally,
            the attention weights of shape :math:`(N,H,S,S_{kv})`.

        :raises ValueError:
            If any of ``queries``, ``keys``, ``values`` is not 4 dimensional.
        """
        if queries.ndim != 4 or keys.ndim != 4 or values.ndim != 4:
            raise ValueError(
                "`ShawRelativePositionSDPA` can only be used as part of a multi-head attention layer and expects its input tensors to be 4 dimensional."
            )

        # (N, H, S, head_dim) @ (N, H, head_dim, S_kv) = (N, H, S, S_kv)
        attn_weights = torch.matmul(queries, keys.transpose(-1, -2))

        query_len, kv_len = queries.size(2), keys.size(2)

        # (S, S_kv)
        #
        # The queries are treated as the last `S` positions of the key
        # sequence. The index matrix is sliced *before* the embedding lookup
        # so that only the rows actually needed are embedded, instead of
        # materializing a (S_kv, S_kv, head_dim) tensor and discarding most
        # of it when `S < S_kv`.
        rel_pos_indices = self._rel_pos_indices(kv_len, queries.device)
        rel_pos_indices = rel_pos_indices[-query_len:]

        # (S, S_kv, head_dim)
        rel_pos_keys = self.rel_k_embed(rel_pos_indices)

        # (N, H, S, head_dim) @ (S, S_kv, head_dim) = (N, H, S, S_kv)
        rel_attn_weights = torch.einsum("nhsm,stm->nhst", queries, rel_pos_keys)

        attn_weights += rel_attn_weights

        attn_weights = attn_weights * (queries.size(-1) ** -0.5)

        if mask is not None:
            attn_weights = attn_weights + mask

        # The softmax is computed in single precision and cast back to the
        # data type of the queries afterwards.
        attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)

        attn_weights = attn_weights.type_as(queries)

        if self.training and self.attn_dropout_p > 0.0:
            attn_weights = dropout(attn_weights, self.attn_dropout_p)

        # (N, H, S, S_kv) @ (N, H, S_kv, head_dim) = (N, H, S, head_dim)
        attn = torch.matmul(attn_weights, values)

        if self.rel_v_embed is not None:
            # (S, S_kv, head_dim)
            rel_pos_values = self.rel_v_embed(rel_pos_indices)

            # (N, H, S, S_kv) @ (S, S_kv, head_dim) = (N, H, S, head_dim)
            rel_attn = torch.einsum("nhst,stm->nhsm", attn_weights, rel_pos_values)

            attn += rel_attn

        return attn, attn_weights if needs_weights else None

    def _rel_pos_indices(self, seq_len: int, device: Device) -> Tensor:
        """Return the embedding indices of all pairwise relative positions.

        Entry ``[i, j]`` of the returned ``(seq_len, seq_len)`` tensor is the
        distance ``j - i`` clipped to ``[-max_left_rel_pos, max_right_rel_pos]``
        and shifted by ``max_left_rel_pos`` so that it is a valid, non-negative
        index into the relative position embedding tables.
        """
        # (1, S)
        pos = torch.arange(seq_len, device=device).unsqueeze(0)

        # (1, S) - (S, 1) = (S, S)
        rel_dist = pos - pos.transpose(0, 1)

        rel_dist = torch.clamp(
            rel_dist, -self.max_left_rel_pos, self.max_right_rel_pos
        )

        return rel_dist + self.max_left_rel_pos

    def extra_repr(self) -> str:
        """:meta private:"""
        s = super().extra_repr()

        return (
            f"{s}, "
            f"model_dim={self.model_dim}, "
            f"num_heads={self.num_heads}, "
            f"max_left_rel_pos={self.max_left_rel_pos}, "
            f"max_right_rel_pos={self.max_right_rel_pos}"
        )

0 comments on commit 199cf93

Please sign in to comment.