From fa670dbfbdf0aad3baa9de6ee32d0026181353c5 Mon Sep 17 00:00:00 2001 From: Filya Geikyan Date: Tue, 27 Aug 2024 17:43:44 +0400 Subject: [PATCH] experiment name --- torchtitan/config_manager.py | 5 +++++ torchtitan/metrics.py | 6 ++++-- train_configs/debug_model.toml | 1 + 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 231de76f..e3d6f171 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -555,6 +555,11 @@ def __init__(self): help="The hash of the aim run to continue with", ) + self.parser.add_argument( + "--metrics.aim_experiment_name", + type=Optional[str], + default=None, + ) def parse_args(self, args_list: list = sys.argv[1:]): self.args_list = args_list diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py index 0c919957..d3a61842 100644 --- a/torchtitan/metrics.py +++ b/torchtitan/metrics.py @@ -94,11 +94,13 @@ def build_gpu_memory_monitor(): class MetricLogger: - def __init__(self, hash, log_dir, save_aim_folder, enable_aim): + def __init__(self, hash, experiment_name, log_dir, save_aim_folder, enable_aim): self.writer: Optional[AimLogger] = None if enable_aim: if hash is not None: self.writer = AimLogger(save_aim_folder, run_hash=hash) + elif experiment_name is not None: + self.writer = AimLogger(save_aim_folder, experiment=experiment_name) else: self.writer = AimLogger(save_aim_folder) @@ -147,5 +149,5 @@ def build_metric_logger( f"Metrics logging active. Aim logs will be saved at /{save_aim_folder}" ) enable_aim = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims) - return MetricLogger(job_config.metrics.aim_hash, log_dir, save_aim_folder, enable_aim) + return MetricLogger(job_config.metrics.aim_hash, job_config.metrics.aim_experiment_name, log_dir, save_aim_folder, enable_aim) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 27d9b4cc..3eba5fef 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -18,6 +18,7 @@ enable_color_printing = true enable_aim = true save_aim_folder = "aim" aim_hash = "c6b4d8b340f74287b82ef928" +aim_experiment_name = "hello" [model] name = "llama3"