diff --git a/fla/ops/rwkv6/fused_recurrent.py b/fla/ops/rwkv6/fused_recurrent.py index ffa9d2da3..eefb9b678 100644 --- a/fla/ops/rwkv6/fused_recurrent.py +++ b/fla/ops/rwkv6/fused_recurrent.py @@ -598,7 +598,7 @@ def backward(ctx, do, dht): offsets=ctx.offsets, head_first=ctx.head_first ) - return dq.to(q), dk.to(k), dv.to(v), dw.to(w), du.to(u), None, dh0.to(initial_state), None, None, None, None + return dq.to(q), dk.to(k), dv.to(v), dw.to(w), du.to(u), None, dh0.to(initial_state) if dh0 is not None else dh0, None, None, None, None def fused_recurrent_rwkv6(