
[Feature Request] Weight Conversion #120

Open
Triang-jyed-driung opened this issue Jan 14, 2025 · 7 comments
Labels
enhancement New feature or request

Comments

@Triang-jyed-driung

Feature Request

Convert official weights to Flash-linear-attention's format, including RWKV, Mamba, etc.

Motivation

The official model weights of many linear models (especially the RWKV series) cannot be loaded directly by Flash-linear-attention's code. Currently, some SFT and RLHF implementations rely heavily on a correct attention_mask implementation to pad and truncate variable-length sequences. (RWKV's attention_mask is known to fail, and Mamba's attention_mask works in the forward pass but likely not in the backward pass.) If we could convert the pre-trained weights to this library's format, making them HF-compatible and able to handle variable lengths, we could easily implement SFT and RLHF for linear models.
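As a rough illustration of the workflow this would enable: a padded batch driven entirely by attention_mask, for both forward and backward. This is a minimal sketch; "fla-hub/converted-model" is a placeholder name, not a real checkpoint.

```python
# Minimal sketch, assuming a converted HF-compatible checkpoint exists.
# "fla-hub/converted-model" is a placeholder repo name, not a real one.
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "fla-hub/converted-model"
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)

batch = tokenizer(
    ["a short prompt", "a much longer prompt that forces padding"],
    padding=True,
    return_tensors="pt",
)
labels = batch["input_ids"].clone()
labels[batch["attention_mask"] == 0] = -100  # exclude padded positions from the loss

out = model(**batch, labels=labels)
out.loss.backward()  # a broken attention_mask shows up exactly here
```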

Your Contribution

I'd like to try SFT and RLHF on some linear models.

Triang-jyed-driung added the enhancement label on Jan 14, 2025
@yzhangcs
Member

yzhangcs commented Jan 14, 2025

@Triang-jyed-driung Hello, actually I've included some conversion scripts, including ones for Llama and RWKV6:
https://github.com/fla-org/flash-linear-attention/tree/main/utils
Both worked fine during the period I was developing GSA. LMK if they fail now.
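For reference, a conversion script of this kind roughly amounts to renaming state-dict keys into the fla module layout and re-saving in HF format. The sketch below is illustrative only: the paths, placeholder repo name, and key mapping are made up and are not the actual rules used in utils/.

```python
# Rough sketch of a weight-conversion script; paths, repo name, and the
# rename() mapping are hypothetical, not the actual rules in utils/.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

official = torch.load("rwkv6-official.pth", map_location="cpu")  # hypothetical path

def rename(key: str) -> str:
    # Hypothetical mapping from official parameter names to fla module names.
    return (key
            .replace("blocks.", "model.layers.")
            .replace("att.", "attn.")
            .replace("ffn.", "mlp."))

converted = {rename(k): v for k, v in official.items()}

config = AutoConfig.from_pretrained("fla-hub/rwkv6-placeholder", trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
missing, unexpected = model.load_state_dict(converted, strict=False)
print(missing, unexpected)  # sanity-check the mapping before saving
model.save_pretrained("rwkv6-fla")
```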

The Mamba code in fla is adapted from HF, so it can directly load HF weights from the HF Hub.

I'd be very glad if you could contribute conversions for other linear models.

@yzhangcs
Member

Check out https://huggingface.co/collections/fla-hub/rwkv6-665aaa86d4714ed3f8595aec

Both are converted from the official checkpoints.

@Triang-jyed-driung
Author

> @Triang-jyed-driung Hello, actually I've included some conversion scripts, including ones for Llama and RWKV6: https://github.com/fla-org/flash-linear-attention/tree/main/utils Both worked fine during the period I was developing GSA. LMK if they fail now.

The original RWKV-6 weights come from BlinkDL, so what you converted are second-hand weights.

> The Mamba code in fla is adapted from HF, so it can directly load HF weights from the HF Hub.

Is attention_mask properly implemented in the backward pass?

@yzhangcs
Member

@Triang-jyed-driung We've done some in-place muls there, so backward might not work.
Why don't you adopt varlen packing for SFT?
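Regarding varlen packing, here is a minimal sketch of the idea for SFT (pure PyTorch with toy token ids): concatenate samples into one long sequence and track the boundaries with cumulative lengths, so no padding and no attention_mask are needed. The cu_seqlens name follows the flash-attention convention; whether a given fla kernel consumes it directly is an assumption to check against the library.

```python
# Minimal varlen-packing sketch with toy data; cu_seqlens follows the
# flash-attention convention and is an assumption, not a documented fla API.
import torch

samples = [
    torch.tensor([101, 7, 8, 9]),
    torch.tensor([101, 4, 5]),
    torch.tensor([101, 2, 3, 6, 11]),
]

input_ids = torch.cat(samples).unsqueeze(0)          # shape [1, total_len], no padding
lens = [len(s) for s in samples]
cu_seqlens = torch.tensor([0] + lens).cumsum(0)      # [0, 4, 7, 12]

# Mask the first token of each packed sample so the LM loss never
# crosses a sequence boundary.
labels = input_ids.clone()
labels[0, cu_seqlens[:-1]] = -100

# Kernels in the flash-attention style typically expect int32 offsets.
cu_seqlens_i32 = cu_seqlens.to(torch.int32)
```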

@yzhangcs
Member

> The original RWKV-6 weights come from BlinkDL, so what you converted are second-hand weights.

I see :godmode: but it's OK as long as the variable names are identical.

@Triang-jyed-driung
Author

Triang-jyed-driung commented Jan 15, 2025

> @Triang-jyed-driung We've done some in-place muls there, so backward might not work. Why don't you adopt varlen packing for SFT?

Packing does not work for RLHF; it is not supported by https://huggingface.co/docs/trl/en/ppo_trainer

@yzhangcs
Member

@Triang-jyed-driung Thank you, you can try the current layers.
I'll keep an eye out for any problems that may arise.
