Finetuning + Multi-GPU Support with PyTorch #223

Open · wants to merge 8 commits into master

Conversation

@misha-chertushkin commented Jan 20, 2025

Description

This PR adds finetuning for the TimesFM model in PyTorch, along with multi-GPU training support. The framework is designed to be easy to use while providing advanced features for more sophisticated use cases.

For review

  • notebooks/finetuning_example.py - client code
  • notebooks/finetuning_example.ipynb - client code and an example of how to use the finetuning framework
  • timesfm/finetuning_torch.py - framework code

The framework code contains the TimesFMFinetuner class, which accepts a model instance and a FinetuningConfig.

Features

  • 🚀 Easy-to-use API for finetuning TimesFM models
  • 📊 Integration with Weights & Biases for experiment tracking
  • 🔧 Configurable training parameters and loss functions
  • 📈 Advanced features including:
    • Multi-GPU training support
    • Custom loss functions
    • Data preprocessing utilities
    • Learning rate scheduling
    • Early stopping
    • Model checkpointing
  • 📝 Comprehensive logging and metrics tracking

Testing

Make sure to follow the installation instructions. Once the pyenv environment has been created successfully, execute poetry run python notebooks/finetuning_torch.py. This will:

  • Create the TimesFM model in PyTorch
  • Load the checkpoint from Hugging Face (see the sketch after this list)
  • Finetune the model on stock data fetched with yfinance
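
For reference, here is a minimal sketch of the first two steps using the public TimesFM 1.2.x API; the checkpoint repo id and the yfinance ticker are illustrative assumptions, and the script in this PR may differ:

import timesfm
import yfinance as yf

# Build the PyTorch TimesFM model and pull a pretrained checkpoint from
# Hugging Face. The hparams must match the chosen checkpoint.
model = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        horizon_len=128,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m-pytorch"),
)

# Fetch a single price series to finetune on (ticker is illustrative).
series = yf.download("GOOG", period="5y")["Close"].to_numpy().squeeze()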

Quick Start

import numpy as np
from timesfm import TimesFm, TimesFmHparams
from timesfm_finetuner import FinetuningConfig, TimesFMFinetuner

# Load your time series data
time_series = np.load('my_time_series.npy')

# Initialize TimesFM model
model = TimesFm(
    hparams=TimesFmHparams(
        backend="cuda",
        per_core_batch_size=32,
        horizon_len=128,
        num_layers=50,
    )
)

# Configure finetuning
config = FinetuningConfig(
    batch_size=64,
    num_epochs=50,
    learning_rate=1e-4,
    use_wandb=True  # Enable W&B logging
)

# Create finetuner
finetuner = TimesFMFinetuner(config, model.config)

# Finetune model
results = finetuner.finetune(model, time_series, "my_model")

Advanced Usage

Multi-GPU Training

gpu_ids = [0, 1]
config = FinetuningConfig(
    batch_size=256,
    num_epochs=5,
    learning_rate=3e-5,
    use_wandb=True,
    distributed=True,
    gpu_ids=gpu_ids,
)
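
A distributed config like this is consumed by one training process per GPU. The sketch below shows one plausible launch pattern assembled from the snippets under review (torch.multiprocessing, the NCCL process-group init, and TimesFMFinetuner(model, config, rank=rank)); the _run_rank helper, the environment defaults, and the finetune() call are assumptions, not the PR's exact entry point.

import os

import torch.distributed as dist
import torch.multiprocessing as mp

from timesfm_finetuner import TimesFMFinetuner  # import path as in the Quick Start above


def _run_rank(rank, world_size, model, config, time_series):
  """Per-GPU worker: one process per entry in config.gpu_ids."""
  os.environ.setdefault("MASTER_ADDR", "localhost")
  os.environ.setdefault("MASTER_PORT", "29500")
  if not dist.is_initialized():
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
  finetuner = TimesFMFinetuner(model, config, rank=rank)
  finetuner.finetune(time_series)  # hypothetical call; see notebooks/finetuning_example.py
  dist.destroy_process_group()


if __name__ == "__main__":
  # model, config and time_series are assumed to be built as in the Quick Start above.
  world_size = len(config.gpu_ids)
  mp.spawn(_run_rank, args=(world_size, model, config, time_series),
           nprocs=world_size, join=True)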

Configuration Options

The FinetuningConfig class provides extensive configuration options:

config = FinetuningConfig(
    # Training parameters
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-4,
    weight_decay=0.01,
    
    # Model parameters
    context_length=128,
    horizon_length=32,

    # Multi-GPU training
    distributed=True,
    gpu_ids=gpu_ids,
)

[WIP]

Hi, nmgvv.
I hope this message finds you well.
Thank you for your amazing work on ‘timesfm’ and for sharing it with the community. I recently came across your pull request and was really excited to explore your code. However, I noticed that some parts of the code have not been uploaded yet, and as a new user, I'm finding it challenging to understand how to utilize the current implementation.
If it's not too much trouble, could you please complete the upload and, if possible, add more comments or examples? I believe it would help not just me but others in the community better understand and benefit from your contributions.
I truly appreciate your effort and understand that preparing and sharing code takes time. Please let me know if there's anything I can do to help with the process.
Thank you again for your hard work, and I look forward to seeing the completed project!

Collaborator

Can we shift these files to a folder named finetuning?

import pandas as pd
import torch
import torch.multiprocessing as mp
import yfinance as yf

Collaborator

nit: Can we also see the results on ETT, and if so, whether they match the pax results?

x_future = torch.tensor(x_future, dtype=torch.float32)

input_padding = torch.zeros_like(x_context)
freq = torch.zeros(1, dtype=torch.long)

Collaborator

Can we give users the option to choose the fine-tuning freq from 0, 1, and 2?

horizon_len=128,
num_layers=50,
use_positional_embedding=False,
context_len=192,

Collaborator

Can we add a comment that this context length can be anything up to 2048, in multiples of 32?

if not torch.distributed.is_initialized():
  torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)

finetuner = TimesFMFinetuner(model, config, rank=rank)

Collaborator

What does rank mean?
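
For context, rank in torch.distributed is the index of the current process within the process group, running from 0 to world_size - 1; it is typically used to select the GPU and to let only one process handle logging and checkpointing. An illustrative usage, not the PR's code:

import torch


def setup_for_rank(rank: int):
  """rank is this process's index in the distributed group (0..world_size-1)."""
  device = torch.device(f"cuda:{rank}")  # each process drives one GPU
  is_main_process = rank == 0            # rank 0 usually owns logging/checkpointing
  return device, is_main_process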

if __name__ == "__main__":
  try:
    # single_gpu_example()  # Single GPU
    multi_gpu_example()  # Multi-GPU

Collaborator

Can we add these two options as an absl flag?

x_context, x_padding, freq, x_future = [t.to(self.device, non_blocking=True) for t in batch]

predictions = self.model(x_context, x_padding.float(), freq)
predictions_mean = predictions[..., 0]

Collaborator

Can we also add support for tuning the quantiles? See here for the loss function in pax.
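
For reference, quantile heads are usually trained with a pinball loss; the sketch below is a generic PyTorch version rather than the pax implementation referenced above, and it assumes the quantile forecasts sit in predictions[..., 1:] next to the mean at index 0.

import torch


def pinball_loss(quantile_preds, target,
                 quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)):
  """Mean pinball loss; quantile_preds is [..., num_quantiles], target matches quantile_preds[..., i]."""
  losses = []
  for i, q in enumerate(quantiles):
    err = target - quantile_preds[..., i]
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    losses.append(torch.maximum(q * err, (q - 1.0) * err))
  return torch.stack(losses, dim=-1).mean()

A combined objective could then add this term to the MSE computed on predictions[..., 0].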

loss, _ = self._process_batch(batch)

if self.config.distributed:
  losses = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]

Collaborator

I am not sure I understand why we are doing torch.zeros_like(loss) here. Should we not get a set of actual losses from each distributed backprop?
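
For context, the zeros_like tensors are the pre-allocated output buffers that dist.all_gather requires; the collective then overwrites each buffer with the actual loss computed on the corresponding rank. A minimal sketch of that pattern (names are illustrative):

import torch
import torch.distributed as dist


def gather_mean_loss(loss):
  """Averages a per-rank scalar loss across an initialized process group."""
  buffers = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
  dist.all_gather(buffers, loss)  # each buffer now holds one rank's real loss
  return torch.stack(buffers).mean()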

self.logger.info(f"Validation samples: {len(val_dataset)}")

try:
  for epoch in range(self.config.num_epochs):

Collaborator

nit: Is it possible to also log train and eval loss within an epoch if the epoch is very long? This is completely optional.

@otakuCandy left a comment

When fine-tuning, I noticed that the following messages are repeatedly printed:
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.
Could you help me take a look at this issue?
