
Replace mamba2 mamba_chunk_scan_combined triton kernel by simple_gla triton kernel #49

Merged · 6 commits merged into fla-org:main on Aug 18, 2024

Conversation

@learning-chip (Contributor) commented Aug 18, 2024

Follow-up to #39 (comment).

Eventually this will allow the e2e mamba2 example (#39) to run without depending on the original mamba_ssm repo.

This PR adds unit tests to ensure equivalence between chunk_simple_gla / torch_simple_gla / torch_simple_gla_recurrent (under fla.ops.simple_gla in this repository) and mamba_chunk_scan_combined / ssd_minimal_discrete (in the mamba_ssm repository). A sketch of such an equivalence check is shown after the test output below.

Unit test output from this PR:

$ pytest -v ./test_simple_gla_for_mamba2.py
====================================================== test session starts ======================================================
collected 6 items                                                                                                               

test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float32-True] PASSED                                                    [ 16%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float32-False] PASSED                                                   [ 33%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float16-True] PASSED                                                    [ 50%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[float16-False] PASSED                                                   [ 66%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[bfloat16-True] PASSED                                                   [ 83%]
test_simple_gla_for_mamba2.py::test_gla_to_mamba2[bfloat16-False] PASSED                                                  [100%]
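
For reference, a minimal sketch of this kind of equivalence check, written against a naive torch recurrence. It is illustrative only: the chunk_simple_gla signature, the shapes, and the tolerances are assumptions, not the actual test file from this PR.

```python
# Hypothetical sketch only; the real test lives in test_simple_gla_for_mamba2.py.
# The chunk_simple_gla signature/return value is an assumption.
import pytest
import torch
from fla.ops.simple_gla import chunk_simple_gla


def naive_simple_gla(q, k, v, g, scale):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; g: [B, H, T] log-space scalar gate per head
    B, H, T, K = q.shape
    S = q.new_zeros(B, H, K, v.shape[-1], dtype=torch.float32)
    o = torch.empty_like(v, dtype=torch.float32)
    for t in range(T):
        # S_t = exp(g_t) * S_{t-1} + k_t^T v_t ;  o_t = (q_t * scale) @ S_t
        S = g[:, :, t, None, None].float().exp() * S \
            + k[:, :, t, :, None].float() * v[:, :, t, None, :].float()
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q[:, :, t].float() * scale, S)
    return o.to(v.dtype)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_chunk_matches_naive(dtype):
    torch.manual_seed(0)
    B, H, T, K, V = 2, 4, 64, 32, 32
    q, k = (torch.randn(B, H, T, K, dtype=dtype, device='cuda') for _ in range(2))
    v = torch.randn(B, H, T, V, dtype=dtype, device='cuda')
    g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda')).to(dtype)
    out = chunk_simple_gla(q, k, v, g, scale=K ** -0.5)  # kwargs assumed
    out = out[0] if isinstance(out, tuple) else out
    torch.testing.assert_close(out, naive_simple_gla(q, k, v, g, K ** -0.5),
                               rtol=1e-2, atol=1e-2)
```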

Differences between the simple_gla kernel and the "mamba2_ssd" kernel (a conversion sketch follows this list):

  • mamba2_ssd uses input/output layout [batch, seq, head, hidden], while simple_gla uses [batch, head, seq, hidden]
  • mamba2_ssd does not apply the attention-inspired scaling q * (DK ** -0.5)
  • mamba2_ssd takes an extra dt input for discretization, but dt can easily be absorbed into the gating matrix A, as is done in the mamba2 example
  • mamba2_ssd's fused kernel does not take a time-varying A (though the minimal torch version does), probably because the time dependence is expressed through dt rather than A_t; simple_gla supports a time-varying g directly
  • mamba2_ssd uses grouped-query attention, while simple_gla (like the other kernels in this repo?) always uses the same number of heads for Q, K, and V. For now, the tests force the same number of heads.

Ref. Section 7.2 of the Mamba-2 paper (figure: group_query).
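
A rough sketch of the input mapping the list above describes. Head grouping (GQA) is ignored, the D/z paths are omitted, and the chunk_simple_gla call signature is an assumption.

```python
# Rough mapping sketch, not the implementation used in the mamba2 example.
import torch
from fla.ops.simple_gla import chunk_simple_gla


def mamba2_ssd_via_simple_gla(x, dt, A, B, C):
    """x: [batch, seq, head, head_dim], dt: [batch, seq, head],
    A: [head] (negative), B/C: [batch, seq, head, state_dim]."""
    # 1. layout: [batch, seq, head, dim] -> [batch, head, seq, dim]
    q = C.transpose(1, 2)                        # C plays the role of queries
    k = B.transpose(1, 2)                        # B plays the role of keys
    v = (x * dt.unsqueeze(-1)).transpose(1, 2)   # dt enters the input term of the SSD recurrence
    # 2. absorb dt into the log-space gate: g_t = A_h * dt_t
    g = (A[None, None, :] * dt).transpose(1, 2)  # [batch, head, seq]
    # 3. disable simple_gla's attention-style scaling, which mamba2_ssd does not apply
    o = chunk_simple_gla(q, k, v, g, scale=1.0)  # scale kwarg assumed
    o = o[0] if isinstance(o, tuple) else o
    return o.transpose(1, 2)                     # back to [batch, seq, head, head_dim]
```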

Todo:

FYI @DanFosing @yzhangcs @sustcsonglin

@yzhangcs (Member) commented

@learning-chip Very cool contribution! I think it would be great if you added some benchmarks comparing the simple_gla and mamba2 kernels, like in https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/ops/benchmark_gla.py.
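
For example, a hypothetical skeleton along the lines of that benchmark script. Both kernel signatures below are assumptions; adjust the calls to the installed fla and mamba_ssm versions.

```python
# Hypothetical benchmark skeleton, not benchmarks/ops/benchmark_gla.py itself.
import torch
import triton
from fla.ops.simple_gla import chunk_simple_gla
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

B, H, D, N, dtype = 8, 16, 64, 128, torch.bfloat16

for T in [256, 512, 1024, 2048, 4096]:
    # simple_gla layout: [batch, head, seq, dim]
    q, k = (torch.randn(B, H, T, N, device='cuda', dtype=dtype) for _ in range(2))
    v = torch.randn(B, H, T, D, device='cuda', dtype=dtype)
    g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda'))
    # mamba2_ssd layout: [batch, seq, head, dim] plus dt and a per-head A
    x = torch.randn(B, T, H, D, device='cuda', dtype=dtype)
    dt = torch.rand(B, T, H, device='cuda', dtype=dtype)
    A = -torch.rand(H, device='cuda')
    Bm = torch.randn(B, T, 1, N, device='cuda', dtype=dtype)  # single B/C group (no GQA)
    Cm = torch.randn(B, T, 1, N, device='cuda', dtype=dtype)

    t_gla = triton.testing.do_bench(lambda: chunk_simple_gla(q, k, v, g))
    t_ssd = triton.testing.do_bench(
        lambda: mamba_chunk_scan_combined(x, dt, A, Bm, Cm, chunk_size=64))
    print(f'T={T}: chunk_simple_gla {t_gla:.3f} ms vs mamba_chunk_scan_combined {t_ssd:.3f} ms')
```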

@yzhangcs (Member) commented

I will be working on GQA soon.

@yzhangcs marked this pull request as ready for review on Aug 18, 2024, 17:40
@yzhangcs merged commit 9aa2480 into fla-org:main on Aug 18, 2024
1 check passed
@learning-chip (Contributor, Author) commented

> add some benchmarks regarding simple_gla and mamba2 kernels like in https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/ops/benchmark_gla.py.

Some quick results: #50
