[RFC] Add RWKV7 kernels and models #105
Comments
The forward pass of the chunkwise implementation has been completed and tested in this commit: 7569595.
The backward pass has been implemented in this commit: e582c28. TODO: implement the RWKV-7 layer and model in FLA format.
I think this might help: https://huggingface.co/SmerkyG/RWKV7-Goose-0.4B-Pile-HF/blob/main/modeling_rwkv7.py Also, the Triton kernels come from https://github.com/johanwind/wind_rwkv/tree/main/wind_rwkv/rwkv7 They use a technique called "backstepping" for the states to avoid recomputation.
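For anyone wanting to sanity-check a fused kernel against something simple, a naive full-precision recurrent reference is the usual baseline. The sketch below is a minimal, hypothetical reference assuming an RWKV-7-style per-step update of the form `S_t = S_{t-1} * diag(w_t) + S_{t-1} @ outer(a_t, b_t) + outer(v_t, k_t)` with readout `y_t = S_t @ q_t`, mirroring the `(w, q, k, v, a, b)` argument layout of the wind_rwkv kernels; the exact decay parameterization and normalization in the real kernels may differ.

```python
import numpy as np

def rwkv7_recurrent_reference(w, q, k, v, a, b):
    """Naive sequential reference for an RWKV-7-style state recurrence.

    Shapes (T = sequence length, D = head dimension): all inputs are (T, D).
    Assumed (hypothetical) per-step update:
        S_t = S_{t-1} * diag(w_t) + S_{t-1} @ outer(a_t, b_t) + outer(v_t, k_t)
        y_t = S_t @ q_t
    The state is accumulated in float64 so this can serve as a high-precision
    baseline when checking chunkwise or fused implementations.
    """
    T, D = q.shape
    S = np.zeros((D, D), dtype=np.float64)
    ys = np.empty((T, D), dtype=np.float64)
    for t in range(T):
        # Per-column decay, low-rank state transition, then rank-1 injection.
        S = S * w[t][None, :] + S @ np.outer(a[t], b[t]) + np.outer(v[t], k[t])
        ys[t] = S @ q[t]
    return ys
```

With `w = 1` and `a = b = 0` this degenerates to plain (unnormalized) linear attention, `y_t = (sum_{s<=t} v_s k_s^T) q_t`, which is a convenient special case for unit tests.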
Hi @Triang-jyed-driung, thanks for the pointers! We're familiar with wind_rwkv's CUDA kernel but were unaware of the Triton kernel. However, I have some concerns regarding the numerical precision of wind's kernel. Has anyone measured its numerical precision, i.e., the relative error compared to the full FP32 recurrent kernels? Also, has anyone benchmarked the speed of wind's Triton kernel against fla's kernel?
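The kind of relative-error check asked about above can be sketched in a few lines: run the same recurrence once in a low-precision dtype and once in FP32, then compare against the higher-precision run. The recurrence and helper names below are illustrative stand-ins, not any kernel's actual API; a simplified diagonal-decay update is used here in place of the full RWKV-7 transition.

```python
import numpy as np

def decayed_recurrence(w, k, v, q, dtype):
    """Gated linear-attention recurrence S_t = S_{t-1} * diag(w_t) + v_t k_t^T,
    run entirely in `dtype` so accumulation error in the state is exposed.
    Hypothetical stand-in for a real kernel under test."""
    w, k, v, q = (x.astype(dtype) for x in (w, k, v, q))
    T, D = q.shape
    S = np.zeros((D, D), dtype=dtype)
    ys = np.empty((T, D), dtype=dtype)
    for t in range(T):
        S = S * w[t][None, :] + np.outer(v[t], k[t])
        ys[t] = S @ q[t]
    return ys

def max_relative_error(approx, ref, eps=1e-6):
    """Elementwise max |approx - ref| / (|ref| + eps), computed in float64."""
    ref64 = ref.astype(np.float64)
    diff = np.abs(approx.astype(np.float64) - ref64)
    return float(np.max(diff / (np.abs(ref64) + eps)))

rng = np.random.default_rng(0)
T, D = 512, 64
w = rng.uniform(0.9, 1.0, size=(T, D))  # mild per-channel decay
k, v, q = (rng.normal(size=(T, D)) for _ in range(3))

ref = decayed_recurrence(w, k, v, q, np.float32)
lowp = decayed_recurrence(w, k, v, q, np.float16)
print("max relative error (fp16 vs fp32):", max_relative_error(lowp, ref))
```

In practice one would substitute the chunkwise/fused kernel output for `lowp` and the FP32 (or FP64) recurrent reference for `ref`, and also inspect the backward pass gradients the same way, since state recomputation tricks like backstepping mainly affect gradient accuracy.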
Smerky (Dan Goldstein) tested different kernels. The fastest kernel is 2x faster end to end, in terms of overall training speed, but at the risk of losing numerical precision. Please ask him (and Wind) for details.