
[RFC] Add RWKV7 kernels and models #105

yzhangcs opened this issue Jan 5, 2025 · 5 comments
yzhangcs (Member) commented Jan 5, 2025

No description provided.

@yzhangcs yzhangcs added the enhancement New feature or request label Jan 5, 2025
@yzhangcs yzhangcs added this to the FLA v1.0.0 release milestone Jan 5, 2025
sustcsonglin (Collaborator) commented:

The forward pass for the chunkwise implementation has been completed and tested in this commit: 7569595.

@yzhangcs yzhangcs changed the title Add RWKV7 kernels [RFC] Add RWKV7 kernels Jan 6, 2025
@sustcsonglin sustcsonglin self-assigned this Jan 6, 2025
sustcsonglin (Collaborator) commented:

The backward pass has been implemented in this commit: e582c28

TODO: implement the RWKV7 layer and model in the FLA format

@sustcsonglin sustcsonglin changed the title [RFC] Add RWKV7 kernels [RFC] Add RWKV7 kernels and models Jan 12, 2025
Triang-jyed-driung commented:

I think this might help: https://huggingface.co/SmerkyG/RWKV7-Goose-0.4B-Pile-HF/blob/main/modeling_rwkv7.py

Also, the triton kernels come from https://github.com/johanwind/wind_rwkv/tree/main/wind_rwkv/rwkv7

They use a technique called "backstepping" for the states to avoid recomputation.
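To make the "backstepping" idea concrete: for an invertible state recurrence, the backward pass can reconstruct earlier states by running the recurrence in reverse, instead of checkpointing and recomputing them. Below is a toy numpy illustration on a simplified decayed outer-product recurrence (RWKV-7's actual transition also includes a low-rank removal term, and the real kernels work in Triton/CUDA; this is only a sketch of the principle).

```python
import numpy as np

# Toy "backstepping" illustration on the simplified recurrence
#   S_t = diag(w_t) @ S_{t-1} + k_t^T v_t,  with w_t > 0.
# Because the transition is invertible, the backward pass can recover
# S_{t-1} from S_t exactly, so intermediate states need not be stored.

rng = np.random.default_rng(0)
T, dk, dv = 8, 4, 4
w = rng.uniform(0.5, 0.99, size=(T, dk))   # per-channel decay, kept > 0
k = rng.standard_normal((T, dk))
v = rng.standard_normal((T, dv))

# Forward pass: in a real kernel only the final state would be kept.
S = np.zeros((dk, dv))
states = [S.copy()]                        # ground truth, for checking only
for t in range(T):
    S = w[t][:, None] * S + np.outer(k[t], v[t])
    states.append(S.copy())

# Backward pass: invert each step to walk the states back in time.
S_rec = states[-1].copy()
for t in reversed(range(T)):
    S_rec = (S_rec - np.outer(k[t], v[t])) / w[t][:, None]
    assert np.allclose(S_rec, states[t], atol=1e-8)
```

The trade-off discussed in this thread follows directly: repeated division by small decays `w_t` can amplify rounding error, which is where the numerical-precision concern about backstepping kernels comes from.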

sustcsonglin (Collaborator) commented:

> I think this might help: https://huggingface.co/SmerkyG/RWKV7-Goose-0.4B-Pile-HF/blob/main/modeling_rwkv7.py
>
> Also, the triton kernels come from https://github.com/johanwind/wind_rwkv/tree/main/wind_rwkv/rwkv7
>
> They use a technique called "backstepping" for the states to avoid recomputation.

Hi @Triang-jyed-driung, thanks for the pointers! We're familiar with wind_rwkv's CUDA kernel but were unaware of the Triton kernel. However, I have some concerns about the numerical precision of wind's kernel. Has anyone measured its relative error against the full-FP32 recurrent kernels? Also, has anyone compared the speed of wind's Triton kernel with fla's kernel?
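The precision check asked about here is usually done by running the same recurrence at two precisions and reporting the relative error. As a hypothetical sketch (not any of the actual kernels), the snippet below uses float64 as the "full-precision reference" and float32 as a stand-in for a fast kernel:

```python
import numpy as np

# Sketch of a numerical-precision check: run one recurrence in float64
# (the reference) and once in float32 (standing in for a fast kernel),
# then report the maximum relative error. The simplified recurrence and
# the query-with-k_t readout are illustrative, not the RWKV7 kernel API.

def recurrence(w, k, v, dtype):
    w, k, v = (x.astype(dtype) for x in (w, k, v))
    T, dk = k.shape
    dv = v.shape[1]
    S = np.zeros((dk, dv), dtype=dtype)
    outs = []
    for t in range(T):
        S = w[t][:, None] * S + np.outer(k[t], v[t])
        outs.append(k[t] @ S)
    return np.stack(outs)

rng = np.random.default_rng(0)
T, dk, dv = 64, 16, 16
w = rng.uniform(0.9, 0.999, size=(T, dk))
k = rng.standard_normal((T, dk))
v = rng.standard_normal((T, dv))

ref = recurrence(w, k, v, np.float64)
fast = recurrence(w, k, v, np.float32)
rel_err = np.abs(fast - ref).max() / np.abs(ref).max()
print(f"max relative error: {rel_err:.2e}")
```

A real comparison would substitute the candidate kernel's output for `fast` and a pure FP32/FP64 recurrent loop for `ref`, over realistic sequence lengths and head dimensions.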

Triang-jyed-driung commented:

Smerky (Dan Goldstein) tested different kernels. The fastest kernel is 2x faster (end to end!) in terms of overall training speed, but at the risk of losing numerical precision. Please ask him (and Wind) for details.
