
Introduce LocalAttentionState (#108)
cbalioglu authored Oct 18, 2023
1 parent e9abbb3 commit c076183
Showing 3 changed files with 141 additions and 30 deletions.
11 changes: 10 additions & 1 deletion bibliography.bib
@@ -105,6 +105,15 @@ @misc{https://doi.org/10.48550/arxiv.2002.05202
primaryClass={cs.LG}
}

@misc{https://doi.org/10.48550/arxiv.2004.05150,
title={Longformer: The Long-Document Transformer},
author={Iz Beltagy and Matthew E. Peters and Arman Cohan},
year={2020},
eprint={2004.05150},
archivePrefix={arXiv},
primaryClass={cs.CL}
}

@misc{https://doi.org/10.48550/arxiv.2005.08100,
doi = {10.48550/ARXIV.2005.08100},
url = {https://arxiv.org/abs/2005.08100},
@@ -232,4 +241,4 @@ @misc{https://doi.org/10.48550/arXiv.2307.09288
eprint={2307.09288},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
}
6 changes: 6 additions & 0 deletions src/fairseq2/nn/transformer/__init__.py
@@ -71,6 +71,12 @@
from fairseq2.nn.transformer.multihead_attention import (
GlobalAttentionState as GlobalAttentionState,
)
from fairseq2.nn.transformer.multihead_attention import (
LocalAttentionState as LocalAttentionState,
)
from fairseq2.nn.transformer.multihead_attention import (
LocalAttentionStateFactory as LocalAttentionStateFactory,
)
from fairseq2.nn.transformer.multihead_attention import (
MultiheadAttention as MultiheadAttention,
)
154 changes: 125 additions & 29 deletions src/fairseq2/nn/transformer/multihead_attention.py
@@ -426,6 +426,10 @@ def forward(
state = self.state_factory(k, v, max_seq_len=k.size(2))

state_bag.set_state(self, state)
else:
# k: (N, H_kv, S_kv, K_h)
# v: (N, H_kv, S_kv, V_h)
k, v = state.get()
else:
if key_padding_mask is not None:
raise ValueError(
@@ -444,9 +448,9 @@
else:
state.append(k, v)

# k: (N, H_kv, S_kv, K_h)
# v: (N, H_kv, S_kv, V_h)
k, v = state.get()
# k: (N, H_kv, S_kv, K_h)
# v: (N, H_kv, S_kv, V_h)
k, v = state.get()

# With Grouped Query Attention, each key/value head is repeated.
if (num_query_groups := self.num_heads // self.num_key_value_heads) > 1:
@@ -558,23 +562,23 @@ def init_output_projection(proj: Linear) -> None:


class AttentionState(IncrementalState):
"""Holds the state of a :class:`MultiheadAttention` module during
incremental decoding."""
"""Holds the projected keys and values of a :class:`MultiheadAttention`
module during incremental decoding."""

@abstractmethod
def append(self, k: Tensor, v: Tensor) -> None:
"""Update the state with ``k``, ``v``, and ``key_padding_mask``.
"""Update the state with ``k`` and ``v``.
:param k:
The projected keys. *Shape:* :math:`(N,H,S_{stp},K_{proj})`, where
:math:`N` is the batch size, :math:`H` is the number of heads,
:math:`S_{stp}` is the number of steps (e.g. 1), and :math:`K_{proj}`
is the projected key size.
The projected keys of the current step. *Shape:*
:math:`(N,H,1,K_{proj})`, where :math:`N` is the batch size,
:math:`H` is the number of heads, :math:`1` is the step length, and
:math:`K_{proj}` is the projected key size.
:param v:
The projected values. *Shape:* :math:`(N,H,S_{stp},V_{proj})`, where
:math:`N` is the batch size, :math:`H` is the number of heads,
:math:`S_{stp}` is the number of steps (e.g. 1), and :math:`V_{proj}`
is the projected value size.
The projected values of the current step. *Shape:*
:math:`(N,H,1,V_{proj})`, where :math:`N` is the batch size,
:math:`H` is the number of heads, :math:`1` is the step length, and
:math:`V_{proj}` is the projected value size.
"""

@abstractmethod
@@ -603,42 +607,43 @@ def __call__(self, k: Tensor, v: Tensor, max_seq_len: int) -> AttentionState:

@final
class GlobalAttentionState(AttentionState):
"""Holds the projected keys and values of a :class:`MultiheadAttention`
"""Holds the past projected keys and values of a :class:`MultiheadAttention`
module during incremental decoding."""

seq_len: int
"""The current sequence length of :attr:`k` and :attr:`v`."""

k: Tensor
"""The projected keys accumulated from the past decoding steps. *Shape:*
:math:`(N,H,S,K_{proj})`, where :math:`N` is the batch size, :math:`H` is
the number of heads, :math:`S` is the reserved sequence length capacity, and
:math:`K_{proj}` is the projected key size."""
:math:`(N,H,S_{res},K_{proj})`, where :math:`N` is the batch size, :math:`H`
is the number of heads, :math:`S_{res}` is the reserved sequence length
capacity, and :math:`K_{proj}` is the projected key size."""

v: Tensor
"""The projected values accumulated from the past decoding steps. *Shape:*
:math:`(N,H,S,V_{proj})`, where :math:`N` is the batch size, :math:`H` is
the number of heads, :math:`S` is the reserved sequence length capacity, and
:math:`V_{proj}` is the projected value size."""
:math:`(N,H,S_{res},V_{proj})`, where :math:`N` is the batch size, :math:`H`
is the number of heads, :math:`S_{res}` is the reserved sequence length
capacity, and :math:`V_{proj}` is the projected value size."""

def __init__(self, k: Tensor, v: Tensor, max_seq_len: int) -> None:
batch_size, num_heads, _, head_dim = k.shape

self.seq_len = 0
batch_size, num_heads, seq_len, head_dim = k.shape

self.k = k.new_empty((batch_size, num_heads, max_seq_len, head_dim))
self.v = v.new_empty((batch_size, num_heads, max_seq_len, head_dim))

self.append(k, v)
self.k[:, :, :seq_len] = k
self.v[:, :, :seq_len] = v

self.seq_len = seq_len

@finaloverride
def append(self, k: Tensor, v: Tensor) -> None:
start, end = self.seq_len, self.seq_len + k.size(2)
pos = self.seq_len

self.k[:, :, start:end] = k
self.v[:, :, start:end] = v
self.k[:, :, pos : pos + 1] = k
self.v[:, :, pos : pos + 1] = v

self.seq_len = end
self.seq_len += 1

@finaloverride
def get(self) -> Tuple[Tensor, Tensor]:
@@ -653,6 +658,97 @@ def reorder(self, new_order: Tensor) -> None:
self.v = self.v.index_select(0, new_order)


@final
class LocalAttentionState(AttentionState):
"""Holds the past :attr:`attn_window_len` projected keys and values of a
:class:`MultiheadAttention` module during incremental decoding.
The intended use of this class is with Sliding Window Attention as described
in :cite:t:`https://doi.org/10.48550/arxiv.2004.05150`.
"""

seq_len: int
"""The current sequence length of :attr:`k` and :attr:`v`."""

attn_window_len: int
"""The attention window length."""

k: Tensor
"""The projected keys accumulated from the past :attr:`attn_window_len`
decoding steps. *Shape:* :math:`(N,H,S_{wnd},K_{proj})`, where :math:`N` is
the batch size, :math:`H` is the number of heads, :math:`S_{wnd}` is the
attention window length (i.e. :attr:`attn_window_len`), and :math:`K_{proj}`
is the projected key size."""

v: Tensor
"""The projected values accumulated from the past :attr:`attn_window_len`
decoding steps. *Shape:* :math:`(N,H,S_{wnd},V_{proj})`, where :math:`N` is
the batch size, :math:`H` is the number of heads, :math:`S_{wnd}` is the
attention window length (i.e. :attr:`attn_window_len`), and :math:`V_{proj}`
is the projected value size."""

def __init__(
self, k: Tensor, v: Tensor, max_seq_len: int, attn_window_len: int
) -> None:
batch_size, num_heads, seq_len, head_dim = k.shape

self.attn_window_len = min(max_seq_len, attn_window_len)

self.k = k.new_empty((batch_size, num_heads, self.attn_window_len, head_dim))
self.v = v.new_empty((batch_size, num_heads, self.attn_window_len, head_dim))

pos = min(seq_len, self.attn_window_len)

self.k[:, :, :pos] = k[:, :, -pos:]
self.v[:, :, :pos] = v[:, :, -pos:]

self.seq_len = seq_len

@finaloverride
def append(self, k: Tensor, v: Tensor) -> None:
if self.seq_len >= self.attn_window_len:
self.k = torch.roll(self.k, shifts=-1, dims=2)
self.v = torch.roll(self.v, shifts=-1, dims=2)

pos = self.attn_window_len - 1
else:
pos = self.seq_len

self.k[:, :, pos : pos + 1] = k
self.v[:, :, pos : pos + 1] = v

self.seq_len += 1

@finaloverride
def get(self) -> Tuple[Tensor, Tensor]:
k = self.k[:, :, : self.seq_len]
v = self.v[:, :, : self.seq_len]

return k, v

@finaloverride
def reorder(self, new_order: Tensor) -> None:
self.k = self.k.index_select(0, new_order)
self.v = self.v.index_select(0, new_order)


class LocalAttentionStateFactory:
"""Constructs instances of :class:`LocalAttentionState`."""

def __init__(self, attn_window_len: int) -> None:
"""
:param attn_window_len:
The attention window length.
"""
self.attn_window_len = attn_window_len

def __call__(self, k: Tensor, v: Tensor, max_seq_len: int) -> LocalAttentionState:
return LocalAttentionState(k, v, max_seq_len, self.attn_window_len)

def __repr__(self) -> str:
return f"LocalAttentionStateFactory(attn_window_len={self.attn_window_len})"


@final
class StaticAttentionState(AttentionState):
"""Holds the static projected keys and values (e.g. encoder-decoder) of a
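The following is a minimal sketch, not part of this commit, that exercises LocalAttentionState directly to illustrate its ring-buffer behavior: the constructor keeps at most attn_window_len positions, and append() shifts a full window with torch.roll. The class is imported exactly as re-exported by the __init__.py change above; all shapes and values are arbitrary.

# Minimal sketch (not part of this commit): LocalAttentionState keeps only the
# last `attn_window_len` decoding steps.
import torch

from fairseq2.nn.transformer import LocalAttentionState

batch_size, num_heads, head_dim = 2, 4, 16

# A "prompt" of 3 steps is copied into the window at construction time.
k0 = torch.randn(batch_size, num_heads, 3, head_dim)
v0 = torch.randn(batch_size, num_heads, 3, head_dim)

state = LocalAttentionState(k0, v0, max_seq_len=128, attn_window_len=4)

# Each incremental decoding step appends a single key/value pair (step length 1).
for _ in range(5):
    k_step = torch.randn(batch_size, num_heads, 1, head_dim)
    v_step = torch.randn(batch_size, num_heads, 1, head_dim)

    state.append(k_step, v_step)

k, v = state.get()

# 8 steps were seen in total, but only the last 4 are retained; older entries
# were shifted out by the torch.roll() call in append().
print(k.shape, v.shape)  # torch.Size([2, 4, 4, 16]) torch.Size([2, 4, 4, 16])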

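LocalAttentionStateFactory is what forward() invokes through self.state_factory(k, v, max_seq_len=...) when it creates fresh state for a decoding session. Below is a minimal sketch, not part of this commit, of constructing and calling the factory directly; how the factory is handed to a concrete MultiheadAttention module (e.g. through a constructor argument) is not shown in this diff, so the shapes and the max_seq_len value here are illustrative only.

# Minimal sketch (not part of this commit): constructing the factory and calling
# it the same way forward() does via self.state_factory(...).
import torch

from fairseq2.nn.transformer import LocalAttentionStateFactory

factory = LocalAttentionStateFactory(attn_window_len=512)

batch_size, num_heads, prompt_len, head_dim = 2, 8, 16, 64

k = torch.randn(batch_size, num_heads, prompt_len, head_dim)
v = torch.randn(batch_size, num_heads, prompt_len, head_dim)

# Returns a LocalAttentionState whose window holds at most 512 positions.
state = factory(k, v, max_seq_len=1024)

print(factory)  # LocalAttentionStateFactory(attn_window_len=512)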