[Bug]: GSA and RWKV6 Occasionally Report Gradient=NaN during Backward #77
Comments
Thank you for reporting this issue. Could you elaborate on the input shapes so that I can run some simulation experiments?
Some debugging shows that the following part of chunk.py (around lines 1016–1025) causes the problem:

    grid = (NV, NT * NC, B * HQ)
    chunk_gsa_bwd_kernel_intra_KV[grid](
        v, g, o, A, do, dv, dg,
        v.stride(1), v.stride(2),
        T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG,
        OVERWRITE_DG=overwrite_dg,
        num_warps=num_warps,
        num_stages=num_stages
    )  # after this call, dv and dg suddenly become NaN
    return dq, dk, dv, dg, dh0

Hyperparameters: everything else is just the default (GatedSlotAttention function defaults).
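For reference, a minimal sketch of the kind of check used to localize this (the helper below is not part of fla; dv/dg refer to the tensors in the snippet above):

```python
import torch

def assert_finite(**tensors):
    """Raise as soon as any tensor contains NaN/Inf, naming the offending one."""
    for name, t in tensors.items():
        if t is not None and not torch.isfinite(t).all():
            raise RuntimeError(f"non-finite values detected in {name}")

# Placed directly after the kernel launch, e.g.:
#     chunk_gsa_bwd_kernel_intra_KV[grid](...)
#     assert_finite(dv=dv, dg=dg)
```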
@WorldEditors How about the sequence length?
Tried sequence lengths of 3K, 12K, and 24K; the problem can be reproduced with all of them.
In GSA, I tried restoring training from a problematic checkpoint, and there are some interesting discoveries: training is only OK if I reinitialize the f_proj. So I guess there is something wrong with the f_proj parameter? But then why are there similar problems in RWKV6? I have no idea.
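For context, a minimal sketch of that experiment, assuming the forget-gate projection is exposed as an nn.Linear named f_proj (the init scheme here is an assumption, not the library default):

```python
import torch.nn as nn

def reinit_f_proj(model: nn.Module) -> None:
    """Discard the learned forget-gate projection after loading a checkpoint."""
    for module in model.modules():
        if isinstance(getattr(module, "f_proj", None), nn.Linear):
            nn.init.xavier_uniform_(module.f_proj.weight)
            if module.f_proj.bias is not None:
                nn.init.zeros_(module.f_proj.bias)
```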
Hi, thanks for reporting it! Do you have the input tensors and model weights so that we can reproduce it?
I'm afraid it is quite challenging to extract a minimal reproducible dataset, code, and model.
Update: I tried bounding the forget gate value as in https://github.com/sustcsonglin/flash-linear-attention/pull/78/files. Running another 400 iterations does not yield any abnormality. But this risks impacting existing models.
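For illustration, a sketch of what such a hard bound can look like, assuming a log-sigmoid gate parameterization and an arbitrary -5.0 floor (not the exact change in the PR):

```python
import torch
import torch.nn.functional as F

def bounded_log_forget_gate(f_logits: torch.Tensor, floor: float = -5.0) -> torch.Tensor:
    """Compute the log-space forget gate and clamp it to a hard lower bound,
    keeping the decay terms used by the chunked kernels in a safer numerical range."""
    f = F.logsigmoid(f_logits)        # log forget gate in (-inf, 0)
    return torch.clamp(f, min=floor)  # hard bound on how small the gate can get
```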
@WorldEditors Hello, sorry for my super late reply. I just refactored the GSA layers/kernels to address potential indexing and gradient errors; could you try again? Looking forward to your feedback.
@yzhangcs We pulled the newest branch and confirmed that this problem still exists (for both GSA and RWKV). Training is OK at the initial stage but explodes later. Currently the only solution we have found to prevent NaN is adding a hard bound on f. It is worth noting that our training data is not natural language; a guess is that for NL data, f and w never grow that large.
@WorldEditors Hello, someone told me that the current fla kernels are not robust enough for visual data. Could you try some other kernels like GLA/DeltaNet? I'm curious whether it is a common problem.
I have been using FP32 for now, and NaN happens as well.
This issue is stale because it has been open for 30 days with no activity. |
This issue was closed because it has been inactive for 7 days since being marked as stale. |
I've identified the source of the NaN issue and resolved it in this commit: 0c1c0b6 (pressure testing on very small log decays has already been conducted). The problem arose from how the first position of the exclusive cumulative decay is handled in RWKV6; GLA and GSA do not encounter this problem since their cumulative decay is inclusive rather than exclusive.
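To illustrate the inclusive/exclusive distinction (the decay values below are made up and this only sketches the scan convention, not the kernel code):

```python
import torch

# Per-step log decays for one channel (made-up values).
g = torch.tensor([-8.0, -0.2, -0.3, -0.1])

# Inclusive scan (GLA/GSA convention): position t includes its own decay g[t].
inclusive = torch.cumsum(g, dim=0)                       # [-8.0, -8.2, -8.5, -8.6]

# Exclusive scan: position 0 gets 0.0 because its own decay is left out,
# which makes the very first position the special case needing extra care.
exclusive = torch.cat([g.new_zeros(1), inclusive[:-1]])  # [ 0.0, -8.0, -8.2, -8.5]
```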
Fixed in 0c1c0b6.
Describe the bug
Running training for GSA and RWKV6 occasionally results in NaN gradients: rare at the beginning, but becoming more frequent as training progresses.
I checked parameters and losses; all are reasonable and show no sign of explosion. The NaN comes suddenly. When the model is switched back to a Transformer, this never happens.
Using with torch.autograd.detect_anomaly() around the backward pass, I get the following log:
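For reference, a minimal sketch of how that check can be wired around a training step (model, batch, and loss_fn are placeholders, not the actual training code):

```python
import torch

def debug_training_step(model, batch, loss_fn):
    # detect_anomaly() makes autograd raise on the first op producing NaN/Inf,
    # at the cost of much slower backward passes, so it is only used for debugging.
    with torch.autograd.detect_anomaly():
        loss = loss_fn(model(batch["input"]), batch["target"])
        loss.backward()
    return loss.detach()
```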
Steps to reproduce the bug
Cannot provide a code sample; it does not happen in a specific model or at specific steps.
Expected behavior
N/A
Environment info
CUDA: 12.6, NVIDIA A800 80GB PCIe