Replace mamba2 mamba_chunk_scan_combined
triton kernel by simple_gla
triton kernel
#49
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Follow-up #39 (comment)
Eventually will allow the e2e mamba2 example #39 to run without the dependency on the original mamba_ssm repo.
This PR adds unit tests to ensure equivalence between {
chunk_simple_gla
/torch_simple_gla
/torch_simple_gla_recurrent
underfla.ops.simple_gla
of this repository} and {mamba_chunk_scan_combined
/ssd_minimal_discrete
insidemamba_ssm
repository}.Unit test output from this PR:
Differences between
simple_gla
kernel and "mamba2_ssd" kernel:[batch, seq, head, hidden]
, while simple_gla uses[batch, head, seq, hidden]
q * (DK ** -0.5)
dt
input for discretization, but this can be easily absorbed into the gating matrixA
as did in mamba2 exampleA
(though the minimal torch version does), probably because the time-dependence is expressed bydt
, notA_t
?simple_gla
supports time-varyingg
directly.simple_gla
(also other kernels in this repo?) always use the same number of heads for Q & K & V. For now, force the same number of heads in tests.Ref Section 7.2 of Mamba-2 paper:
Todo:
simple_gla
kernel (Mamba-Codestral usesn_groups=8
)FYI @DanFosing @yzhangcs @sustcsonglin