Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New Model] RWKV #1902

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open

Conversation

JanFidor
Copy link
Contributor

@JanFidor JanFidor commented Jul 17, 2023

Fixes #1817 .

Quick summary

For now the implementation follows pretty closely what was described in the paper. The implementation from the official RWKV repo has quite a few improvements which weren't discussed in the paper, but for now I wanted to get at least a workable model.

Roadmap

  • Update model initializations which were hard coded for now. I know, very bad idea, but the paper used initializations assigning different weights to different embedding which didn't feel like a good idea for a TS model.
  • Use teacher forcing for training
  • Make a benchmark with SOTA models (ex. DLinear, NLinear, TFTModel).
  • Add support for past covariates (wasn't 100% sure how to do it with the model being auto-regressive)
  • More initialization benchmarks
  • Browsing the RWKV repo to look for improvements which would make sense in a TS model
  • This one is a long shot, but I was thinking about adding support for both future and static covariates. It would require fiddling with the attention mechanism, but it feels doable.
  • Add support for probabilistic forecasting
  • Even more benchmarks (especially performance wise, as the RWKV should do pretty well when it come to long input and output chunk lengths)
  • Add tests and update Readme, Changelog and docstrings

There's still a lot of things to be done, but I wanted to put up a PR as a quick update on how everything's going and a simple roadmap for the future

@JanFidor JanFidor requested a review from dennisbader as a code owner July 17, 2023 19:53
@codecov-commenter
Copy link

codecov-commenter commented Jul 17, 2023

Codecov Report

Patch coverage: 23.78% and project coverage change: -0.98 ⚠️

Comparison is base (a5560cc) 93.95% compared to head (7a1f0ee) 92.97%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1902      +/-   ##
==========================================
- Coverage   93.95%   92.97%   -0.98%     
==========================================
  Files         125      126       +1     
  Lines       11773    11923     +150     
==========================================
+ Hits        11061    11086      +25     
- Misses        712      837     +125     
Impacted Files Coverage Δ
darts/models/forecasting/rwkv_model.py 23.31% <23.31%> (ø)
darts/models/__init__.py 57.69% <100.00%> (+0.54%) ⬆️

... and 6 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@dennisbader
Copy link
Collaborator

Hi @JanFidor, and thanks for this PR. Just to let you know that we're wrapping up the last few things for the release in 1-2 weeks. Once that's done we'll come back to this and review 🚀

@gdevos010
Copy link
Contributor

@JanFidor Were you able to benchmark this model?

@JanFidor
Copy link
Contributor Author

JanFidor commented Aug 30, 2023

@gdevos010 just some basic ones, I still have to play around with parameter initializations. On SunspotsDataset I noticed that NLinear and Transformer were having noticeable MAPE changes depending on output_chunk_length (changes around 60 <-> 200 ) while RWKV was consistently performing around 100. I also threw in ETTh1 dataset, with 720 input_chunk _length 336 output_chunk_length. The RWKV had terrible MAPE. Not sure it the architecture was at fault or if it was caused by under fitting. I'll try to make a more comprehensive benchmark next week

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[New model] RWKV
4 participants