From ea8b850ad5583145d19850828930f5698d38ec2d Mon Sep 17 00:00:00 2001 From: Jannis Becktepe Date: Fri, 11 Oct 2024 15:41:11 +0200 Subject: [PATCH 1/2] fix: loading path for PBT + seeds --- .../hypersweeper/hypersweeper_sweeper.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/hydra_plugins/hypersweeper/hypersweeper_sweeper.py b/hydra_plugins/hypersweeper/hypersweeper_sweeper.py index 5ae12ee..5e19e25 100644 --- a/hydra_plugins/hypersweeper/hypersweeper_sweeper.py +++ b/hydra_plugins/hypersweeper/hypersweeper_sweeper.py @@ -13,9 +13,8 @@ import pandas as pd import wandb from hydra.utils import to_absolute_path -from omegaconf import DictConfig, OmegaConf - from hydra_plugins.hypersweeper.utils import Info, Result, read_warmstart_data +from omegaconf import DictConfig, OmegaConf if TYPE_CHECKING: from ConfigSpace import Configuration, ConfigurationSpace @@ -219,8 +218,7 @@ def run_configs(self, infos): values = [*list(infos[i].config.values())] if self.budget_arg_name is not None: values += [infos[i].budget] - if self.load_tf and self.iteration > 0: - values += [Path(self.checkpoint_dir) / f"{infos[i].load_path!s}{self.checkpoint_path_typing}"] + if self.slurm: names += ["hydra.launcher.timeout_min"] @@ -232,7 +230,11 @@ def run_configs(self, infos): if self.seeds: for s in self.seeds: local_values = values.copy() + load_path = Path(self.checkpoint_dir) / f"{infos[i].load_path!s}_s{s}{self.checkpoint_path_typing}" save_path = self.get_save_path(i, s) + + if self.load_tf and self.iteration > 0: + local_values += [load_path] if self.checkpoint_tf: local_values += [save_path] @@ -246,14 +248,20 @@ def run_configs(self, infos): For non-deterministic target functions, seeds must be provided. If the optimizer you chose does not support this, manually set the 'seeds' parameter of the sweeper to a list of seeds.""" + load_path = Path(self.checkpoint_dir) / f"{infos[i].load_path!s}_s{s}{self.checkpoint_path_typing}" save_path = self.get_save_path(i) + job_overrides = tuple(self.global_overrides) + tuple( f"{name}={val}" for name, val in zip([*names, self.seed_keyword], [*values, infos[i].seed], strict=True) ) overrides.append(job_overrides) else: + load_path = Path(self.checkpoint_dir) / f"{infos[i].load_path!s}{self.checkpoint_path_typing}" save_path = self.get_save_path(i) + + if self.load_tf and self.iteration > 0: + values += [load_path] if self.checkpoint_tf: values += [save_path] From c559733f9c3e6ad9e63fd06307693861e1a43c5b Mon Sep 17 00:00:00 2001 From: Jannis Becktepe Date: Fri, 11 Oct 2024 15:41:32 +0200 Subject: [PATCH 2/2] fix: self-destruct --- hydra_plugins/hyper_pbt/hyper_pbt.py | 39 ++++++++++++++++------------ 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/hydra_plugins/hyper_pbt/hyper_pbt.py b/hydra_plugins/hyper_pbt/hyper_pbt.py index 7b85c11..a034bf6 100644 --- a/hydra_plugins/hyper_pbt/hyper_pbt.py +++ b/hydra_plugins/hyper_pbt/hyper_pbt.py @@ -3,11 +3,12 @@ from __future__ import annotations import numpy as np -from ConfigSpace.hyperparameters import (CategoricalHyperparameter, - NormalIntegerHyperparameter, - OrdinalHyperparameter, - UniformIntegerHyperparameter) - +from ConfigSpace.hyperparameters import ( + CategoricalHyperparameter, + NormalIntegerHyperparameter, + OrdinalHyperparameter, + UniformIntegerHyperparameter, +) from hydra_plugins.hypersweeper import Info @@ -88,6 +89,7 @@ def ask(self): self.population_id += 1 if iteration_end: self.iteration += 1 + return Info( config=config, budget=self.budget_per_run, @@ -160,18 +162,21 @@ def tell(self, info, value): if self.model_based: self.fit_model(self.performance_history, self.config_history) - if self.self_destruct and self.iteration > 1: - import shutil - - print(info) - # Try to remove the checkpoint without seeds - path = self.checkpoint_dir / f"{info.load_path!s}{self.checkpoint_path_typing}" - shutil.rmtree(path, ignore_errors=True) - # Try to remove the checkpoint with seeds - for s in self.seeds: - path = self.checkpoint_dir / f"{info.load_path!s}_s{s}{self.checkpoint_path_typing}" - shutil.rmtree(path, ignore_errors=True) - + # Now that we have finished the iteration, + # we can safely remove all checkpoints from the previous iteration + print(f"Finished iteration {self.iteration}") + print("Remove checkpoints") + if self.self_destruct and self.iteration > 1: + self.remove_checkpoints(self.iteration - 2) + + def remove_checkpoints(self, iteration: int) -> None: + """Remove checkpoints.""" + import os + + # Delete all files in checkpoints dir starting with iteration_{iteration} + for file in os.listdir(self.checkpoint_dir): + if file.startswith(f"iteration_{iteration}"): + os.remove(os.path.join(self.checkpoint_dir, file)) def make_pbt(configspace, pbt_args): """Make a PBT instance for optimization."""