Finetuning + Multi-GPU Support with PyTorch #223
base: master
Conversation
Hi, nmgvv.
I hope this message finds you well.
Thank you for your amazing work on ‘timesfm’ and for sharing it with the community. I recently came across your pull request and was really excited to explore your code. However, I noticed that some parts of the code have not been uploaded yet, and as a new user, I'm finding it challenging to understand how to utilize the current implementation.
If it's not too much trouble, could you please complete the upload and, if possible, add more comments or examples? I believe it would help not just me but others in the community better understand and benefit from your contributions.
I truly appreciate your effort and understand that preparing and sharing code takes time. Please let me know if there's anything I can do to help with the process.
Thank you again for your hard work, and I look forward to seeing the completed project!
Can we shift these files to a folder named finetuning?
import pandas as pd
import torch
import torch.multiprocessing as mp
import yfinance as yf
nit: Can we also see the results on ETT, and if so, whether they match the pax results?
x_future = torch.tensor(x_future, dtype=torch.float32)

input_padding = torch.zeros_like(x_context)
freq = torch.zeros(1, dtype=torch.long)
Can we give users the option to choose the fine-tuning freq between 0, 1, and 2?
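For illustration, a minimal sketch of how a user-selectable frequency bucket could be threaded into the input preparation; the helper name and signature here are assumptions, not part of this PR.

import torch

def prepare_inputs(x_context: torch.Tensor, freq: int = 0):
  """Builds the padding and frequency tensors for one series (illustrative helper).

  Args:
    x_context: context-window tensor.
    freq: TimesFM frequency bucket, one of {0, 1, 2}.
  """
  if freq not in (0, 1, 2):
    raise ValueError(f"freq must be 0, 1, or 2, got {freq}")
  input_padding = torch.zeros_like(x_context)
  freq_tensor = torch.tensor([freq], dtype=torch.long)
  return input_padding, freq_tensor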
horizon_len=128,
num_layers=50,
use_positional_embedding=False,
context_len=192,
Can we add a comment that this context length can be anything up to 2048 in multiples of 32?
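A small sketch of what such a comment could look like; hparams is a hypothetical variable name, and the values simply mirror the snippet above.

hparams = dict(
    horizon_len=128,
    num_layers=50,
    use_positional_embedding=False,
    # context_len can be any multiple of 32, up to 2048 (e.g. 32, 64, ..., 2048).
    context_len=192,
)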
if not torch.distributed.is_initialized():
  torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)

finetuner = TimesFMFinetuner(model, config, rank=rank)
What does rank mean?
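For context, rank here follows the standard torch.distributed convention: each training process gets an integer index in [0, world_size). A generic setup sketch, not this PR's exact code; the MASTER_ADDR/MASTER_PORT values are placeholders.

import os
import torch.distributed as dist

def setup_process(rank: int, world_size: int) -> None:
  """Initializes one process of the distributed group.

  rank: index of this process, 0 .. world_size - 1; with one process per GPU
    it is typically also used as the CUDA device index, and rank 0 usually
    handles logging and checkpointing.
  world_size: total number of processes participating in training.
  """
  os.environ.setdefault("MASTER_ADDR", "localhost")  # placeholder rendezvous address
  os.environ.setdefault("MASTER_PORT", "12355")      # placeholder port
  dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)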
if __name__ == "__main__":
  try:
    # single_gpu_example()  # Single GPU
    multi_gpu_example()  # Multi-GPU
Can we add these two options as an absl flag?
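A sketch of what that could look like with absl; the flag name is hypothetical, and single_gpu_example / multi_gpu_example are the functions from the snippet above.

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_enum("mode", "multi", ["single", "multi"],
                  "Run the single-GPU or multi-GPU finetuning example.")

def main(argv):
  del argv  # Unused.
  if FLAGS.mode == "single":
    single_gpu_example()
  else:
    multi_gpu_example()

if __name__ == "__main__":
  app.run(main)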
notebooks/finetuning_torch.py (outdated)
x_context, x_padding, freq, x_future = [t.to(self.device, non_blocking=True) for t in batch]

predictions = self.model(x_context, x_padding.float(), freq)
predictions_mean = predictions[..., 0]
Can we add support for also tuning the quantiles? See here for the loss function in pax.
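For reference, a generic pinball (quantile) loss sketch; this is not the pax implementation referenced above, and the quantile levels shown are just examples.

import torch

def quantile_loss(pred_quantiles: torch.Tensor,
                  target: torch.Tensor,
                  quantiles=(0.1, 0.5, 0.9)) -> torch.Tensor:
  """Pinball loss averaged over the given quantile heads.

  pred_quantiles: [..., num_quantiles] predictions, one slice per quantile level.
  target: [...] ground-truth values.
  """
  losses = []
  for i, q in enumerate(quantiles):
    error = target - pred_quantiles[..., i]
    # Pinball loss: q * error if error >= 0, else (q - 1) * error.
    losses.append(torch.maximum(q * error, (q - 1) * error))
  return torch.mean(torch.stack(losses))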
notebooks/finetuning_torch.py (outdated)
loss, _ = self._process_batch(batch)

if self.config.distributed:
  losses = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
I am not sure I understand why we are doing torch.zeros_like(loss) here. Shouldn't we get a set of actual losses from each distributed backprop?
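For what it's worth, this looks like the standard all_gather idiom: the zeros are pre-allocated receive buffers that dist.all_gather then overwrites with each rank's actual loss. A minimal sketch of that pattern (the function name is illustrative):

import torch
import torch.distributed as dist

def gather_mean_loss(loss: torch.Tensor) -> torch.Tensor:
  """Averages a per-rank scalar loss across all ranks, e.g. for logging."""
  # Pre-allocate one receive buffer per rank; the zeros are placeholders only.
  buffers = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
  dist.all_gather(buffers, loss)  # buffers[r] now holds rank r's actual loss
  return torch.stack(buffers).mean()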
notebooks/finetuning_torch.py (outdated)
self.logger.info(f"Validation samples: {len(val_dataset)}")

try:
  for epoch in range(self.config.num_epochs):
nit: If an epoch is very large, is it possible to also log the train and eval loss within the epoch? This is completely optional.
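A small sketch of what such optional mid-epoch logging could look like; run_epoch, model_step, and log_every_n_steps are hypothetical names, not fields of the current FinetuningConfig.

import logging

def run_epoch(model_step, train_loader, epoch: int, log_every_n_steps: int = 100):
  """Runs one training epoch and logs the running mean loss every few steps."""
  logger = logging.getLogger(__name__)
  running = 0.0
  for step, batch in enumerate(train_loader, start=1):
    running += model_step(batch)  # model_step is assumed to return a float loss
    if step % log_every_n_steps == 0:
      logger.info("epoch %d step %d: mean train loss %.4f",
                  epoch, step, running / step)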
When fine-tuning, I noticed that the following messages are repeatedly printed:
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.
Could you help me take a look at this issue?
Description
This PR adds finetuning for the TimesFM model on PyTorch and also adds Multi-GPU training on PyTorch. The framework is designed to be easy to use while providing advanced features for more sophisticated use cases.
For review
notebooks/finetuning_example.py - client code
notebooks/finetuning_example.ipynb - client code and an example of how to use the finetuning framework
timesfm/finetuning_torch.py - framework code

The framework code contains the class TimesFMFinetuner, which accepts an instance of the model and the FinetuningConfig.
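A hedged usage sketch based on the description and the review snippets above; the import path, the FinetuningConfig fields shown, and the name of the training method are assumptions that may differ from the final code.

from timesfm.finetuning_torch import FinetuningConfig, TimesFMFinetuner

model = ...  # a loaded PyTorch TimesFM model instance
config = FinetuningConfig(num_epochs=5, distributed=False)  # field names taken from the review snippets
finetuner = TimesFMFinetuner(model, config)

# train_dataset / val_dataset are user-prepared datasets (hypothetical names);
# the training entry point is assumed to look roughly like:
# finetuner.train(train_dataset, val_dataset)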
Features
Testing
Make sure to follow the installation instructions. After you have the pyenv env successfully created, just execute poetry run python notebooks/finetuning_torch.py. This will:

Quick Start
Advanced Usage
Multi-GPU Training
Configuration Options
The FinetuningConfig class provides extensive configuration options. [WIP]