
Commit

Merge branch 'main' of github.pie.apple.com:foundation-models/axlearn into rpang_scheduler
ruomingp committed Dec 9, 2024
2 parents c38eb19 + 53f2cbb commit bd633d9
Showing 24 changed files with 1,387 additions and 554 deletions.
5 changes: 3 additions & 2 deletions axlearn/common/adapter_torch_test.py
@@ -13,6 +13,7 @@
from absl.testing import absltest, parameterized
from jax import numpy as jnp

from axlearn.common import attention_bias
from axlearn.common.adapter_torch import (
NEG_INF,
AdapterCausalLmModelBuilder,
@@ -355,8 +356,8 @@ def test_transformer_attention_layer_forward(self, structure: str, norm: str):

rng = np.random.RandomState(123)
target = rng.randn(2, 7, target_dim).astype(np.float32)
attention_logit_biases = np.zeros(target.shape[:-1]).astype(bool)[:, :, None]
attention_logit_biases[:, -2:] = True
attention_logit_biases = np.zeros(target.shape[:-1]).astype(float)[:, :, None]
attention_logit_biases[:, -2:] = attention_bias.NEG_INF
torch_inputs = {
"target": torch.as_tensor(target),
"attention_logit_biases": torch.as_tensor(attention_logit_biases),
259 changes: 43 additions & 216 deletions axlearn/common/attention.py
@@ -33,13 +33,17 @@
"""Attention layers with pjit partition specs.
On `attention_logit_biases`:
* A biases tensor can have one of the following shapes:
* For methods that take a tensor, a biases Tensor can have one of the following shapes:
* [target_length, source_length]
* [batch, target_length, source_length]
* [batch, num_heads, target_length, source_length].
* Each value represents a bias to be added to the attention logits
(therefore a -inf represents a disconnected position pair).
* biases=None represents an all-zero tensor, i.e., all position pairs are connected.
* For methods that take a BaseAttentionBias, the value() will always be None or a 4d Tensor with
the above semantics.
TODO(apghml) Convert everything to take an instance of BaseAttentionBias rather than a Tensor.
On `segment_ids`:
* A tensor of shape [batch, target_length] with values in [0, num_segments].
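The shape and -inf conventions described above can be illustrated with plain JAX, independently of the new `BaseAttentionBias` classes this change introduces. A minimal sketch (the variable names and shapes are illustrative only, not taken from the diff):

```
import jax
import jax.numpy as jnp

NEG_INF = -1e15  # Same constant the module previously defined inline.
target_length = source_length = 4

# Causal bias: query position i may attend to key position j only if i >= j.
i = jnp.arange(target_length)[:, None]
j = jnp.arange(source_length)[None, :]
bias_2d = jnp.where(i >= j, 0.0, NEG_INF)    # [target_length, source_length]
bias_4d = bias_2d[None, None, :, :]          # [1, 1, target_length, source_length]

# Broadcasts against [batch, num_heads, target_length, source_length] logits.
logits = jnp.zeros((2, 8, target_length, source_length))
probs = jax.nn.softmax(logits + bias_4d, axis=-1)
```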
@@ -65,6 +69,17 @@
from jax.core import Primitive

from axlearn.common import ops, param_init
from axlearn.common.attention_bias import (
NEG_INF,
BaseAttentionBias,
CausalAttentionBias,
MaskFn,
MaskFnAttentionBias,
SegmentIdAttentionBias,
as_attention_bias,
causal_mask,
make_segment_mask,
)
from axlearn.common.base_layer import (
BaseLayer,
FactorizationSpec,
@@ -116,8 +131,6 @@
split_prng_key,
)

NEG_INF = -1e15


class ForwardMode(enum.Enum):
"""ForwardMode describes the type of computation to be done in a forward pass through a layer.
@@ -303,66 +316,6 @@ def extend_step(
raise NotImplementedError(type(self))


def make_causal_biases(seq_len: int) -> Tensor:
"""Generates attention logit biases for causal masking.
Args:
seq_len: Sequence length.
Returns:
A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf if i < j,
0 otherwise.
"""
# TODO(sneha): support batching
return bool_to_bias(causal_mask(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
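For reference, a small usage sketch of this helper as defined above, with the values that its docstring and the `bool_to_bias`/`causal_mask` definitions below imply:

```
# Rows index query positions, columns index key positions; future positions
# receive NEG_INF (-1e15).
biases = make_causal_biases(3)
# biases ==
# [[ 0.0e+00, -1.0e+15, -1.0e+15],
#  [ 0.0e+00,  0.0e+00, -1.0e+15],
#  [ 0.0e+00,  0.0e+00,  0.0e+00]]
```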


def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor:
"""Generates attention logit biases for sliding window attention.
Args:
seq_len: Sequence length.
sliding_window_size: Size of the sliding window.
Returns:
A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf
if i - j > sliding_window_size or i < j, 0 otherwise.
"""
mask_fn = sliding_window_causal_mask(sliding_window_size)
return bool_to_bias(mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))


def bool_to_bias(mask: Tensor) -> Tensor:
"""Converts a bool mask tensor to a bias mask tensor.
Maps:
0 -> NEG_INF
1 -> 0.
"""
if mask.dtype != jnp.bool:
raise ValueError("mask must be a Boolean tensor.")
return (~mask) * NEG_INF


def make_segment_mask(*, source_segments: Tensor, target_segments: Tensor) -> Tensor:
"""Generates attention logit biases given the segment ids.
... such that positions belonging to different segments cannot attend to each other.
Args:
source_segments: An integer tensor of shape [batch, ..., source_length].
target_segments: An integer tensor of shape [batch, ..., target_length].
Returns:
A float Tensor of shape [batch, 1, ..., target_length, source_length] where the
value at [..., i, j] = 0 if target_segments[..., i] == source_segments[..., j], or -inf
otherwise.
"""
target_segments = jnp.expand_dims(target_segments, -1)
source_segments = jnp.expand_dims(source_segments, -2)
res = (jax.lax.ne(source_segments, target_segments) * NEG_INF)[:, None, ...]
return res
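A short worked example of `make_segment_mask` as defined above, assuming a single batch element with two segments (the values follow directly from the documented semantics):

```
import jax.numpy as jnp

# Tokens 0-1 belong to segment 1, tokens 2-3 to segment 2.
segment_ids = jnp.asarray([[1, 1, 2, 2]])  # [batch=1, length=4]
mask = make_segment_mask(source_segments=segment_ids, target_segments=segment_ids)
# mask.shape == (1, 1, 4, 4); for query position 0:
# mask[0, 0, 0] == [0, 0, NEG_INF, NEG_INF]  (cross-segment pairs are disconnected)
```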


class LearnedPositionalEmbedding(BaseLayer):
"""TODO(ruoming): Remove LearnedPositionalEmbedding. We can just use the Embedding layer."""

@@ -1555,83 +1508,6 @@ def default_scale_factor_config() -> InstantiableConfig[ScaleFn]:
return config_for_function(constant_scale_fn).set(value=1)


class MaskFn(Protocol):
"""A broadcastable function for computing a boolean logit mask."""

def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor:
"""Returns a bool Tensor of whether the query token at `query_position` should attend
to the key token at `key_position`.
Implementations have the following contract:
* Must support scalar arguments.
* If given non-scalar arguments of the same shape, the result must be the same as
applying the function elementwise over these arguments. I.e.,
```
x = f(jnp.asarray([1,2]), jnp.asarray([3,4]))
assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None]
```
* If given non-scalar arguments of different shapes, the result must be the same if we
first broadcast the arguments against each other to make them have the same shape.
* Beyond requiring broadcastability, must not impose any constraints on the shapes of its
arguments.
Args:
query_position: The index in the sequence of query vectors.
key_position: The index in the sequence of key vectors.
Returns:
Whether the query and key vectors with the given index should attend to one another.
True means they should attend. False means they should not.
The shape is the same as the shape obtained after broadcasting the inputs against each
other.
"""


def _composite_masks(op: Callable[[Tensor, Tensor], Tensor], *mask_fns: ConfigOr[MaskFn]):
if len(mask_fns) == 0:
raise RuntimeError(f"Input must not be empty: {mask_fns}")

def mask(query_position: Tensor, key_position: Tensor):
fns = [maybe_instantiate(arg) for arg in mask_fns]
result = fns[0](query_position, key_position)
for mask in fns[1:]:
result = op(result, mask(query_position, key_position))
return result

return mask


def or_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn:
"""Returns a MaskFn that's the union of provided MaskFn's."""
return _composite_masks(jnp.logical_or, *mask_fns)


def and_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn:
"""Returns a MaskFn that's the intersection of provided MaskFn's."""
return _composite_masks(jnp.logical_and, *mask_fns)


def causal_mask(query_position: Tensor, key_position: Tensor) -> Tensor:
"""Returns the given entry of a causal attention mask.
Implements the `MaskFn` protocol.
See that and `MultiheadAttention.Config.mask`.
"""
return query_position >= key_position


def sliding_window_causal_mask(sliding_window_size: int):
"""Returns a causal MaskFn for sliding window attentions of a given window size.
Implements the `MaskFn` protocol.
"""

def mask(query_position: Tensor, key_position: Tensor):
return query_position - key_position <= sliding_window_size

return and_masks(causal_mask, mask)
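A usage sketch of the combinators above (the positions are illustrative; `and_masks`, `causal_mask`, and `sliding_window_causal_mask` are the functions defined in this file):

```
import jax.numpy as jnp

window_mask = sliding_window_causal_mask(sliding_window_size=2)
query_position = jnp.arange(5)[:, None]
key_position = jnp.arange(5)[None, :]
allowed = window_mask(query_position, key_position)
# allowed[i, j] is True iff i >= j and i - j <= 2; e.g. row 4 is [F, F, T, T, T].
```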


class MultiheadAttention(BaseLayer):
"""A basic multi-head attention layer.
@@ -1747,7 +1623,7 @@ def _forward_for_mode(
key: Optional[Tensor] = None,
value: Optional[Tensor] = None,
kv_state: Optional[KVState] = None,
attention_logit_biases: Optional[Tensor] = None,
attention_logit_biases: Union[None, Tensor, BaseAttentionBias] = None,
segment_ids: Optional[Tensor] = None,
cached_states: Optional[NestedTensor] = None,
return_aux: Optional[set[str]] = None,
@@ -1819,36 +1695,35 @@ def _forward_for_mode(
self.vlog(3, "atten.q_proj=%s", q_proj.sum())
self.vlog(3, "atten.k_proj=%s", k_proj.sum())
self.vlog(3, "atten.v_proj=%s", v_proj.sum())
if attention_logit_biases is not None:
if attention_logit_biases.ndim == 3:
# [batch, 1, target_length, source_length].
attention_logit_biases = attention_logit_biases[:, None, :, :]
elif attention_logit_biases.ndim == 2:
# [1, 1, target_length, source_length].
attention_logit_biases = attention_logit_biases[None, None, :, :]
elif attention_logit_biases.ndim != 4:
raise ValueError(
f"Invalid attention_logit_biases shape: {attention_logit_biases.shape}."
)
attention_logit_biases = as_attention_bias(attention_logit_biases)
if self._mask_fn is not None:
kv_pos = jnp.arange(k_proj.shape[1])[None, :] # [1, source_len]
query_pos = jnp.arange(q_proj.shape[1])[None] # [1, target_length]
target_positions = None
if mode == ForwardMode.EXTEND_STEP:
time_step = cached_states["i_proj"]["time_step"] # [B]
# [B, target_length], target_length is often 1 for decoding, but not always.
query_pos = query_pos + time_step[:, None]
mask = self._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos)
if mask is not None:
attention_logit_biases = apply_attention_logit_biases(
mask.astype(q_proj.dtype),
attention_logit_biases,
target_positions = cached_states["i_proj"]["time_step"]
if self._mask_fn is causal_mask:
# Needed for legacy flash attention implementations that don't have
# sparse mask support.
# E.g., the legacy tpu flash attention, all current gpu flash attention
# implementations.
attention_logit_biases += CausalAttentionBias(
shape=(q_proj.shape[1], k_proj.shape[1]),
target_positions=target_positions,
dtype=q_proj.dtype,
)
else:
attention_logit_biases += MaskFnAttentionBias(
self._mask_fn,
shape=(q_proj.shape[1], k_proj.shape[1]),
target_positions=target_positions,
dtype=q_proj.dtype,
)
if segment_ids is not None:
attention_logit_biases += SegmentIdAttentionBias(segment_ids)
context, probs = self._compute_attention(
q_proj=q_proj,
k_proj=k_proj,
v_proj=v_proj,
attention_logit_biases=attention_logit_biases,
segment_ids=segment_ids,
)
self.vlog(3, "atten.prob=%s", probs[0, 0, 0, :])
self.vlog(3, "atten.context=%s", context.sum())
@@ -1865,38 +1740,13 @@ def _forward_for_mode(
)
return dict(i_proj=i_proj_state), output
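A sketch of how the new bias composition above fits together, using only the constructor calls visible in this diff (provided by `axlearn.common.attention_bias` per the new imports); the shapes and example segment ids are illustrative:

```
import jax.numpy as jnp

from axlearn.common.attention_bias import (
    CausalAttentionBias,
    SegmentIdAttentionBias,
    as_attention_bias,
)

target_len = source_len = 8
bias = as_attention_bias(None)  # None means an all-zero (fully connected) bias.
bias += CausalAttentionBias(
    shape=(target_len, source_len),
    target_positions=None,
    dtype=jnp.float32,
)
bias += SegmentIdAttentionBias(jnp.asarray([[0, 0, 0, 0, 1, 1, 1, 1]]))
# Per the file comments, bias.value() is either None or a 4d bias tensor that
# downstream attention implementations add to the logits.
```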

def _logit_biases_for_mask(
self, *, mode: ForwardMode, query_pos: Tensor, kv_pos: Tensor
) -> Optional[Tensor]:
"""Returns the configured attention mask in the form of logit biases.
... or None if the implementation of _compute_attention supports applying masks natively.
Args:
mode: The forward propagation mode, chosen from
(ForwardMode.FORWARD, ForwardMode.INIT_STATES, ForwardMode.EXTEND_STEP).
query_pos: The index in the sequence of query vectors, [1|batch, target_length].
kv_pos: The index in the sequence of kv vectors, [1|batch, source_length].
Returns:
A logit bias tensor [1|batch, 1, target_length, source_length].
"""
del mode
kv_pos = kv_pos[:, None] # [1|B, 1, source_len]
query_pos = query_pos[..., None] # [1|B, target_len, 1]
# [1|B, 1, target_len, source_len]
mask = self._mask_fn(query_pos, kv_pos)[:, None]
mask = bool_to_bias(mask)
return mask

def _compute_attention(
self,
*,
q_proj: Tensor,
k_proj: Tensor,
v_proj: Tensor,
attention_logit_biases: Optional[Tensor] = None,
segment_ids: Optional[Tensor] = None,
attention_logit_biases: BaseAttentionBias,
) -> tuple[Tensor, Tensor]:
"""Computes attention context and probs.
@@ -1905,27 +1755,15 @@ def _compute_attention(
k_proj: [batch_size, source_length, num_heads, per_head_dim].
v_proj: [batch_size, source_length, num_heads, per_head_dim].
attention_logit_biases: See ``On attention logit biases`` in the file comments.
segment_ids: See ``segment_ids`` in the file comments.
Returns:
The context of shape [batch_size, target_length, num_heads, per_head_dim],
and probs of shape [batch, num_heads, target_length, source_length].
"""
# Merge segment ids into attention_logit_biases.
if segment_ids is not None:
if q_proj.shape[1] != k_proj.shape[1]:
raise ValueError(
"segment_ids is only supported for query and key with identical lengths."
)
attention_logit_biases = apply_attention_logit_biases(
make_segment_mask(source_segments=segment_ids, target_segments=segment_ids),
attention_logit_biases,
)

logits = self._compute_logits(q_proj, k_proj)
logits = self._cap_logits(logits)
self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :])
probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases)
probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases.value())
probs = self.dropout(probs)
context = self._compute_context(probs, v_proj)
context = self._remat_name(context, "context")
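The behavioral change here is that the bias now reaches the softmax through `attention_logit_biases.value()`, with the segment mask already folded in upstream. As a hypothetical reference for what that step computes (not the library's `softmax_with_biases` implementation), where a None bias acts as zero per the file comments:

```
import jax
import jax.numpy as jnp

def reference_softmax_with_biases(logits, biases):
    """Hypothetical reference: add biases (if any) to logits, softmax over the source axis."""
    # logits: [batch, num_heads, target_length, source_length]; biases: None or broadcastable.
    if biases is not None:
        logits = logits + biases
    return jax.nn.softmax(logits, axis=-1)
```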
@@ -2203,34 +2041,23 @@ class SigmoidAttention(MultiheadAttention):
class Config(MultiheadAttention.Config):
"""Configures SigmoidAttention."""

seq_len: Required[int] = REQUIRED # Maximum sequence length used.
seq_len: Required[int] = REQUIRED # Maximum sequence length used.

def _compute_attention(
self,
*,
q_proj: Tensor,
k_proj: Tensor,
v_proj: Tensor,
attention_logit_biases: Optional[Tensor] = None,
segment_ids: Optional[Tensor] = None,
attention_logit_biases: BaseAttentionBias,
) -> tuple[Tensor, Tensor]:
"""See `MultiheadAttention._compute_attention` for details."""
# Merge segment ids into attention_logit_biases.
if segment_ids is not None:
if q_proj.shape[1] != k_proj.shape[1]:
raise ValueError(
"segment_ids is only supported for query and key with identical lengths."
)
attention_logit_biases = apply_attention_logit_biases(
make_segment_mask(source_segments=segment_ids, target_segments=segment_ids),
attention_logit_biases,
)

cfg = self.config
logits = self._compute_logits(q_proj, k_proj)
logits = self._cap_logits(logits)
self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :])

attention_logit_biases = attention_logit_biases.value()
if attention_logit_biases is None:
attention_logit_biases = 0
# To approximate softmax, we subtract a bias dependent on sequence length.
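The rest of this hunk is truncated, but the comment states the idea: with an elementwise sigmoid in place of softmax, a bias depending on the configured maximum sequence length is subtracted so the attention weights stay on a softmax-like scale. A hypothetical sketch assuming that bias is `log(seq_len)`, as in the sigmoid-attention literature (this is not the truncated code itself; `seq_len` stands in for `cfg.seq_len`):

```
import jax
import jax.numpy as jnp

def sigmoid_attention_probs(logits, biases, *, seq_len):
    # Hypothetical: subtracting log(seq_len) plays the role of softmax's normalization.
    if biases is None:
        biases = 0.0
    return jax.nn.sigmoid(logits + biases - jnp.log(seq_len))
```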
(Diffs for the remaining changed files are not shown here.)