
[Bug]: GSA and RWKV6 Occasionally Report Gradient=NAN when Backward #77

Open
WorldEditors opened this issue Nov 7, 2024 · 16 comments
Labels
bug Something isn't working


@WorldEditors
Contributor

Describe the bug

Training GSA and RWKV models occasionally produces NaN gradients. This is rare early in training but becomes more frequent as training progresses.
I checked the parameters and losses; all of them are reasonable and show no sign of explosion. The NaNs appear suddenly. After switching the model back to a Transformer, this never happens.

Running with torch.detect_anomaly() enabled, I get the following log:

  File "/home/xxx/Codes/flash-linear-attention/fla/layers/gsa.py", line 181, in forward
    o, recurrent_state = chunk_gsa(q, k, v, s, f,
  File "/home/xxx/Codes/flash-linear-attention/fla/ops/gsa/chunk.py", line 1203, in chunk_gsa
    ov, *final_state = ChunkGSAFunction.apply(q, k, v, s, g, scale, hk0, hv0, output_final_state, checkpoint_level)
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
W1107 18:46:16.781683 139687942383424 torch/multiprocessing/spawn.py:146] Terminating process 2106869 via signal SIGTERM
W1107 18:46:16.783345 139687942383424 torch/multiprocessing/spawn.py:146] Terminating process 2106870 via signal SIGTERM
W1107 18:46:16.783725 139687942383424 torch/multiprocessing/spawn.py:146] Terminating process 2106871 via signal SIGTERM
Traceback (most recent call last):
  File "train.py", line 10, in <module>
    runner.start(AnyMDPRSA, AnyMDPEpoch, AnyMDPEpoch)
  File "/home/xxx/Codes/L3C_Baselines/l3c_baselines/utils/trainer.py", line 335, in start
    mp.spawn(dist_process,
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes
    while not context.join():
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 189, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
    fn(i, *args)
  File "/home/xxx/Codes/L3C_Baselines/l3c_baselines/utils/trainer.py", line 288, in dist_process
    for need_evaluate in train_object.run(epoch, device, device_type):
  File "/home/xxx/Codes/L3C_Baselines/l3c_baselines/utils/trainer.py", line 151, in run
    self.computer.compute(
  File "/home/xxx/Codes/L3C_Baselines/projects/AnyMDP/anymdp_epoch.py", line 80, in compute
    syn_loss.backward()
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/xxx/.local/lib/python3.8/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'ChunkGSAFunctionBackward' returned nan values in its 3th output.

Steps to reproduce the bug

I can't provide a code sample: the NaN does not occur reproducibly for a specific model at a specific step.

Expected behavior

N/A

Environment info

  1. torch: 2.4.1
  2. triton: 3.0.0

CUDA: 12.6, NVIDIA A800 80GB PCIe

WorldEditors added the bug (Something isn't working) label Nov 7, 2024
@yzhangcs
Member

yzhangcs commented Nov 7, 2024

Thank you for reporting this issue. Could you give more details on the input shapes so that I can run some simulation experiments?

@WorldEditors
Contributor Author

WorldEditors commented Nov 7, 2024

Some debugging shows that the following part of chunk.py (lines 1016 to 1025) causes the problem:

```python
grid = (NV, NT * NC, B * HQ)
chunk_gsa_bwd_kernel_intra_KV[grid](
    v, g, o, A, do, dv, dg,
    v.stride(1), v.stride(2),
    T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG,
    OVERWRITE_DG=overwrite_dg,
    num_warps=num_warps,
    num_stages=num_stages
)  # after this kernel launch, dv and dg suddenly become NaN
return dq, dk, dv, dg, dh0
```

Hyperparameters:
hidden_size = 512,
num_slots = 64,
nheads = 4

Everything else uses the GatedSlotAttention defaults.
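For reference, the launch grid in the kernel excerpt above, (NV, NT * NC, B * HQ), splits the backward pass into value blocks, chunk/sub-chunk rows, and batch-head pairs. The plain-Python sketch below reproduces that arithmetic, assuming NT, NC, and NV are ceiling divisions of T by BT, BT by BC, and V by BV (the usual convention for chunked kernels, but not confirmed in this thread). The helper names and block sizes are hypothetical, and V = 128 assumes a per-head value dimension of hidden_size // nheads = 512 // 4.

```python
def cdiv(a, b):
    """Ceiling division: number of blocks of size b needed to cover a."""
    return (a + b - 1) // b


def intra_kv_grid(B, HQ, T, V, BT, BC, BV):
    """Grid (NV, NT * NC, B * HQ) under the ceiling-division assumption."""
    NT = cdiv(T, BT)   # chunks along the sequence
    NC = cdiv(BT, BC)  # sub-chunks inside each chunk
    NV = cdiv(V, BV)   # blocks along the value dimension
    return (NV, NT * NC, B * HQ)


# Hypothetical sizes: batch 1, 4 heads, 3K tokens, value dim 128.
print(intra_kv_grid(B=1, HQ=4, T=3072, V=128, BT=64, BC=16, BV=64))  # (2, 192, 4)
```

Under the same block sizes, the middle axis grows linearly with sequence length: 12,288 tokens give 768 chunk/sub-chunk rows.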

@yzhangcs
Member

yzhangcs commented Nov 7, 2024

@WorldEditors What about the sequence length?

@WorldEditors
Contributor Author

I tried sequence lengths of 3K, 12K, and 24K; the problem can be reproduced with all of them.

@WorldEditors
Contributor Author

In GSA, I tried restoring training from a problematic checkpoint and made some interesting discoveries:

Training only recovers if I reinitialize the f_proj parameters.
Reinitializing the other parameters (q_proj, k_proj, v_proj, g_norm, o_proj) does not solve the problem.

So I suspect something is wrong with the f_proj parameters.

But why does RWKV6 show a similar problem? I have no idea.

@sustcsonglin
Collaborator

Hi, thanks for reporting it! Do you have an input tensor and model weights with which we can reproduce it?

@WorldEditors
Contributor Author


I'm afraid it is quite challenging to extract a minimal reproducible dataset, code, and model.
However, it seems very likely that this has something to do with the forget gate (f and g).
Hopefully this gives you some clues.

@WorldEditors
Contributor Author

Update:

I tried bounding the forget gate values as in

https://github.com/sustcsonglin/flash-linear-attention/pull/78/files

Running another 400 iterations did not produce any abnormality.

However, this risks affecting existing models.

@yzhangcs
Member

@WorldEditors Hello, sorry for my very late reply. I just refactored the GSA layers/kernels to address potential indexing and gradient errors; could you try it again? Looking forward to your feedback.

@WorldEditors
Contributor Author

WorldEditors commented Dec 3, 2024


@yzhangcs We pulled the newest branch and confirmed that this problem still exists (for both GSA and RWKV). Training is fine at the initial stage but explodes later.

So far, the only way we have found to prevent NaNs is to add a hard bound on f.
Moreover, the loosest hard bound that still prevents NaNs is |f| < 20 for GSA and |w| < 12 for RWKV.

It is worth noting that our training data is not natural language. Our guess is that with natural-language data, f and w never become that large.
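A hard bound like this is an element-wise clamp. The plain-Python sketch below applies the bounds quoted above to hypothetical gate values; the helper name is made up, and exactly where the clamp sits inside the GSA/RWKV6 layers (or in PR #78) is not reproduced here. In PyTorch the equivalent operation would be torch.clamp(f, min=-20.0, max=20.0).

```python
# Loosest bounds reported in this thread that still prevented NaNs.
GSA_F_BOUND = 20.0    # |f| bound for GSA
RWKV6_W_BOUND = 12.0  # |w| bound for RWKV6


def hard_bound(values, bound):
    """Clamp every gate value element-wise into [-bound, bound]."""
    return [min(max(v, -bound), bound) for v in values]


# Hypothetical raw gate values for a few channels.
f_raw = [-35.0, -3.2, 0.4, 27.5]
w_raw = [-15.0, -11.9, 2.0]

print(hard_bound(f_raw, GSA_F_BOUND))    # [-20.0, -3.2, 0.4, 20.0]
print(hard_bound(w_raw, RWKV6_W_BOUND))  # [-12.0, -11.9, 2.0]
```

A hard clamp passes zero gradient to values outside the bound. A smooth alternative such as bound * tanh(f / bound) keeps gradients everywhere but also shifts in-range values slightly, which matters if existing checkpoints must behave exactly as before.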

@yzhangcs
Member

yzhangcs commented Dec 3, 2024

@WorldEditors Hello, someone told me that the current fla kernels are not robust enough for visual data.
For now, you may need to add QK norm to prevent explosions if you are using fp16/bf16.

Could you try some other kernels, such as GLA/DeltaNet? I'm curious whether this is a common problem.
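The thread does not say which QK norm variant is meant. A common one is per-head L2 normalization of the query and key vectors (another is applying RMSNorm to them). The plain-Python sketch below shows the L2 variant with hypothetical helper names; in PyTorch the same step would be torch.nn.functional.normalize(q, p=2, dim=-1) applied to the per-head q and k tensors.

```python
import math


def l2_normalize(vec, eps=1e-6):
    """Rescale a vector to unit L2 norm; eps avoids division by zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


# Hypothetical query/key vectors for one head with large magnitudes.
q = [30.0, 40.0]
k = [300.0, -400.0]

raw_score = dot(q, k)                             # -7000.0
qk_score = dot(l2_normalize(q), l2_normalize(k))  # about -0.28

# After normalization, |q_hat . k_hat| <= 1 (Cauchy-Schwarz), whatever the input scale.
print(raw_score, qk_score)
```

Note that this only bounds the q·k scores; it does not constrain the gate or decay values.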

@WorldEditors
Contributor Author


I have been using FP32 so far, and the NaNs happen there as well.
I will try GLA later.


github-actions bot commented Jan 5, 2025

This issue is stale because it has been open for 30 days with no activity.

github-actions bot added the stale label Jan 5, 2025

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions bot closed this as not planned (Won't fix, can't repro, duplicate, stale) Jan 12, 2025
@sustcsonglin
Collaborator

I've identified the source of the NaN issue and fixed it in commit 0c1c0b6 (stress testing with very small log decays has already been done).

The problem was that the first position of b_gq - b_gn[None, :] could become very large: b_gq holds the exclusive cumulative decay (excluding the current position), which leads to this edge case. This caused tl.exp(b_gq - b_gn[None, :]) to explode.

GLA and GSA do not encounter this problem, since their cumulative decay is inclusive rather than exclusive.
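The edge case can be reproduced with scalar arithmetic. The sketch below is a simplified single-channel model, not the kernel code. It assumes log decays are negative, that b_gq holds either the inclusive or the exclusive cumulative log decay, and that the reference b_gn is the inclusive cumulative decay at the first position of the sub-chunk; all function names are hypothetical. With inclusive sums every exponent b_gq - b_gn is at most 0. With exclusive sums the first exponent equals minus that position's log decay, which exceeds the float32 exp limit (about 88.72) once the decay is strong enough. The resulting inf becomes NaN as soon as it is multiplied by zero.

```python
import math

# Largest x for which exp(x) is finite in float32 (ln of FLT_MAX, about 88.72).
FP32_EXP_MAX = math.log(3.4028234663852886e38)


def inclusive_cumsum(log_decay):
    """G[t] = sum(log_decay[0..t]): decay up to and including position t."""
    out, acc = [], 0.0
    for w in log_decay:
        acc += w
        out.append(acc)
    return out


def exclusive_cumsum(log_decay):
    """G[t] = sum(log_decay[0..t-1]): decay strictly before position t."""
    out, acc = [], 0.0
    for w in log_decay:
        out.append(acc)
        acc += w
    return out


def max_exponent(g_q, g_n):
    """Largest exponent g_q[t] - g_n that would be passed to exp()."""
    return max(g - g_n for g in g_q)


# Hypothetical log decays for one sub-chunk; the first position decays very strongly.
log_decay = [-100.0, -0.5, -0.5, -0.5]

g_incl = inclusive_cumsum(log_decay)
g_excl = exclusive_cumsum(log_decay)
g_n = g_incl[0]  # assumed reference: inclusive decay at the sub-chunk start

incl_max = max_exponent(g_incl, g_n)
excl_max = max_exponent(g_excl, g_n)

print(incl_max)                  # 0.0
print(excl_max)                  # 100.0, i.e. -log_decay[0]
print(excl_max > FP32_EXP_MAX)   # True: exp() overflows to inf in float32
print(math.isnan(math.inf * 0.0))  # True: inf * 0 gives NaN
```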

@sustcsonglin
Collaborator

Fixed in 0c1c0b6

github-actions bot removed the stale label Jan 19, 2025
Projects
None yet
Development

No branches or pull requests

3 participants