
fix dh0 is None breaking backward pass #102

Merged: 1 commit merged into fla-org:main on Dec 31, 2024
Conversation

@Sxela (Contributor) commented on Dec 30, 2024:

When running the backward pass, I get:

process_batch
    accelerator.backward(loss)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1983, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.10/dist-packages/fla/utils.py", line 18, in wrapper
    return fn(ctx,
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 511, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/fla/ops/rwkv6/fused_recurrent.py", line 601, in backward
    return dq.to(q), dk.to(k), dv.to(v), dw.to(w), du.to(u), None, dh0.to(initial_state), None, None, None, None
AttributeError: 'NoneType' object has no attribute 'to'

Steps to reproduce:

import torch
import torch.nn as nn
from fla.layers.rwkv6 import RWKV6Attention

# No initial_state is passed to the layer, so there is no dh0 gradient to return
layer = RWKV6Attention(hidden_size=320).cuda()
t = torch.randn((300, 10, 320)).cuda()
t.requires_grad = True

t2 = torch.randn((300, 10, 320)).cuda()
res = layer(t)[0]  # the layer returns a tuple; take the output tensor

criterion = nn.MSELoss()

loss = criterion(t2, res)
loss.backward()  # raises AttributeError: 'NoneType' object has no attribute 'to'
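
The failure comes from the last return value in fused_recurrent.py: when no initial_state is supplied, dh0 stays None, so dh0.to(initial_state) fails. A None guard along these lines would avoid the call on None (a sketch of the pattern only; the merged commit may differ):

    # Sketch: only cast dh0 when it exists (no initial_state means dh0 is None)
    dh0 = dh0.to(initial_state) if dh0 is not None else None
    return dq.to(q), dk.to(k), dv.to(v), dw.to(w), du.to(u), None, dh0, None, None, None, None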

@yzhangcs (Member) commented:
Thank you!

@yzhangcs yzhangcs merged commit cdbd3ac into fla-org:main Dec 31, 2024
1 check failed