diff --git a/notebooks/finetuning_example.py b/notebooks/finetuning_example.py
index f1f8a74..07325b9 100644
--- a/notebooks/finetuning_example.py
+++ b/notebooks/finetuning_example.py
@@ -261,7 +261,7 @@ def multi_gpu_example():
   config = FinetuningConfig(
       batch_size=256,
       num_epochs=5,
-      learning_rate=1e-4,
+      learning_rate=3e-5,
       use_wandb=True,
       distributed=True,
       gpu_ids=gpu_ids,
diff --git a/notebooks/finetuning_torch.py b/notebooks/finetuning_torch.py
index 29215ef..a3c7f3e 100644
--- a/notebooks/finetuning_torch.py
+++ b/notebooks/finetuning_torch.py
@@ -1,115 +1,204 @@
 """
 TimesFM Finetuner: A flexible framework for finetuning TimesFM models on custom datasets.
-
-Example usage:
-  ```python
-  # Prepare datasets
-  train_dataset = TimeSeriesDataset(train_data, context_length=128, horizon_length=32)
-  val_dataset = TimeSeriesDataset(val_data, context_length=128, horizon_length=32)
-
-  # Initialize model and configuration
-  model = TimesFm(...)
-  config = FinetuningConfig(
-      batch_size=64,
-      num_epochs=50,
-      learning_rate=1e-4,
-      use_wandb=True
-  )
-
-  # Create finetuner
-  finetuner = TimesFMFinetuner(model, config)
-
-  # Finetune model
-  results = finetuner.finetune(train_dataset, val_dataset)
-  ```
 """
-import abc
 import logging
-import multiprocessing as mp
 import os
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
-import torch.optim as optim
+import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, Dataset

 import wandb


+class MetricsLogger(ABC):
+  """Abstract base class for logging metrics during training.
+
+  This class defines the interface for logging metrics during model training.
+  Concrete implementations can log to different backends (e.g., WandB, TensorBoard).
+  """
+
+  @abstractmethod
+  def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+    """Log metrics to the specified backend.
+
+    Args:
+      metrics: Dictionary containing metric names and values.
+      step: Optional step number or epoch for the metrics.
+    """
+    pass
+
+  @abstractmethod
+  def close(self) -> None:
+    """Clean up any resources used by the logger."""
+    pass
+
+
+class WandBLogger(MetricsLogger):
+  """Weights & Biases implementation of metrics logging.
+
+  Args:
+    project: Name of the W&B project.
+    config: Configuration dictionary to log.
+    rank: Process rank in distributed training.
+  """
+
+  def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):
+    self.rank = rank
+    if rank == 0:
+      wandb.init(project=project, config=config)
+
+  def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+    """Log metrics to W&B if on the main process.
+
+    Args:
+      metrics: Dictionary of metrics to log.
+      step: Current training step or epoch.
+    """
+    if self.rank == 0:
+      wandb.log(metrics, step=step)
+
+  def close(self) -> None:
+    """Finish the W&B run if on the main process."""
+    if self.rank == 0:
+      wandb.finish()
+
+
+class DistributedManager:
+  """Manages distributed training setup and cleanup.
+
+  Args:
+    world_size: Total number of processes.
+    rank: Process rank.
+    master_addr: Address of the master process.
+    master_port: Port for distributed communication.
+    backend: PyTorch distributed backend to use.
+ """ + + def __init__( + self, + world_size: int, + rank: int, + master_addr: str = "localhost", + master_port: str = "12358", + backend: str = "nccl", + ): + self.world_size = world_size + self.rank = rank + self.master_addr = master_addr + self.master_port = master_port + self.backend = backend + + def setup(self) -> None: + """Initialize the distributed environment.""" + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = self.master_port + + if not dist.is_initialized(): + dist.init_process_group(backend=self.backend, world_size=self.world_size, rank=self.rank) + + def cleanup(self) -> None: + """Clean up the distributed environment.""" + if dist.is_initialized(): + dist.destroy_process_group() + + @dataclass class FinetuningConfig: - """Configuration for TimesFM finetuning process.""" + """Configuration for model training. + + Args: + batch_size: Number of samples per batch. + num_epochs: Number of training epochs. + learning_rate: Initial learning rate. + weight_decay: L2 regularization factor. + device: Device to train on ('cuda' or 'cpu'). + distributed: Whether to use distributed training. + gpu_ids: List of GPU IDs to use. + master_port: Port for distributed training. + master_addr: Address for distributed training. + use_wandb: Whether to use Weights & Biases logging. + wandb_project: W&B project name. + """ - # Training parameters batch_size: int = 32 num_epochs: int = 20 learning_rate: float = 1e-4 weight_decay: float = 0.01 - - # Hardware parameters device: str = "cuda" if torch.cuda.is_available() else "cpu" distributed: bool = False - world_size: int = 1 - - # Logging parameters - use_wandb: bool = False - wandb_project: str = "timesfm-finetuning" - - gpu_ids: List[int] = field(default_factory=lambda: [0]) # List of GPU IDs to use - distributed: bool = False + gpu_ids: List[int] = field(default_factory=lambda: [0]) master_port: str = "12358" master_addr: str = "localhost" + use_wandb: bool = False + wandb_project: str = "timesfm-finetuning" class TimesFMFinetuner: + """Handles model training and validation. + + Args: + model: PyTorch model to train. + config: Training configuration. + rank: Process rank for distributed training. + loss_fn: Loss function (defaults to MSE). + logger: Optional logging.Logger instance. 
+ """ + def __init__( self, - model, + model: nn.Module, config: FinetuningConfig, rank: int = 0, - loss_fn: Optional[callable] = None, + loss_fn: Optional[Callable] = None, logger: Optional[logging.Logger] = None, ): self.model = model self.config = config - self.logger = logger or logging.getLogger(__name__) self.rank = rank - - if config.distributed: - self._setup_distributed(rank) - + self.logger = logger or logging.getLogger(__name__) self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") self.loss_fn = loss_fn or (lambda x, y: torch.mean((x - y.squeeze(-1)) ** 2)) - if config.use_wandb and rank == 0: - self._setup_wandb() + if config.use_wandb: + self.metrics_logger = WandBLogger(config.wandb_project, config.__dict__, rank) - def _setup_distributed(self, rank): - """Setup distributed training environment.""" - os.environ["MASTER_ADDR"] = self.config.master_addr - os.environ["MASTER_PORT"] = self.config.master_port + if config.distributed: + self.dist_manager = DistributedManager( + world_size=len(config.gpu_ids), + rank=rank, + master_addr=config.master_addr, + master_port=config.master_port, + ) + self.dist_manager.setup() + self.model = self._setup_distributed_model() - if not dist.is_initialized(): - dist.init_process_group(backend="nccl", world_size=len(self.config.gpu_ids), rank=rank) + def _setup_distributed_model(self) -> nn.Module: + """Configure model for distributed training.""" + self.model = self.model.to(self.device) + return DDP( + self.model, device_ids=[self.config.gpu_ids[self.rank]], output_device=self.config.gpu_ids[self.rank] + ) - def _setup_wandb(self) -> None: - """Initialize Weights & Biases logging.""" + def _create_dataloader(self, dataset: Dataset, is_train: bool) -> DataLoader: + """Create appropriate DataLoader based on training configuration. - def _setup_wandb(self) -> None: - """Initialize Weights & Biases logging only on the main process.""" - if self.rank == 0: # Only initialize on main process - wandb.init(project=self.config.wandb_project, config=self.config.__dict__) + Args: + dataset: Dataset to create loader for. + is_train: Whether this is for training (affects shuffling). - def _create_dataloader(self, dataset: Dataset, name: str) -> DataLoader: - """Create a dataloader from a dataset.""" + Returns: + DataLoader instance. + """ if self.config.distributed: sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=len(self.config.gpu_ids), rank=dist.get_rank(), shuffle=name == "train" + dataset, num_replicas=len(self.config.gpu_ids), rank=dist.get_rank(), shuffle=is_train ) else: sampler = None @@ -117,23 +206,44 @@ def _create_dataloader(self, dataset: Dataset, name: str) -> DataLoader: return DataLoader( dataset, batch_size=self.config.batch_size, - shuffle=(name == "train" and not self.config.distributed), + shuffle=(is_train and not self.config.distributed), sampler=sampler, ) + def _process_batch(self, batch: List[torch.Tensor]) -> tuple: + """Process a single batch of data. + + Args: + batch: List of input tensors. + + Returns: + Tuple of (loss, predictions). 
+ """ + x_context, x_padding, freq, x_future = [t.to(self.device, non_blocking=True) for t in batch] + + predictions = self.model(x_context, x_padding.float(), freq) + predictions_mean = predictions[..., 0] + last_patch_pred = predictions_mean[:, -1, :] + + loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1)) + + return loss, predictions + def _train_epoch(self, train_loader: DataLoader, optimizer: torch.optim.Optimizer) -> float: - """Train for one epoch with loss debugging.""" + """Train for one epoch. + + Args: + train_loader: DataLoader for training data. + optimizer: Optimizer instance. + + Returns: + Average training loss for the epoch. + """ self.model.train() total_loss = 0.0 - n_batches = len(train_loader) for batch in train_loader: - x_context, x_padding, freq, x_future = [t.to(self.device, non_blocking=True) for t in batch] - - predictions = self.model(x_context, x_padding.float(), freq) - predictions_mean = predictions[..., 0] - last_patch_pred = predictions_mean[:, -1, :] - loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1)) + loss, _ = self._process_batch(batch) if self.config.distributed: losses = [torch.zeros_like(loss) for _ in range(dist.get_world_size())] @@ -145,22 +255,23 @@ def _train_epoch(self, train_loader: DataLoader, optimizer: torch.optim.Optimize total_loss += loss.item() - return total_loss / n_batches + return total_loss / len(train_loader) def _validate(self, val_loader: DataLoader) -> float: - """Perform validation with loss debugging.""" + """Perform validation. + + Args: + val_loader: DataLoader for validation data. + + Returns: + Average validation loss. + """ self.model.eval() total_loss = 0.0 with torch.no_grad(): for batch in val_loader: - x_context, x_padding, freq, x_future = [t.to(self.device) for t in batch] - - predictions = self.model(x_context, x_padding.float(), freq) - predictions_mean = predictions[..., 0] - last_patch_pred = predictions_mean[:, -1, :] - - loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1)) + loss, _ = self._process_batch(batch) if self.config.distributed: losses = [torch.zeros_like(loss) for _ in range(dist.get_world_size())] @@ -171,31 +282,23 @@ def _validate(self, val_loader: DataLoader) -> float: return total_loss / len(val_loader) def finetune(self, train_dataset: Dataset, val_dataset: Dataset) -> Dict[str, Any]: - """ - Finetune the TimesFM model on the provided datasets. + """Train the model. Args: - train_dataset: Training dataset - val_dataset: Validation dataset + train_dataset: Training dataset. + val_dataset: Validation dataset. Returns: - Dict containing training history and best model path + Dictionary containing training history. 
""" self.model = self.model.to(self.device) + train_loader = self._create_dataloader(train_dataset, is_train=True) + val_loader = self._create_dataloader(val_dataset, is_train=False) - if self.config.distributed: - self.model = DDP( - self.model, - device_ids=[self.config.gpu_ids[dist.get_rank()]], - output_device=self.config.gpu_ids[dist.get_rank()], - ) - - train_loader = self._create_dataloader(train_dataset, "train") - val_loader = self._create_dataloader(val_dataset, "val") - - optimizer = optim.Adam( + optimizer = torch.optim.Adam( self.model.parameters(), lr=self.config.learning_rate, weight_decay=self.config.weight_decay ) + history = {"train_loss": [], "val_loss": [], "learning_rate": []} self.logger.info(f"Starting training for {self.config.num_epochs} epochs...") @@ -205,67 +308,33 @@ def finetune(self, train_dataset: Dataset, val_dataset: Dataset) -> Dict[str, An try: for epoch in range(self.config.num_epochs): train_loss = self._train_epoch(train_loader, optimizer) - val_loss = self._validate(val_loader) - current_lr = optimizer.param_groups[0]["lr"] - if self.config.distributed: - train_tensor = torch.tensor(train_loss, device=self.device) - val_tensor = torch.tensor(val_loss, device=self.device) - - world_size = dist.get_world_size() - train_losses = [torch.zeros_like(train_tensor, device=self.device) for _ in range(world_size)] - val_losses = [torch.zeros_like(val_tensor, device=self.device) for _ in range(world_size)] - - dist.all_gather(train_losses, train_tensor) - dist.all_gather(val_losses, val_tensor) - - if self.rank == 0 and self.config.use_wandb: - train_losses = [t.cpu().item() for t in train_losses] - val_losses = [t.cpu().item() for t in val_losses] - - for gpu_idx, (t_loss, v_loss) in enumerate(zip(train_losses, val_losses)): - wandb.log( - { - f"train_loss_gpu_{gpu_idx}": t_loss, - f"val_loss_gpu_{gpu_idx}": v_loss, - }, - commit=False, - ) - - wandb.log( - { - "train_loss": train_loss, - "val_loss": val_loss, - "learning_rate": current_lr, - "epoch": epoch + 1, - } - ) - history["train_loss"].append(train_loss) - history["val_loss"].append(val_loss) - history["learning_rate"].append(current_lr) - - else: - if self.config.use_wandb: - wandb.log( - { - "train_loss": train_loss, - "val_loss": val_loss, - "learning_rate": current_lr, - "epoch": epoch + 1, - } - ) - history["train_loss"].append(train_loss) - history["val_loss"].append(val_loss) - history["learning_rate"].append(current_lr) + metrics = { + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": current_lr, + "epoch": epoch + 1, + } + + if self.config.use_wandb: + self.metrics_logger.log_metrics(metrics) + + history["train_loss"].append(train_loss) + history["val_loss"].append(val_loss) + history["learning_rate"].append(current_lr) if self.rank == 0: - print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}") + self.logger.info(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}") + except KeyboardInterrupt: self.logger.info("Training interrupted by user") if self.config.distributed: - dist.destroy_process_group() + self.dist_manager.cleanup() + + if self.config.use_wandb: + self.metrics_logger.close() return {"history": history}