diff --git a/darts-acquisition/src/darts_acquisition/admin.py b/darts-acquisition/src/darts_acquisition/admin.py new file mode 100644 index 0000000..f03355e --- /dev/null +++ b/darts-acquisition/src/darts_acquisition/admin.py @@ -0,0 +1,60 @@ +"""Download of admin level files for the regions.""" + +import io +import logging +import time +import zipfile +from pathlib import Path + +import requests + +logger = logging.getLogger(__name__) + + +def _download_zip(url: str, admin_dir: Path): + tick_fstart = time.perf_counter() + response = requests.get(url) + + # Get the downloaded data as a byte string + data = response.content + + tick_download = time.perf_counter() + logger.debug(f"Downloaded {len(data)} bytes in {tick_download - tick_fstart:.2f} seconds") + + # Create a bytesIO object + with io.BytesIO(data) as buffer: + # Create a zipfile.ZipFile object and extract the files to a directory + admin_dir.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(buffer, "r") as zip_ref: + # Extract the files to the specified directory + zip_ref.extractall(admin_dir) + + tick_extract = time.perf_counter() + logger.debug(f"Extraction completed in {tick_extract - tick_download:.2f} seconds") + + +def download_admin_files(admin_dir: Path): + """Download the admin files for the regions. + + Files will be stored under [admin_dir]/adm1.shp and [admin_dir]/adm2.shp. + + Args: + admin_dir (Path): The path to the admin files. + + """ + tick_fstart = time.perf_counter() + + # Download the admin files + admin_1_url = "https://github.com/wmgeolab/geoBoundaries/raw/main/releaseData/CGAZ/geoBoundariesCGAZ_ADM1.zip" + admin_2_url = "https://github.com/wmgeolab/geoBoundaries/raw/main/releaseData/CGAZ/geoBoundariesCGAZ_ADM2.zip" + + admin_dir.mkdir(exist_ok=True, parents=True) + + logger.debug(f"Downloading {admin_1_url} to {admin_dir.resolve()}") + _download_zip(admin_1_url, admin_dir) + + logger.debug(f"Downloading {admin_2_url} to {admin_dir.resolve()}") + _download_zip(admin_2_url, admin_dir) + + tick_fend = time.perf_counter() + logger.info(f"Downloaded admin files in {tick_fend - tick_fstart:.2f} seconds") diff --git a/darts-export/src/darts_export/miniviz.py b/darts-export/src/darts_export/miniviz.py new file mode 100644 index 0000000..e69de29 diff --git a/darts-segmentation/src/darts_segmentation/training/data.py b/darts-segmentation/src/darts_segmentation/training/data.py index 56d1212..ae7dd57 100644 --- a/darts-segmentation/src/darts_segmentation/training/data.py +++ b/darts-segmentation/src/darts_segmentation/training/data.py @@ -11,7 +11,8 @@ import albumentations as A # noqa: N812 import lightning as L # noqa: N812 import torch -from torch.utils.data import DataLoader, Dataset, random_split +from sklearn.model_selection import KFold +from torch.utils.data import DataLoader, Dataset, Subset logger = logging.getLogger(__name__.replace("darts_", "darts.")) @@ -120,6 +121,7 @@ def __init__( self, data_dir: Path, batch_size: int, + current_fold: int = 0, augment: bool = True, num_workers: int = 0, in_memory: bool = False, @@ -128,6 +130,7 @@ def __init__( self.save_hyperparameters() self.data_dir = data_dir self.batch_size = batch_size + self.current_fold = current_fold self.augment = augment self.num_workers = num_workers self.in_memory = in_memory @@ -138,15 +141,14 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = No if not self.in_memory else DartsDatasetInMemory(self.data_dir, self.augment) ) - splits = [0.8, 0.1, 0.1] - generator = torch.Generator().manual_seed(42) - self.train, self.val, self.test = random_split(dataset, splits, generator) + + kf = KFold(n_splits=5) + train_idx, val_idx = list(kf.split(dataset))[self.current_fold] + self.train = Subset(dataset, train_idx) + self.val = Subset(dataset, val_idx) def train_dataloader(self): return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) def val_dataloader(self): return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers) - - def test_dataloader(self): - return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers) diff --git a/darts/src/darts/cli.py b/darts/src/darts/cli.py index 24e5ca9..a02cd54 100644 --- a/darts/src/darts/cli.py +++ b/darts/src/darts/cli.py @@ -15,7 +15,13 @@ run_native_sentinel2_pipeline, run_native_sentinel2_pipeline_fast, ) -from darts.legacy_training import convert_lightning_checkpoint, preprocess_s2_train_data, sweep_smp, train_smp +from darts.legacy_training import ( + convert_lightning_checkpoint, + optuna_sweep_smp, + preprocess_s2_train_data, + train_smp, + wandb_sweep_smp, +) from darts.utils.config import ConfigParser from darts.utils.logging import LoggingManager @@ -72,7 +78,8 @@ def env_info(): app.command(group=train_group)(preprocess_s2_train_data) app.command(group=train_group)(train_smp) app.command(group=train_group)(convert_lightning_checkpoint) -app.command(group=train_group)(sweep_smp) +app.command(group=train_group)(wandb_sweep_smp) +app.command(group=train_group)(optuna_sweep_smp) # Custom wrapper for the create_arcticdem_vrt function, which dodges the loading of all the heavy modules @@ -96,9 +103,10 @@ def launcher( # noqa: D103 *tokens: Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)], log_dir: Path = Path("logs"), config_file: Path = Path("config.toml"), + tracebacks_show_locals: bool = False, ): command, bound, _ = app.parse_args(tokens) - LoggingManager.add_logging_handlers(command.__name__, console, log_dir) + LoggingManager.add_logging_handlers(command.__name__, console, log_dir, tracebacks_show_locals) logger.debug(f"Running on Python version {sys.version} from {__name__} ({root_file})") return command(*bound.args, **bound.kwargs) diff --git a/darts/src/darts/legacy_training/__init__.py b/darts/src/darts/legacy_training/__init__.py index ed089f9..e31d832 100644 --- a/darts/src/darts/legacy_training/__init__.py +++ b/darts/src/darts/legacy_training/__init__.py @@ -1,6 +1,7 @@ """Legacy training module for DARTS.""" from darts.legacy_training.preprocess import preprocess_s2_train_data as preprocess_s2_train_data -from darts.legacy_training.train import sweep_smp as sweep_smp +from darts.legacy_training.train import optuna_sweep_smp as optuna_sweep_smp from darts.legacy_training.train import train_smp as train_smp +from darts.legacy_training.train import wandb_sweep_smp as wandb_sweep_smp from darts.legacy_training.util import convert_lightning_checkpoint as convert_lightning_checkpoint diff --git a/darts/src/darts/legacy_training/preprocess.py b/darts/src/darts/legacy_training/preprocess.py index 3128d60..51d8927 100644 --- a/darts/src/darts/legacy_training/preprocess.py +++ b/darts/src/darts/legacy_training/preprocess.py @@ -7,13 +7,11 @@ from pathlib import Path from typing import Literal -import toml - logger = logging.getLogger(__name__) def split_dataset_paths( - s2_paths: list[Path], train_data_dir: Path, test_val_split: float, test_regions: list[str] | None + s2_paths: list[Path], train_data_dir: Path, test_val_split: float, test_regions: list[str] | None, admin_dir: Path ): """Split the dataset into a cross-val, a val-test and a test dataset. @@ -33,6 +31,7 @@ def split_dataset_paths( """ # Import here to avoid long loading times when running other commands import geopandas as gpd + from darts_acquisition.admin import download_admin_files from darts_acquisition.s2 import parse_s2_tile_id from sklearn.model_selection import train_test_split @@ -45,11 +44,32 @@ def split_dataset_paths( test_paths: list[Path] = [] training_paths: list[Path] = [] if test_regions: + # Download admin files if they do not exist + admin1_fpath = admin_dir / "geoBoundariesCGAZ_ADM1.shp" + admin2_fpath = admin_dir / "geoBoundariesCGAZ_ADM2.shp" + + if not admin1_fpath.exists() or not admin2_fpath.exists(): + download_admin_files(admin_dir) + + # Load the admin files + admin1 = gpd.read_file(admin1_fpath) + admin2 = gpd.read_file(admin2_fpath) + + # Get the regions from the admin files + test_region_geometries_adm1 = admin1[admin1["shapeName"].isin(test_regions)] + test_region_geometries_adm2 = admin2[admin2["shapeName"].isin(test_regions)] + + logger.debug(f"Found {len(test_region_geometries_adm1)} admin1-regions in {admin1_fpath}") + logger.debug(f"Found {len(test_region_geometries_adm2)} admin2-regions in {admin2_fpath}") + for fpath in s2_paths: _, s2_tile_id, _ = parse_s2_tile_id(fpath) - labels = gpd.read_file(fpath / f"{s2_tile_id}.shp") - # If any of the regions is in the test region, add to the test set - if labels["region"].isin(test_regions).any(): + labels = gpd.read_file(fpath / f"{s2_tile_id}.shp").to_crs("EPSG:4326") + # Check if any label is intersecting with the test regions + adm1_intersects = labels.overlay(test_region_geometries_adm1, how="intersection") + adm2_intersects = labels.overlay(test_region_geometries_adm2, how="intersection") + + if (len(adm1_intersects.index) > 0) or (len(adm2_intersects.index) > 0): test_paths.append(fpath) else: training_paths.append(fpath) @@ -92,6 +112,7 @@ def preprocess_s2_train_data( train_data_dir: Path, arcticdem_dir: Path, tcvis_dir: Path, + admin_dir: Path, preprocess_cache: Path | None = None, device: Literal["cuda", "cpu", "auto"] | int | None = None, dask_worker: int = min(16, mp.cpu_count() - 1), @@ -116,6 +137,7 @@ def preprocess_s2_train_data( arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist. tcvis_dir (Path): The directory containing the TCVis data. + admin_dir (Path): The directory containing the admin files. preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None. device (Literal["cuda", "cpu"] | int, optional): The device to run the model on. If "cuda" take the first device (0), if int take the specified device. @@ -144,6 +166,7 @@ def preprocess_s2_train_data( # Import here to avoid long loading times when running other commands import geopandas as gpd import pandas as pd + import toml import torch import xarray as xr from darts_acquisition.arcticdem import load_arcticdem_tile @@ -189,14 +212,19 @@ def preprocess_s2_train_data( # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region) n_patches = 0 + n_patches_by_mode = {"cross-val": 0, "val-test": 0, "test": 0} joint_lables = [] s2_paths = sorted(sentinel2_dir.glob("*/")) logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}") - path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions) + path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions, admin_dir) for i, (fpath, output_dir, mode) in enumerate(path_gen): try: _, s2_tile_id, tile_id = parse_s2_tile_id(fpath) + logger.debug( + f"Processing sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}' ({tile_id=}) to split '{mode}'" + ) + # Check for a cached preprocessed file if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists(): cache_file = preprocess_cache / f"{tile_id}.nc" @@ -251,14 +279,15 @@ def preprocess_s2_train_data( outdir_y = output_dir / "y" outdir_x.mkdir(exist_ok=True, parents=True) outdir_y.mkdir(exist_ok=True, parents=True) - n_patches = 0 + patch_id = None for patch_id, (x, y) in enumerate(gen): torch.save(x, outdir_x / f"{tile_id}_pid{patch_id}.pt") torch.save(y, outdir_y / f"{tile_id}_pid{patch_id}.pt") n_patches += 1 + n_patches_by_mode[mode] += 1 if n_patches > 0 and len(labels) > 0: labels["mode"] = mode - joint_lables.append(labels) + joint_lables.append(labels.to_crs("EPSG:3413")) logger.info( f"Processed sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}'" @@ -300,4 +329,4 @@ def preprocess_s2_train_data( with open(train_data_dir / "config.toml", "w") as f: toml.dump(config, f) - logger.info(f"Saved {n_patches} patches to {train_data_dir}") + logger.info(f"Saved {n_patches} ({n_patches_by_mode}) patches to {train_data_dir}") diff --git a/darts/src/darts/legacy_training/train.py b/darts/src/darts/legacy_training/train.py index 440cd18..2e38152 100644 --- a/darts/src/darts/legacy_training/train.py +++ b/darts/src/darts/legacy_training/train.py @@ -2,7 +2,9 @@ import logging import time +from collections import defaultdict from pathlib import Path +from statistics import mean import toml import yaml @@ -15,6 +17,7 @@ def train_smp( # Data config train_data_dir: Path, artifact_dir: Path = Path("lightning_logs"), + current_fold: int = 0, # Hyperparameters model_arch: str = "Unet", model_encoder: str = "dpn107", @@ -56,6 +59,7 @@ def train_smp( train_data_dir (Path): Path to the training data directory. artifact_dir (Path, optional): Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs"). + current_fold (int, optional): The current fold to train on. Must be in [0, 4]. Defaults to 0. model_arch (str, optional): Model architecture to use. Defaults to "Unet". model_encoder (str, optional): Encoder to use. Defaults to "dpn107". model_encoder_weights (str | None, optional): Path to the encoder weights. Defaults to None. @@ -79,6 +83,9 @@ def train_smp( wandb_project (str | None, optional): Weights and Biases Project. Defaults to None. run_name (str | None, optional): Name of this run, as a further grouping method for logs etc. Defaults to None. + Returns: + Trainer: The trainer object used for training. + """ import lightning as L # noqa: N812 import lovely_tensors @@ -105,6 +112,7 @@ def train_smp( lovely_tensors.monkey_patch() torch.set_float32_matmul_precision("medium") + torch.manual_seed(42) preprocess_config = toml.load(train_data_dir / "config.toml")["darts"] @@ -121,8 +129,21 @@ def train_smp( ) # Data and model - datamodule = DartsDataModule(train_data_dir, batch_size, augment, num_workers) - model = SMPSegmenter(config, learning_rate, gamma, focal_loss_alpha, focal_loss_gamma, plot_every_n_val_epochs) + datamodule = DartsDataModule( + data_dir=train_data_dir / "cross-val", + batch_size=batch_size, + current_fold=current_fold, + augment=augment, + num_workers=num_workers, + ) + model = SMPSegmenter( + config=config, + learning_rate=learning_rate, + gamma=gamma, + focal_loss_alpha=focal_loss_alpha, + focal_loss_gamma=focal_loss_gamma, + plot_every_n_val_epochs=plot_every_n_val_epochs, + ) # Loggers trainer_loggers = [ @@ -166,13 +187,20 @@ def train_smp( tick_fend = time.perf_counter() logger.info(f"Finished training '{run_name}' in {tick_fend - tick_fstart:.2f}s.") + if wandb_entity and wandb_project: + wandb_logger.finalize("success") + wandb_logger.experiment.finish(exit_code=0) + logger.debug(f"Finalized WandB logging for '{run_name}'") + + return trainer + -def sweep_smp( +def wandb_sweep_smp( *, # Data and sweep config train_data_dir: Path, sweep_config: Path, - sweep_count: int = 10, + n_trials: int = 10, sweep_id: str | None = None, artifact_dir: Path = Path("lightning_logs"), # Epoch and Logging config @@ -186,7 +214,7 @@ def sweep_smp( wandb_entity: str | None = None, wandb_project: str | None = None, ): - """Create a sweep and run it on specified cuda devices, or continue an existing sweep. + """Create a sweep with wandb and run it on the specified cuda device, or continue an existing sweep. If `sweep_id` is None, a new sweep will be created. Otherwise, the sweep with the given ID will be continued. @@ -200,10 +228,12 @@ def sweep_smp( Since plotting is quite costly, you can reduce the frequency. Works similar like early stopping. In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`. + This will NOT use cross-validation. For cross-validation, use `optuna_sweep_smp`. + Example: In one terminal, start a sweep: ```sh - $ rye run darts sweep_smp --config-file /path/to/sweep-config.toml + $ rye run darts wandb-sweep-smp --config-file /path/to/sweep-config.toml ... # Many logs Created sweep with ID 123456789 ... # More logs from spawned agent @@ -211,7 +241,7 @@ def sweep_smp( In another terminal, start an a second agent: ```sh - $ rye run darts sweep_smp --sweep-id 123456789 + $ rye run darts wandb-sweep-smp --sweep-id 123456789 ... ``` @@ -221,7 +251,7 @@ def sweep_smp( Hyperparameters must contain the following fields: `model_arch`, `model_encoder`, `augment`, `gamma`, `batch_size`. Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information. - sweep_count (int, optional): Number of runs to execute. Defaults to 10. + n_trials (int, optional): Number of runs to execute. Defaults to 10. sweep_id (str | None, optional): The ID of the sweep. If None, a new sweep will be created. Defaults to None. artifact_dir (Path, optional): Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs"). @@ -282,6 +312,7 @@ def run(): # Data config train_data_dir=train_data_dir, artifact_dir=artifact_dir, + current_fold=0, # Hyperparameters model_arch=model_arch, model_encoder=model_encoder, @@ -309,4 +340,153 @@ def run(): return logger.info("Starting a default sweep agent") - wandb.agent(sweep_id, function=run, count=sweep_count, project=wandb_project, entity=wandb_entity) + wandb.agent(sweep_id, function=run, count=n_trials, project=wandb_project, entity=wandb_entity) + + +def optuna_sweep_smp( + *, + # Data and sweep config + train_data_dir: Path, + sweep_config: Path, + n_trials: int = 10, + sweep_db: str | None = None, + sweep_id: str | None = None, + artifact_dir: Path = Path("lightning_logs"), + # Epoch and Logging config + max_epochs: int = 100, + log_every_n_steps: int = 10, + check_val_every_n_epoch: int = 3, + plot_every_n_val_epochs: int = 5, + # Device and Manager config + num_workers: int = 0, + device: int | str | None = None, + wandb_entity: str | None = None, + wandb_project: str | None = None, +): + """Create an optuna sweep and run it on the specified cuda device, or continue an existing sweep. + + If `sweep_id` already exists in `sweep_db`, the sweep will be continued. Otherwise, a new sweep will be created. + + If a `cuda_device` is specified, run an agent on this device. If None, do nothing. + + You can specify the frequency on how often logs will be written and validation will be performed. + - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation. + - `check_val_every_n_epoch` specifies how often validation will be performed. + This will also affect early stopping. + - `plot_every_n_val_epochs` specifies how often validation samples will be plotted. + Since plotting is quite costly, you can reduce the frequency. Works similar like early stopping. + In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`. + + This will use cross-validation. + + Example: + In one terminal, start a sweep: + ```sh + $ rye run darts sweep-smp --config-file /path/to/sweep-config.toml + ... # Many logs + Created sweep with ID 123456789 + ... # More logs from spawned agent + ``` + + In another terminal, start an a second agent: + ```sh + $ rye run darts sweep-smp --sweep-id 123456789 + ... + ``` + + Args: + train_data_dir (Path): Path to the training data directory. + sweep_config (Path): Path to the sweep yaml configuration file. Must contain a valid wandb sweep configuration. + Hyperparameters must contain the following fields: `model_arch`, `model_encoder`, `augment`, `gamma`, + `batch_size`. + Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information. + n_trials (int, optional): Number of runs to execute. Defaults to 10. + sweep_db (str | None, optional): Path to the optuna database. If None, a new database will be created. + sweep_id (str | None, optional): The ID of the sweep. If None, a new sweep will be created. Defaults to None. + artifact_dir (Path, optional): Path to the training output directory. + Will contain checkpoints and metrics. Defaults to Path("lightning_logs"). + max_epochs (int, optional): Maximum number of epochs to train. Defaults to 100. + log_every_n_steps (int, optional): Log every n steps. Defaults to 10. + check_val_every_n_epoch (int, optional): Check validation every n epochs. Defaults to 3. + plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5. + num_workers (int, optional): Number of Dataloader workers. Defaults to 0. + device (int | str | None, optional): The device to run the model on. Defaults to None. + wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None. + wandb_project (str | None, optional): Weights and Biases Project. Defaults to None. + run_name (str | None, optional): Name of this run, as a further grouping method for logs etc. Defaults to None. + + """ + import optuna + + from darts.legacy_training.util import suggest_optuna_params_from_wandb_config + + with sweep_config.open("r") as f: + sweep_configuration = yaml.safe_load(f) + + logger.debug(f"Loaded sweep configuration from {sweep_config.resolve()}:\n{sweep_configuration}") + + def objective(trial): + hparams = suggest_optuna_params_from_wandb_config(trial, sweep_configuration) + logger.info(f"Running trial with parameters: {hparams}") + + # We set the default weights to None, to be able to use different architectures + model_encoder_weights = None + # We set early stopping to None, because wandb will handle the early stopping + early_stopping_patience = None + learning_rate = hparams["learning_rate"] + gamma = hparams["gamma"] + batch_size = hparams["batch_size"] + model_arch = hparams["model_arch"] + model_encoder = hparams["model_encoder"] + augment = hparams["augment"] + + crossval_scores = defaultdict(list) + for current_fold in range(5): + logger.info(f"Running cross-validation fold {current_fold}") + trainer = train_smp( + # Data config + train_data_dir=train_data_dir, + artifact_dir=artifact_dir, + current_fold=current_fold, + # Hyperparameters + model_arch=model_arch, + model_encoder=model_encoder, + model_encoder_weights=model_encoder_weights, + augment=augment, + learning_rate=learning_rate, + gamma=gamma, + batch_size=batch_size, + # Epoch and Logging config + early_stopping_patience=early_stopping_patience, + max_epochs=max_epochs, + log_every_n_steps=log_every_n_steps, + check_val_every_n_epoch=check_val_every_n_epoch, + plot_every_n_val_epochs=plot_every_n_val_epochs, + # Device and Manager config + num_workers=num_workers, + device=device, + wandb_entity=wandb_entity, + wandb_project=wandb_project, + run_name=f"{sweep_id}_t{trial.number}f{current_fold}", + ) + for metric, value in trainer.callback_metrics.items(): + crossval_scores[metric].append(value.item()) + + logger.debug(f"Cross-validation scores: {crossval_scores}") + crossval_jaccard = mean(crossval_scores["val/JaccardIndex"]) + crossval_recall = mean(crossval_scores["val/Recall"]) + return crossval_jaccard, crossval_recall + + study = optuna.create_study( + storage=sweep_db, + study_name=sweep_id, + directions=["maximize", "maximize"], + load_if_exists=True, + ) + + if device is None: + logger.info("No device specified, closing script...") + return + + logger.info("Starting optimizing") + study.optimize(objective, n_trials=n_trials) diff --git a/darts/src/darts/legacy_training/util.py b/darts/src/darts/legacy_training/util.py index 20d838d..a50a22b 100644 --- a/darts/src/darts/legacy_training/util.py +++ b/darts/src/darts/legacy_training/util.py @@ -54,3 +54,103 @@ def convert_lightning_checkpoint( torch.save(own_ckpt, out_checkpoint) logger.info(f"Saved converted checkpoint to {out_checkpoint.resolve()}") + + +def suggest_optuna_params_from_wandb_config(trial, config: dict): + """Get optuna parameters from a wandb sweep config. + + This functions translate a wandb sweep config to a dict of values, suggested from optuna. + + Args: + trial (optuna.Trial): The optuna trial + config (dict): The wandb sweep config + + Raises: + ValueError: If the distribution is not recognized. + + Returns: + dict: A dict of parameters with the values suggested from optuna. + + Example: + Assume a wandb config which looks like this: + + ```yaml + parameters: + learning_rate: + max: !!float 1e-3 + min: !!float 1e-7 + distribution: log_uniform_values + batch_size: + value: 8 + gamma: + value: 0.9 + augment: + value: True + model_arch: + values: + - UnetPlusPlus + - Unet + model_encoder: + values: + - resnext101_32x8d + - resnet101 + - dpn98 + + ``` + + This function will return a dict like this: + + ```python + { + "learning_rate": trial.suggest_loguniform("learning_rate", 1e-7, 1e-3), + "batch_size": 8, + "gamma": 0.9, + "augment": True, + "model_arch": trial.suggest_categorical("model_arch", ["UnetPlusPlus", "Unet"]), + "model_encoder": trial.suggest_categorical( + "model_encoder", ["resnext101_32x8d", "resnet101", "dpn98"] + ), + } + ``` + + See https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information on the sweep config. + + Note: Not all distribution types are supported. + + """ + import optuna + + trial: optuna.Trial = trial + + wandb_params: dict[str, dict] = config["parameters"] + + conv = {} + for param, constrains in wandb_params.items(): + # Handle bad case first: user didn't specified the "distribution key" + if "distribution" not in constrains.keys(): + if "value" in constrains.keys(): + conv[param] = constrains["value"] + elif "values" in constrains.keys(): + conv[param] = trial.suggest_categorical(param, constrains["values"]) + elif "min" in constrains.keys() and "max" in constrains.keys(): + conv[param] = trial.suggest_float(param, constrains["min"], constrains["max"]) + else: + raise ValueError(f"Unknown constrains for parameter {param}") + continue + + # Now handle the good case where the user specified the distribution + distribution = constrains["distribution"] + if distribution == "categorical": + conv[param] = trial.suggest_categorical(param, constrains["values"]) + elif distribution == "int_uniform": + conv[param] = trial.suggest_int(param, constrains["min"], constrains["max"]) + elif distribution == "uniform": + conv[param] = trial.suggest_float(param, constrains["min"], constrains["max"]) + elif distribution == "q_uniform": + conv[param] = trial.suggest_float(param, constrains["min"], constrains["max"], step=constrains["q"]) + elif distribution == "log_uniform_values": + conv[param] = trial.suggest_float(param, constrains["min"], constrains["max"], log=True) + else: + raise ValueError(f"Unknown distribution {distribution}") + + return conv diff --git a/darts/src/darts/utils/logging.py b/darts/src/darts/utils/logging.py index 585e163..69a1108 100644 --- a/darts/src/darts/utils/logging.py +++ b/darts/src/darts/utils/logging.py @@ -39,13 +39,14 @@ def setup_logging(self): logging.getLogger("darts").setLevel(DARTS_LEVEL) logging.captureWarnings(True) - def add_logging_handlers(self, command: str, console: Console, log_dir: Path): + def add_logging_handlers(self, command: str, console: Console, log_dir: Path, tracebacks_show_locals: bool = False): """Add logging handlers (rich-console and file) to the application. Args: command (str): The command that is run. console (Console): The rich console to log everything to. log_dir (Path): The directory to save the logs to. + tracebacks_show_locals (bool): Whether to show local variables in tracebacks. """ import distributed @@ -66,7 +67,7 @@ def add_logging_handlers(self, command: str, console: Console, log_dir: Path): console=console, rich_tracebacks=True, tracebacks_suppress=[cyclopts, L, torch, torch.utils.data, xr, distributed], - tracebacks_show_locals=True, + tracebacks_show_locals=tracebacks_show_locals, ) rich_handler.setFormatter( logging.Formatter( diff --git a/pyproject.toml b/pyproject.toml index ba4681c..6b7d841 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "bokeh>=3.5.2", "jupyter-bokeh>=4.0.5", "seaborn>=0.13.2", + "optuna>=4.1.0", ] readme = "README.md" requires-python = ">= 3.11"