Commit: beta vector
hypnopump committed Aug 8, 2024
1 parent e06fdeb commit 81f7ae8
Showing 2 changed files with 47 additions and 16 deletions.
7 changes: 6 additions & 1 deletion fla/ops/delta_rule/naive.py
@@ -10,15 +10,20 @@ def delta_rule_recurrence(q, k, v, beta):
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)

+if beta.ndim < v.ndim:
+    beta = beta[..., None]

for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
-_v = _v * beta_i[..., None]
+_v = _v * beta_i
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)

return o


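With the change above, delta_rule_recurrence accepts beta either as one scalar per head and timestep (shape [b, h, l]) or as a head-wise vector with one gate per value channel (shape [b, h, l, d_v]). A minimal usage sketch against the naive reference, with hypothetical sizes (shapes inferred from the function body; the example itself is not part of the commit):

import torch
from fla.ops.delta_rule.naive import delta_rule_recurrence

b, h, l, d_k, d_v = 2, 4, 16, 32, 64  # hypothetical sizes
q = torch.randn(b, h, l, d_k)
k = torch.randn(b, h, l, d_k)
v = torch.randn(b, h, l, d_v)

beta_scalar = torch.rand(b, h, l)       # one gate per head and timestep
beta_vector = torch.rand(b, h, l, d_v)  # one gate per value channel (new in this commit)

o_scalar = delta_rule_recurrence(q, k, v, beta_scalar)
o_vector = delta_rule_recurrence(q, k, v, beta_vector)

# Sanity check: a vector beta with identical channels reproduces the scalar path.
o_check = delta_rule_recurrence(q, k, v, beta_scalar[..., None].expand(-1, -1, -1, d_v))
assert torch.allclose(o_scalar, o_check, atol=1e-5)
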
56 changes: 41 additions & 15 deletions fla/ops/delta_rule/recurrent_fuse.py
@@ -42,6 +42,7 @@ def fused_recurrent_fwd_kernel(
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
+IS_BETA_VECTOR: tl.constexpr, # whether beta is headwise vector or scalar
):

# indices
@@ -50,7 +51,10 @@ def fused_recurrent_fwd_kernel(
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
-p_beta = beta + i_bh * T
+if IS_BETA_VECTOR:
+    p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
+else:
+    p_beta = beta + i_bh * T
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)

mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
@@ -71,7 +75,10 @@ def fused_recurrent_fwd_kernel(
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_v_minus = tl.sum(h * _k[None, :], axis=1)
_v -= _v_minus
-_beta = tl.load(p_beta).to(tl.float32)
+if IS_BETA_VECTOR:
+    _beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
+else:
+    _beta = tl.load(p_beta).to(tl.float32)
# in-place overwrite
tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)
_v *= _beta
@@ -84,7 +91,7 @@ def fused_recurrent_fwd_kernel(
p_k += DK
p_o += DV
p_v += DV
-p_beta += 1
+p_beta += DV if IS_BETA_VECTOR else 1

if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
@@ -129,6 +136,7 @@ def fused_recurrent_bwd_kernel(
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
+IS_BETA_VECTOR: tl.constexpr, # whether beta is headwise vector or scalar
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
@@ -138,8 +146,13 @@ def fused_recurrent_bwd_kernel(
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
-p_beta = beta + i_bh * T + T - 1
-p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
+if IS_BETA_VECTOR:
+    p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
+    p_dbeta = dbeta + (i_bh + i_k * B * H) * s_vo_h + i_v * \
+        BV + tl.arange(0, BV) + (T - 1) * DV
+else:
+    p_beta = beta + i_bh * T + T - 1
+    p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1

p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + (T - 1) * DK
@@ -152,17 +165,23 @@ def fused_recurrent_bwd_kernel(
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
-_beta = tl.load(p_beta).to(tl.float32)
+if IS_BETA_VECTOR:
+    _beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
+else:
+    _beta = tl.load(p_beta).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
-d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)
+d_k = tl.sum(d_h * (_v * _beta)[None, :], axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)

-d_beta = tl.sum(d_v * _v)
+d_beta = d_v * _v if IS_BETA_VECTOR else tl.sum(d_v * _v)
d_v = d_v * _beta

tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
-tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))
+if IS_BETA_VECTOR:
+    tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv)
+else:
+    tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))

d_h -= _k[:, None] * d_v[None, :]

@@ -172,8 +191,8 @@ def fused_recurrent_bwd_kernel(
p_v -= DV
p_dk -= DK
p_dv -= DV
-p_dbeta -= 1
-p_beta -= 1
+p_dbeta -= DV if IS_BETA_VECTOR else 1
+p_beta -= DV if IS_BETA_VECTOR else 1

tl.debug_barrier()

@@ -182,7 +201,10 @@ def fused_recurrent_bwd_kernel(
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
-p_beta = beta + i_bh * T
+if IS_BETA_VECTOR:
+    p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
+else:
+    p_beta = beta + i_bh * T
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV
@@ -199,7 +221,10 @@ def fused_recurrent_bwd_kernel(
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
-_beta = tl.load(p_beta).to(tl.float32)
+if IS_BETA_VECTOR:
+    _beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
+else:
+    _beta = tl.load(p_beta).to(tl.float32)
_v *= _beta

h += _k[:, None] * _v[None, :]
@@ -219,7 +244,7 @@ def fused_recurrent_bwd_kernel(
p_dk += DK
p_dv += DV
p_dq += DK
-p_beta += 1
+p_beta += DV if IS_BETA_VECTOR else 1


class FusedRecurrentFunction(torch.autograd.Function):
@@ -252,7 +277,8 @@ def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_sta
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
-STORE_FINAL_STATE=final_state is not None
+STORE_FINAL_STATE=final_state is not None,
+IS_BETA_VECTOR=beta.ndim == v.ndim,
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, beta, initial_state)
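
In the fused kernels above, the IS_BETA_VECTOR flag mainly changes how beta is addressed: a head-wise vector beta shares the layout of v and its pointer advances by DV per timestep, while a scalar beta stores one value per batch-head and timestep and advances by 1. A small illustrative sketch of the flat offsets, assuming a contiguous (B*H, T, DV) layout for v so that s_vo_h == T * DV; the helper name is ours and not part of the library:

import torch

def beta_offsets(i_bh, t, T, DV, BV, i_v=0, is_beta_vector=True):
    # Flat element offsets the forward kernel reads for beta at timestep t.
    if is_beta_vector:
        # Same addressing as v: batch-head base + timestep block + value lanes
        # (p_beta starts at i_bh * s_vo_h + i_v * BV and is advanced by DV per step).
        return i_bh * T * DV + t * DV + i_v * BV + torch.arange(BV)
    # One scalar per (batch-head, timestep), advanced by 1 per step.
    return i_bh * T + t

print(beta_offsets(3, 5, T=16, DV=64, BV=64))                        # vector beta
print(beta_offsets(3, 5, T=16, DV=64, BV=64, is_beta_vector=False))  # scalar beta

The backward kernel walks the same offsets in reverse, starting from the (T - 1) * DV (vector) or T - 1 (scalar) position.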

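On the backward side, the gradient of beta becomes per channel (d_beta = d_v * _v) instead of a per-timestep sum (tl.sum(d_v * _v)) when IS_BETA_VECTOR holds. A quick way to sanity-check that math is a finite-difference gradient check on the naive reference in double precision; this is a hedged sketch that only exercises the PyTorch reference, and comparing against the fused kernel would additionally need the wrapper not shown in this diff:

import torch
from fla.ops.delta_rule.naive import delta_rule_recurrence

torch.manual_seed(0)
b, h, l, d_k, d_v = 1, 1, 4, 8, 8  # tiny sizes keep the finite differences cheap

def leaf(*shape):
    return torch.randn(*shape, dtype=torch.float64, requires_grad=True)

q, k, v = leaf(b, h, l, d_k), leaf(b, h, l, d_k), leaf(b, h, l, d_v)
beta_vector = torch.rand(b, h, l, d_v, dtype=torch.float64, requires_grad=True)
beta_scalar = torch.rand(b, h, l, dtype=torch.float64, requires_grad=True)

# Compares autograd gradients of the reference against numerical gradients,
# once for the new head-wise vector beta and once for the scalar beta.
assert torch.autograd.gradcheck(delta_rule_recurrence, (q, k, v, beta_vector))
assert torch.autograd.gradcheck(delta_rule_recurrence, (q, k, v, beta_scalar))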