Skip to content

Commit

Permalink
[GLA] Support varlen mode
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Nov 30, 2024
1 parent ddee66d commit 20f30c6
Show file tree
Hide file tree
Showing 5 changed files with 811 additions and 249 deletions.
173 changes: 117 additions & 56 deletions fla/ops/common/chunk_h.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from typing import Tuple
from typing import Optional, Tuple

import torch
import triton
Expand All @@ -15,11 +15,12 @@
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'],
key=['BT', 'BK', 'BV', 'USE_G', 'USE_GK', 'USE_GV'],
)
@triton.heuristics({
'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
'STORE_FINAL_STATE': lambda args: args['ht'] is not None
'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit
def chunk_fwd_kernel_h(
Expand All @@ -31,38 +32,50 @@ def chunk_fwd_kernel_h(
gv,
h0,
ht,
offsets,
c_offsets,
T: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_GV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
USE_OFFSETS: tl.constexpr,
HEAD_FIRST: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_h = i_nh // H, i_nh % H
if USE_OFFSETS:
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(c_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT

# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(h0 + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

for i_t in range(NT):
if HEAD_FIRST:
p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
else:
p_k = tl.make_block_ptr(k + i_b * T*H*K + i_h * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
Expand All @@ -74,24 +87,24 @@ def chunk_fwd_kernel_h(
# scalar decay
if USE_G:
if HEAD_FIRST:
b_g_last = tl.load(g + i_bh * T + last_idx)
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
b_g_last = tl.load(g + i_nh * T + last_idx)
p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
else:
b_g_last = tl.load(g + i_b * T * H + last_idx * H + i_h)
p_g = g + i_b * T*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
b_h *= tl.exp(b_g_last)
b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
b_v = (b_v * tl.exp(b_g_last - b_g)[:, None]).to(b_v.dtype)

# vector decay, h = Diag(gk) @ h
if USE_GK:
if HEAD_FIRST:
p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
else:
p_gk = tl.make_block_ptr(gk + i_b * T*H*K + i_h * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + i_b * T*H*K + last_idx * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
b_h *= tl.exp(b_gk_last)[:, None]
Expand All @@ -102,11 +115,11 @@ def chunk_fwd_kernel_h(
# vector decay, h = h @ Diag(gv)
if USE_GV:
if HEAD_FIRST:
p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
else:
p_gv = tl.make_block_ptr(gv + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + i_b * T*H*V + last_idx * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
b_h *= tl.exp(b_gv_last)[None, :]
Expand All @@ -117,7 +130,7 @@ def chunk_fwd_kernel_h(
b_h += tl.dot(b_k, b_v)

if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


Expand All @@ -128,11 +141,12 @@ def chunk_fwd_kernel_h(
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'],
key=['BT', 'BK', 'BV', 'USE_G', 'USE_GK', 'USE_GV'],
)
@triton.heuristics({
'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None
'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit
def chunk_bwd_kernel_dh(
Expand All @@ -144,6 +158,8 @@ def chunk_bwd_kernel_dh(
dh,
dht,
dh0,
offsets,
c_offsets,
scale,
T: tl.constexpr,
HQ: tl.constexpr,
Expand All @@ -153,36 +169,49 @@ def chunk_bwd_kernel_dh(
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
NG: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_GV: tl.constexpr,
STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
USE_FINAL_STATE_GRADIENT: tl.constexpr,
USE_OFFSETS: tl.constexpr,
HEAD_FIRST: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_bg = i_bh // NG
i_b, i_hq = i_bh // HQ, i_bh % HQ
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_bg = i_nh // NG
i_n, i_hq = i_nh // HQ, i_nh % HQ
i_h = i_hq // NG
if USE_OFFSETS:
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(c_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT

# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
if USE_FINAL_STATE_GRADIENT:
p_dht = tl.make_block_ptr(dht + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)

for i_t in range(NT - 1, -1, -1):
p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
if HEAD_FIRST:
p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
else:
p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
last_idx = min(i_t * BT + BT, T) - 1
# [BK, BT]
if HEAD_FIRST:
p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
else:
p_q = tl.make_block_ptr(q + i_b * T*HQ*K + i_hq * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_b * T*HQ*V + i_hq * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BV]
Expand All @@ -194,8 +223,8 @@ def chunk_bwd_kernel_dh(
p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
b_g_last = tl.load(g + i_bg * T + last_idx)
else:
p_g = g + i_b * T*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
b_g_last = tl.load(g + i_b * T * H + last_idx * H + i_h)
p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
b_q = (b_q * tl.exp(b_g)[None, :]).to(b_q.dtype)

Expand All @@ -204,11 +233,11 @@ def chunk_bwd_kernel_dh(
if USE_GK:
if HEAD_FIRST:
p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + i_bg * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)

else:
p_gk = tl.make_block_ptr(gk + i_b * T*H*K + i_h * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + i_b * T*H*K + last_idx * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)

b_gk = tl.load(p_gk, boundary_check=(0, 1))
Expand All @@ -219,10 +248,10 @@ def chunk_bwd_kernel_dh(
if USE_GV:
if HEAD_FIRST:
p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + i_bg * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
else:
p_gv = tl.make_block_ptr(gv + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + i_b * T*H*V + last_idx * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)

b_gv = tl.load(p_gv, boundary_check=(0, 1))
Expand All @@ -234,7 +263,7 @@ def chunk_bwd_kernel_dh(
b_dh += tl.dot(b_q, b_do)

if STORE_INITIAL_STATE_GRADIENT:
p_dh0 = tl.make_block_ptr(dh0 + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))


Expand All @@ -247,6 +276,8 @@ def chunk_fwd_h(
h0: torch.Tensor,
output_final_state: bool,
states_in_fp32: bool = False,
offsets: Optional[torch.Tensor] = None,
c_offsets: Optional[torch.Tensor] = None,
head_first: bool = True,
chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -255,13 +286,26 @@ def chunk_fwd_h(
else:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = chunk_size
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
# N: the actual number of sequences in the batch with either equal or variable lengths
if offsets is None:
N, NT, c_offsets = B, triton.cdiv(T, BT), None
else:
N = len(offsets) - 1
if c_offsets is None:
c_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
NT = c_offsets[-1]
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

h = k.new_empty(B, H, NT * K, V, dtype=k.dtype if not states_in_fp32 else torch.float32)
ht = k.new_empty(B, H, K, V, dtype=torch.float32) if output_final_state else None
if head_first:
h = k.new_empty(B, H, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32)
else:
h = k.new_empty(B, NT, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32)
ht = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

chunk_fwd_kernel_h[(NK, NV, B * H)](
grid = (NK, NV, N * H)
chunk_fwd_kernel_h[grid](
k=k,
v=v,
h=h,
Expand All @@ -270,14 +314,15 @@ def chunk_fwd_h(
gv=gv,
h0=h0,
ht=ht,
offsets=offsets,
c_offsets=c_offsets,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
NT=NT,
USE_G=g is not None,
USE_GK=gk is not None,
USE_GV=gv is not None,
Expand All @@ -298,6 +343,8 @@ def chunk_bwd_dh(
dht: torch.Tensor,
scale: float,
states_in_fp32: bool = False,
offsets: Optional[torch.Tensor] = None,
c_offsets: Optional[torch.Tensor] = None,
head_first: bool = True,
chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -308,15 +355,28 @@ def chunk_bwd_dh(
B, T, H, K, V = *k.shape, v.shape[-1]
HQ = q.shape[2]
BT = chunk_size
# N: the actual number of sequences in the batch with either equal or variable lengths
if offsets is None:
N, NT, c_offsets = B, triton.cdiv(T, BT), None
else:
N = len(offsets) - 1
if c_offsets is None:
c_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
NT = c_offsets[-1]
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
# number of groups in GQA
NG = HQ // H

dh = k.new_empty(B, HQ, NT * K, V, dtype=k.dtype if not states_in_fp32 else torch.float32)
if head_first:
dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32)
else:
dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32)
dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
chunk_bwd_kernel_dh[(NK, NV, B * HQ)](

grid = (NK, NV, N * H)
chunk_bwd_kernel_dh[grid](
q=q,
g=g,
gk=gk,
Expand All @@ -325,6 +385,8 @@ def chunk_bwd_dh(
dh=dh,
dht=dht,
dh0=dh0,
offsets=offsets,
c_offsets=c_offsets,
scale=scale,
T=T,
HQ=HQ,
Expand All @@ -334,7 +396,6 @@ def chunk_bwd_dh(
BT=BT,
BK=BK,
BV=BV,
NT=NT,
NG=NG,
USE_G=g is not None,
USE_GK=gk is not None,
Expand Down
Loading

0 comments on commit 20f30c6

Please sign in to comment.