diff --git a/fla/ops/common/chunk_h.py b/fla/ops/common/chunk_h.py index 50633aebc..db7f68cec 100644 --- a/fla/ops/common/chunk_h.py +++ b/fla/ops/common/chunk_h.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright (c) 2024, Songlin Yang, Yu Zhang -from typing import Tuple +from typing import Optional, Tuple import torch import triton @@ -15,11 +15,12 @@ triton.Config({}, num_warps=4), triton.Config({}, num_warps=8), ], - key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'], + key=['BT', 'BK', 'BV', 'USE_G', 'USE_GK', 'USE_GV'], ) @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, - 'STORE_FINAL_STATE': lambda args: args['ht'] is not None + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None }) @triton.jit def chunk_fwd_kernel_h( @@ -31,6 +32,8 @@ def chunk_fwd_kernel_h( gv, h0, ht, + offsets, + c_offsets, T: tl.constexpr, H: tl.constexpr, K: tl.constexpr, @@ -38,31 +41,41 @@ def chunk_fwd_kernel_h( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - NT: tl.constexpr, USE_G: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): - i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_b, i_h = i_bh // H, i_bh % H + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(c_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT # [BK, BV] b_h = tl.zeros([BK, BV], dtype=tl.float32) if USE_INITIAL_STATE: - p_h0 = tl.make_block_ptr(h0 + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) for i_t in range(NT): if HEAD_FIRST: - p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) else: - p_k = tl.make_block_ptr(k + i_b * T*H*K + i_h * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_v = tl.make_block_ptr(v + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) # [BK, BT] @@ -74,12 +87,12 @@ def chunk_fwd_kernel_h( # scalar decay if USE_G: if HEAD_FIRST: - b_g_last = tl.load(g + i_bh * T + last_idx) - p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT) + b_g_last = tl.load(g + i_nh * T + last_idx) + p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT) p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) else: - b_g_last = tl.load(g + i_b * T * H + last_idx * H + i_h) - p_g = g + i_b * T*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h b_h *= tl.exp(b_g_last) b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_v = (b_v * tl.exp(b_g_last - b_g)[:, None]).to(b_v.dtype) @@ -87,11 +100,11 @@ def chunk_fwd_kernel_h( # vector decay, h = Diag(gk) @ h if USE_GK: if HEAD_FIRST: - p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) else: - p_gk = tl.make_block_ptr(gk + i_b * T*H*K + i_h * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + i_b * T*H*K + last_idx * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) b_h *= tl.exp(b_gk_last)[:, None] @@ -102,11 +115,11 @@ def chunk_fwd_kernel_h( # vector decay, h = h @ Diag(gv) if USE_GV: if HEAD_FIRST: - p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) else: - p_gv = tl.make_block_ptr(gv + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + i_b * T*H*V + last_idx * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) b_h *= tl.exp(b_gv_last)[None, :] @@ -117,7 +130,7 @@ def chunk_fwd_kernel_h( b_h += tl.dot(b_k, b_v) if STORE_FINAL_STATE: - p_ht = tl.make_block_ptr(ht + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) @@ -128,11 +141,12 @@ def chunk_fwd_kernel_h( triton.Config({}, num_warps=4), triton.Config({}, num_warps=8), ], - key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'], + key=['BT', 'BK', 'BV', 'USE_G', 'USE_GK', 'USE_GV'], ) @triton.heuristics({ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, - 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None }) @triton.jit def chunk_bwd_kernel_dh( @@ -144,6 +158,8 @@ def chunk_bwd_kernel_dh( dh, dht, dh0, + offsets, + c_offsets, scale, T: tl.constexpr, HQ: tl.constexpr, @@ -153,36 +169,49 @@ def chunk_bwd_kernel_dh( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - NT: tl.constexpr, NG: tl.constexpr, USE_G: tl.constexpr, USE_GK: tl.constexpr, USE_GV: tl.constexpr, STORE_INITIAL_STATE_GRADIENT: tl.constexpr, USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): - i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_bg = i_bh // NG - i_b, i_hq = i_bh // HQ, i_bh % HQ + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_nh // NG + i_n, i_hq = i_nh // HQ, i_nh % HQ i_h = i_hq // NG + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(c_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + # [BK, BV] b_dh = tl.zeros([BK, BV], dtype=tl.float32) if USE_FINAL_STATE_GRADIENT: - p_dht = tl.make_block_ptr(dht + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) for i_t in range(NT - 1, -1, -1): - p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + if HEAD_FIRST: + p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) last_idx = min(i_t * BT + BT, T) - 1 # [BK, BT] if HEAD_FIRST: - p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) else: - p_q = tl.make_block_ptr(q + i_b * T*HQ*K + i_hq * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_b * T*HQ*V + i_hq * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) # [BT, BV] @@ -194,8 +223,8 @@ def chunk_bwd_kernel_dh( p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) b_g_last = tl.load(g + i_bg * T + last_idx) else: - p_g = g + i_b * T*H + (i_t * BT + tl.arange(0, BT)) * H + i_h - b_g_last = tl.load(g + i_b * T * H + last_idx * H + i_h) + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) b_q = (b_q * tl.exp(b_g)[None, :]).to(b_q.dtype) @@ -204,11 +233,11 @@ def chunk_bwd_kernel_dh( if USE_GK: if HEAD_FIRST: p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + i_bg * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK) else: - p_gk = tl.make_block_ptr(gk + i_b * T*H*K + i_h * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_gk_last = gk + i_b * T*H*K + last_idx * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -219,10 +248,10 @@ def chunk_bwd_kernel_dh( if USE_GV: if HEAD_FIRST: p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + i_bg * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) + p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV) else: - p_gv = tl.make_block_ptr(gv + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_gv_last = gv + i_b * T*H*V + last_idx * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) b_gv = tl.load(p_gv, boundary_check=(0, 1)) @@ -234,7 +263,7 @@ def chunk_bwd_kernel_dh( b_dh += tl.dot(b_q, b_do) if STORE_INITIAL_STATE_GRADIENT: - p_dh0 = tl.make_block_ptr(dh0 + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) @@ -247,6 +276,8 @@ def chunk_fwd_h( h0: torch.Tensor, output_final_state: bool, states_in_fp32: bool = False, + offsets: Optional[torch.Tensor] = None, + c_offsets: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -255,13 +286,26 @@ def chunk_fwd_h( else: B, T, H, K, V = *k.shape, v.shape[-1] BT = chunk_size - BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V)) - NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, c_offsets = B, triton.cdiv(T, BT), None + else: + N = len(offsets) - 1 + if c_offsets is None: + c_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1) + NT = c_offsets[-1] + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) - h = k.new_empty(B, H, NT * K, V, dtype=k.dtype if not states_in_fp32 else torch.float32) - ht = k.new_empty(B, H, K, V, dtype=torch.float32) if output_final_state else None + if head_first: + h = k.new_empty(B, H, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32) + else: + h = k.new_empty(B, NT, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32) + ht = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None - chunk_fwd_kernel_h[(NK, NV, B * H)]( + grid = (NK, NV, N * H) + chunk_fwd_kernel_h[grid]( k=k, v=v, h=h, @@ -270,6 +314,8 @@ def chunk_fwd_h( gv=gv, h0=h0, ht=ht, + offsets=offsets, + c_offsets=c_offsets, T=T, H=H, K=K, @@ -277,7 +323,6 @@ def chunk_fwd_h( BT=BT, BK=BK, BV=BV, - NT=NT, USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, @@ -298,6 +343,8 @@ def chunk_bwd_dh( dht: torch.Tensor, scale: float, states_in_fp32: bool = False, + offsets: Optional[torch.Tensor] = None, + c_offsets: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -308,15 +355,28 @@ def chunk_bwd_dh( B, T, H, K, V = *k.shape, v.shape[-1] HQ = q.shape[2] BT = chunk_size + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, c_offsets = B, triton.cdiv(T, BT), None + else: + N = len(offsets) - 1 + if c_offsets is None: + c_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1) + NT = c_offsets[-1] BK = min(triton.next_power_of_2(K), 64) BV = min(triton.next_power_of_2(V), 64) - NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) # number of groups in GQA NG = HQ // H - dh = k.new_empty(B, HQ, NT * K, V, dtype=k.dtype if not states_in_fp32 else torch.float32) + if head_first: + dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32) + else: + dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float32) dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None - chunk_bwd_kernel_dh[(NK, NV, B * HQ)]( + + grid = (NK, NV, N * H) + chunk_bwd_kernel_dh[grid]( q=q, g=g, gk=gk, @@ -325,6 +385,8 @@ def chunk_bwd_dh( dh=dh, dht=dht, dh0=dh0, + offsets=offsets, + c_offsets=c_offsets, scale=scale, T=T, HQ=HQ, @@ -334,7 +396,6 @@ def chunk_bwd_dh( BT=BT, BK=BK, BV=BV, - NT=NT, NG=NG, USE_G=g is not None, USE_GK=gk is not None, diff --git a/fla/ops/common/fused_recurrent.py b/fla/ops/common/fused_recurrent.py index 31c64b463..c5138dfda 100644 --- a/fla/ops/common/fused_recurrent.py +++ b/fla/ops/common/fused_recurrent.py @@ -21,7 +21,8 @@ ) @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, - 'STORE_FINAL_STATE': lambda args: args['ht'] is not None + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None }) @triton.jit def fused_recurrent_fwd_kernel( @@ -34,6 +35,7 @@ def fused_recurrent_fwd_kernel( o, # output [NK, B, H, T, V]/[NK, B, T, H, V] h0, # initial hidden state [B, H, K, V] ht, # final hidden state [B, H, K, V] + offsets, scale, B: tl.constexpr, T: tl.constexpr, @@ -48,34 +50,42 @@ def fused_recurrent_fwd_kernel( USE_GV: tl.constexpr, # whether to use gv USE_INITIAL_STATE: tl.constexpr, # whether to use initial state STORE_FINAL_STATE: tl.constexpr, # whether to store final state - HEAD_FIRST: tl.constexpr # whether the inputs are in the head-first format + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr ): # indices - i_v, i_k, i_bh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) - i_b, i_h = i_bh // H, i_bh % H + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T if HEAD_FIRST: - p_q = q + i_bh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_k = k + i_bh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_v = v + i_bh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_o = o + (i_k * B*H + i_bh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) if USE_G: - p_g = g + i_bh * T + ((T-1) if REVERSE else 0) + p_g = g + i_nh * T + ((T-1) if REVERSE else 0) if USE_GK: - p_gk = gk + i_bh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) if USE_GV: - p_gv = gv + i_bh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) else: - p_q = q + i_b * T*H*K + ((T-1) * H*K if REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) - p_k = k + i_b * T*H*K + ((T-1) * H*K if REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) - p_v = v + i_b * T*H*V + ((T-1) * H*V if REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) - p_o = o + (i_k * B + i_b) * T*H*V + ((T-1) * H*V if REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) if USE_G: - p_g = g + i_b * T*H + ((T-1) * H if REVERSE else 0) + i_h + p_g = g + bos*H + ((T-1) * H if REVERSE else 0) + i_h if USE_GK: - p_gk = gk + i_b * T*H*K + ((T-1) * H*K if REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) if USE_GV: - p_gv = gv + i_b * T*H*V + ((T-1) * H*V if REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) mask_k = (i_k * BK + tl.arange(0, BK)) < K mask_v = (i_v * BV + tl.arange(0, BV)) < V @@ -83,7 +93,7 @@ def fused_recurrent_fwd_kernel( b_h = tl.zeros([BV, BK], dtype=tl.float32) if USE_INITIAL_STATE: - p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) for _ in range(0, T): @@ -115,7 +125,7 @@ def fused_recurrent_fwd_kernel( p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) if STORE_FINAL_STATE: - p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) @@ -130,7 +140,8 @@ def fused_recurrent_fwd_kernel( @triton.heuristics({ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, - 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None }) # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 @triton.jit @@ -148,7 +159,8 @@ def fused_recurrent_bwd_kernel( dv, # gradient wrt value [NK, B, H, T, V]/[NV, B, T, H, V] dht, # gradient wrt final hidden state [B, H, K, V] dh0, # gradient wrt initial hidden state [B, H, K, V] - scale, # K ** -0.5 + offsets, + scale, B: tl.constexpr, T: tl.constexpr, H: tl.constexpr, @@ -163,35 +175,43 @@ def fused_recurrent_bwd_kernel( USE_INITIAL_STATE: tl.constexpr, # whether to use initial state STORE_INITIAL_STATE_GRADIENT: tl.constexpr, # whether to store gradient wrt initial state USE_FINAL_STATE_GRADIENT: tl.constexpr, # whether to compute gradient wrt final state - HEAD_FIRST: tl.constexpr # whether the inputs are in the head-first format + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr ): - i_v, i_k, i_bh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) - i_b, i_h = i_bh // H, i_bh % H + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T if HEAD_FIRST: - p_q = q + i_bh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_k = k + i_bh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_v = v + i_bh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_do = do + i_bh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_dq = dq + (i_v * B*H + i_bh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) if USE_G: - p_g = g + i_bh * T + ((T-1) if REVERSE else 0) + p_g = g + i_nh * T + ((T-1) if REVERSE else 0) if USE_GK: - p_gk = gk + i_bh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) if USE_GV: - p_gv = gv + i_bh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) else: - p_q = q + i_b * T*H*K + ((T-1) * H*K if REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) - p_k = k + i_b * T*H*K + ((T-1) * H*K if REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) - p_v = v + i_b * T*H*V + ((T-1) * H*V if REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) - p_do = do + i_b * T*H*V + ((T-1) * H*V if REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) - p_dq = dq + (i_v * B + i_b) * T*H*K + ((T-1) * H*K if REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) if USE_G: - p_g = g + i_b * T*H + ((T-1) * H if REVERSE else 0) + i_h + p_g = g + bos*H + ((T-1) * H if REVERSE else 0) + i_h if USE_GK: - p_gk = gk + i_b * T*H*K + ((T-1) * H*K if REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) if USE_GV: - p_gv = gv + i_b * T*H*V + ((T-1) * H*V if REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) mask_k = i_k * BK + tl.arange(0, BK) < K mask_v = i_v * BV + tl.arange(0, BV) < V @@ -199,7 +219,7 @@ def fused_recurrent_bwd_kernel( b_h = tl.zeros([BK, BV], dtype=tl.float32) if USE_INITIAL_STATE: - p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) for _ in range(0, T): @@ -236,35 +256,35 @@ def fused_recurrent_bwd_kernel( tl.debug_barrier() if HEAD_FIRST: - p_q = q + i_bh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_k = k + i_bh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_v = v + i_bh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_do = do + i_bh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) - p_dk = dk + (i_v * B*H + i_bh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) - p_dv = dv + (i_k * B*H + i_bh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) if USE_G: - p_g = g + i_bh * T + ((T - 1) if not REVERSE else 0) + p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0) if USE_GK: - p_gk = gk + i_bh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) if USE_GV: - p_gv = gv + i_bh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) else: - p_q = q + i_b * T*H*K + ((T - 1) * H*K if not REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) - p_k = k + i_b * T*H*K + ((T - 1) * H*K if not REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) - p_v = v + i_b * T*H*V + ((T - 1) * H*V if not REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) - p_do = do + i_b * T*H*V + ((T - 1) * H*V if not REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) - p_dk = dk + (i_v * B + i_b) * T*H*K + ((T - 1) * H*K if not REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) - p_dv = dv + (i_k * B + i_b) * T*H*V + ((T - 1) * H*V if not REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T - 1) if not REVERSE else 0))*H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T - 1) if not REVERSE else 0))*H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) if USE_G: - p_g = g + i_b * T*H + ((T - 1) if not REVERSE else 0) + i_h + p_g = g + bos*H + ((T - 1) if not REVERSE else 0) + i_h if USE_GK: - p_gk = gk + i_b * T*H*K + ((T - 1) * H*K if not REVERSE else 0) + i_h * K + i_k * BK + tl.arange(0, BK) + p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) if USE_GV: - p_gv = gv + i_b * T*H*V + ((T - 1) * H*V if not REVERSE else 0) + i_h * V + i_v * BV + tl.arange(0, BV) + p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) b_dh = tl.zeros([BK, BV], dtype=tl.float32) if USE_FINAL_STATE_GRADIENT: - p_dht = dht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32) for _ in range(T): @@ -301,7 +321,7 @@ def fused_recurrent_bwd_kernel( p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V if STORE_INITIAL_STATE_GRADIENT: - p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h) @@ -316,6 +336,7 @@ def fused_recurrent_fwd( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, reverse: bool = False, + offsets: Optional[torch.Tensor] = None, head_first: bool = True ): if head_first: @@ -324,15 +345,16 @@ def fused_recurrent_fwd( B, T, H, K, V = *k.shape, v.shape[-1] BK, BV = min(K, 64), min(V, 64) NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + N = B if offsets is None else len(offsets) - 1 h0 = initial_state if output_final_state: - ht = q.new_empty(B, H, K, V, dtype=torch.float32) + ht = q.new_empty(N, H, K, V, dtype=torch.float32) else: ht = None o = q.new_empty(NK, *v.shape, dtype=torch.float32) - grid = (NV, NK, B * H) + grid = (NV, NK, N * H) fused_recurrent_fwd_kernel[grid]( q, k, @@ -343,6 +365,7 @@ def fused_recurrent_fwd( o, h0, ht, + offsets, scale, B=B, T=T, @@ -374,12 +397,14 @@ def fused_recurrent_bwd( scale: Optional[float] = None, initial_state: Optional[torch.Tensor] = None, reverse: bool = False, + offsets: Optional[torch.Tensor] = None, head_first: bool = True ): if head_first: B, H, T, K, V = *k.shape, v.shape[-1] else: B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 BK, BV = min(K, 64), min(V, 64) NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) @@ -388,9 +413,9 @@ def fused_recurrent_bwd( dk = q.new_empty(NV, *k.shape, dtype=torch.float32) dv = q.new_empty(NK, *v.shape, dtype=torch.float32) h0 = initial_state - dh0 = torch.empty_like(initial_state) if (initial_state is not None) else None + dh0 = torch.empty_like(initial_state) if initial_state is not None else None - grid = (NV, NK, B * H) + grid = (NV, NK, N * H) fused_recurrent_bwd_kernel[grid]( q, k, @@ -405,6 +430,7 @@ def fused_recurrent_bwd( dv, dht, dh0, + offsets, scale, B=B, T=T, @@ -427,18 +453,21 @@ def fused_recurrent_bwd( dg = chunk_global_cumsum( (dq * q.float() - dk * k.float()).sum(-1), reverse=not reverse, + offsets=offsets, head_first=head_first ) if gk is not None: dgk = chunk_global_cumsum( dq * q.float() - dk * k.float(), reverse=not reverse, + offsets=offsets, head_first=head_first ) if gv is not None: dgv = chunk_global_cumsum( do.float() * o.float() - dv * v.float(), reverse=not reverse, + offsets=offsets, head_first=head_first ) @@ -462,6 +491,7 @@ def forward( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, reverse: bool = False, + offsets: Optional[torch.Tensor] = None, head_first: bool = True ): o, ht = fused_recurrent_fwd( @@ -475,11 +505,13 @@ def forward( initial_state=initial_state, output_final_state=output_final_state, reverse=reverse, + offsets=offsets, head_first=head_first ) ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o) ctx.scale = scale ctx.reverse = reverse + ctx.offsets = offsets ctx.head_first = head_first return o.to(q.dtype), ht @@ -510,9 +542,10 @@ def backward(ctx, do, dht): scale=ctx.scale, initial_state=initial_state, reverse=ctx.reverse, + offsets=ctx.offsets, head_first=ctx.head_first ) - return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None def fused_recurrent( @@ -526,8 +559,22 @@ def fused_recurrent( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, reverse: bool = False, + offsets: Optional[torch.Tensor] = None, head_first: bool = True ): if scale is None: scale = k.shape[-1] ** -0.5 - return FusedRecurrentFunction.apply(q, k, v, g, gk, gv, scale, initial_state, output_final_state, reverse, head_first) + return FusedRecurrentFunction.apply( + q, + k, + v, + g, + gk, + gv, + scale, + initial_state, + output_final_state, + reverse, + offsets, + head_first + ) diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py index 68e8a667e..3305ffe64 100644 --- a/fla/ops/gla/chunk.py +++ b/fla/ops/gla/chunk.py @@ -21,12 +21,15 @@ ], key=["BC", "BK"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_fwd_A_kernel_intra_sub_inter( q, k, g, A, + offsets, + indices, scale, T: tl.constexpr, H: tl.constexpr, @@ -35,11 +38,19 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_i, i_j = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + if i_t * BT + i_i * BC >= T: return if i_i <= i_j: @@ -55,13 +66,13 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) p_gk = tl.make_block_ptr(g + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gn = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) else: - p_q = tl.make_block_ptr(q + i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_b*T*H*K+i_h*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) - p_gk = tl.make_block_ptr(g + i_b*T*H*K+i_h*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_b * T*H*K + (i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.max_contiguous(tl.multiple_of(g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) @@ -79,7 +90,7 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( if HEAD_FIRST: p_A = tl.make_block_ptr(A + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) else: - p_A = tl.make_block_ptr(A + i_b*T*H*BT + i_h*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) @@ -92,12 +103,15 @@ def chunk_gla_fwd_A_kernel_intra_sub_inter( ], key=["BK", "BT"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_fwd_A_kernel_intra_sub_intra( q, k, g, A, + offsets, + indices, scale, T: tl.constexpr, H: tl.constexpr, @@ -105,11 +119,19 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + if i_t * BT + i_i * BC >= T: return @@ -121,14 +143,14 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) - p_k = tl.max_contiguous(tl.multiple_of(k + i_bh * T*K + (i_t * BT + i_j * BC) * K + o_k, BK), BK) - p_gk = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) else: - o_A = i_b * T*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC - p_q = tl.make_block_ptr(q + i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) - p_k = tl.max_contiguous(tl.multiple_of(k + i_b * T*H*K + (i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) - p_gk = tl.max_contiguous(tl.multiple_of(g + i_b * T*H*K + (i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = tl.max_contiguous(tl.multiple_of(k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) @@ -152,12 +174,15 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra( ], key=["BC", "BK"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_fwd_A_kernel_intra_sub_intra_split( q, k, g, A, + offsets, + indices, scale, B: tl.constexpr, T: tl.constexpr, @@ -167,12 +192,22 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_tc // NC, i_tc % NC i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + if i_t * BT + i_i * BC >= T: return @@ -185,14 +220,14 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( o_A = (i_k * B*H + i_bh) * T * BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BC p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = tl.max_contiguous(tl.multiple_of(k + i_bh * T*K + (i_t * BT + i_j * BC) * K + o_k, BK), BK) - p_gk = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) else: - o_A = (i_k * B + i_b) * T*H*BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC - p_q = tl.make_block_ptr(q + i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = tl.max_contiguous(tl.multiple_of(k + i_b * T*H*K + (i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) - p_gk = tl.max_contiguous(tl.multiple_of(g + i_b * T*H*K + (i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) + o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.max_contiguous(tl.multiple_of(k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k, BK), BK) b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) @@ -216,20 +251,33 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_split( ], key=["BC"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( A, A2, + offsets, + indices, B: tl.constexpr, T: tl.constexpr, H: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, NK: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + if i_t * BT + i_c * BC >= T: return @@ -238,12 +286,12 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( if HEAD_FIRST: p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh)*T*BC, (T, BC), (BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) else: - p_A = tl.make_block_ptr(A + (i_k*B+i_b)*T*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) b_A += tl.load(p_A, boundary_check=(0, 1)) if HEAD_FIRST: p_A2 = tl.make_block_ptr(A2 + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) else: - p_A2 = tl.make_block_ptr(A2 + i_b*T*H*BT+i_h*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) @@ -256,6 +304,7 @@ def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( ], key=["BK", "BV", "BT"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_fwd_kernel_o( q, @@ -264,6 +313,8 @@ def chunk_gla_fwd_kernel_o( h, o, A, + offsets, + indices, scale, T: tl.constexpr, H: tl.constexpr, @@ -272,11 +323,22 @@ def chunk_gla_fwd_kernel_o( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - NT: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] b_o = tl.zeros([BT, BV], dtype=tl.float32) @@ -284,11 +346,11 @@ def chunk_gla_fwd_kernel_o( if HEAD_FIRST: p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - + p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) else: - p_q = tl.make_block_ptr(q + i_b * T*H*K + i_h*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_g = tl.make_block_ptr(g + i_b * T*H*K + i_h*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) @@ -308,9 +370,9 @@ def chunk_gla_fwd_kernel_o( p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) else: - p_v = tl.make_block_ptr(v + i_b*T*H*V + i_h*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o + i_b*T*H*V + i_h*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_A = tl.make_block_ptr(A + i_b*T*H*BT + i_h*BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) # [BT, BT] @@ -329,6 +391,7 @@ def chunk_gla_fwd_kernel_o( ], key=["BK", "NC", "BT"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_bwd_kernel_intra( q, @@ -337,6 +400,8 @@ def chunk_gla_bwd_kernel_intra( dA, dq, dk, + offsets, + indices, T: tl.constexpr, H: tl.constexpr, K: tl.constexpr, @@ -344,11 +409,18 @@ def chunk_gla_bwd_kernel_intra( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H i_t, i_i = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos if i_t * BT + i_i * BC >= T: return @@ -358,15 +430,15 @@ def chunk_gla_bwd_kernel_intra( if HEAD_FIRST: p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) else: - p_g = tl.make_block_ptr(g + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_g = tl.load(p_g, boundary_check=(0, 1)) b_dq = tl.zeros([BC, BK], dtype=tl.float32) if i_i > 0: if HEAD_FIRST: - p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gn = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) else: - p_gn = tl.max_contiguous(tl.multiple_of(g + i_b*T*H*K + (i_t * BT + i_i * BC) * H*K + i_h*K + o_k, BK), BK) + p_gn = tl.max_contiguous(tl.multiple_of(g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k, BK), BK) # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(0, i_i): @@ -375,9 +447,9 @@ def chunk_gla_bwd_kernel_intra( p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) else: - p_k = tl.make_block_ptr(k+i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g+i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) - p_dA = tl.make_block_ptr(dA+i_b*T*H*BT+i_h*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -392,14 +464,14 @@ def chunk_gla_bwd_kernel_intra( m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T if HEAD_FIRST: o_dA = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC - p_kj = tl.max_contiguous(tl.multiple_of(k + i_bh * T*K + (i_t * BT + i_i * BC) * K + o_k, BK), BK) - p_gkj = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gkj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) else: - o_dA = i_b * T*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC - p_kj = tl.max_contiguous(tl.multiple_of(k + i_b * T*H*K + (i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) - p_gkj = tl.max_contiguous(tl.multiple_of(g + i_b * T*H*K + (i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) - p_dq = tl.make_block_ptr(dq + i_b*T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC + p_kj = tl.max_contiguous(tl.multiple_of(k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) + p_gkj = tl.max_contiguous(tl.multiple_of(g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] @@ -421,8 +493,8 @@ def chunk_gla_bwd_kernel_intra( p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) else: - p_k = tl.make_block_ptr(k + i_b * T*H*K + i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g + i_b * T*H*K + i_h*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) @@ -434,7 +506,7 @@ def chunk_gla_bwd_kernel_intra( if HEAD_FIRST: p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh*T*K + (i_t * BT + i_i * BC + BC - 1)*K + o_k, BK), BK) else: - p_gn = tl.max_contiguous(tl.multiple_of(g + i_b*T*H*K + (i_t * BT + i_i * BC + BC - 1)*H*K + i_h*K + o_k, BK), BK) + p_gn = tl.max_contiguous(tl.multiple_of(g + bos*H*K + (i_t * BT + i_i * BC + BC - 1)*H*K + i_h*K + o_k, BK), BK) # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(i_i + 1, NC): @@ -443,9 +515,9 @@ def chunk_gla_bwd_kernel_intra( p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (BT, T), (1, BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) else: - p_q = tl.make_block_ptr(q+i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_g = tl.make_block_ptr(g+i_b*T*H*K+i_h*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_dA = tl.make_block_ptr(dA+i_b*T*H*BT+i_h*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) # [BC, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_g = tl.load(p_g, boundary_check=(0, 1)) @@ -458,14 +530,14 @@ def chunk_gla_bwd_kernel_intra( b_dk *= tl.exp(b_gn[None, :] - b_gk) if HEAD_FIRST: o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) - p_qj = tl.max_contiguous(tl.multiple_of(q + i_bh * T*K + (i_t * BT + i_i * BC) * K + o_k, BK), BK) - p_gqj = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gqj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) p_dk = tl.make_block_ptr(dk + i_bh*T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) else: - o_dA = i_b * T*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) - p_qj = tl.max_contiguous(tl.multiple_of(q + i_b * T*H*K + (i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) - p_gqj = tl.max_contiguous(tl.multiple_of(g + i_b * T*H*K + (i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) - p_dk = tl.make_block_ptr(dk + i_b*T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) + p_qj = tl.max_contiguous(tl.multiple_of(q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) + p_gqj = tl.max_contiguous(tl.multiple_of(g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k, BK), BK) + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC,] b_dA = tl.load(dA + o_dA + j * (1 if HEAD_FIRST else H) * BT) @@ -489,21 +561,31 @@ def chunk_gla_bwd_kernel_intra( ], key=["BV", "BT"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_bwd_kernel_dA( v, do, dA, + offsets, + indices, scale, T: tl.constexpr, H: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos b_dA = tl.zeros([BT, BT], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): @@ -511,15 +593,15 @@ def chunk_gla_bwd_kernel_dA( p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) else: - p_do = tl.make_block_ptr(do + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_v = tl.make_block_ptr(v + i_b * T*H*V + i_h * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) b_v = tl.load(p_v, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) b_dA += tl.dot(b_do, b_v) if HEAD_FIRST: p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) else: - p_dA = tl.make_block_ptr(dA + i_b * T*H*BT + i_h * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] b_dA = tl.where(m_s, b_dA * scale, 0.) tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) @@ -534,6 +616,7 @@ def chunk_gla_bwd_kernel_dA( ], key=["BK", "BV", "BT"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_bwd_kernel_dv( k, @@ -542,6 +625,8 @@ def chunk_gla_bwd_kernel_dv( do, dh, dv, + offsets, + indices, T: tl.constexpr, H: tl.constexpr, K: tl.constexpr, @@ -549,20 +634,30 @@ def chunk_gla_bwd_kernel_dv( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - NT: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T if HEAD_FIRST: p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) else: - p_A = tl.make_block_ptr(A + i_b * T*H*BT + i_h * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) - p_do = tl.make_block_ptr(do + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_dv = tl.make_block_ptr(dv + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) b_A = tl.load(p_A, boundary_check=(0, 1)) b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.) @@ -578,11 +673,12 @@ def chunk_gla_bwd_kernel_dv( p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + min(i_t * BT + BT, T) * K - K + o_k, BK), BK) + p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) else: - p_k = tl.make_block_ptr(k + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gk = tl.make_block_ptr(g + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_b * T*H*K + (min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k, BK), BK) - p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.max_contiguous(tl.multiple_of(g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k, BK), BK) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_gk = tl.load(p_gk, boundary_check=(0, 1)) @@ -604,6 +700,7 @@ def chunk_gla_bwd_kernel_dv( ], key=["BK", "BV", "BT"], ) +@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None}) @triton.jit def chunk_gla_bwd_kernel_inter( q, @@ -618,6 +715,8 @@ def chunk_gla_bwd_kernel_inter( dq2, dk2, dg, + offsets, + indices, scale, T: tl.constexpr, H: tl.constexpr, @@ -626,11 +725,21 @@ def chunk_gla_bwd_kernel_inter( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, - NT: tl.constexpr, + USE_OFFSETS: tl.constexpr, HEAD_FIRST: tl.constexpr ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T o_k = i_k * BK + tl.arange(0, BK) m_k = o_k < K @@ -638,8 +747,8 @@ def chunk_gla_bwd_kernel_inter( p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (min(T, i_t * BT + BT)-1) * K + o_k, BK), BK) else: - p_gk = tl.make_block_ptr(g + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_gn = tl.max_contiguous(tl.multiple_of(g + i_b * T*H*K + (min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k, BK), BK) + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.max_contiguous(tl.multiple_of(g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k, BK), BK) b_gn = tl.load(p_gn, mask=m_k, other=0) b_dq = tl.zeros([BT, BK], dtype=tl.float32) b_dk = tl.zeros([BT, BK], dtype=tl.float32) @@ -649,11 +758,13 @@ def chunk_gla_bwd_kernel_inter( if HEAD_FIRST: p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) else: - p_v = tl.make_block_ptr(v + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr(do + i_b * T*H*V + i_h * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) # [BT, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) b_do = tl.load(p_do, boundary_check=(0, 1)) @@ -677,10 +788,10 @@ def chunk_gla_bwd_kernel_inter( p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) else: - p_q = tl.make_block_ptr(q + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dq = tl.make_block_ptr(dq + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_dgk += tl.sum(b_dk * b_k, axis=0) @@ -697,9 +808,9 @@ def chunk_gla_bwd_kernel_inter( p_dk = tl.make_block_ptr(dk2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dg = tl.make_block_ptr(dg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) else: - p_dq = tl.make_block_ptr(dq2 + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dk = tl.make_block_ptr(dk2 + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_dg = tl.make_block_ptr(dg + i_b * T*H*K + i_h * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) @@ -710,6 +821,8 @@ def chunk_gla_fwd_intra_gk( k: torch.Tensor, g: torch.Tensor, scale: float, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ): @@ -718,10 +831,18 @@ def chunk_gla_fwd_intra_gk( else: B, T, H, K = k.shape BT = min(chunk_size, triton.next_power_of_2(T)) + if offsets is None: + NT = triton.cdiv(T, BT) + else: + if indices is None: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) + NT = len(indices) BC = min(16, triton.next_power_of_2(T)) BK = min(64, triton.next_power_of_2(K)) NC = triton.cdiv(BT, BC) - NT = triton.cdiv(T, BT) A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float32) grid = (NT, NC * NC, B * H) @@ -730,6 +851,8 @@ def chunk_gla_fwd_intra_gk( k, g, A, + offsets, + indices, scale, T=T, H=H, @@ -740,6 +863,7 @@ def chunk_gla_fwd_intra_gk( NC=NC, HEAD_FIRST=head_first ) + grid = (NT, NC, B * H) # load the entire [BC, K] blocks into SRAM at once if K <= 256: @@ -749,6 +873,8 @@ def chunk_gla_fwd_intra_gk( k, g, A, + offsets, + indices, scale, T=T, H=H, @@ -762,13 +888,16 @@ def chunk_gla_fwd_intra_gk( else: BK = min(128, triton.next_power_of_2(K)) NK = triton.cdiv(K, BK) - A_intra = q.new_empty(NK, B, H, T, BC, dtype=torch.float32) + A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float32) + grid = (NK, NT * NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid]( q, k, g, A_intra, + offsets, + indices, scale, B=B, T=T, @@ -780,10 +909,13 @@ def chunk_gla_fwd_intra_gk( NC=NC, HEAD_FIRST=head_first ) + grid = (NT, NC, B * H) chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid]( A_intra, A, + offsets, + indices, B=B, T=T, H=H, @@ -802,6 +934,8 @@ def chunk_gla_fwd_o_gk( A: torch.Tensor, h: torch.Tensor, scale: float, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ): @@ -810,10 +944,17 @@ def chunk_gla_fwd_o_gk( else: B, T, H, K, V = *q.shape, v.shape[-1] BT = min(chunk_size, triton.next_power_of_2(T)) + if offsets is None: + NT = triton.cdiv(T, BT) + else: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) + NT = len(indices) BK = min(32, triton.next_power_of_2(K)) BV = min(32, triton.next_power_of_2(V)) NV = triton.cdiv(V, BV) - NT = triton.cdiv(T, BT) grid = (NV, NT, B * H) o = torch.empty_like(v) @@ -824,6 +965,8 @@ def chunk_gla_fwd_o_gk( h, o, A, + offsets, + indices, scale, T=T, H=H, @@ -832,7 +975,6 @@ def chunk_gla_fwd_o_gk( BT=BT, BK=BK, BV=BV, - NT=NT, HEAD_FIRST=head_first ) return o @@ -842,6 +984,8 @@ def chunk_gla_bwd_dA( v: torch.Tensor, do: torch.Tensor, scale: float, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ): @@ -850,15 +994,25 @@ def chunk_gla_bwd_dA( else: B, T, H, V = v.shape BT = min(chunk_size, triton.next_power_of_2(T)) + if offsets is None: + N, NT = B, triton.cdiv(T, BT) + else: + if indices is None: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) + N, NT = len(offsets) - 1, len(indices) BV = min(64, triton.next_power_of_2(V)) - NT = triton.cdiv(T, BT) dA = v.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float32) - grid = (NT, B * H) + grid = (NT, N * H) chunk_gla_bwd_kernel_dA[grid]( v, do, dA, + offsets, + indices, scale, T=T, H=H, @@ -876,7 +1030,8 @@ def chunk_gla_bwd_dv( A: torch.Tensor, do: torch.Tensor, dh: torch.Tensor, - scale: float, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ): @@ -884,13 +1039,21 @@ def chunk_gla_bwd_dv( B, H, T, K, V = *k.shape, do.shape[-1] else: B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, triton.next_power_of_2(T)) + if offsets is None: + N, NT = B, triton.cdiv(T, BT) + else: + if indices is None: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) + N, NT = len(offsets) - 1, len(indices) BK = min(64, triton.next_power_of_2(K)) BV = min(32, triton.next_power_of_2(V)) - BT = min(chunk_size, triton.next_power_of_2(T)) - NT = triton.cdiv(T, BT) dv = torch.empty_like(do) - grid = (triton.cdiv(V, BV), NT, B * H) + grid = (triton.cdiv(V, BV), NT, N * H) chunk_gla_bwd_kernel_dv[grid]( k, g, @@ -898,6 +1061,8 @@ def chunk_gla_bwd_dv( do, dh, dv, + offsets, + indices, T=T, H=H, K=K, @@ -905,7 +1070,6 @@ def chunk_gla_bwd_dv( BT=BT, BK=BK, BV=BV, - NT=NT, HEAD_FIRST=head_first ) return dv @@ -916,6 +1080,8 @@ def chunk_gla_bwd_dqk_intra( k: torch.Tensor, g: torch.Tensor, dA: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ): @@ -924,14 +1090,23 @@ def chunk_gla_bwd_dqk_intra( else: B, T, H, K = q.shape BT = min(chunk_size, triton.next_power_of_2(T)) + if offsets is None: + N, NT = B, triton.cdiv(T, BT) + else: + if indices is None: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) + N, NT = len(offsets) - 1, len(indices) BC = min(16, triton.next_power_of_2(T)) BK = min(64, triton.next_power_of_2(K)) NK = triton.cdiv(K, BK) - NT = triton.cdiv(T, BT) NC = triton.cdiv(BT, BC) + dq = torch.empty_like(q, dtype=torch.float32) dk = torch.empty_like(k, dtype=torch.float32) - grid = (NK, NT * NC, B * H) + grid = (NK, NT * NC, N * H) chunk_gla_bwd_kernel_intra[grid]( q, k, @@ -939,6 +1114,8 @@ def chunk_gla_bwd_dqk_intra( dA, dq, dk, + offsets, + indices, T=T, H=H, K=K, @@ -962,6 +1139,8 @@ def chunk_gla_bwd_dqkg( dq: torch.Tensor, dk: torch.Tensor, scale: float, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ): @@ -970,13 +1149,21 @@ def chunk_gla_bwd_dqkg( else: B, T, H, K, V = *k.shape, v.shape[-1] BT = min(chunk_size, triton.next_power_of_2(T)) + if offsets is None: + N, NT = B, triton.cdiv(T, BT) + else: + if indices is None: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) + N, NT = len(offsets) - 1, len(indices) BK = min(64, triton.next_power_of_2(K)) BV = min(64, triton.next_power_of_2(V)) NK = triton.cdiv(K, BK) - NT = triton.cdiv(T, BT) dg = torch.empty_like(g) - grid = (NK, NT, B * H) + grid = (NK, NT, N * H) # work around triton compiler bugs. dq2 = torch.empty_like(dq) dk2 = torch.empty_like(dk) @@ -993,6 +1180,8 @@ def chunk_gla_bwd_dqkg( dq2, dk2, dg, + offsets, + indices, scale, T=T, H=H, @@ -1001,7 +1190,6 @@ def chunk_gla_bwd_dqkg( BT=BT, BK=BK, BV=BV, - NT=NT, HEAD_FIRST=head_first ) return dq2, dk2, dg @@ -1016,13 +1204,21 @@ def chunk_gla_fwd( scale: float, initial_state: torch.Tensor, output_final_state: bool, + offsets: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: T = q.shape[2] if head_first else q.shape[1] BT = min(chunk_size, triton.next_power_of_2(T)) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, BT, head_first=head_first) + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, head_first=head_first) + + indices = None + if offsets is not None: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) h, ht = chunk_fwd_h( k=k, v=v, @@ -1032,9 +1228,11 @@ def chunk_gla_fwd( h0=initial_state, output_final_state=output_final_state, states_in_fp32=False, + offsets=offsets, head_first=head_first, chunk_size=BT ) + # the intra A is kept in fp32 # the computation has very marginal effect on the entire throughput A = chunk_gla_fwd_intra_gk( @@ -1042,6 +1240,8 @@ def chunk_gla_fwd( k=k, g=g_cumsum, scale=scale, + offsets=offsets, + indices=indices, head_first=head_first, chunk_size=BT ) @@ -1052,6 +1252,8 @@ def chunk_gla_fwd( A=A, h=h, scale=scale, + offsets=offsets, + indices=indices, head_first=head_first, chunk_size=BT ) @@ -1070,14 +1272,21 @@ def chunk_gla_bwd( A: torch.Tensor, do: torch.Tensor, dht: torch.Tensor, + offsets: Optional[torch.Tensor] = None, head_first: bool = True, chunk_size: int = 64 ): T = q.shape[2] if head_first else q.shape[1] BT = min(chunk_size, triton.next_power_of_2(T)) if g_cumsum is None: - g_cumsum = chunk_local_cumsum(g, BT, head_first=head_first) - + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, head_first=head_first) + + indices = None + if offsets is not None: + indices = torch.cat([ + torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1) + for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()) + ]) if h is None: h, _ = chunk_fwd_h( k=k, @@ -1088,6 +1297,7 @@ def chunk_gla_bwd( h0=initial_state, output_final_state=False, states_in_fp32=True, + offsets=offsets, head_first=head_first, chunk_size=BT ) @@ -1103,6 +1313,7 @@ def chunk_gla_bwd( dht=dht, scale=scale, states_in_fp32=True, + offsets=offsets, head_first=head_first, chunk_size=BT ) @@ -1112,7 +1323,8 @@ def chunk_gla_bwd( A=A, do=do, dh=dh, - scale=scale, + offsets=offsets, + indices=indices, head_first=head_first, chunk_size=BT ) @@ -1121,6 +1333,8 @@ def chunk_gla_bwd( v=v, do=do, scale=scale, + offsets=offsets, + indices=indices, head_first=head_first, chunk_size=BT ) @@ -1129,6 +1343,8 @@ def chunk_gla_bwd( k=k, g=g_cumsum, dA=dA, + offsets=offsets, + indices=indices, head_first=head_first, chunk_size=BT ) @@ -1143,6 +1359,8 @@ def chunk_gla_bwd( dq=dq, dk=dk, scale=scale, + offsets=offsets, + indices=indices, head_first=head_first, chunk_size=BT ) @@ -1153,8 +1371,20 @@ class ChunkGLAFunction(torch.autograd.Function): @staticmethod @contiguous - def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, head_first): - BT = 64 + def forward( + ctx, + q, + k, + v, + g, + scale, + initial_state, + output_final_state, + offsets, + head_first + ): + T = q.shape[2] if head_first else q.shape[1] + chunk_size = min(64, triton.next_power_of_2(T)) g_cumsum, A, h, ht, o = chunk_gla_fwd( q=q, k=k, @@ -1164,8 +1394,9 @@ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, head_firs scale=scale, initial_state=initial_state, output_final_state=output_final_state, + offsets=offsets, head_first=head_first, - chunk_size=BT + chunk_size=chunk_size ) # recompute g_cumsum in bwd pass if g.dtype != torch.float32: @@ -1173,8 +1404,9 @@ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, head_firs else: g = None ctx.save_for_backward(q, k, v, g, g_cumsum, initial_state, A) - ctx.BT = BT + ctx.chunk_size = chunk_size ctx.scale = scale + ctx.offsets = offsets ctx.head_first = head_first return o, ht @@ -1182,7 +1414,7 @@ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, head_firs @contiguous def backward(ctx, do, dht): q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors - BT, scale, head_first = ctx.BT, ctx.scale, ctx.head_first + chunk_size, scale, offsets, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.head_first dq, dk, dv, dg, dh0 = chunk_gla_bwd( q=q, k=k, @@ -1195,10 +1427,11 @@ def backward(ctx, do, dht): initial_state=initial_state, do=do, dht=dht, + offsets=offsets, head_first=head_first, - chunk_size=BT + chunk_size=chunk_size ) - return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None + return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None, None def chunk_gla( @@ -1209,6 +1442,7 @@ def chunk_gla( scale: Optional[int] = None, initial_state: torch.Tensor = None, output_final_state: bool = False, + offsets: Optional[torch.Tensor] = None, head_first: bool = True ) -> Tuple[torch.Tensor, torch.Tensor]: r""" @@ -1225,20 +1459,66 @@ def chunk_gla( Scale factor for the GLA attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. initial_state (Optional[torch.Tensor]): - Initial state of shape `[B, H, K, V]`. Default: `None`. + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. output_final_state (Optional[bool]): - Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + offsets (Optional[torch.Tensor]): + Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch. + For example, + if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively. + If provided, the inputs are concatenated and the batch size `B` is expected to be 1. + Default: `None`. head_first (Optional[bool]): Whether the inputs are in the head-first format. + This head-first format is not supported for variable-length inputs. Default: `True`. Returns: o (torch.Tensor): Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. final_state (torch.Tensor): - Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`. + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gla import chunk_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = chunk_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, offsets with 5 start/end positions are expected + >>> offsets = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + offsets=offsets, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) """ + if offsets is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(offsets) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}.") if scale is None: scale = q.shape[-1] ** -0.5 - o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, head_first) + o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, offsets, head_first) return o, final_state diff --git a/fla/ops/gla/fused_recurrent.py b/fla/ops/gla/fused_recurrent.py index 52e05237e..6f735c4c6 100644 --- a/fla/ops/gla/fused_recurrent.py +++ b/fla/ops/gla/fused_recurrent.py @@ -18,16 +18,17 @@ def fused_recurrent_gla( initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, reverse: bool = False, + offsets: Optional[torch.Tensor] = None, head_first: bool = True ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Args: q (torch.Tensor): - queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. k (torch.Tensor): - keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. v (torch.Tensor): - values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. gk (torch.Tensor): Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. gv (torch.Tensor): @@ -36,24 +37,79 @@ def fused_recurrent_gla( Scale factor for the attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. initial_state (Optional[torch.Tensor]): - Initial state of shape `[B, H, K, V]`. Default: `None`. + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. output_final_state (Optional[bool]): - Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. - reverse (Optional[bool]): - If `True`, process the state passing in reverse order. Default: `False`. + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + offsets (Optional[torch.Tensor]): + Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch. + For example, + if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively. + If provided, the inputs are concatenated and the batch size `B` is expected to be 1. + Default: `None`. head_first (Optional[bool]): Whether the inputs are in the head-first format. + This head-first format is not supported for variable-length inputs. Default: `True`. Returns: o (torch.Tensor): Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. final_state (torch.Tensor): - Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`. + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gla import fused_recurrent_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, offsets with 5 start/end positions are expected + >>> offsets = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + offsets=offsets, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) """ - assert q.dim() == k.dim() == v.dim() == 4, "q, k, v must have 4 dimensions" - assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + if offsets is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(offsets) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}.") if scale is None: scale = k.shape[-1] ** -0.5 - o, final_state = fused_recurrent(q, k, v, None, gk, gv, scale, initial_state, output_final_state, reverse, head_first) + o, final_state = fused_recurrent( + q=q, + k=k, + v=v, + g=None, + gk=gk, + gv=gv, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + offsets=offsets, + head_first=head_first + ) return o, final_state diff --git a/tests/ops/test_gla.py b/tests/ops/test_gla.py index bc75ce6eb..3a5927715 100644 --- a/tests/ops/test_gla.py +++ b/tests/ops/test_gla.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from fla.ops.gla import chunk_gla, fused_recurrent_gla +from fla.ops.gla.naive import naive_recurrent_gla def get_abs_err(x, y): @@ -26,15 +27,76 @@ def assert_close(prefix, ref, tri, ratio): @pytest.mark.parametrize("B", [4]) +@pytest.mark.parametrize("T", [300, 512]) @pytest.mark.parametrize("H", [4]) +@pytest.mark.parametrize("D", [32, 64, 100]) +@pytest.mark.parametrize("dtype", [torch.float]) +def test_fused_recurrent( + B: int, + T: int, + H: int, + D: int, + dtype: torch.dtype +): + torch.manual_seed(42) + + q = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_() + k = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_() + v = torch.randn((B, H, T, D), dtype=dtype, device='cuda').requires_grad_() + g = F.logsigmoid(torch.randn((B, H, T, D), dtype=dtype, device='cuda')).requires_grad_() + h0 = torch.randn(B, H, D, D, device='cuda').requires_grad_() + + do = torch.randn_like(v) + dht = torch.randn_like(h0) + ref, ref_ht = naive_recurrent_gla( + q=q, + k=k, + v=v, + gk=g, + initial_state=h0, + output_final_state=True + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward() + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg, g.grad = g.grad.clone(), None + ref_dh0, h0.grad = h0.grad.clone(), None + + tri, tri_ht = fused_recurrent_gla( + q=q, + k=k, + v=v, + gk=g, + initial_state=h0, + output_final_state=True + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward() + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), None + tri_dh0, h0.grad = h0.grad.clone(), None + + assert_close(" o", ref, tri, 0.005) + assert_close(" ht", ref_ht, tri_ht, 0.005) + assert_close(" dq", ref_dq, tri_dq, 0.005) + assert_close(" dk", ref_dk, tri_dk, 0.005) + assert_close(" dv", ref_dv, tri_dv, 0.005) + assert_close(" dg", ref_dg, tri_dg, 0.005) + assert_close("dh0", ref_dh0, tri_dh0, 0.005) + + +@pytest.mark.parametrize("B", [4]) @pytest.mark.parametrize("T", [130, 146, 162, 178, 300, 2048]) +@pytest.mark.parametrize("H", [4]) @pytest.mark.parametrize("D", [300, 100]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float]) -@pytest.mark.parametrize("head_first", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float]) +@pytest.mark.parametrize("head_first", [False]) def test_chunk( B: int, - H: int, T: int, + H: int, D: int, dtype: torch.dtype, head_first: bool @@ -80,3 +142,59 @@ def test_chunk( assert_close(" dv", ref_dv, tri_dv, 0.005) assert_close(" dg", ref_dg, tri_dg, 0.005) assert_close("dh0", ref_dh0, tri_dh0, 0.005) + + +@pytest.mark.parametrize("N", [4]) +@pytest.mark.parametrize("T", [64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048]) +@pytest.mark.parametrize("H", [4]) +@pytest.mark.parametrize("D", [300, 100]) +@pytest.mark.parametrize("dtype", [torch.float]) +def test_chunk_varlen( + N: int, + T: int, + H: int, + D: int, + dtype: torch.dtype, +): + torch.manual_seed(42) + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + # randomly split the sequence into N segments + offsets = torch.cat([ + torch.tensor([0], dtype=torch.long), + torch.arange(16, T)[torch.randperm(T - 1)[:N-1]], + torch.tensor([T], dtype=torch.long) + ], 0).cuda().sort()[0] + print(offsets) + # seq-first required for inputs with variable lengths + q = torch.randn((1, T, H, D), dtype=dtype, device='cuda').requires_grad_() + k = torch.randn((1, T, H, D), dtype=dtype, device='cuda').requires_grad_() + v = torch.randn((1, T, H, D), dtype=dtype, device='cuda').requires_grad_() + g = F.logsigmoid(torch.randn((1, T, H, D), dtype=dtype, device='cuda')).requires_grad_() + h0 = torch.randn((N, H, D, D), dtype=dtype, device='cuda').requires_grad_() + do = torch.randn_like(v) + + ref, ref_ht = fused_recurrent_gla(q, k, v, g, initial_state=h0, output_final_state=True, offsets=offsets, head_first=False) + ref, _ = fused_recurrent_gla(q, k, v, g, initial_state=h0, output_final_state=False, offsets=offsets, head_first=False) + + (ref * do).sum().backward() + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg, g.grad = g.grad.clone(), None + ref_dh0, h0.grad = h0.grad.clone(), None + + tri, tri_ht = chunk_gla(q, k, v, g, initial_state=h0, output_final_state=True, offsets=offsets, head_first=False) + ((tri * do).sum()).backward() + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), None + tri_dh0, h0.grad = h0.grad.clone(), None + + assert_close(" o", ref, tri, 0.004) + assert_close(" ht", ref_ht, tri_ht, 0.005) + assert_close(" dq", ref_dq, tri_dq, 0.005) + assert_close(" dk", ref_dk, tri_dk, 0.005) + assert_close(" dv", ref_dv, tri_dv, 0.005) + assert_close(" dg", ref_dg, tri_dg, 0.005) + assert_close("dh0", ref_dh0, tri_dh0, 0.005)