Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add normalization to BlockRNNModel #1748

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

JanFidor
Copy link
Contributor

@JanFidor JanFidor commented May 8, 2023

Fixes #1649.

Summary

I've added normalization parameter to the BlockRNNModel, I've brainstormed how to do it for RNNModel and I couldn't come up with a way that wouldn't require some type of dynamic aggregation of the hidden states, so I decided to make the PR for BlockRNN for now.
I added two torch modules to simplify the rnn sequence, not sure if it's the cleanest way to implement it, but it's at least very readable.

Other Information

I also added layer norm, because it was a simple addition and it seems to be the recommended normalization for RNNs. I also considered adding group normalization, but it would either need constant num_groups parameter or additional constructor parameter for BlockRNNModel

@JanFidor JanFidor requested review from hrzn and dennisbader as code owners May 8, 2023 06:41
@JanFidor
Copy link
Contributor Author

Some of the tests were failing, I'll check if continues after merging develop. One of them was test_fit_predict_determinism() which after debugging turned out to fail for ARIMA model, It wasn't in a scope of this PR so I'm unsure what might have happened. Might be a problem with my local build, I'll wait and see what the github actions say

@codecov-commenter
Copy link

codecov-commenter commented May 15, 2023

Codecov Report

Attention: 14 lines in your changes are missing coverage. Please review.

Comparison is base (8cb04f6) 93.88% compared to head (5de39ea) 93.78%.
Report is 2 commits behind head on master.

Files Patch % Lines
darts/models/forecasting/block_rnn_model.py 80.55% 7 Missing ⚠️
darts/utils/torch.py 66.66% 7 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1748      +/-   ##
==========================================
- Coverage   93.88%   93.78%   -0.10%     
==========================================
  Files         135      135              
  Lines       13425    13461      +36     
==========================================
+ Hits        12604    12625      +21     
- Misses        821      836      +15     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@JanFidor
Copy link
Contributor Author

I've been thinking whether adding batch norm makes sense in this case, as repeated rescaling would cause gradient explosion, the very thing LSTM / GRU were supposed to combat. I'm inclined to only allow layer normalization (maybe also group norm), so that users don't accidentally fall into that trap. Let me know if you think that would fit with darts design philosophy !

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @JanFidor, thank you again for contributing to darts.

Found this article, applying batch norm only on the output (not the hidden state). I would be curious to see a benchmark of the BlockRNNModel with and without normalization, check if we observe similar results.

self.norm = nn.BatchNorm1d(feature_size)

def forward(self, input):
input = self._reshape_input(input) # Reshape N L C -> N C L
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is more about swapping axes that reshaping, I would instead use the corresponding torch function:

Suggested change
input = self._reshape_input(input) # Reshape N L C -> N C L
# Reshape N L C -> N C L
input = input.swapaxes(1,2)

def forward(self, input):
input = self._reshape_input(input) # Reshape N L C -> N C L
input = self.norm(input)
input = self._reshape_input(input)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
input = self._reshape_input(input)
input = input.swapaxes(1,2)

@JanFidor
Copy link
Contributor Author

Thanks for another review @madtoinou ! The article looks exciting (at least after skimming it and reading the abstract ;P ) I found some implementations online, but I'd rather understand the actual idea first before implementing it, so it might take me a little longer compared to the other 2 PRs

@JanFidor
Copy link
Contributor Author

JanFidor commented Sep 7, 2023

Hi @madtoinou , quick update! I've read the paper and have an idea how to implement it. It might need a little bit of magic to get the time_step_index into the model input, but I think it should be doable, I'll let you know when I'll get everything running or if I stumble into some problem

@JanFidor
Copy link
Contributor Author

JanFidor commented Sep 25, 2023

Quick update @madtoinou. I've been browsing through the codebase and wanted to get your thoughts on my planned approach. I think that the simplest approach would be to manually add a past encoder with static position, but that would require expanding IntegerIndexEncoder which only supports 'relative' for now. That said, I'm not sure at which point the Encoders are applied to the TS and this approach depends on it happening before TS are sliced for training. It's also possible to manually add a "static index" component, but I think this approach would be more elegant and static IntegerIndexEncoder might be useful in other implementations in the future

@JanFidor
Copy link
Contributor Author

JanFidor commented Oct 6, 2023

Hi again @madtoinou! I wanted to get your thoughts on my new idea for the implementation. I went back to the paper and found the mention of using the batch norms specifically when training. Wouldn't it just suffice to store input_chunk_length batch norms instead, which would be significantly easier, might actually be more inline with what the paper proposes and wouldn't require TS of same length for training? I'll go ahead with this idea and give another update once I make some basic benchmarks.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@JanFidor JanFidor force-pushed the feature/rnn-normalization branch from 1a4627f to 544ab44 Compare February 7, 2024 17:48
…lization

# Conflicts:
#	darts/models/forecasting/block_rnn_model.py
@JanFidor
Copy link
Contributor Author

JanFidor commented Feb 7, 2024

It took some playing around but I think I managed to fix most of the git history (please ignore the git push --force hahaha)

target_size: int,
normalization: str = None,
):
if not num_layers_out_fc:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get this point here.

num_layers_out_fc is a list of integers correct?
Suppose num_layers_out_fc = [], then not num_layers_out_fc is True.
So why num_layers_out_fc = [] ?


last = input_size
feats = []
for feature in num_layers_out_fc + [

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i will rather use the extend method for lists

last = feature
return nn.Sequential(*feats)

def _normalization_layer(self, normalization: str, hidden_size: int):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if normalization is different from batch and layer the method return None. is this intended?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Retrieve last hidden state for RNNModel and BlockRNNModel
4 participants