Skip to content

Commit

Permalink
Merge pull request #22 from becktepe/main
Browse files Browse the repository at this point in the history
[Fix] Loading paths + self-destruct for PBT
  • Loading branch information
TheEimer authored Oct 11, 2024
2 parents bf73b56 + 8360767 commit 44dea99
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
39 changes: 22 additions & 17 deletions hydra_plugins/hyper_pbt/hyper_pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

import numpy as np
from ConfigSpace.hyperparameters import (CategoricalHyperparameter,
NormalIntegerHyperparameter,
OrdinalHyperparameter,
UniformIntegerHyperparameter)

from ConfigSpace.hyperparameters import (
CategoricalHyperparameter,
NormalIntegerHyperparameter,
OrdinalHyperparameter,
UniformIntegerHyperparameter,
)
from hydra_plugins.hypersweeper import Info


Expand Down Expand Up @@ -88,6 +89,7 @@ def ask(self):
self.population_id += 1
if iteration_end:
self.iteration += 1

return Info(
config=config,
budget=self.budget_per_run,
Expand Down Expand Up @@ -160,18 +162,21 @@ def tell(self, info, value):
if self.model_based:
self.fit_model(self.performance_history, self.config_history)

if self.self_destruct and self.iteration > 1:
import shutil

print(info)
# Try to remove the checkpoint without seeds
path = self.checkpoint_dir / f"{info.load_path!s}{self.checkpoint_path_typing}"
shutil.rmtree(path, ignore_errors=True)
# Try to remove the checkpoint with seeds
for s in self.seeds:
path = self.checkpoint_dir / f"{info.load_path!s}_s{s}{self.checkpoint_path_typing}"
shutil.rmtree(path, ignore_errors=True)

# Now that we have finished the iteration,
# we can safely remove all checkpoints from the previous iteration
print(f"Finished iteration {self.iteration}")
print("Remove checkpoints")
if self.self_destruct and self.iteration > 1:
self.remove_checkpoints(self.iteration - 2)

def remove_checkpoints(self, iteration: int) -> None:
"""Remove checkpoints."""
import os

# Delete all files in checkpoints dir starting with iteration_{iteration}
for file in os.listdir(self.checkpoint_dir):
if file.startswith(f"iteration_{iteration}"):
os.remove(os.path.join(self.checkpoint_dir, file))

def make_pbt(configspace, pbt_args):
"""Make a PBT instance for optimization."""
Expand Down
16 changes: 12 additions & 4 deletions hydra_plugins/hypersweeper/hypersweeper_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
import pandas as pd
import wandb
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf

from hydra_plugins.hypersweeper.utils import Info, Result, read_warmstart_data
from omegaconf import DictConfig, OmegaConf

if TYPE_CHECKING:
from ConfigSpace import Configuration, ConfigurationSpace
Expand Down Expand Up @@ -219,8 +218,7 @@ def run_configs(self, infos):
values = [*list(infos[i].config.values())]
if self.budget_arg_name is not None:
values += [infos[i].budget]
if self.load_tf and self.iteration > 0:
values += [Path(self.checkpoint_dir) / f"{infos[i].load_path!s}{self.checkpoint_path_typing}"]


if self.slurm:
names += ["hydra.launcher.timeout_min"]
Expand All @@ -232,7 +230,11 @@ def run_configs(self, infos):
if self.seeds:
for s in self.seeds:
local_values = values.copy()
load_path = Path(self.checkpoint_dir) / f"{infos[i].load_path!s}_s{s}{self.checkpoint_path_typing}"
save_path = self.get_save_path(i, s)

if self.load_tf and self.iteration > 0:
local_values += [load_path]
if self.checkpoint_tf:
local_values += [save_path]

Expand All @@ -246,14 +248,20 @@ def run_configs(self, infos):
For non-deterministic target functions, seeds must be provided.
If the optimizer you chose does not support this,
manually set the 'seeds' parameter of the sweeper to a list of seeds."""
load_path = Path(self.checkpoint_dir) / f"{infos[i].load_path!s}_s{s}{self.checkpoint_path_typing}"
save_path = self.get_save_path(i)

job_overrides = tuple(self.global_overrides) + tuple(
f"{name}={val}"
for name, val in zip([*names, self.seed_keyword], [*values, infos[i].seed], strict=True)
)
overrides.append(job_overrides)
else:
load_path = Path(self.checkpoint_dir) / f"{infos[i].load_path!s}{self.checkpoint_path_typing}"
save_path = self.get_save_path(i)

if self.load_tf and self.iteration > 0:
values += [load_path]
if self.checkpoint_tf:
values += [save_path]

Expand Down

0 comments on commit 44dea99

Please sign in to comment.