[Bug]: Grad_norm & Loss are NaN when training Gated DeltaNet on fineweb-edu-10BT #111
Comments
@Chris-city Hi, could you provide the detailed running commands?
BTW, did you pull the latest commits? We have fixed some out-of-bounds overflows recently.
Hi @Chris-city, I think the NaN issue has been fixed in #99. Let me know if the latest commit still has the NaN issue.
Hi @yzhangcs @sustcsonglin, I have already pulled the latest commits, but the issue persists. In my latest attempt, I found that the problem seems to be related to the Triton version. After updating to Triton==3.1.0, torch==2.5.1, and CUDA==12.4, the code started running successfully again; as of writing this reply, it has been running for 5k iterations without issues. However, with the previous environment (Triton==3.0.0, torch==2.4.1, CUDA==12.1), I consistently encountered NaN issues.
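For anyone reproducing this, a minimal sketch for recording the exact Triton/PyTorch/CUDA versions in play (the print labels are illustrative, not part of the original report):

import torch
import triton

# Record the versions the NaN behaviour appears to depend on
# (Triton 3.0.0 triggered NaNs in this report, 3.1.0 did not).
print("triton:", triton.__version__)
print("torch: ", torch.__version__)
print("cuda:  ", torch.version.cuda)
print("gpu:   ", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")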
@Chris-city Thank you! It seems that there are still some risky places we are not aware of. Could you save the crashed instances (q/k/v/g) once you meet NaNs/Infs?

if torch.isnan(...).any() or torch.isinf(...).any():
    torch.save(...)

We will check it soon.
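A fuller version of that check, as a sketch only (the helper name, tensor set, and save path below are placeholders, not code from this repository):

import torch

def dump_if_nonfinite(q, k, v, g, path="nan_dump.pt"):
    # Save the offending kernel inputs so the failing case can be replayed offline.
    tensors = {"q": q, "k": k, "v": v, "g": g}
    bad = [name for name, t in tensors.items()
           if torch.isnan(t).any() or torch.isinf(t).any()]
    if bad:
        torch.save({name: t.detach().cpu() for name, t in tensors.items()}, path)
        raise RuntimeError(f"non-finite values found in {bad}; inputs saved to {path}")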
Which GPU type were you using?
Hi @sustcsonglin, I used A800-SXM4-80G GPUs. I found that it was indeed an issue with the Triton version: training couldn't run on version 3.0.0, but I successfully completed the training on version 3.1.0.
@Chris-city Interesting, we will keep an eye on it. Currently we have no idea what's going wrong. BTW, can you pass this pytest with Triton 3.0?
Describe the bug
Thank you for your excellent work! I’m using the training framework to train Gated-DeltaNet on the fineweb-edu-10BT dataset. However, I’ve noticed that regardless of which random seed I choose (e.g., 42, 2024, 3407) or which combination of model parameters I try, both the Loss and the Grad_norm in the training process always turn into NaN after around 100 iterations.
Steps to reproduce the bug
configs
{ "attn_mode": "chunk",
"bos_token_id": 1,
"eos_token_id": 2,
"expand_v": 1,
"fuse_cross_entropy": true,
"fuse_norm": true,
"hidden_act": "swish",
"hidden_ratio": 4,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": null,
"max_position_embeddings": 2048,
"model_type": "gated_deltanet",
"num_heads": 8,
"head_dim": 128,
"num_hidden_layers": 24,
"norm_first": false,
"norm_eps": 1e-06,
"tie_word_embeddings": true,
"use_cache": true,
"vocab_size": 32000 }
Output
{'loss': 10.1581, 'grad_norm': 2.6989870071411133, 'learning_rate': 9.375e-06, 'epoch': 0.0, 'num_tokens': 8388608, 'throughput': 13446.846687884983}
{'loss': 8.5623, 'grad_norm': 1.323537826538086, 'learning_rate': 1.875e-05, 'epoch': 0.0, 'num_tokens': 16777216, 'throughput': 21131.363145263378}
{'loss': 7.5845, 'grad_norm': 1.1358221769332886, 'learning_rate': 2.8125e-05, 'epoch': 0.0, 'num_tokens': 25165824, 'throughput': 26117.370942931575}
{'loss': 6.8409, 'grad_norm': 1.108870267868042, 'learning_rate': 3.75e-05, 'epoch': 0.0, 'num_tokens': 33554432, 'throughput': 29617.0397897203}
{'loss': 6.2549, 'grad_norm': 1.0967456102371216, 'learning_rate': 4.6874999999999994e-05, 'epoch': 0.0, 'num_tokens': 41943040, 'throughput': 32197.435661579908}
{'loss': 5.8436, 'grad_norm': 1.3389238119125366, 'learning_rate': 5.625e-05, 'epoch': 0.0, 'num_tokens': 50331648, 'throughput': 34184.24951713704}
{'loss': 5.5759, 'grad_norm': 1.0862421989440918, 'learning_rate': 6.5625e-05, 'epoch': 0.01, 'num_tokens': 58720256, 'throughput': 35757.30501202298}
{'loss': 5.3681, 'grad_norm': 1.2632718086242676, 'learning_rate': 7.5e-05, 'epoch': 0.01, 'num_tokens': 67108864, 'throughput': 37028.25834612053}
{'loss': 5.2003, 'grad_norm': 1.0916602611541748, 'learning_rate': 8.437499999999999e-05, 'epoch': 0.01, 'num_tokens': 75497472, 'throughput': 38089.754995776224}
{'loss': 5.0471, 'grad_norm': 1.062748670578003, 'learning_rate': 9.374999999999999e-05, 'epoch': 0.01, 'num_tokens': 83886080, 'throughput': 38983.185703428244}
{'loss': 4.7514, 'grad_norm': nan, 'learning_rate': 0.00010312499999999999, 'epoch': 0.01, 'num_tokens': 92274688, 'throughput': 39772.48555585977}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0001125, 'epoch': 0.01, 'num_tokens': 100663296, 'throughput': 40469.73887130221}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.000121875, 'epoch': 0.01, 'num_tokens': 109051904, 'throughput': 41055.38687761855}
Expected behavior
I don’t know how to solve this issue.
Environment info