Update amp custom_fwd, custom_bwd usage for torch 2.4.0 compatibility
mirceamironenco committed Aug 25, 2024
1 parent 3583315 commit 1a8fc1b
Showing 18 changed files with 60 additions and 67 deletions.
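
Background: starting with torch 2.4.0, the torch.cuda.amp.custom_fwd/custom_bwd decorators are deprecated in favor of torch.amp.custom_fwd/custom_bwd, which take an explicit device_type argument. The autocast_custom_fwd/autocast_custom_bwd names imported from fla.utils throughout this commit are compatibility aliases over that API; their actual definitions live in fla/utils.py, which is not part of this diff. A minimal sketch of what such aliases could look like, assuming a packaging-based version check (hypothetical, not the repository's code):

import functools

import torch
from packaging import version

# Hypothetical sketch only; the real aliases are defined in fla/utils.py.
if version.parse(torch.__version__) >= version.parse("2.4"):
    # torch >= 2.4: the decorators live in torch.amp and require a device_type.
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")
else:
    # Older torch: fall back to the CUDA-specific decorators.
    autocast_custom_fwd = torch.cuda.amp.custom_fwd
    autocast_custom_bwd = torch.cuda.amp.custom_bwd

With aliases like these, every call site in the diff below only needs to swap the decorator name.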
7 changes: 3 additions & 4 deletions fla/ops/abc/recurrent_fuse.py
@@ -7,9 +7,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 @triton.jit
@@ -284,7 +283,7 @@ class FusedRecurrentGatedABCFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(
         ctx,
         q: torch.Tensor,
@@ -374,7 +373,7 @@ def forward(
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
         B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
7 changes: 3 additions & 4 deletions fla/ops/based/chunk_fuse.py
@@ -5,9 +5,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 # on-the-fly computation without materializing hidden statets into HBMs
 
@@ -305,7 +304,7 @@ class FusedChunkBasedFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale=1):
         B, H, T, K, V = *k.shape, v.shape[-1]
 
@@ -338,7 +337,7 @@ def forward(ctx, q, k, v, scale=1):
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dz):
         q, k, v = ctx.saved_tensors
         B, H, T, K, V = *k.shape, v.shape[-1]
7 changes: 3 additions & 4 deletions fla/ops/based/parallel.py
@@ -6,9 +6,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 # Based: An Educational and Effective Sequence Mixer
 # https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
@@ -314,7 +313,7 @@ class ParallelBasedFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale):
         BTL, BTS = 128, 32
         assert BTL % BTS == 0
@@ -349,7 +348,7 @@ def forward(ctx, q, k, v, scale):
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dz):
         q, k, v = ctx.saved_tensors
         scale = ctx.scale
6 changes: 3 additions & 3 deletions fla/ops/delta_rule/chunk.py
@@ -4,11 +4,11 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from fla.ops.delta_rule.wy_fast import (bwd_prepare_wy_repr,
                                         fwd_prepare_wy_repr, fwd_recompute_w_u)
 from fla.ops.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd
 
 
 @triton.autotune(
@@ -491,7 +491,7 @@ class ChunkDeltaRuleFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
         # obtain WY representation. u is actually the new v.
         w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
@@ -512,7 +512,7 @@ def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoin
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, d_ht=None):
         q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
         BT = ctx.BT
7 changes: 3 additions & 4 deletions fla/ops/delta_rule/chunk_fuse.py
@@ -5,10 +5,9 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 # on-the-fly computation without materializing hidden statets into HBMs
@@ -327,7 +326,7 @@ class FusedChunkDeltaRuleFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
         # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
         assert checkpoint_level in [0, 1]
@@ -345,7 +344,7 @@ def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoin
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, d_final_state=None):
         q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
         chunk_size = ctx.chunk_size
7 changes: 3 additions & 4 deletions fla/ops/delta_rule/utils.py
@@ -4,10 +4,9 @@
 import triton
 import triton.language as tl
 from einops import rearrange
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 # Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
@@ -191,7 +190,7 @@ def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
 
 class WYRepresentationPrepration(torch.autograd.Function):
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     @staticmethod
     def forward(ctx, k, v, beta, chunk_size):
         o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
@@ -200,7 +199,7 @@ def forward(ctx, k, v, beta, chunk_size):
         return o_cumdecay, v_new
 
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     @staticmethod
     def backward(ctx, do, do2):
         k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
7 changes: 3 additions & 4 deletions fla/ops/delta_rule/wy_fast.py
@@ -4,9 +4,8 @@
 import triton
 import triton.language as tl
 from einops import rearrange
-from torch.cuda.amp import custom_bwd, custom_fwd
 
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 # Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
@@ -288,7 +287,7 @@ class WYRepresentationPrepration(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, k, v, beta, chunk_size=64):
         ctx.BT = chunk_size
         w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
@@ -297,7 +296,7 @@ def forward(ctx, k, v, beta, chunk_size=64):
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, dw, du):
         k, v, beta, A = ctx.saved_tensors
         BT = ctx.BT
7 changes: 3 additions & 4 deletions fla/ops/gla/chunk_fuse.py
@@ -12,11 +12,10 @@
 import triton.language as tl
 from einops import rearrange
 from packaging import version
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
                                     prepare_qg_kg)
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 @triton.jit
@@ -304,7 +303,7 @@ class FusedChunkGLAFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
         ctx.g_dtype = g.dtype
         g_original = g
@@ -396,7 +395,7 @@ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, g_origin, A, initial_state = ctx.saved_tensors
         B, H, T, K, V = *k.shape, v.shape[-1]
7 changes: 3 additions & 4 deletions fla/ops/gla/recurrent_fuse.py
@@ -7,9 +7,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 # on-the-fly computation without materializing hidden statets into HBMs
 
@@ -223,7 +222,7 @@ class FusedRecurrentGLAFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
         B, H, T, K, V = *q.shape, v.shape[-1]
         # default scale
@@ -270,7 +269,7 @@ def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_s
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, gk, gv, initial_state, o = ctx.saved_tensors
         batch_size, n_heads, seq_len, K = q.shape
7 changes: 3 additions & 4 deletions fla/ops/linear_attn/chunk.py
@@ -6,10 +6,9 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from fla.ops.linear_attn.utils import normalize_output
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 @triton.jit
@@ -238,7 +237,7 @@ class ChunkLinearAttentionFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale, initial_state, output_final_state):
         B, H, T, K, V = *q.shape, v.shape[-1]
         BT = 64
@@ -282,7 +281,7 @@ def forward(ctx, q, k, v, scale, initial_state, output_final_state):
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, h = ctx.saved_tensors
 
7 changes: 3 additions & 4 deletions fla/ops/linear_attn/chunk_fuse.py
@@ -7,10 +7,9 @@
 import triton
 import triton.language as tl
 from packaging import version
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 from fla.ops.linear_attn.utils import normalize_output
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 @triton.jit
@@ -208,7 +207,7 @@ class FusedChunkLinearAttentionFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale, initial_state, output_final_state):
         B, H, T, K, V = *k.shape, v.shape[-1]
         BT = 64
@@ -255,7 +254,7 @@ def forward(ctx, q, k, v, scale, initial_state, output_final_state):
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dht=None):
         q, k, v, initial_state = ctx.saved_tensors
         B, H, T, K, V = *k.shape, v.shape[-1]
7 changes: 3 additions & 4 deletions fla/ops/rebased/parallel.py
@@ -4,9 +4,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 # Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
 # https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
@@ -339,7 +338,7 @@ class ParallelBasedFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, scale):
         BTL, BTS = 128, 32
         assert BTL % BTS == 0
@@ -374,7 +373,7 @@ def forward(ctx, q, k, v, scale):
 
     @staticmethod
     @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, dz):
         q, k, v = ctx.saved_tensors
         scale = ctx.scale
7 changes: 3 additions & 4 deletions fla/ops/retention/chunk.py
@@ -6,9 +6,8 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
-from fla.utils import contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
 
 
 @triton.autotune(
@@ -375,7 +374,7 @@ class ChunkRetentionFunction(torch.autograd.Function):
 
     @staticmethod
     @contiguous
-    @custom_fwd
+    @autocast_custom_fwd
     def forward(ctx, q, k, v, initial_state, output_final_state, scale, checkpoint_level):
         BT = 64
         h, final_state = chunk_fwd_h_fn(k, v, BT, initial_state, output_final_state)
@@ -388,7 +387,7 @@ def forward(ctx, q, k, v, initial_state, output_final_state, scale, checkpoint_l
 
     @staticmethod
    @contiguous
-    @custom_bwd
+    @autocast_custom_bwd
     def backward(ctx, do, d_ht=None):
         BT, scale = ctx.BT, ctx.scale
         q, k, v, h, initial_state = ctx.saved_tensors
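
For reference, the torch >= 2.4 decorators that the new fla.utils aliases wrap are applied to an arbitrary torch.autograd.Function like this (toy example under that assumption, not code from this repository):

import torch

class Scale(torch.autograd.Function):
    # Toy function illustrating the decorator pattern used throughout this commit;
    # custom_bwd makes backward run under the same autocast state as forward.
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output * ctx.alpha, None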