Merge pull request #84 from vasqu/fix-mamba2-slow-path
[Mamba2] Fix slow path
yzhangcs authored Nov 23, 2024
2 parents 5cc40c5 + 7c6c3d2 commit 05ed6f8
Showing 1 changed file with 14 additions and 16 deletions.

fla/models/mamba2/modeling_mamba2.py
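Note on the hunk below: among other comment cleanups, it rewrites the intra-chunk state contraction as a direct broadcast-and-sum instead of a chain of permutes. As a quick sanity sketch (hypothetical small sizes; shapes follow the chunked layout used in torch_forward; the einsum is only a reference for checking, not code from this file), the new expression computes a single contraction over the chunk-length axis:

import torch

# Hypothetical small sizes; shapes follow the chunked layout in torch_forward.
b, c, l, h, n, p = 2, 3, 4, 2, 8, 5
B_decay = torch.randn(b, c, l, h, n)        # (batch, chunks, chunk_len, heads, state_size)
hidden_states = torch.randn(b, c, l, h, p)  # (batch, chunks, chunk_len, heads, head_dim)

# New formulation from the diff: broadcast, then sum over the chunk-length axis.
states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)   # (b, c, h, p, n)

# The same contraction written as one einsum, used here only as a check.
states_ref = torch.einsum('bclhn,bclhp->bchpn', B_decay, hidden_states)

assert torch.allclose(states, states_ref, atol=1e-6)

As far as these shapes go, the older permute-heavy version produces the same tensor; this part of the change mainly simplifies the expression.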
@@ -542,46 +542,44 @@ def torch_forward(
# This is the analog of a causal mask
L = torch.exp(segment_sum(A))

- # First, contraction of C and B to get G (attention-weights like)
+ # Contraction of C and B to get G (attention-weights like)
# shape: (b, c, l, s, h, n)
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)

- # Step 2: Compute M, equivalent to applying attention mask to weights
+ # Compute M, equivalent to applying attention mask to weights
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
M = M_intermediate.sum(dim=-1)

- # Step 3: Compute Y_diag (apply to values)
- Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
+ # Compute Y_diag (apply to values)
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)

# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)

decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
- B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
- # permute back B * decay states
- states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] *
-           hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
+ B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
+ states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)

# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if cache_params is not None and cache_params.seqlen_offset > 0:
previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))

- states_permuted = states.permute(0, 2, 1, 3, 4)
- result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
- new_states = result.permute(0, 2, 1, 3, 4)
+ decay_chunk = decay_chunk.transpose(1, 3)
+ new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
states, ssm_state = new_states[:, :-1], new_states[:, -1]

- # Compute state -> output conversion per chunk
+ # 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
- # compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
- # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
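For the inter-chunk recurrence changed above, the rewritten two-liner contracts the per-chunk states with decay_chunk over the source-chunk axis. A minimal sketch (hypothetical sizes; decay_chunk assumed to be indexed [batch, heads, target_chunk, source_chunk] as produced by segment_sum, and states as [batch, chunk, heads, head_dim, state_size]) checking the rewritten expression against an einsum over that axis:

import torch

# Hypothetical small sizes; values are arbitrary, only the contraction pattern matters.
b, h, c1, p, n = 2, 3, 5, 4, 6          # c1 = number of chunks + 1 (initial state prepended)
decay_chunk = torch.randn(b, h, c1, c1)  # (batch, heads, target_chunk, source_chunk), assumed layout
states = torch.randn(b, c1, h, p, n)     # (batch, chunk, heads, head_dim, state_size)

# Rewritten recurrence from the diff: bring the source-chunk axis to dim 1, then sum over it.
decay_t = decay_chunk.transpose(1, 3)                                         # (b, source, target, h)
new_states = (decay_t[..., None, None] * states[:, :, None, ...]).sum(dim=1)  # (b, target, h, p, n)

# The same contraction as an einsum over the source-chunk axis, used here only as a check.
ref = torch.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)

assert torch.allclose(new_states, ref, atol=1e-5)

If these axis conventions match the ones in torch_forward, summing over the source-chunk axis is what produces the SSM states at each chunk boundary, which is the part of the slow path this commit targets.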