From 53f2cbba9ccdca024df5cf138a00a9565f1d9c36 Mon Sep 17 00:00:00 2001 From: apghml <143655008+apghml@users.noreply.github.com> Date: Sun, 8 Dec 2024 20:53:18 -0800 Subject: [PATCH] Refactor attention bias/mask/segments. (#873) --- axlearn/common/adapter_torch_test.py | 5 +- axlearn/common/attention.py | 259 +----- axlearn/common/attention_bias.py | 742 ++++++++++++++++++ axlearn/common/attention_bias_test.py | 280 +++++++ axlearn/common/attention_test.py | 96 +-- axlearn/common/bert_test.py | 3 +- axlearn/common/decoder_test.py | 2 +- axlearn/common/dit_test.py | 2 +- axlearn/common/encoder_decoder.py | 2 +- axlearn/common/eval_retrieval.py | 2 +- axlearn/common/eval_retrieval_test.py | 2 +- .../common/flash_attention/gpu_attention.py | 6 +- axlearn/common/flash_attention/layer.py | 144 +--- axlearn/common/flash_attention/layer_test.py | 46 +- .../common/flash_attention/tpu_attention.py | 132 ++-- .../tpu_attention_benchmark.py | 21 +- .../flash_attention/tpu_attention_test.py | 21 +- axlearn/common/flash_attention/utils.py | 162 ++-- axlearn/common/metrics_text_dual_encoder.py | 2 +- axlearn/common/multiway_transformer_test.py | 3 +- axlearn/common/poolings.py | 3 +- axlearn/common/splade.py | 2 +- axlearn/common/ssm_test.py | 2 +- axlearn/vision/attention.py | 2 +- 24 files changed, 1387 insertions(+), 554 deletions(-) create mode 100644 axlearn/common/attention_bias.py create mode 100644 axlearn/common/attention_bias_test.py diff --git a/axlearn/common/adapter_torch_test.py b/axlearn/common/adapter_torch_test.py index 3f1e06c6e..4350604ce 100644 --- a/axlearn/common/adapter_torch_test.py +++ b/axlearn/common/adapter_torch_test.py @@ -13,6 +13,7 @@ from absl.testing import absltest, parameterized from jax import numpy as jnp +from axlearn.common import attention_bias from axlearn.common.adapter_torch import ( NEG_INF, AdapterCausalLmModelBuilder, @@ -355,8 +356,8 @@ def test_transformer_attention_layer_forward(self, structure: str, norm: str): rng = np.random.RandomState(123) target = rng.randn(2, 7, target_dim).astype(np.float32) - attention_logit_biases = np.zeros(target.shape[:-1]).astype(bool)[:, :, None] - attention_logit_biases[:, -2:] = True + attention_logit_biases = np.zeros(target.shape[:-1]).astype(float)[:, :, None] + attention_logit_biases[:, -2:] = attention_bias.NEG_INF torch_inputs = { "target": torch.as_tensor(target), "attention_logit_biases": torch.as_tensor(attention_logit_biases), diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index b1e071e82..26aceb797 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -33,13 +33,17 @@ """Attention layers with pjit partition specs. On `attention_logit_biases`: -* A biases tensor can have one of the following shapes: +* For methods that take a tensor, a biases Tensor can have one of the following shapes: * [target_length, source_length] * [batch, target_length, source_length] * [batch, num_heads, target_length, source_length]. * Each value represents a bias to be added to the attention logits (therefore a -inf represents a disconnected position pair). * biases=None represents an all-zero tensor, i.e., all position pairs are connected. +* For methods that take a BaseAttentionBias, the value() will always be None or a 4d Tensor with + the above semantics. + +TODO(apghml) Convert everything to take an instance of BaseAttentionBias rather than a Tensor. On `segment_ids`: * A tensor of shape [batch, target_length] with values in [0, num_segments]. 
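As a quick reference for the conventions described in this docstring, here is a minimal sketch (assuming the `axlearn.common.attention_bias` helpers added later in this patch) of how explicit bias tensors and segment ids relate:

```python
import jax.numpy as jnp
from axlearn.common.attention_bias import make_causal_biases, make_segment_mask

# A [target_length, source_length] causal bias; -inf marks disconnected position pairs.
biases = make_causal_biases(4)
# Broadcast to the canonical [batch, num_heads, target_length, source_length] form.
biases = biases[None, None, :, :]

# segment_ids: [batch, target_length]; positions in different segments may not attend
# to each other, which make_segment_mask expresses as a -inf bias.
segment_ids = jnp.asarray([[1, 1, 2, 2]])
segment_bias = make_segment_mask(source_segments=segment_ids, target_segments=segment_ids)
assert segment_bias.shape == (1, 1, 4, 4)
```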
@@ -65,6 +69,17 @@ from jax.core import Primitive from axlearn.common import ops, param_init +from axlearn.common.attention_bias import ( + NEG_INF, + BaseAttentionBias, + CausalAttentionBias, + MaskFn, + MaskFnAttentionBias, + SegmentIdAttentionBias, + as_attention_bias, + causal_mask, + make_segment_mask, +) from axlearn.common.base_layer import ( BaseLayer, FactorizationSpec, @@ -116,8 +131,6 @@ split_prng_key, ) -NEG_INF = -1e15 - class ForwardMode(enum.Enum): """ForwardMode describes the type of computation to be done in a forward pass through a layer. @@ -303,66 +316,6 @@ def extend_step( raise NotImplementedError(type(self)) -def make_causal_biases(seq_len: int) -> Tensor: - """Generates attention logit biases for causal masking. - - Args: - seq_len: Sequence length. - - Returns: - A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf if i < j, - 0 otherwise. - """ - # TODO(sneha): support batching - return bool_to_bias(causal_mask(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :])) - - -def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor: - """Generates attention logit biases for sliding window attention. - - Args: - seq_len: Sequence length. - - Returns: - A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf - if i - j > sliding_window_size or i < j, 0 otherwise. - """ - mask_fn = sliding_window_causal_mask(sliding_window_size) - return bool_to_bias(mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :])) - - -def bool_to_bias(mask: Tensor) -> Tensor: - """Converts a bool mask tensor to a bias mask tensor. - - Maps: - 0 -> -NEG_INF - 1 -> 0. - """ - if mask.dtype != jnp.bool: - raise ValueError("mask must be a Boolean tensor.") - return (~mask) * NEG_INF - - -def make_segment_mask(*, source_segments: Tensor, target_segments: Tensor) -> Tensor: - """Generates attention logit biases given the segment ids. - - ... such that positions belonging to different segments cannot attend to each other. - - Args: - source_segments: An integer tensor of shape [batch, ..., source_length]. - target_segments: An integer tensor of shape [batch, ..., target_length]. - - Returns: - A float Tensor of shape [batch, 1, ..., target_length, source_length] where the - value at [..., i, j] = 0 if target_segments[..., i] == source_segments[..., j], or -inf - otherwise. - """ - target_segments = jnp.expand_dims(target_segments, -1) - source_segments = jnp.expand_dims(source_segments, -2) - res = (jax.lax.ne(source_segments, target_segments) * NEG_INF)[:, None, ...] - return res - - class LearnedPositionalEmbedding(BaseLayer): """TODO(ruoming): Remove LearnedPositionalEmbedding. We can just use the Embedding layer.""" @@ -1555,83 +1508,6 @@ def default_scale_factor_config() -> InstantiableConfig[ScaleFn]: return config_for_function(constant_scale_fn).set(value=1) -class MaskFn(Protocol): - """A broadcastable function for computing a boolean logit mask.""" - - def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor: - """Returns a bool Tensor of whether the query token at `query_position` should attend - to the key token at `key_position`. - - Implementations have the following contract: - * Must support scalar arguments. - * If given non-scalar arguments of the same shape, the result must be the same as - applying the function elementwise over these arugments. 
I.e., - ``` - x = f(jnp.asarray([1,2]), jnp.asarray([3,4])) - assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None] - ``` - * If given non-scalar arguments of different shapes, the result must be the same if we - first broadcast the arguments against each other to make them have the same shape. - * Beyond requiring broadcastability, must not impose any constraints on the shapes of its - arguments. - - Args: - query_position: The index in the sequence of query vectors. - key_position: The index in the sequence of key vectors. - - Returns: - Whether the query and key vectors with the given index should attend to one another. - True means they should attend. False means they should not. - The shape is the same as the shape obtained after broadcasting the inputs against each - other. - """ - - -def _composite_masks(op: Callable[[Tensor, Tensor], Tensor], *mask_fns: ConfigOr[MaskFn]): - if len(mask_fns) == 0: - raise RuntimeError(f"Input must not be empty: {mask_fns}") - - def mask(query_position: Tensor, key_position: Tensor): - fns = [maybe_instantiate(arg) for arg in mask_fns] - result = fns[0](query_position, key_position) - for mask in fns[1:]: - result = op(result, mask(query_position, key_position)) - return result - - return mask - - -def or_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn: - """Returns a MaskFn that's the union of provided MaskFn's.""" - return _composite_masks(jnp.logical_or, *mask_fns) - - -def and_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn: - """Returns a MaskFn that's the intersection of provided MaskFn's.""" - return _composite_masks(jnp.logical_and, *mask_fns) - - -def causal_mask(query_position: Tensor, key_position: Tensor) -> Tensor: - """Returns the given entry of a causal attention mask. - - Implements the `MaskFn` protocol. - See that and `MultiheadAttention.Config.mask`. - """ - return query_position >= key_position - - -def sliding_window_causal_mask(sliding_window_size: int): - """Returns a causal MaskFn for sliding window attentions of a given window size. - - Implements the `MaskFn` protocol. - """ - - def mask(query_position: Tensor, key_position: Tensor): - return query_position - key_position <= sliding_window_size - - return and_masks(causal_mask, mask) - - class MultiheadAttention(BaseLayer): """A basic multi-head attention layer. @@ -1747,7 +1623,7 @@ def _forward_for_mode( key: Optional[Tensor] = None, value: Optional[Tensor] = None, kv_state: Optional[KVState] = None, - attention_logit_biases: Optional[Tensor] = None, + attention_logit_biases: Union[None, Tensor, BaseAttentionBias] = None, segment_ids: Optional[Tensor] = None, cached_states: Optional[NestedTensor] = None, return_aux: Optional[set[str]] = None, @@ -1819,36 +1695,35 @@ def _forward_for_mode( self.vlog(3, "atten.q_proj=%s", q_proj.sum()) self.vlog(3, "atten.k_proj=%s", k_proj.sum()) self.vlog(3, "atten.v_proj=%s", v_proj.sum()) - if attention_logit_biases is not None: - if attention_logit_biases.ndim == 3: - # [batch, 1, target_length, source_length]. - attention_logit_biases = attention_logit_biases[:, None, :, :] - elif attention_logit_biases.ndim == 2: - # [1, 1, target_length, source_length]. - attention_logit_biases = attention_logit_biases[None, None, :, :] - elif attention_logit_biases.ndim != 4: - raise ValueError( - f"Invalid attention_logit_biases shape: {attention_logit_biases.shape}." 
- ) + attention_logit_biases = as_attention_bias(attention_logit_biases) if self._mask_fn is not None: - kv_pos = jnp.arange(k_proj.shape[1])[None, :] # [1, source_len] - query_pos = jnp.arange(q_proj.shape[1])[None] # [1, target_length] + target_positions = None if mode == ForwardMode.EXTEND_STEP: - time_step = cached_states["i_proj"]["time_step"] # [B] - # [B, target_length], target_length is often 1 for decoding, but not always. - query_pos = query_pos + time_step[:, None] - mask = self._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) - if mask is not None: - attention_logit_biases = apply_attention_logit_biases( - mask.astype(q_proj.dtype), - attention_logit_biases, + target_positions = cached_states["i_proj"]["time_step"] + if self._mask_fn is causal_mask: + # Needed for legacy flash attention implementations that don't have + # sparse mask support. + # E.g., the legacy tpu flash attention, all current gpu flash attention + # implementations. + attention_logit_biases += CausalAttentionBias( + shape=(q_proj.shape[1], k_proj.shape[1]), + target_positions=target_positions, + dtype=q_proj.dtype, ) + else: + attention_logit_biases += MaskFnAttentionBias( + self._mask_fn, + shape=(q_proj.shape[1], k_proj.shape[1]), + target_positions=target_positions, + dtype=q_proj.dtype, + ) + if segment_ids is not None: + attention_logit_biases += SegmentIdAttentionBias(segment_ids) context, probs = self._compute_attention( q_proj=q_proj, k_proj=k_proj, v_proj=v_proj, attention_logit_biases=attention_logit_biases, - segment_ids=segment_ids, ) self.vlog(3, "atten.prob=%s", probs[0, 0, 0, :]) self.vlog(3, "atten.context=%s", context.sum()) @@ -1865,38 +1740,13 @@ def _forward_for_mode( ) return dict(i_proj=i_proj_state), output - def _logit_biases_for_mask( - self, *, mode: ForwardMode, query_pos: Tensor, kv_pos: Tensor - ) -> Optional[Tensor]: - """Returns the configured attention mask in the form of logit biases. - - ... or None if the implementation of _compute_attention supports applying masks natively. - - Args: - mode: The forward propagation mode, chosen from - (ForwardMode.FORWARD, ForwardMode.INIT_STATES, ForwardMode.EXTEND_STEP). - query_pos: The index in the sequence of query vectors, [1|batch, target_length]. - kv_pos: The index in the sequence of kv vectors, [1|batch, source_length]. - - Returns: - A logit bias tensor [1|batch, 1, target_length, source_length]. - """ - del mode - kv_pos = kv_pos[:, None] # [1|B, 1, source_len] - query_pos = query_pos[..., None] # [1|B, target_len, 1] - # [1|B, 1, target_len, source_len] - mask = self._mask_fn(query_pos, kv_pos)[:, None] - mask = bool_to_bias(mask) - return mask - def _compute_attention( self, *, q_proj: Tensor, k_proj: Tensor, v_proj: Tensor, - attention_logit_biases: Optional[Tensor] = None, - segment_ids: Optional[Tensor] = None, + attention_logit_biases: BaseAttentionBias, ) -> tuple[Tensor, Tensor]: """Computes attention context and probs. @@ -1905,27 +1755,15 @@ def _compute_attention( k_proj: [batch_size, source_length, num_heads, per_head_dim]. v_proj: [batch_size, source_length, num_heads, per_head_dim]. attention_logit_biases: See ``On attention logit biases`` in the file comments. - segment_ids: See ``segment_ids`` in the file comments. Returns: The context of shape [batch_size, target_length, num_heads, per_head_dim], and probs of shape [batch, num_heads, target_length, source_length]. """ - # Merge segment ids into attention_logit_biases. 
- if segment_ids is not None: - if q_proj.shape[1] != k_proj.shape[1]: - raise ValueError( - "segment_ids is only supported for query and key with identical lengths." - ) - attention_logit_biases = apply_attention_logit_biases( - make_segment_mask(source_segments=segment_ids, target_segments=segment_ids), - attention_logit_biases, - ) - logits = self._compute_logits(q_proj, k_proj) logits = self._cap_logits(logits) self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) - probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases) + probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases.value()) probs = self.dropout(probs) context = self._compute_context(probs, v_proj) context = self._remat_name(context, "context") @@ -2203,7 +2041,7 @@ class SigmoidAttention(MultiheadAttention): class Config(MultiheadAttention.Config): """Configures SigmoidAttention.""" - seq_len: Required[int] = REQUIRED # Maximum sequence length used. + seq_len: Required[int] = REQUIRED # Maximum sequence length used. def _compute_attention( self, @@ -2211,26 +2049,15 @@ def _compute_attention( q_proj: Tensor, k_proj: Tensor, v_proj: Tensor, - attention_logit_biases: Optional[Tensor] = None, - segment_ids: Optional[Tensor] = None, + attention_logit_biases: BaseAttentionBias, ) -> tuple[Tensor, Tensor]: """See `MultiheadAttention._compute_attention` for details.""" - # Merge segment ids into attention_logit_biases. - if segment_ids is not None: - if q_proj.shape[1] != k_proj.shape[1]: - raise ValueError( - "segment_ids is only supported for query and key with identical lengths." - ) - attention_logit_biases = apply_attention_logit_biases( - make_segment_mask(source_segments=segment_ids, target_segments=segment_ids), - attention_logit_biases, - ) - cfg = self.config logits = self._compute_logits(q_proj, k_proj) logits = self._cap_logits(logits) self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) + attention_logit_biases = attention_logit_biases.value() if attention_logit_biases is None: attention_logit_biases = 0 # To approximate softmax, we subtract a bias dependent on sequence length. diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py new file mode 100644 index 000000000..f6a292136 --- /dev/null +++ b/axlearn/common/attention_bias.py @@ -0,0 +1,742 @@ +# Copyright © 2024 Apple Inc. + +"""Data structures for working with different kinds of attention logit biases. + +Some downstream optimizations, e.g., flash attention, rely on specific subclasses of +BaseAttentionBias being used. E.g., CausalAttentionBias should be used to specify the causal mask. +The optimizations will not happen if a BoolTensorAttentionBias is used to specify the causal mask +instead. +These "special" bias classes also include SegmentIdAttentionBias and MaskFnAttentionBias. + +Note that the various `BaseAttentionBias` classes are not intended to be instantiated at +configuration time. Instead, they are geared towards people developing layers who can then return +an instance of `BaseAttentionBias` where they would have before returned an explicit attention +bias Tensor. Currently, we don't have support for AttentionLogitBiasLayer returning these objects +instead of explicit bias Tensors, but such support can be added in the future if needed in a +fully backwards-compatible manner. 
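To illustrate the intended developer-facing usage, here is a minimal sketch (assuming the classes defined later in this file): biases combine lazily with `+` and are only materialized via `value()`.

```python
import jax.numpy as jnp
from axlearn.common.attention_bias import CausalAttentionBias, SegmentIdAttentionBias

# Addition is lazy: the result is a CompositeAttentionBias wrapping both members.
bias = CausalAttentionBias(shape=(4, 4)) + SegmentIdAttentionBias(jnp.asarray([[1, 1, 2, 2]]))
# value() materializes a [batch or 1, num_heads or 1, target_len, source_len] tensor,
# or None if every member is zero. Downstream optimizations (e.g., flash attention)
# can instead inspect the structured members without materializing them.
assert bias.value().shape == (1, 1, 4, 4)
```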
+""" + +import dataclasses +import functools +import typing +from typing import ( + Callable, + Generic, + Iterable, + Optional, + Protocol, + Sequence, + Type, + TypeVar, + Union, + cast, + final, +) + +import jax +from jax import numpy as jnp +from jax.sharding import PartitionSpec + +from axlearn.common import struct +from axlearn.common.config import ConfigOr, maybe_instantiate +from axlearn.common.utils import Tensor + +NEG_INF = -1e15 + +# We use OpT when we have a function like fn(x: OpT) -> OpT where we want to annotate +# that the functions return value is None/not None according to whether the input is None/not None. +# This makes e.g. IDE inspections able to understand that fn(jnp.ones(5)) is in fact a Tensor and +# not None. +OpT = typing.TypeVar("OpT", type(None), Tensor) +B = TypeVar("B", bound="BaseAttentionBias") + + +@functools.partial(struct.dataclass, eq=False) +class BaseAttentionBias: + """Base class representing attention logit biases.""" + + # The dtype of the biases to return in `value()`. + # If None, do not cast the dtype. + dtype: Optional[jnp.dtype] = struct.field(kw_only=True, default=None, pytree_node=False) + + @final + def value(self) -> Optional[Tensor]: + """Return a tensor with the biases or None if there are no biases. + + Shape: [batch or 1, num_heads or 1, target_len, source_len]. + + The dtype will be cast to `self.dtype` if it is not None. + """ + value = self._value() + if self.dtype is not None and value is not None: + value = value.astype(self.dtype) + return self._broadcast_value(value) + + def _value(self) -> Optional[Tensor]: + """Internal version of `value()` without the casting and broadcasting done in the public + method. + + Subclasses must implement this. + + Shape: Any of: + * [target_len, source_len] + * [batch or 1, target_len, source_len] + * [batch or 1, num_heads or 1, target_len, source_len]. + """ + raise NotImplementedError + + def __add__(self, other: "BaseAttentionBias") -> "CompositeAttentionBias": + """Returns a bias tensor representing the sum of `self` and `other`. + + The implementation lazily adds them by creating a CompositeAttentionBias + from the biases being added. + """ + return CompositeAttentionBias([self, other]) + + def astype(self, dtype: jnp.dtype) -> "BaseAttentionBias": + """Return a new bias whose dtype is `dtype`.""" + result = dataclasses.replace(self, dtype=dtype) + result = cast(BaseAttentionBias, result) + return result + + @classmethod + def _broadcast_value(cls, value: OpT) -> OpT: + """Broadcasts `value` to a canonical 4 dimensional attention bias shape. + + Raises: + ValueError: If the shape of `value` is not 2, 3, or 4 dimensional. + """ + if value is None or value.ndim == 4: + return value + if value.ndim == 2: + # Shape: [1, 1, target_length, source_length]. + return value[None, None, :, :] + elif value.ndim == 3: + # Shape: [batch, 1, target_length, source_length]. + return value[:, None, :, :] + raise ValueError(f"Invalid attention_logit_biases shape: {value.shape}.") + + def eval_shape(self): + return jax.eval_shape(self.value).shape + + def partition_spec( + self, mha_dim_to_partition_spec: dict[str, PartitionSpec] + ) -> Union["BaseAttentionBias", PartitionSpec]: + """Compute a partition spec for this bias.""" + raise NotImplementedError + + def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]": + """Split this bias into a bias of type `cls` and a residual. + + If the two returned biases are added together, the result is equivalent to + the value of this bias. 
+ + The default implementation returns `self` as either the bias or residual. + Which field is set to `self` is based on whether this is an instance of `cls`. + If it is an instance, `self` is returned in the bias field and the `residual` field will be + a `BaseAttentionBias` with value() None. + If not, it is returned in the `residual` field, and the `bias` field is set to `None`. + """ + if isinstance(self, cls): + return BiasAndResidual(bias=self, residual=CompositeAttentionBias([])) + return BiasAndResidual(bias=None, residual=self) + + @classmethod + def from_sequence(cls, biases: Sequence["BaseAttentionBias"]) -> Optional["BaseAttentionBias"]: + """Constructs a single combined attention bias of the same type as this class + from a sequence of such biases. + + If the sequence is empty, returns None. + + The default implementation returns the bias if the sequence has length one and + raises for length > 1. + + Raises: + NotImplementedError: If the sequence has length > 1. + TypeError: If `seq` contains a type of bias that is not an instance of this class. + ValueError: If `eval_shape()` is not the same for every bias. + """ + + if not biases: + return None + + shape = biases[0].eval_shape() + for bias in biases: + if not isinstance(bias, cls): + raise TypeError(f"Got bias type {type(bias)}, not instance of {cls}.") + if bias.eval_shape() != shape: + raise ValueError(f"Got shape mismatch {bias.eval_shape()} != {shape}.") + + if len(biases) == 1: + return biases[0] + raise NotImplementedError + + +@struct.dataclass +class BiasAndResidual(BaseAttentionBias, Generic[B]): + """A bias and residual where the bias has type `B` (or is None) and the residual + has any type. + + Used to represent an original bias that has been split into the sum of two biases. + + See `BaseAttentionBias.bias_and_residual()`. + """ + + bias: Optional[B] + residual: BaseAttentionBias + + def _value(self) -> Optional[Tensor]: + return CompositeAttentionBias([self.bias, self.residual]).value() + + def __iter__(self): + return iter((self.bias, self.residual)) + + +@struct.dataclass +class CompositeAttentionBias(BaseAttentionBias): + """A lazily evaluated list of biases that are added together to get the final bias. + + The implementation automatically flattens nested instances of `CompositeAttentionBias`. + + Biases that have a `value()` of None or are equal to None are automatically omitted + when iterating over this instance. However, they remain in the `biases` list + and are therefore still part of the pytree structure. + """ + + # The biases to add to obtain the final bias. + biases: Sequence[BaseAttentionBias] + + def _value(self) -> Optional[Tensor]: + """Returns the sum of the biases. + + If all biases have value None, this is guaranteed to also return None. + + Shape: [batch or 1, num_heads or 1, target_len, source_len]. + + Raises: + ValueError: If one of the biases in the sum has the wrong shape. + """ + biases = self._nonzero() + if not biases: + return None + + result = 0.0 + for bias in biases: + result += bias.value() + return result + + def __add__(self, other: BaseAttentionBias) -> "CompositeAttentionBias": + return self.__class__([self, other]) + + def _nonzero(self) -> Sequence[BaseAttentionBias]: + """Returns an sequence of biases in this collection except those detected as zero. + + Returned biases are not guaranteed to be nonzero, but are guaranteed to not return None. 
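A brief sketch of the behavior described for `CompositeAttentionBias` (assuming the classes defined in this file): zero members are dropped when the value is materialized, and `bias_and_residual()` recovers a structured member from the composite.

```python
import jax.numpy as jnp
from axlearn.common.attention_bias import (
    CausalAttentionBias,
    CompositeAttentionBias,
    SegmentIdAttentionBias,
    ZeroAttentionBias,
)

# A composite of zero biases materializes to None.
assert CompositeAttentionBias([ZeroAttentionBias(), ZeroAttentionBias()]).value() is None

# bias_and_residual() extracts the member of the requested type; everything else stays
# in the residual, so the sum of the two parts is equivalent to the original bias.
composite = CompositeAttentionBias(
    [CausalAttentionBias(shape=(4, 4)), SegmentIdAttentionBias(jnp.asarray([[1, 1, 2, 2]]))]
)
causal, residual = composite.bias_and_residual(CausalAttentionBias)
assert isinstance(causal, CausalAttentionBias)
assert residual.value().shape == (1, 1, 4, 4)  # The segment-id bias remains in the residual.
```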
+ """ + filt = lambda b: b.value() is not None + return list(filter(filt, self.biases)) + + def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]": + """Split this bias into a bias of type `cls` and a residual. + + Compared to the default implementation, this determines which instance of `cls` to return + by calling `bias_and_residual()` on each member of this collection. It also recursively + calls `bias_and_residual` on any residuals obtained from such BiasAndResidual objects. + + All non-None biases returned from doing this are then merged using `cls.from_sequence()` + before returning. + """ + bias_and_residual = super().bias_and_residual(cls) + if bias_and_residual.bias is not None: + return bias_and_residual + remaining_biases = list(self._nonzero()) + cls_biases = [] + residuals = [] + while remaining_biases: + bias = remaining_biases.pop() + bias_and_residual = bias.bias_and_residual(cls) + if bias_and_residual.bias is not None: + cls_biases.append(bias_and_residual.bias) + send_residual_to = remaining_biases + else: + send_residual_to = residuals + if bias_and_residual.residual.value() is not None: + send_residual_to.append(bias_and_residual.residual) + return BiasAndResidual( + bias=cls.from_sequence(cls_biases), residual=CompositeAttentionBias(residuals) + ) + + def partition_spec( + self, mha_dim_to_partition_spec: dict[str, PartitionSpec] + ) -> Union[BaseAttentionBias, PartitionSpec]: + return CompositeAttentionBias( + [ + b.partition_spec(mha_dim_to_partition_spec) if b is not None else PartitionSpec() + for b in self.biases + ], + dtype=self.dtype, + ) + + def _flatten(self) -> "CompositeAttentionBias": + """Returns a flattened version of this instance. + + Used only for testing/debugging + """ + remaining = [self] + biases = [] + while remaining: + bias = remaining.pop() + if isinstance(bias, CompositeAttentionBias): + remaining.extend(bias._nonzero()) # pylint: disable=protected-access + else: + biases.append(bias) + return CompositeAttentionBias(biases) + + +def split(bias: BaseAttentionBias, *cls: Type[BaseAttentionBias]) -> Iterable[BaseAttentionBias]: + """Split `bias` into an iterable of biases of `len(cls) + 1` instances, where the ith instances + has type cls[i] or ZeroAttentionBias. + + Each bias will only be present once in the output, with ties broken based on the first + matching type in `cls`. + + The correctness of this function requires that `bias_and_residual(cls)` + has the property that future calls to `bias_and_residual()` on any residuals + obtained directly or indirectly from the original residual will never return an instance + of `cls`. + + Raises: + NotImplementedError: If any residual is encountered with a type in `cls`. + """ + for c in cls: + result, bias = bias.bias_and_residual(c) + if isinstance(bias, cls): + raise NotImplementedError("Got a residual of type in `cls`.") + if result is None: + yield ZeroAttentionBias() + else: + yield result + yield bias + + +@struct.dataclass +class TensorAttentionBias(BaseAttentionBias): + """An attention bias represented as an explicit Tensor.""" + + # The explicit value of the bias. + # Shape: [batch, num_heads, target_len, source_len]. + _internal_value: Tensor + + def __post_init__(self): + # Because TensorAttentionBias is a struct.dataclass and the automatically generated pytree + # flattening methods for all struct.dataclasses always flatten to a list of the dataclass + # fields. (I.e., not the result of calling value().) 
+ # Therefore, we enforce a consistent shape so that the partition spec correctly lines + # up wit the dimensions of the stored Tensor. + if getattr(self._internal_value, "ndim", 4) != 4: + raise ValueError(f"Invalid shape {self._internal_value.shape}.") + + def _value(self) -> Tensor: + return self._internal_value + + def partition_spec( + self, mha_dim_to_partition_spec: dict[str, PartitionSpec] + ) -> Union[BaseAttentionBias, PartitionSpec]: + shape = self.eval_shape() + spec = mha_dim_to_partition_spec["bnts"] + return _spec_for_explicit_bias(spec=spec, shape=shape) + + @classmethod + def from_tensor(cls, tensor: Tensor) -> "TensorAttentionBias": + """Constructs an instance of this class, automatically canonicalizing the shape of + `tensor` to the required form. + + Unlike a CompositeAttentionBias, this can be used as a mask in SplashAttention. + """ + return cls(cls._broadcast_value(tensor)) + + +def _spec_for_explicit_bias( + spec: PartitionSpec, shape: tuple[int, ...] +) -> Union[BaseAttentionBias, PartitionSpec]: + """Return a PartionSpec for an explicit bias tensor of the given shape baed on `spec`.""" + # Explicit attention bias: [batch_size, num_heads, target_len, source_len]. + if spec != PartitionSpec(None): + if shape[0] == 1: + spec = PartitionSpec(None, *spec[1:]) + if shape[1] == 1: + spec = PartitionSpec(spec[0], None, *spec[2:]) + return spec + + +@struct.dataclass +class BoolAttentionBias(BaseAttentionBias): + """An attention bias represented as a boolean mask.""" + + @final + def _value(self) -> Optional[Tensor]: + bool_value = self.bool_value() + if bool_value is None: + return None + return bool_to_bias(bool_value) + + @final + def bool_value(self) -> Optional[Tensor]: + """Return a tensor with the boolean values from `self.mask` before they have been converted + to biases. + + Shape: Same as `self.value()`. + """ + bool_value = self._bool_value() + if bool_value is not None: + bool_value = bool_value.astype(bool) + return self._broadcast_value(bool_value) + + def _bool_value(self) -> Optional[Tensor]: + """Internal version of `bool_value()` without the casting and broadcasting done in the + public method. + + Subclasses must implement this. + + Shape: Same as `self._value()`. + """ + raise NotImplementedError + + +@struct.dataclass +class SegmentIdAttentionBias(BoolAttentionBias): + """An attention bias defined by segment ids.""" + + # See ``on segment ids'' in the module docstring. + segment_ids: Tensor + + def _bool_value(self) -> Optional[Tensor]: + return _make_bool_segment_mask( + target_segments=self.segment_ids, source_segments=self.segment_ids + ) + + def partition_spec( + self, mha_dim_to_partition_spec: dict[str, PartitionSpec] + ) -> Union[BaseAttentionBias, PartitionSpec]: + # Segment IDs: [batch_size, seq_len]. + q_spec = mha_dim_to_partition_spec["btnh"] + if q_spec == PartitionSpec(None): + return PartitionSpec(None) + return PartitionSpec(q_spec[0], q_spec[1]) + + +class MaskFn(Protocol): + """A broadcastable function for computing a boolean logit mask.""" + + def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor: + """Returns a bool Tensor of whether the query token at `query_position` should attend + to the key token at `key_position`. + + Implementations have the following contract: + * Must support scalar arguments. + * If given non-scalar arguments of the same shape, the result must be the same as + applying the function elementwise over these arugments. 
I.e., + ``` + x = f(jnp.asarray([1,2]), jnp.asarray([3,4])) + assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None] + ``` + * If given non-scalar arguments of different shapes, the result must be the same if we + first broadcast the arguments against each other to make them have the same shape. + * Beyond requiring broadcastability, must not impose any constraints on the shapes of its + arguments. + * Must return a non-tracer value when run in `jax.ensure_compile_time_eval()`. This will + typically be the case as long as you don't access tensors stored in global variables. + + Args: + query_position: The index in the sequence of query vectors. + key_position: The index in the sequence of key vectors. + + Returns: + Whether the query and key vectors with the given index should attend to one another. + True means they should attend. False means they should not. + The shape is the same as the shape obtained after broadcasting the inputs against each + other according to numpy broadcasting semantics. + For example, a common usage pattern is that `query_position` has shape + `[batch, tgt_seq, 1]` and `key_position` has shape `[batch, 1, src_seq]` and the mask + will have shape `[batch, tgt_seq, src_seq]`. + Reference for bradcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html + """ + + +@struct.dataclass +class MaskFnAttentionBias(BoolAttentionBias): + """An attention bias represented as an implicit boolean mask.""" + + # The function defining the contents of the mask. + mask: MaskFn = struct.field(pytree_node=False) + # The shape [target_len, source_len] of the mask. + shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False) + # The positions in the query sequence that the mask should be computed for. + # I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at + # `target_positions[batch, num_heads i]` may attend to. + # If None, set `target_positions[batch, num_heads, i] = i`. + # Shape: [batch]. + # This is typically used during decoding to specify the locations in the sequence being + # being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch + # entry respectively, we would set `target_positions = jnp.asarray([5, 7])`. + target_positions: Optional[Tensor] = None + + def _bool_value(self) -> Optional[Tensor]: + """Return a tensor with the boolean values from `self.mask` before they have been converted + to biases. + + Shape: + - If `target_positions` is None: [target_len, source_len] + - Else: [batch, target_len, source_len]. + """ + target_positions, source_positions = jnp.indices(self.shape, sparse=True) + if self.target_positions is not None: + target_positions = self.target_positions + if target_positions.ndim == 1: + # pylint: disable-next=unsubscriptable-object + target_positions = target_positions[:, None] + jnp.arange(self.shape[0]) + while target_positions.ndim < 3: + target_positions = target_positions[..., None] + return self.mask(target_positions, source_positions) # pylint: disable=not-callable + + @classmethod + def from_sequence( + cls, biases: Sequence["MaskFnAttentionBias"] + ) -> Optional["MaskFnAttentionBias"]: + """Constructs a single combined `MaskFnAttentionBias` from a Sequence of them. + + The sequence is first filtered to remove biases that are detected as all zero. + + If the sequence only has one element after doing so, that one element is returned without + modification. + + If the sequence is empty, returns None. 
+ + Raises: + ValueError: If `target_positions` is set for any bias. + """ + try: + return super().from_sequence(biases) + except NotImplementedError: + pass + for bias in biases: + if bias.target_positions is not None: + raise ValueError(f"target_positions was not None for {bias}.") + + # Combine masks. + mask = lambda query_position, key_position: jnp.all( + jnp.stack([b.mask(query_position, key_position) for b in biases]), axis=0 + ) + return MaskFnAttentionBias(mask=mask, shape=biases[0].shape) + + def partition_spec( + self, mha_dim_to_partition_spec: dict[str, PartitionSpec] + ) -> Union[BaseAttentionBias, PartitionSpec]: + return PartitionSpec(*mha_dim_to_partition_spec["bnts"][0:1]) + + +@struct.dataclass +class BoolTensorAttentionBias(BoolAttentionBias): + """An attention bias represented as an explicit boolean mask.""" + + # The explicit bool value of the bias. + _internal_bool_value: Tensor + + def __post_init__(self): + if getattr(self._internal_bool_value, "ndim", 4) != 4: + raise ValueError(f"Invalid shape {self._internal_bool_value.shape}.") + if getattr(self._internal_bool_value, "dtype", bool) != bool: + raise ValueError(f"Invalid dtype {self._internal_bool_value.dtype}, expected bool.") + + def _bool_value(self) -> Tensor: + """Return a tensor with the boolean values from `self.mask` before they have been converted + to biases. + """ + return self._internal_bool_value + + def partition_spec( + self, mha_dim_to_partition_spec: dict[str, PartitionSpec] + ) -> Union[BaseAttentionBias, PartitionSpec]: + shape = self.eval_shape() + spec = mha_dim_to_partition_spec["bnts"] + return _spec_for_explicit_bias(spec=spec, shape=shape) + + @classmethod + def from_tensor(cls, tensor: Tensor) -> "BoolTensorAttentionBias": + """Constructs an instance of this class, automatically canonicalizing the shape of + `tensor` to the required form. + """ + return cls(cls._broadcast_value(tensor)) + + +def as_attention_bias(bias: Union[None, Tensor, B]) -> B: + """Converts `bias` to an instance of `BaseAttentionBias`. + + Raises: + ValueError: If `bias` is a Tensor but does not have a floating point dtype. + NotImplementedError: If `bias` is an unknown type. + """ + if bias is None: + return ZeroAttentionBias() + if isinstance(bias, Tensor): + if not jnp.issubdtype(bias.dtype, jnp.floating): + raise ValueError(f"bias must have a floating dtype, got {bias.dtype}.") + return TensorAttentionBias.from_tensor(bias) + if isinstance(bias, BaseAttentionBias): + return bias + raise NotImplementedError(type(bias)) + + +def causal_mask(query_position: Tensor, key_position: Tensor) -> Tensor: + """Returns the given entry of a causal attention mask. + + Implements the `MaskFn` protocol. + See that and `MultiheadAttention.Config.mask`. + """ + return query_position >= key_position + + +@struct.dataclass +@final +class CausalAttentionBias(MaskFnAttentionBias): # pylint: disable=final-error + """A causal attention mask.""" + + mask: Optional[MaskFn] = struct.field(pytree_node=False, default=causal_mask) + + @classmethod + def from_sequence( + cls, biases: Sequence["CausalAttentionBias"] + ) -> Optional["CausalAttentionBias"]: + try: + return super().from_sequence(biases) + except NotImplementedError: + pass + return biases[0] + + +@struct.dataclass +@final +class ZeroAttentionBias(BoolAttentionBias): + """ "Attention bias that adds zero. + + It is better to check whether a bias has `value()` None rather than using + an isinstacne check on this class, since the former is more general. 
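For instance, a minimal sketch of that check, using `as_attention_bias` as defined above (all names taken from this module):

```python
import jax.numpy as jnp
from axlearn.common.attention_bias import as_attention_bias

# None converts to a zero bias; prefer checking value() over an isinstance check.
bias = as_attention_bias(None)
assert bias.value() is None

# A float Tensor converts to a TensorAttentionBias with a canonical 4-d value.
explicit = as_attention_bias(jnp.zeros((2, 4, 4), dtype=jnp.float32))
assert explicit.value().shape == (2, 1, 4, 4)
```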
+ """ + + def _bool_value(self) -> None: + return None + + def partition_spec( + self, mha_dim_to_partition_spec: dict[str, PartitionSpec] + ) -> Union[BaseAttentionBias, PartitionSpec]: + # Nothing to shard. + return PartitionSpec() + + def __eq__(self, other): + return type(other) is type(self) + + +def _composite_masks(op: Callable[[Tensor, Tensor], Tensor], *mask_fns: ConfigOr[MaskFn]): + if len(mask_fns) == 0: + raise RuntimeError(f"Input must not be empty: {mask_fns}") + + def mask(query_position: Tensor, key_position: Tensor): + fns = [maybe_instantiate(arg) for arg in mask_fns] + result = fns[0](query_position, key_position) + for mask in fns[1:]: + result = op(result, mask(query_position, key_position)) + return result + + return mask + + +def or_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn: + """Returns a MaskFn that's the union of provided MaskFn's.""" + return _composite_masks(jnp.logical_or, *mask_fns) + + +def and_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn: + """Returns a MaskFn that's the intersection of provided MaskFn's.""" + return _composite_masks(jnp.logical_and, *mask_fns) + + +def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn: + """Returns a causal MaskFn for sliding window attentions of a given window size. + + Implements the `MaskFn` protocol. + """ + + def mask(query_position: Tensor, key_position: Tensor): + return query_position - key_position <= sliding_window_size + + return and_masks(causal_mask, mask) + + +def make_causal_biases(seq_len: int) -> Tensor: + """Generates attention logit biases for causal masking. + + Args: + seq_len: Sequence length. + + Returns: + A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf if i < j, + 0 otherwise. + """ + # TODO(sneha): Support batching. + return bool_to_bias(causal_mask(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :])) + + +def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor: + """Generates attention logit biases for sliding window attention. + + Args: + seq_len: Sequence length. + + Returns: + A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf + if i - j > sliding_window_size or i < j, 0 otherwise. + """ + mask_fn = sliding_window_causal_mask(sliding_window_size) + return bool_to_bias(mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :])) + + +def bool_to_bias(mask: OpT) -> OpT: + """Converts a bool mask tensor to a bias mask tensor. + + Maps: + 0 -> -NEG_INF + 1 -> 0. + """ + if mask is None: + return None + if mask.dtype != jnp.bool: + raise ValueError("mask must be a Boolean tensor.") + return (~mask) * NEG_INF + + +def _make_bool_segment_mask(*, source_segments: Tensor, target_segments: Tensor) -> Tensor: + """The same as `make_segment_mask()` but returns a bool mask tensor instead of a flaot + bias tensor, where True corresponds to a bias of 0 and False corresponds to a bias of NEG_INF> + """ + target_segments = jnp.expand_dims(target_segments, -1) + source_segments = jnp.expand_dims(source_segments, -2) + return jax.lax.eq(source_segments, target_segments)[:, None, ...] + + +def make_segment_mask(*, source_segments: Tensor, target_segments: Tensor) -> Tensor: + """Generates attention logit biases given the segment ids. + + ... such that positions belonging to different segments cannot attend to each other. + + Args: + source_segments: An integer tensor of shape [batch, ..., source_length]. + target_segments: An integer tensor of shape [batch, ..., target_length]. 
+ + Returns: + A float Tensor of shape [batch, 1, ..., target_length, source_length] where the + value at [..., i, j] = 0 if target_segments[..., i] == source_segments[..., j], or -inf + otherwise. + """ + return NEG_INF * ~_make_bool_segment_mask( + source_segments=source_segments, target_segments=target_segments + ) diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py new file mode 100644 index 000000000..0932df100 --- /dev/null +++ b/axlearn/common/attention_bias_test.py @@ -0,0 +1,280 @@ +# Copyright © 2024 Apple Inc. + +"""Tests for attention_bias.py.""" +from typing import Optional + +import chex +import jax.numpy as jnp +import jax.util +from absl.testing import parameterized +from jax.sharding import PartitionSpec + +from axlearn.common import attention_bias, test_utils +from axlearn.common.attention_bias import ( + CausalAttentionBias, + CompositeAttentionBias, + MaskFnAttentionBias, + SegmentIdAttentionBias, + TensorAttentionBias, +) +from axlearn.common.utils import Tensor + + +class AttentionBiasTest(test_utils.TestCase): + def test_causal_attention_bias(self): + bias = attention_bias.CausalAttentionBias(shape=(5, 5)) + chex.assert_trees_all_close(bias.value(), attention_bias.make_causal_biases(5)[None, None]) + self.assertIsInstance(bias, attention_bias.CausalAttentionBias) + + bias = attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)) + self.assertNotIsInstance(bias, attention_bias.CausalAttentionBias) + + def test_zero_attention_bias(self): + bias = attention_bias.ZeroAttentionBias() + self.assertEqual(bias.value(), None) + + bias = attention_bias.MaskFnAttentionBias(None, shape=(5, 5)) + self.assertNotIsInstance(bias, attention_bias.ZeroAttentionBias) + + self.assertNotIsInstance( + attention_bias.CausalAttentionBias(shape=(5, 5)), attention_bias.ZeroAttentionBias + ) + + def test_base_attention_bias_value(self): + """Tests `BaseAttentionBias.value()`.""" + # pylint: disable=function-redefined + + class TestAttentionBias(attention_bias.BaseAttentionBias): + def _value(self) -> Optional[Tensor]: + return jnp.ones((5, 7)) + + self.assertEqual(TestAttentionBias().value().shape, (1, 1, 5, 7)) + + class TestAttentionBias(attention_bias.BaseAttentionBias): + def _value(self) -> Optional[Tensor]: + return jnp.ones((3, 5, 7)) + + self.assertEqual(TestAttentionBias().value().shape, (3, 1, 5, 7)) + + class TestAttentionBias(attention_bias.BaseAttentionBias): + def _value(self) -> Optional[Tensor]: + return jnp.ones((2, 3, 5, 7)) + + self.assertEqual(TestAttentionBias().value().shape, (2, 3, 5, 7)) + + def test_base_attention_bias_and_residual(self): + """Tests `BaseAttentionBias.bias_and_residual()`.""" + bias = attention_bias.ZeroAttentionBias() + self.assertEqual( + bias.bias_and_residual(attention_bias.ZeroAttentionBias), + attention_bias.BiasAndResidual(bias=bias, residual=CompositeAttentionBias([])), + ) + self.assertEqual( + bias.bias_and_residual(attention_bias.BaseAttentionBias), + attention_bias.BiasAndResidual(bias=bias, residual=CompositeAttentionBias([])), + ) + self.assertEqual( + bias.bias_and_residual(int), attention_bias.BiasAndResidual(bias=None, residual=bias) + ) + + def test_composite_attention_bias_zero(self): + # Test handling of zero biases. 
+ bias = attention_bias.CompositeAttentionBias( + [attention_bias.ZeroAttentionBias(), attention_bias.ZeroAttentionBias()] + ) + self.assertEqual(bias.value(), None) + self.assertEqual(bias._nonzero(), []) # pylint: disable=protected-access + # The partition spec needs to have the same structure as the biases list. + self.assertEqual(bias.partition_spec({}).biases, [PartitionSpec(), PartitionSpec()]) + + def test_composite_attention_bias(self): + # Test value(). + b1 = attention_bias.CausalAttentionBias(shape=(5, 5)) + # Opposite of causal mask. + b2 = attention_bias.MaskFnAttentionBias(shape=(5, 5), mask=lambda q, k: q < k) + expected = attention_bias.MaskFnAttentionBias( + shape=(5, 5), + mask=lambda q, k: jnp.zeros(jnp.broadcast_shapes(q.shape, k.shape), dtype=bool), + ) + bias = attention_bias.CompositeAttentionBias([b1, b2]) + self.assertNestedEqual(bias.value(), expected.value()) + + # Test adding biases. + bias = b1 + b2 + self.assertNestedEqual(bias.value(), expected.value()) + + # Test bias_and_residual(). + bias = attention_bias.CompositeAttentionBias([b2, b1]) + bias_and_residual = bias.bias_and_residual(attention_bias.CausalAttentionBias) + self.assertNestedEqual(bias_and_residual.bias.value(), b1.value()) + self.assertNestedEqual(bias_and_residual.residual.value(), b2.value()) + + bias_and_residual = bias.bias_and_residual(attention_bias.MaskFnAttentionBias) + self.assertNestedEqual(bias_and_residual.bias.value(), bias.value()) + self.assertIs(bias_and_residual.residual.value(), None) + + bias_and_residual = bias.bias_and_residual(attention_bias.ZeroAttentionBias) + self.assertNestedEqual(bias_and_residual.bias, None) + self.assertNestedEqual(bias_and_residual.residual.value(), bias.value()) + + bias_and_residual = bias.bias_and_residual(attention_bias.CompositeAttentionBias) + self.assertNestedEqual(bias_and_residual.bias.value(), bias.value()) + self.assertNestedEqual(bias_and_residual.residual.value(), None) + + bias_and_residual = (b1 + b1).bias_and_residual(attention_bias.CausalAttentionBias) + self.assertNestedEqual(bias_and_residual.bias.value(), b1.value()) + self.assertNestedEqual(bias_and_residual.residual.value(), None) + + def test_bias_and_residual_repeated_call(self): + """Test repeated calls to `bias_and_residual()`.""" + b1 = attention_bias.CausalAttentionBias(shape=(5, 5)) + # Opposite of causal mask. + b2 = attention_bias.MaskFnAttentionBias(shape=(5, 5), mask=lambda q, k: q < k) + bias = attention_bias.CompositeAttentionBias([b2, b1]) + causal_bias, residual = bias.bias_and_residual(attention_bias.CausalAttentionBias) + mask_fn_bias, residual = residual.bias_and_residual(attention_bias.MaskFnAttentionBias) + self.assertIs(causal_bias, b1) + self.assertIs(mask_fn_bias, b2) + self.assertIs(residual.value(), None) + + # Test nested CompositeAttentionBias. + bias = CompositeAttentionBias([CompositeAttentionBias([b1]), b2]) + bias_and_residual = bias.bias_and_residual(attention_bias.MaskFnAttentionBias) + self.assertNestedEqual(bias_and_residual.bias.value(), bias.value()) + self.assertIsInstance(bias_and_residual.bias, attention_bias.MaskFnAttentionBias) + self.assertIs(bias_and_residual.residual.value(), None) + + def test_split(self): + b1 = attention_bias.CausalAttentionBias(shape=(5, 5)) + # Opposite of causal mask. 
+ b2 = attention_bias.MaskFnAttentionBias(shape=(5, 5), mask=lambda q, k: q < k) + bias = attention_bias.CompositeAttentionBias([b2, b1]) + causal_bias, mask_fn_bias, residual = attention_bias.split( + bias, attention_bias.CausalAttentionBias, attention_bias.MaskFnAttentionBias + ) + self.assertIs(causal_bias, b1) + self.assertIs(mask_fn_bias, b2) + self.assertIs(residual.value(), None) + + zero_bias, residual = attention_bias.split(bias, attention_bias.TensorAttentionBias) + self.assertIs(zero_bias.value(), None) + self.assertNestedEqual(residual.value(), bias.value()) + + b3 = attention_bias.SegmentIdAttentionBias(jnp.asarray([1, 1, 2, 2, 2])) + segment, mask, residual = attention_bias.split( + b1 + b3, attention_bias.SegmentIdAttentionBias, attention_bias.MaskFnAttentionBias + ) + self.assertIs(segment, b3) + self.assertIs(mask, b1) + self.assertIs(residual.value(), None) + + @parameterized.product( + causal=[None, attention_bias.CausalAttentionBias(shape=(3, 3))], + segment_ids=[None, attention_bias.SegmentIdAttentionBias(jnp.asarray([1, 2, 3]))], + mask=[None, attention_bias.MaskFnAttentionBias(mask=lambda q, k: q < k, shape=(3, 3))], + ) + def test_split_subsets( + self, + causal: Optional[CausalAttentionBias], + segment_ids: Optional[SegmentIdAttentionBias], + mask: Optional[MaskFnAttentionBias], + ): + """Tests split() where the input CompositeBias contains any possible subsets of a + causal, segment id, and mask fn attention bias. + """ + bias_list = [mask, causal, segment_ids] + bias_list = [b for b in bias_list if b is not None] + bias = attention_bias.CompositeAttentionBias(bias_list) + new_bias_list = attention_bias.split( + bias, + attention_bias.CausalAttentionBias, + attention_bias.SegmentIdAttentionBias, + attention_bias.MaskFnAttentionBias, + ) + new_bias_list = [b if b.value() is not None else None for b in new_bias_list] + expected = [causal, segment_ids, mask, None] + for b1, b2 in jax.util.safe_zip(new_bias_list, expected): + self.assertIs(b1, b2) + + def test_tensor_attention_bias(self): + bias = attention_bias.TensorAttentionBias.from_tensor(jnp.ones((5, 7))) + self.assertNestedEqual(bias.value(), jnp.ones((1, 1, 5, 7))) + + def test_segment_id_attention_bias(self): + bias = attention_bias.SegmentIdAttentionBias( + jnp.asarray([[1, 1, 2, 2, 2, 0], [1, 2, 3, 4, 5, 6]]) + ) + expected = attention_bias.bool_to_bias( + jnp.asarray( + [ + [ + [True, True, False, False, False, False], + [True, True, False, False, False, False], + [False, False, True, True, True, False], + [False, False, True, True, True, False], + [False, False, True, True, True, False], + [False, False, False, False, False, True], + ], + jnp.eye(6, 6, dtype=bool), + ], + dtype=bool, + ) + ) + expected = expected[:, None, :, :] + self.assertNestedEqual(bias.value(), expected) + + def test_mask_fn_attention_bias_from_sequence(self): + """Tests `MaskFnAttentionBias.from_sequence()`.""" + b1 = attention_bias.CausalAttentionBias(shape=(5, 5)) + # Opposite of causal mask. 
+ b2 = attention_bias.MaskFnAttentionBias(shape=(5, 5), mask=lambda q, k: q < k) + + self.assertNestedEqual( + attention_bias.MaskFnAttentionBias.from_sequence([b1, b2]).value(), (b1 + b2).value() + ) + self.assertIsInstance( + attention_bias.MaskFnAttentionBias.from_sequence([b1]), + attention_bias.CausalAttentionBias, + ) + self.assertIs(attention_bias.MaskFnAttentionBias.from_sequence([]), None) + + def test_mask_fn_attention_bias(self): + bias = attention_bias.MaskFnAttentionBias(mask=lambda q, k: q >= k, shape=(5, 5)) + self.assertNestedEqual( + bias.value(), jnp.asarray(attention_bias.make_causal_biases(5))[None, None] + ) + + bias = attention_bias.MaskFnAttentionBias( + mask=lambda q, k: q >= k, shape=(4, 7), target_positions=jnp.asarray([3, 1]) + ) + expected = jnp.asarray( + [ + [ + [True, True, True, True, False, False, False], + [True, True, True, True, True, False, False], + [True, True, True, True, True, True, False], + [True, True, True, True, True, True, True], + ], + [ + [True, True, False, False, False, False, False], + [True, True, True, False, False, False, False], + [True, True, True, True, False, False, False], + [True, True, True, True, True, False, False], + ], + ], + dtype=bool, + ) + expected = attention_bias.bool_to_bias(expected)[:, None, :] + self.assertNestedEqual(bias.value(), expected) + + def test_bool_tensor_attention_bias(self): + bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool)) + self.assertNestedEqual( + bias.value(), attention_bias.bool_to_bias(jnp.ones((1, 1, 5, 7), dtype=bool)) + ) + + def test_astype(self): + bias = TensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=jnp.float32)) + self.assertEqual(bias.value().dtype, jnp.float32) + bias = bias.astype(jnp.bfloat16) + self.assertEqual(bias.value().dtype, jnp.bfloat16) diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 2f1b71387..5d4aeb623 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -13,6 +13,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
"""Tests attention layers.""" + import contextlib import copy import itertools @@ -38,14 +39,12 @@ from transformers.models.roformer import modeling_roformer as hf_roformer from transformers.models.xlnet import modeling_xlnet as hf_xlnet -from axlearn.common import attention, test_utils, utils +from axlearn.common import attention, attention_bias, test_utils, utils from axlearn.common.attention import ( FEED_FORWARD_SAVE_PATTERN, - NEG_INF, BaseStackedTransformerLayer, BaseTransformerLayer, BottleNeckAdapterTransformerLayer, - ForwardMode, FusedGroupedQKVLinear, FusedQKVLinear, KVState, @@ -69,20 +68,23 @@ _save_and_offload_only_these_names_regex, apply_attention_logit_biases, apply_rotary_position_embeddings, - bool_to_bias, build_remat_spec, - causal_mask, compute_padding_biases, - make_causal_biases, - make_sliding_window_causal_biases, rel_pos_to_abs_pos, scaled_hidden_dim, set_double_shard_weights_config, sinusoidal_positional_embeddings, - sliding_window_causal_mask, update_data_with_skip_connection, xl_attention_logits, ) +from axlearn.common.attention_bias import ( + NEG_INF, + bool_to_bias, + causal_mask, + make_causal_biases, + make_sliding_window_causal_biases, + sliding_window_causal_mask, +) from axlearn.common.base_layer import ( BaseLayer, DefaultTensorStats, @@ -172,7 +174,7 @@ class MaskTest(absltest.TestCase): def test_causal_mask(self): expected = jnp.array([[0.0, NEG_INF, NEG_INF], [0.0, 0.0, NEG_INF], [0.0, 0.0, 0.0]]) - actual = attention.make_causal_biases(3) + actual = attention_bias.make_causal_biases(3) self.assertTrue(jnp.all(actual <= expected)) def test_segment_mask(self): @@ -188,7 +190,7 @@ def test_segment_mask(self): ] ] ) - actual = attention.make_segment_mask( + actual = attention_bias.make_segment_mask( target_segments=jnp.asarray([[1, 1, 2, 0]]), source_segments=jnp.asarray([[2, 2, 0, 1]]), ) @@ -2180,55 +2182,6 @@ def test_causal( # The outputs are equivalent. 
self.assertNestedAllClose(outputs[0], outputs[1]) - def test_logit_biases_for_mask(self): - model_dim = 16 - num_heads = 4 - cfg = attention.MultiheadAttention.default_config().set( - name="test", - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - mask=causal_mask, - ) - layer = cfg.instantiate(parent=None) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - - query_len, kv_len = 2, 3 - query_pos = jnp.arange(query_len)[None] - kv_pos = jnp.arange(kv_len)[None] - inputs = dict(mode=ForwardMode.FORWARD, query_pos=query_pos, kv_pos=kv_pos) - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="_logit_biases_for_mask", - ) - self.assertNestedAllClose( - layer_outputs, - bool_to_bias(jnp.array([[1, 0, 0], [1, 1, 0]], dtype=jnp.bool))[None, None], - ) - - time_step = jnp.array([1, 2]) - query_pos = time_step[:, None] - kv_len = 4 - kv_pos = jnp.arange(kv_len)[None] - inputs = dict(mode=ForwardMode.EXTEND_STEP, query_pos=query_pos, kv_pos=kv_pos) - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="_logit_biases_for_mask", - ) - self.assertNestedAllClose( - layer_outputs, - bool_to_bias(jnp.array([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=jnp.bool))[:, None, None, :], - ) - @parameterized.product( base_cfg=( attention.MultiheadAttention.default_config(), @@ -2374,7 +2327,7 @@ def test_gqa_forward( ), key=None, value=None, - attention_logit_biases=attention.make_causal_biases(tgt_len), + attention_logit_biases=attention_bias.make_causal_biases(tgt_len), ) # Get outputs. forward_key = jax.random.PRNGKey(456) @@ -2437,7 +2390,7 @@ def _test_extend_step( # Make key and value distinct from query. Otherwise, it is equivalent # to the query only case. key = value = query + 0.1 - attention_logit_biases = attention.make_causal_biases(tgt_len) + attention_logit_biases = attention_bias.make_causal_biases(tgt_len) return_aux = {"probs"} inputs = dict( query=query, @@ -2600,7 +2553,7 @@ def _test_prefill_states( # Make key and value distinct from query. Otherwise, it is equivalent # to the query only case. 
key = value = query + 0.1 - attention_logit_biases = attention.make_causal_biases(tgt_len) + attention_logit_biases = attention_bias.make_causal_biases(tgt_len) return_aux = {"probs"} forward_outputs, _ = F( @@ -2809,6 +2762,7 @@ def test_gqa_against_mha(self): q = jax.random.uniform(data_key, (batch, seq_len, num_heads, per_head_dim)) k = jax.random.uniform(data_key, (batch, seq_len, num_kv_heads, per_head_dim)) v = jax.random.uniform(data_key, (batch, seq_len, num_kv_heads, per_head_dim)) + attention_logit_biases = attention_bias.ZeroAttentionBias() (test_context, ref_probs), _ = F( test_layer, @@ -2816,7 +2770,9 @@ def test_gqa_against_mha(self): state=state, is_training=False, prng_key=prng_key, - inputs=dict(q_proj=q, k_proj=k, v_proj=v), + inputs=dict( + q_proj=q, k_proj=k, v_proj=v, attention_logit_biases=attention_logit_biases + ), ) k = jnp.repeat(k, num_heads // num_kv_heads, axis=2) @@ -2828,7 +2784,9 @@ def test_gqa_against_mha(self): state=state, is_training=False, prng_key=prng_key, - inputs=dict(q_proj=q, k_proj=k, v_proj=v), + inputs=dict( + q_proj=q, k_proj=k, v_proj=v, attention_logit_biases=attention_logit_biases + ), ) assert_allclose(ref_context, test_context) @@ -2999,7 +2957,7 @@ def test_sigmoid_compute_attention(self, qkv_value: float, expected_value: float q_proj=jnp.full(qkv_shape, fill_value=qkv_value), k_proj=jnp.full(qkv_shape, fill_value=qkv_value), v_proj=jnp.full(qkv_shape, fill_value=qkv_value), - attention_logit_biases=attention.make_causal_biases(seq_len), + attention_logit_biases=attention_bias.CausalAttentionBias(shape=(seq_len, seq_len)), ) # Get outputs. @@ -3831,7 +3789,7 @@ def test_with_golden_value(self): batch_size, tgt_len = 2, 6 rng = np.random.default_rng(seed=123) target = rng.random([batch_size, tgt_len, model_dim], dtype=np.float32) - mask = attention.make_causal_biases(tgt_len) + mask = attention_bias.make_causal_biases(tgt_len) mask = jnp.tile(mask[None, None, :, :], (batch_size, num_heads, 1, 1)) layer_outputs, _ = F( layer, @@ -4075,7 +4033,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type): target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) source = jax.random.normal(jax.random.PRNGKey(456), [batch_size, src_len, model_dim * 2]) - self_attention_logit_biases = attention.make_causal_biases(tgt_len) + self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) cross_attention_logit_biases = ( jnp.array(np.random.randint(0, 2, [tgt_len, src_len])) * NEG_INF ) @@ -4203,7 +4161,7 @@ def test_transformer_prefill_states(self, transformer_type, layer_type): target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) source = jax.random.normal(jax.random.PRNGKey(456), [batch_size, src_len, model_dim * 2]) - self_attention_logit_biases = attention.make_causal_biases(tgt_len) + self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) cross_attention_logit_biases = ( jnp.array(np.random.randint(0, 2, [tgt_len, src_len])) * NEG_INF ) @@ -5207,7 +5165,7 @@ def test_forward(self, bottleneck_ratio): state = adapter.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) data = jax.random.normal(jax.random.PRNGKey(1), [batch_size, tgt_len, model_dim]) - self_attention_logit_biases = attention.make_causal_biases(tgt_len) + self_attention_logit_biases = attention_bias.make_causal_biases(tgt_len) outputs, _ = F( adapter, diff --git a/axlearn/common/bert_test.py b/axlearn/common/bert_test.py index
d7bb4da73..acf7c25f9 100644 --- a/axlearn/common/bert_test.py +++ b/axlearn/common/bert_test.py @@ -12,7 +12,8 @@ from transformers.models.bert import modeling_bert as hf_bert from axlearn.common import bert, utils -from axlearn.common.attention import NEG_INF, BaseStackedTransformerLayer +from axlearn.common.attention import BaseStackedTransformerLayer +from axlearn.common.attention_bias import NEG_INF from axlearn.common.layers import ( BinaryClassificationMetric, Dropout, diff --git a/axlearn/common/decoder_test.py b/axlearn/common/decoder_test.py index 73386428c..288124af5 100644 --- a/axlearn/common/decoder_test.py +++ b/axlearn/common/decoder_test.py @@ -19,7 +19,6 @@ from axlearn.common import causal_lm, decoding, logit_modifiers, utils from axlearn.common.attention import ( - NEG_INF, ALiBiAttentionLogitBiasLayer, CausalAttentionLogitBiasLayer, MultiheadAttention, @@ -28,6 +27,7 @@ TransformerAttentionLayer, TransformerLayer, ) +from axlearn.common.attention_bias import NEG_INF from axlearn.common.base_layer import DefaultTensorStats, RematSpec from axlearn.common.causal_lm import gpt_decoder_config from axlearn.common.config import InstantiableConfig, config_for_function diff --git a/axlearn/common/dit_test.py b/axlearn/common/dit_test.py index c0e14d44e..5233c3896 100644 --- a/axlearn/common/dit_test.py +++ b/axlearn/common/dit_test.py @@ -20,7 +20,7 @@ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed from torch import nn -from axlearn.common.attention import NEG_INF +from axlearn.common.attention_bias import NEG_INF from axlearn.common.dit import ( AdaptiveLayerNormModulation, DiTAttentionLayer, diff --git a/axlearn/common/encoder_decoder.py b/axlearn/common/encoder_decoder.py index 6c0945d83..0f7241645 100644 --- a/axlearn/common/encoder_decoder.py +++ b/axlearn/common/encoder_decoder.py @@ -6,7 +6,7 @@ from jax import numpy as jnp -from axlearn.common.attention import NEG_INF, make_segment_mask +from axlearn.common.attention_bias import NEG_INF, make_segment_mask from axlearn.common.base_encoder_decoder import BaseEncoderDecoderModel from axlearn.common.config import ConfigOr, config_class from axlearn.common.decoder import Decoder diff --git a/axlearn/common/eval_retrieval.py b/axlearn/common/eval_retrieval.py index 11795f6df..a0cab8dbd 100644 --- a/axlearn/common/eval_retrieval.py +++ b/axlearn/common/eval_retrieval.py @@ -16,7 +16,7 @@ from axlearn.common import file_system as fs from axlearn.common import utils -from axlearn.common.attention import NEG_INF +from axlearn.common.attention_bias import NEG_INF from axlearn.common.base_model import BaseModel from axlearn.common.config import REQUIRED, Required, config_class from axlearn.common.evaler import GlobalMetricCalculator, PredictionOutputs diff --git a/axlearn/common/eval_retrieval_test.py b/axlearn/common/eval_retrieval_test.py index 4783b0daa..9a7bacc9b 100644 --- a/axlearn/common/eval_retrieval_test.py +++ b/axlearn/common/eval_retrieval_test.py @@ -11,7 +11,7 @@ from absl.testing import parameterized from jax.experimental.pjit import pjit -from axlearn.common.attention import NEG_INF +from axlearn.common.attention_bias import NEG_INF from axlearn.common.eval_retrieval import ( CLIPRetrievalMetricCalculator, CxcImageRetrievalMetricCalculator, diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index c1b19106e..43e62d8bb 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ 
-36,7 +36,7 @@ from jax.experimental import pallas as pl from jax.experimental.pallas import gpu as plgpu -from axlearn.common.attention import NEG_INF +from axlearn.common.attention_bias import NEG_INF Tensor = jax.Array @@ -218,6 +218,8 @@ def flash_attention( ): """Computes attention outputs following FlashAttention. + If provided, bias, segment_ids, and any causal mask are applied on top of one another. + Args: query: Query of shape [batch_size, target_length, num_heads, per_head_dim]. key: Key of shape [batch_size, source_length, num_heads, per_head_dim]. @@ -813,6 +815,8 @@ def cudnn_dot_product_attention( ): """Computes dot-product attention given query (Q), key (K), and value (V). + If provided, bias, segment_ids, and any causal mask are applied on top of one another. + Reference implementation: https://github.com/google/jax/blob/f4158ace933482844c145a6b919bf5dc86e084ba/jax/_src/cudnn/fused_attention_stablehlo.py#L927. https://github.com/openxla/xla/blob/536ba0b7d74f6637a7a772471a99ecf4f578aef2/xla/service/gpu/cublas_cudnn.cc#L77. diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py index 36da8d595..4317f5148 100644 --- a/axlearn/common/flash_attention/layer.py +++ b/axlearn/common/flash_attention/layer.py @@ -7,18 +7,12 @@ import jax import jax.numpy as jnp -from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.shard_map import shard_map from jax.interpreters.pxla import thread_resources from jax.sharding import PartitionSpec -from axlearn.common.attention import ( - ForwardMode, - GroupedQueryAttention, - apply_attention_logit_biases, - causal_mask, - make_segment_mask, -) +from axlearn.common.attention import GroupedQueryAttention +from axlearn.common.attention_bias import BaseAttentionBias from axlearn.common.config import config_class from axlearn.common.flash_attention.utils import ( MultiHeadAttentionImpl, @@ -97,36 +91,6 @@ def default_config(cls) -> Config: } return cfg - def _is_mask_fn_used(self): - backend = self._backend() - # bias and segment_ids should also be None to use mask_fn (cf. _tpu_splash_attention in - # tpu_attention.py). - - return ( - backend == "tpu" - and self.per_head_dim() % splash_attention_kernel.NUM_LANES == 0 - and self._mask_fn is not None - ) - - def _logit_biases_for_mask( - self, *, mode: ForwardMode, query_pos: Tensor, kv_pos: Tensor - ) -> Optional[Tensor]: - if self._mask_fn is None: - return None - elif mode == ForwardMode.EXTEND_STEP: - # Use biases for decoding. - return super()._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) - elif self._is_mask_fn_used(): - # Biases are not needed in favor of mask_fn, which is supported in Splash Attention. - return None - elif self._mask_fn is causal_mask: - # Causal mode is supported natively in Flash Attention. - return None - else: - # Fall back to biases. In the subsequent _compute_attention calls, _mask_fn should not - # be used. - return super()._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) - def _backend(self): # For compatibility with AOT compilation, we obtain the backend type from physical_mesh. 
global_mesh = thread_resources.env.physical_mesh @@ -137,14 +101,9 @@ def _backend(self): backend = jax.default_backend() return backend - def _logit_biases_spec(self, attention_logit_biases: Tensor) -> Tensor: - spec = self.config.mha_dim_to_partition_spec["bnts"] - if spec != PartitionSpec(None): - if attention_logit_biases.shape[0] == 1: - spec = PartitionSpec(None, *spec[1:]) - if attention_logit_biases.shape[1] == 1: - spec = PartitionSpec(spec[0], None, *spec[2:]) - return spec + def _logit_biases_spec(self, attention_logit_biases: BaseAttentionBias) -> BaseAttentionBias: + cfg = self.config + return attention_logit_biases.partition_spec(cfg.mha_dim_to_partition_spec) def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor: """Repeats key or value heads dim to match the query. @@ -163,8 +122,7 @@ def _compute_attention( q_proj: Tensor, k_proj: Tensor, v_proj: Tensor, - attention_logit_biases: Optional[Tensor] = None, - segment_ids: Optional[Tensor] = None, + attention_logit_biases: BaseAttentionBias, ) -> tuple[Tensor, Tensor]: cfg = self.config backend = self._backend() @@ -176,89 +134,19 @@ def _compute_attention( batch, target_len, num_heads, _ = q_proj.shape _, source_len, _, _ = k_proj.shape - # Merge segment ids into attention_logit_biases. - if segment_ids is not None and attention_logit_biases is not None: - if q_proj.shape[1] != k_proj.shape[1]: - raise ValueError( - "segment_ids is only supported for query and key with identical lengths." - ) - attention_logit_biases = apply_attention_logit_biases( - make_segment_mask(source_segments=segment_ids, target_segments=segment_ids), - attention_logit_biases, - ) - segment_ids = None - - if attention_logit_biases is not None: - if attention_logit_biases.ndim != 4: - raise ValueError( - f"Expected attention_logit_biases.ndim == 4, got {attention_logit_biases.ndim}" - ) - bias_shape = attention_logit_biases.shape - if (bias_shape[0] != 1 and bias_shape[0] != batch) or ( - bias_shape[1] != 1 and bias_shape[1] != num_heads - ): - raise ValueError( - "attention_logit_biases must broadcast to " - f"{(batch, num_heads, target_len, source_len)}, " - f"got {attention_logit_biases.shape}." - ) - attention_logit_biases = attention_logit_biases.astype(q_proj.dtype) - - if attention_logit_biases is None or self._mask_fn is causal_mask: - mask_fn = self._mask_fn - else: - mask_fn = None - - # During GPU decoding, fall back to plain MHA implementation - # since the seq_len will not be divisible by block size. - # For prefill, seq_len can be > 1 and logit biases may not always be provided, - # so we retain `mask_fn`. - # For decoding, seq_len = 1 and logit biases are always provided, - # so we set `mask_fn` to None. - if q_proj.shape[1] % 128 != 0: - backend = "xla" - # TODO(senyut): Implement FlashDecoding kernel and support TPU decoding. - if q_proj.shape[1] == 1: - mask_fn = None - elif backend == "gpu" and q_proj.shape[1] != k_proj.shape[1]: - # TODO(xuan-zou): Generalize GPU Flash Attention for q_len != kv_len. - # Remove pytest.skip corresponding to q_len != kv_len in layer_test.py once fixed. - raise NotImplementedError( - f"Query length {q_proj.shape[1]} must be equal to KV length " - f"{k_proj.shape[1]} for correctly supported GPU flash attention usage." - ) - - if backend == "tpu": - assert q_proj.shape[1] % cfg.tpu_block_size == 0, ( - f"Target seq len {q_proj.shape[1]} must be " - f"divisible by block size {cfg.tpu_block_size}." 
- ) - assert k_proj.shape[1] % cfg.tpu_block_size == 0, ( - f"Source seq len {k_proj.shape[1]} must be " - f"divisible by block size {cfg.tpu_block_size}." - ) + attention_logit_biases = attention_logit_biases.astype(q_proj.dtype) jit_attn: MultiHeadAttentionImpl = flash_attention_implementation( backend=backend, - mask=mask_fn, softmax_scale=1.0, block_size=cfg.tpu_block_size, ) - q_spec = cfg.mha_dim_to_partition_spec["btnh"] - segment_ids_spec = ( - PartitionSpec(q_spec[0], q_spec[1]) - if q_spec != PartitionSpec(None) - else PartitionSpec(None) + attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases) + attention_logit_biases = with_sharding_constraint( + attention_logit_biases, attention_logit_biases_spec ) - attention_logit_biases_spec = cfg.mha_dim_to_partition_spec["bnts"] - if attention_logit_biases is not None: - attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases) - attention_logit_biases = with_sharding_constraint( - attention_logit_biases, attention_logit_biases_spec - ) - # Scale query and key. q_proj = self.scale_query(q_proj) k_proj = self.scale_key(k_proj) @@ -268,14 +156,6 @@ def _compute_attention( k_proj = with_sharding_constraint(k_proj, cfg.mha_dim_to_partition_spec["bsnh"]) v_proj = with_sharding_constraint(v_proj, cfg.mha_dim_to_partition_spec["bsnh"]) - if segment_ids is not None: - if segment_ids.shape[0] != q_proj.shape[0]: - raise ValueError( - "segment_ids must have matching batch dim: " - f"{segment_ids.shape} vs. {q_proj.shape[0]}" - ) - segment_ids = with_sharding_constraint(segment_ids, segment_ids_spec) - # We need to manually partition pallas | jax-triton calls. # Note: shard_map doesn't support kwargs. partitioned_mha = shard_map( @@ -288,8 +168,6 @@ def _compute_attention( cfg.mha_dim_to_partition_spec["bsnh"], # Bias that can broadcast to [batch_size, num_heads, seq_len, seq_len]. attention_logit_biases_spec, - # Segment IDs [batch_size, seq_len]. - segment_ids_spec, ), # O [batch_size, seq_len, num_heads, per_head_dim]. out_specs=cfg.mha_dim_to_partition_spec["btnh"], @@ -299,7 +177,7 @@ def _compute_attention( ) outputs = with_sharding_constraint( - partitioned_mha(q_proj, k_proj, v_proj, attention_logit_biases, segment_ids), + partitioned_mha(q_proj, k_proj, v_proj, attention_logit_biases), cfg.output_dim_to_partition_spec["btnh"], ) diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index e04d5e3de..e2bf076e0 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -1,11 +1,15 @@ # Copyright © 2023 Apple Inc. """Tests FlashAttention layers.""" - +# pylint: disable=ungrouped-imports import math import os from unittest import mock +from jax.sharding import PartitionSpec + +from axlearn.common.utils import Tensor + # Due to reference layer using XLA, # set the following environment variables to avoid OOM in GPU tests. 
# pylint: disable=wrong-import-position @@ -20,9 +24,11 @@ from jax.experimental import mesh_utils from jax.sharding import Mesh -from axlearn.common.attention import ( - GroupedQueryAttention, - apply_attention_logit_biases, +from axlearn.common.attention import GroupedQueryAttention, apply_attention_logit_biases +from axlearn.common.attention_bias import ( + CompositeAttentionBias, + SegmentIdAttentionBias, + TensorAttentionBias, bool_to_bias, make_causal_biases, sliding_window_causal_mask, @@ -72,8 +78,9 @@ def _fake_inputs( jax.random.PRNGKey(3), p=0.5, shape=[batch, num_heads, query_len, kv_len] ) bias = bool_to_bias(bias) + bias = TensorAttentionBias(bias) else: - bias = None + bias = CompositeAttentionBias([]) if use_segment_ids: segment_ids = jnp.ones([batch, kv_len], dtype=jnp.int32) else: @@ -314,6 +321,16 @@ def test_backend(self, batch, seq_len, num_heads, per_head_dim, mesh, mesh_axis_ def test_shard_biases(self, batch, seq_len, num_heads, per_head_dim, mesh, mesh_axis_names): if not is_supported_mesh_shape(mesh): pytest.skip(reason=f"Unsupported mesh {mesh}.") + + def as_tensor_bias(bias: Tensor) -> CompositeAttentionBias: + return CompositeAttentionBias([TensorAttentionBias(bias)]) + + def as_partition_spec(pytree: CompositeAttentionBias) -> PartitionSpec: + self.assertIsInstance(pytree, CompositeAttentionBias) + pytree = jax.tree.leaves(pytree) + self.assertLen(pytree, 1) + return next(iter(pytree)) + with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names): test_layer, _, _, _ = _prepare_layers( num_heads=num_heads, @@ -323,18 +340,32 @@ def test_shard_biases(self, batch, seq_len, num_heads, per_head_dim, mesh, mesh_ sliding_window_size=None, ) bias = jnp.ones((batch, num_heads, seq_len, seq_len)) + bias = as_tensor_bias(bias) spec = test_layer._logit_biases_spec(bias) # pylint: disable=protected-access + spec = as_partition_spec(spec) self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["bnts"]) bias = jnp.ones((batch, 1, seq_len, seq_len)) + bias = as_tensor_bias(bias) spec = test_layer._logit_biases_spec(bias) # pylint: disable=protected-access + spec = as_partition_spec(spec) self.assertEqual(spec[1], None) bias = jnp.ones((1, 1, seq_len, seq_len)) + bias = as_tensor_bias(bias) spec = test_layer._logit_biases_spec(bias) # pylint: disable=protected-access + spec = as_partition_spec(spec) self.assertEqual(spec[0], None) self.assertEqual(spec[1], None) + segment_ids = CompositeAttentionBias( + [SegmentIdAttentionBias(jnp.ones((batch, seq_len)))] + ) + spec = test_layer._logit_biases_spec(segment_ids) # pylint: disable=protected-access + spec = as_partition_spec(spec) + self.assertIsInstance(spec, PartitionSpec) + self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["btnh"][:2]) + @parameterized.product( _TEST_CONFIGS, query_len_multiplier=[0.5, 1, 2], @@ -572,6 +603,11 @@ def test_extend_step( causal, sliding_window_size, ): + print( + f"batch={batch}, seq_len={seq_len} (ignored->16), num_heads={num_heads}, \n" + f"per_head_dim={per_head_dim}, mesh={mesh}, mesh_axis_names={mesh_axis_names}, \n" + f"causal={causal}, sliding_window_size={sliding_window_size}" + ) # Limit generation length to 16 to save test time. 
seq_len = 16 dtype = jnp.bfloat16 diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py index dff777958..2335498ee 100644 --- a/axlearn/common/flash_attention/tpu_attention.py +++ b/axlearn/common/flash_attention/tpu_attention.py @@ -2,7 +2,7 @@ """Wrappers for FlashAttention on TPU in JAX with logit bias support.""" import functools -from typing import Optional +from typing import Optional, Union import jax import jax.numpy as jnp @@ -31,7 +31,13 @@ splash_attention_mask, ) -from axlearn.common.attention import MaskFn, apply_attention_logit_biases, bool_to_bias, causal_mask +from axlearn.common.attention import apply_attention_logit_biases +from axlearn.common.attention_bias import ( + CausalAttentionBias, + MaskFnAttentionBias, + ZeroAttentionBias, + as_attention_bias, +) from axlearn.common.utils import Tensor @@ -42,7 +48,7 @@ def tpu_flash_attention( bias: Tensor = None, # [batch_size, num_heads, source_len, target_len] segment_ids: Tensor = None, # [batch_size, source_len] *, - mask: Optional[MaskFn] = None, + mask: Optional[MaskFnAttentionBias] = None, softmax_scale: float = 1.0, block_size: int = 128, ): @@ -54,10 +60,12 @@ def tpu_flash_attention( 1. within-kernel ordering of attention-bias addition and softmax scaling differ to axlearn, 2. it's more efficient to scale outside the kernel vs. fix order of ops in kernel. + If provided, bias, segment_ids, and mask are applied on top of one another. + Args: query: The query tensor, of shape [batch_size, source_len, num_heads, head_dim]. key: The key tensor, of shape [batch_size, target_len, num_heads, head_dim]. - value: The value tensor, of shape [batch_size, target_len, num_heads, head_dim]. + value: The value tensor, of shape [batch_size, source_len, num_heads, head_dim]. bias: The attention biases, can broadcast to shape [batch_size, num_heads, source_len, target_len]. segment_ids: The id of which segment each token belongs to. Attention is not computed @@ -72,7 +80,9 @@ def tpu_flash_attention( Raises: NotImplementedError: If no implementation with support for the arguments is found. - ValueError: If the head_dim of the query, key, and value are not all equal.""" + ValueError: If the head_dim of the query, key, and value are not all equal. + ValueError: if the target or source sequence length is not divisible by block_size.` + """ if segment_ids is not None: assert query.shape[1] == key.shape[1] and query.shape[1] == value.shape[1] # Apply the softmax scale outside the kernel (see docstring for why). @@ -92,6 +102,17 @@ def tpu_flash_attention( f"{head_dim} != {value.shape[3]}" ) + if query.shape[1] % block_size != 0: + raise ValueError( + f"Target seq len {query.shape[1]} must be divisible by block size {block_size}." + ) + if key.shape[1] % block_size != 0: + raise ValueError( + f"Source seq len {key.shape[1]} must be divisible by block size {block_size}." + ) + + mask = as_attention_bias(mask) + # Switch num_heads and seq_len axes. 
query = jnp.einsum("btnh->bnth", query) key = jnp.einsum("bsnh->bnsh", key) @@ -129,13 +150,7 @@ def tpu_flash_attention( block_q_dq=block_size, ) context = _legacy_tpu_flash_attention( - query, - key, - value, - bias, - segment_ids=segment_ids, - mask=mask, - block_sizes=block_sizes, + query, key, value, bias, segment_ids=segment_ids, mask=mask, block_sizes=block_sizes ) logging.warning( "Falling back to legacy flash attention because SplashAttention is not supported.\n" @@ -161,11 +176,13 @@ def _legacy_tpu_flash_attention( bias: Tensor = None, # [batch_size, num_heads, source_len, target_len] segment_ids: Tensor = None, # [batch_size, source_len] *, - mask: Optional[MaskFn] = None, + mask: MaskFnAttentionBias, block_sizes: Optional[LegacyBlockSizes] = None, ) -> Tensor: # [batch_size, num_heads, source_len, head_dim]. """Wraps JAX's legacy TPU flash-attention. + If provided, bias, segment_ids, and mask are applied on top of one another. + Args: query: The query tensor, of shape [batch_size, num_heads, source_len, head_dim]. key: The key tensor, of shape [batch_size, num_heads, target_len, head_dim]. @@ -184,13 +201,9 @@ def _legacy_tpu_flash_attention( Raises: NotImplementedError: If a custom (non-causal, non-full) mask is specified. """ - causal = mask is causal_mask - if mask is not None and not causal: - rows = jnp.arange(0, query.shape[2]) - cols = jnp.arange(0, key.shape[2]) - bias = apply_attention_logit_biases( - bool_to_bias(mask(rows[:, None], cols[None, :]))[None, None, :, :], bias - ) + causal = isinstance(mask, CausalAttentionBias) + if not causal and mask.value() is not None: + bias = apply_attention_logit_biases(mask.value(), bias) context = pallas_tpu_flash_attention( q=query, @@ -214,77 +227,88 @@ class SplashAttentionUnsupportedError(NotImplementedError): @functools.partial( jax.jit, - static_argnames=[ - "mask", # Mask objects don't actually contain jax arrays, so they are static. - "block_sizes", - ], + static_argnames=["block_sizes"], ) def _tpu_splash_attention( - query: Tensor, # [batch_size, num_heads, source_len, head_dim] - key: Tensor, # [batch_size, num_heads, target_len, head_dim] - value: Tensor, # [batch_size, num_heads, target_len, head_dim] - bias: Tensor = None, # [batch_size, num_heads, source_len, target_len] - segment_ids: Tensor = None, # [batch_size, source_len] + query: Tensor, # [batch_size, num_heads, target_len, head_dim] + key: Tensor, # [batch_size, num_heads, source_len, head_dim] + value: Tensor, # [batch_size, num_heads, source_len, head_dim] + bias: Optional[Tensor] = None, # [batch_size, num_heads, target_len, source_len] + segment_ids: Optional[Tensor] = None, # [batch_size, target_len] *, - mask: Optional[MaskFn] = None, + mask: Union[MaskFnAttentionBias | ZeroAttentionBias], block_sizes: Optional[splash_attention_kernel.BlockSizes] = None, -) -> Tensor: # [batch_size, num_heads, source_len, head_dim]. +) -> Tensor: # [batch_size, num_heads, target_len, head_dim]. """Wraps JAX's sparse TPU flash-attention. Args: - query: The query tensor, of shape [batch_size, num_heads, source_len, head_dim]. - key: The key tensor, of shape [batch_size, num_heads, target_len, head_dim]. + query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim]. + key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim]. value: The value tensor, of shape [batch_size, num_heads, source_len, head_dim]. - bias: The attention biases, of shape [batch_size, num_heads, source_len, target_len]. 
+ bias: The attention biases, of shape [batch_size, num_heads, target_len, source_len]. segment_ids: The id of which segment each token belongs to. Attention is not computed between tokens in different segments. - Shape: [batch_size, source_len]. + Shape: [batch_size, target_len]. mask: The mask to apply. This is more compute efficient compared to setting bias = -inf. block_sizes: An object containing values that can be used to tune the performance such as the block size to chunk things into. Returns: - The context tensor, of shape [batch_size, num_heads, source_len, head_dim]. + The context tensor, of shape [batch_size, num_heads, target_len, head_dim]. Raises: - NotImplementedError: If a bias is also specified or the head_dim is not divisible by - 128. + SplashAttentionUnsupportedError: If splash attention does not support the given arguments. + This happens if any of the following is true: + - bias is not None. + - The per_head_dim is not divisible by 128. + - segment_ids is not None. + - The source and target lengths are different and a nonzero mask is used. + TypeError: If mask is not an instance of `MaskFnAttentionBias. """ - source_len = query.shape[2] - target_len = key.shape[2] + target_len = query.shape[2] + source_len = key.shape[2] num_heads = query.shape[1] head_dim = query.shape[3] if bias is not None: raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.") - if head_dim % splash_attention_kernel.NUM_LANES != 0: - raise SplashAttentionUnsupportedError( - "SplashAttention requires " - f"head_dim=={splash_attention_kernel.NUM_LANES}, " - f"got {head_dim}." - ) + with jax.ensure_compile_time_eval(): + if jnp.any( + jnp.asarray([target_len, source_len, head_dim]) % splash_attention_kernel.NUM_LANES != 0 + ): + raise SplashAttentionUnsupportedError( + "SplashAttention requires target_len, source_len, head_dim are divisible by" + f" {splash_attention_kernel.NUM_LANES}, got {target_len, source_len, head_dim}." + ) if segment_ids is not None: raise SplashAttentionUnsupportedError( "The public API for SplashAttention that we " "currently use does not support segment ids." ) - if source_len != target_len and mask is not None: + if target_len != source_len and mask.value() is not None: raise SplashAttentionUnsupportedError( "Query and key/value must have same length when mask is used." ) + if mask.value() is not None and not isinstance(mask, MaskFnAttentionBias): + raise TypeError(type(mask)) + if mask.value() is not None and isinstance(mask.target_positions, jax.core.Tracer): + raise SplashAttentionUnsupportedError( + "Non-static value of `target_positions` is not supported.\n" + "Are you decoding using SplashAttention? That's not supported." + ) - mask_shape = (source_len, target_len) - if mask is None: + mask_shape = (target_len, source_len) + if mask.value() is None: mask = splash_attention_mask.FullMask(mask_shape) else: - # Use fewer bytes for the mask. - rows = np.arange(source_len, dtype=np.int32) - cols = np.arange(target_len, dtype=np.int32) with jax.ensure_compile_time_eval(): - mask_array = np.asarray(mask(rows[:, None], cols[None, :])) + # MaskFn always supports compile time eval. + mask_array = np.asarray(mask.bool_value()) + # Squeeze first two leading dimensions. + mask_array = mask_array.reshape(mask_array.shape[-2:]) - # NumpyMask is backed by a dense [source_len, target_len] numpy array. + # NumpyMask is backed by a dense [target_len, source_len] numpy array. 
# May consume a large amount of host memory for long sequences at compile time. mask = splash_attention_mask.NumpyMask(array=mask_array) diff --git a/axlearn/common/flash_attention/tpu_attention_benchmark.py b/axlearn/common/flash_attention/tpu_attention_benchmark.py index e203d27b4..379048bc7 100644 --- a/axlearn/common/flash_attention/tpu_attention_benchmark.py +++ b/axlearn/common/flash_attention/tpu_attention_benchmark.py @@ -38,7 +38,13 @@ import jax import jax.numpy as jnp -from axlearn.common.attention import causal_mask, sliding_window_causal_mask +from axlearn.common.attention_bias import ( + CompositeAttentionBias, + MaskFnAttentionBias, + TensorAttentionBias, + causal_mask, + sliding_window_causal_mask, +) from axlearn.common.flash_attention.utils import flash_attention_implementation, mha_reference _BENCHMARK_CONFIGS = { @@ -132,18 +138,23 @@ def _benchmark( mask = causal_mask elif causal: mask = sliding_window_causal_mask(sliding_window_size) + mask = MaskFnAttentionBias(mask, shape=(seq_len, seq_len)) + if use_bias: + bias = CompositeAttentionBias([mask, TensorAttentionBias(bias)]) + else: + bias = CompositeAttentionBias([mask]) # Get fwd & bwd timing information when softmax scaling applied before calling the kernel. mha_impl = flash_attention_implementation( - "tpu", mask=mask, softmax_scale=softmax_scale, block_size=block_size + "tpu", softmax_scale=softmax_scale, block_size=block_size ) - flash_fwd_time = _time_call(lambda: mha_impl(q, k, v, bias, segment_ids)) + flash_fwd_time = _time_call(lambda: mha_impl(q, k, v, bias)) flash_grad_fn = jax.jit( - jax.grad(lambda q, k, v, b, s: mha_impl(q, k, v, b, s).mean(), argnums=(0, 1, 2)) + jax.grad(lambda q, k, v, b: mha_impl(q, k, v, b).mean(), argnums=(0, 1, 2)) ) - flash_bwd_time = _time_call(lambda: flash_grad_fn(q, k, v, bias, segment_ids)[0]) + flash_bwd_time = _time_call(lambda: flash_grad_fn(q, k, v, bias)[0]) print(f"ref_fwd:{ref_fwd_time:.4f}s, flash_fwd:{flash_fwd_time:.4f}s") print(f"ref_bwd:{ref_bwd_time:.4f}s, flash_bwd:{flash_bwd_time:.4f}s\n") diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py index 1ba38d7a4..14c3dea7f 100644 --- a/axlearn/common/flash_attention/tpu_attention_test.py +++ b/axlearn/common/flash_attention/tpu_attention_test.py @@ -16,7 +16,11 @@ from jax.interpreters.pxla import thread_resources from jax.sharding import Mesh, NamedSharding, PartitionSpec -from axlearn.common.attention import causal_mask, sliding_window_causal_mask +from axlearn.common.attention_bias import ( + MaskFnAttentionBias, + causal_mask, + sliding_window_causal_mask, +) from axlearn.common.flash_attention import tpu_attention from axlearn.common.flash_attention.utils import mha_reference from axlearn.common.test_utils import TestCase, is_supported_mesh_shape @@ -32,7 +36,7 @@ def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor: The mask is the same as `causal_mask`. However, this implementation requires specially handling to use with - SplashAttention since `tpu_flash-attention()` needs to wrap this function + SplashAttention since `tpu_flash_attention()` needs to wrap this function to return numpy values if the input is numpy. (Otherwise we get tracer errors in jit.) 
""" return jnp.greater_equal(query_position, key_position) @@ -117,7 +121,9 @@ def fn(q, k, v): ) softmax_scale = q.shape[-1] ** -0.5 - mask = sliding_window_causal_mask(sliding_window_size) + mask = MaskFnAttentionBias( + sliding_window_causal_mask(sliding_window_size), shape=(seq_len, seq_len) + ) attn = lambda q, k, v: tpu_attention.tpu_flash_attention( q, k, v, mask=mask, softmax_scale=softmax_scale @@ -171,6 +177,12 @@ def test_forward_and_backward( or with_segment_ids or (query_length_multiplier != 1 and mask is not None) ) + print( + f"{batch_size=}, {kv_len=}, {num_heads=}, \n" + f"{per_head_dim=}, {query_length_multiplier=}, {mask=}, \n" + f"{attention_bias_type=}, {with_segment_ids=} \n" + f"{causal=}, {fallback_to_legacy=}" + ) if fallback_to_legacy and mask is jax_fn_mask: pytest.skip("Custom masks are not supported by legacy attention.") @@ -203,6 +215,9 @@ def ref_fn(q, k, v, bias, ids): legacy_flash_wrapper = unittest.mock.Mock(wraps=tpu_attention._legacy_tpu_flash_attention) + if mask is not None: + mask = MaskFnAttentionBias(mask, shape=(query_len, kv_len)) + def fn(q, k, v, bias, ids): record_legacy_call = unittest.mock.patch.object( tpu_attention, "_legacy_tpu_flash_attention", legacy_flash_wrapper diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 7da0543f1..85af50aeb 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -8,7 +8,18 @@ import jax.numpy as jnp from absl import logging -from axlearn.common.attention import NEG_INF, MaskFn, causal_mask, softmax_with_biases +from axlearn.common.attention import softmax_with_biases +from axlearn.common.attention_bias import ( + NEG_INF, + BaseAttentionBias, + CausalAttentionBias, + CompositeAttentionBias, + MaskFnAttentionBias, + SegmentIdAttentionBias, + TensorAttentionBias, + ZeroAttentionBias, + split, +) from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention from axlearn.common.flash_attention.tpu_attention import tpu_flash_attention @@ -38,7 +49,6 @@ def mha_reference( segment_ids: segment ids tensor with shape [batch_size, seq_len]. causal: whether the attention is causal. softmax_scale: a scalar value applied to the logits before softmax. - bias_type: the type of bias to apply. "matrix" for matrix bias, "vector" for additive bias. Returns: A tensor with shape [batch_size, seq_len, num_heads, per_head_dim]. @@ -77,7 +87,6 @@ def mha_reference( def flash_attention_implementation( backend: Literal["cpu", "tpu", "gpu", "xla"], *, - mask: Optional[MaskFn] = None, softmax_scale: float, block_size: int = 128, ) -> MultiHeadAttentionImpl: @@ -85,8 +94,6 @@ def flash_attention_implementation( Args: backend: A valid XLA backend name. 'cpu' intended for testing only. - mask: A mask to use when computing the attention. This allows for more efficient - computation than setting bias = -inf on certain backends. softmax_scale: A scalar value applied to the logits before softmax. block_size: The size of the computation-block unit, only applies to the 'tpu' backend. A multiple of 128, and should be less than the target sequence length. @@ -98,24 +105,71 @@ def flash_attention_implementation( Raises: NotImplementedError: If implementation for the backend is not available. 
""" - causal = mask is causal_mask - if mask is not None and not causal and backend != "tpu": - raise NotImplementedError( - "Custom (non-causal, non-full) mask only supported on TPU.\n" - "You can use NEG_INF biases instead, but it won't " - "have the sparsity optimizations." - ) - if backend == "gpu": - # shard_map-decorated function needs to be jitted. - @jax.jit - def jit_attn(query, key, value, bias, segment_ids): + + # shard_map-decorated function needs to be jitted. + @jax.jit + def jit_attn( + query: Tensor, + key: Tensor, + value: Tensor, + bias: BaseAttentionBias, + *, + backend: str = backend, + ) -> Tensor: + # Fall back to plain MHA implementation when the seq_len is not be divisible by + # block size. + if query.shape[1] % block_size != 0: + backend = "xla" + # For decoding, fall back to non-flash implementation and merge all biases + # into a dense floating point bias tensor since that implementation does not + # support target_positions. + if query.shape[1] == 1: + # TODO(senyut): Implement FlashDecoding kernel and support TPU decoding. + backend = "xla" + bias = TensorAttentionBias(bias.value()) + + bias = CompositeAttentionBias([bias]) + + def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: + """Return the segment ids Tensor from the sequence of segment ids attention + biases or None if there are no segment ids. + """ + if segment_ids is None or segment_ids.value() is None: + return None + if query.shape[1] != key.shape[1]: + raise ValueError( + "segment_ids is only supported for query and key with identical lengths." + ) + if segment_ids.eval_shape()[0] != query.shape[0]: + raise ValueError( + "segment_ids must have matching batch dim: " + f"{segment_ids.eval_shape()} vs. {query.shape[0]}" + ) + return segment_ids.segment_ids + + if backend == "gpu": + if query.shape[1] != key.shape[1]: + # TODO(xuan-zou): Generalize GPU Flash Attention for q_len != kv_len. + # Remove pytest.skip corresponding to q_len != kv_len in layer_test.py once fixed. + raise NotImplementedError( + f"Query length {query.shape[1]} must be equal to KV length " + f"{key.shape[1]} for correctly supported GPU flash attention usage." + ) + + # We have two implementations to choose from. + # Both support `causal`. + # One supports `segment_ids`. + causal, segment_ids, explicit_bias = split( + bias, CausalAttentionBias, SegmentIdAttentionBias + ) + # Fall back to triton gpu kernel if: - # - segment_ids is not None, - # - bias is not None, - # - query/key/value are in float32. + # - segment_ids is not None, or + # - explicit_bias is not empty, or + # - query/key/value is in float32. if ( - segment_ids is not None - or bias is not None + segment_ids.value() is not None + or explicit_bias.value() is not None or jnp.float32 in (query.dtype, key.dtype, value.dtype) ): logging.warning("Flash attention falling back to Triton GPU kernel.") @@ -123,61 +177,65 @@ def jit_attn(query, key, value, bias, segment_ids): query, key, value, - bias=bias, - segment_ids=segment_ids, + bias=explicit_bias.value(), + segment_ids=get_segment_ids(segment_ids), softmax_scale=softmax_scale, - causal=causal, + causal=causal.value() is not None, ) else: + explicit_bias += segment_ids return cudnn_dot_product_attention( query, key, value, - bias=bias, + bias=explicit_bias.value(), softmax_scale=softmax_scale, - causal=causal, + causal=causal.value() is not None, dropout_rate=0.0, ) - return jit_attn - - elif backend == "tpu": - # shard_map-decorated function needs to be jitted. 
- @jax.jit - def jit_attn(query, key, value, bias, segment_ids): - context = tpu_flash_attention( + elif backend == "tpu": + # `mask` is supported. + # `segment_ids` is supported. + # Optimized handling for the above two types. + # Fallback for types that aren't instances of either of the above. + mask, segment_ids, explicit_bias = split( + bias, MaskFnAttentionBias, SegmentIdAttentionBias + ) + return tpu_flash_attention( query, key, value, - bias=bias, - segment_ids=segment_ids, - mask=mask, + bias=explicit_bias.value(), + segment_ids=get_segment_ids(segment_ids), + # The `from_sequence()` function guarantees that if there is only one + # mask, it is returned without modification. + # This allows the `causal` path in `_legacy_tpu_flash_attention()` to work. + mask=mask if not isinstance(mask, ZeroAttentionBias) else None, softmax_scale=softmax_scale, block_size=block_size, ) - return context - - return jit_attn - elif backend in ("cpu", "xla"): - if backend == "cpu": - logging.warning("Flash attention CPU backend is for testing only.") - logging.warning("Flash attention falling back using plain MHA implementation") + elif backend in ("cpu", "xla"): + if backend == "cpu": + logging.warning("Flash attention CPU backend is for testing only.") + logging.warning("Flash attention falling back using plain MHA implementation") - # shard_map-decorated function needs to be jitted. - @jax.jit - def jit_attn(query, key, value, bias, segment_ids): + # `causal` is supported. + # `segment_ids` is supported. + causal, segment_ids, explicit_bias = split( + bias, CausalAttentionBias, SegmentIdAttentionBias + ) return mha_reference( query, key, value, - bias=bias, - segment_ids=segment_ids, - causal=causal, + bias=explicit_bias.value(), + segment_ids=get_segment_ids(segment_ids), + causal=causal.value() is not None, softmax_scale=softmax_scale, ) - return jit_attn - - else: raise NotImplementedError(f"Backend ({backend}) does not have an implementation.") + + return jit_attn diff --git a/axlearn/common/metrics_text_dual_encoder.py b/axlearn/common/metrics_text_dual_encoder.py index 3b0423c52..65a751f75 100644 --- a/axlearn/common/metrics_text_dual_encoder.py +++ b/axlearn/common/metrics_text_dual_encoder.py @@ -5,7 +5,7 @@ from jax import numpy as jnp -from axlearn.common.attention import NEG_INF +from axlearn.common.attention_bias import NEG_INF from axlearn.common.config import REQUIRED, Required, config_class from axlearn.common.evaler import GlobalMetricCalculator, PredictionOutputs from axlearn.common.loss import contrastive_logits diff --git a/axlearn/common/multiway_transformer_test.py b/axlearn/common/multiway_transformer_test.py index 7c1dcba76..9df7a9b58 100644 --- a/axlearn/common/multiway_transformer_test.py +++ b/axlearn/common/multiway_transformer_test.py @@ -9,14 +9,13 @@ from absl.testing import absltest, parameterized from axlearn.common.attention import ( - NEG_INF, RepeatedTransformerLayer, StackedTransformerLayer, TransformerFeedForwardLayer, TransformerLayer, build_remat_spec, - make_causal_biases, ) +from axlearn.common.attention_bias import NEG_INF, make_causal_biases from axlearn.common.module import functional as F from axlearn.common.multiway_transformer import ( IMAGE_MODALITY, diff --git a/axlearn/common/poolings.py b/axlearn/common/poolings.py index d01b7c95e..a228b4d9a 100644 --- a/axlearn/common/poolings.py +++ b/axlearn/common/poolings.py @@ -15,12 +15,11 @@ import jax.numpy as jnp from axlearn.common.attention import ( - NEG_INF, TransformerAttentionLayer, 
TransformerFeedForwardLayer, - make_segment_mask, scaled_hidden_dim, ) +from axlearn.common.attention_bias import NEG_INF, make_segment_mask from axlearn.common.base_layer import BaseLayer, ParameterSpec from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class from axlearn.common.layers import Linear diff --git a/axlearn/common/splade.py b/axlearn/common/splade.py index a008df5ee..e88c001d6 100644 --- a/axlearn/common/splade.py +++ b/axlearn/common/splade.py @@ -8,7 +8,7 @@ import jax.numpy as jnp -from axlearn.common.attention import NEG_INF +from axlearn.common.attention_bias import NEG_INF from axlearn.common.bert import BertLMHead from axlearn.common.config import config_class from axlearn.common.layers import BaseClassificationHead, RedirectToSharedModule, get_activation_fn diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index ab666e62d..6ce8bb49e 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -27,7 +27,7 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig -from axlearn.common.attention import make_causal_biases +from axlearn.common.attention_bias import make_causal_biases from axlearn.common.config import InstantiableConfig from axlearn.common.module import functional as F from axlearn.common.ssm import ( diff --git a/axlearn/vision/attention.py b/axlearn/vision/attention.py index 4b96db95b..7b61d264d 100644 --- a/axlearn/vision/attention.py +++ b/axlearn/vision/attention.py @@ -21,9 +21,9 @@ MultiheadAttention, TransformerAttentionLayer, apply_attention_logit_biases, - make_segment_mask, softmax_with_biases, ) +from axlearn.common.attention_bias import make_segment_mask from axlearn.common.base_layer import ParameterSpec from axlearn.common.config import REQUIRED, InstantiableConfig, config_class from axlearn.common.layers import get_stochastic_depth_linear_rate
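The refactor above replaces raw bias tensors and segment-id arguments with composable BaseAttentionBias objects. Below is a minimal usage sketch, assuming only the constructors and helpers that appear in this diff (CompositeAttentionBias, CausalAttentionBias, SegmentIdAttentionBias, TensorAttentionBias, split); the shapes are toy values and the exact arrays returned by value() may differ from what the comments suggest.

    import jax.numpy as jnp

    from axlearn.common.attention_bias import (
        CausalAttentionBias,
        CompositeAttentionBias,
        SegmentIdAttentionBias,
        TensorAttentionBias,
        split,
    )

    batch, num_heads, seq_len = 2, 4, 128

    # Compose several bias sources into one object. Each component can materialize
    # itself via value() as a float tensor broadcastable to
    # [batch, num_heads, target_length, source_length], or report None if it is empty.
    bias = CompositeAttentionBias(
        [
            CausalAttentionBias(shape=(seq_len, seq_len)),
            SegmentIdAttentionBias(jnp.ones([batch, seq_len], dtype=jnp.int32)),
            TensorAttentionBias(jnp.zeros([batch, num_heads, seq_len, seq_len])),
        ]
    )

    # Backend dispatch (cf. flash_attention_implementation in flash_attention/utils.py)
    # pulls out the components a kernel handles natively; whatever is left over is
    # materialized as a dense bias.
    causal, segment_ids, explicit_bias = split(bias, CausalAttentionBias, SegmentIdAttentionBias)
    use_causal_kernel_path = causal.value() is not None
    dense_bias = explicit_bias.value()  # None, or a dense bias tensor added to the logits.

This mirrors the GPU/TPU branches in flash_attention/utils.py above: causal and segment-id components map onto kernel-native arguments, and only the residual explicit_bias is passed as a floating-point bias tensor.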