Skip to content

Commit

Permalink
experiment name
Browse files Browse the repository at this point in the history
  • Loading branch information
FilyaGeikyan committed Aug 27, 2024
1 parent dcb1253 commit fa670db
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
5 changes: 5 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,11 @@ def __init__(self):
help="The hash of the aim run to continue with",
)

# argparse calls `type` on the raw command-line string, so it must be a
# plain converter such as `str`; a typing construct like Optional[str]
# cannot convert the value. "Optional" is expressed through default=None.
self.parser.add_argument(
    "--metrics.aim_experiment_name",
    type=str,
    default=None,
    help=(
        "The name of the aim experiment to log the run under "
        "(only used when no aim run hash is given)"
    ),
)
def parse_args(self, args_list: list = sys.argv[1:]):
self.args_list = args_list

Expand Down
6 changes: 4 additions & 2 deletions torchtitan/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ def build_gpu_memory_monitor():


class MetricLogger:
def __init__(self, hash, log_dir, save_aim_folder, enable_aim):
def __init__(self, hash, experiment_name, log_dir, save_aim_folder, enable_aim):
self.writer: Optional[AimLogger] = None
if enable_aim:
if hash is not None:
self.writer = AimLogger(save_aim_folder, run_hash=hash)
elif experiment_name is not None:
self.writer = AimLogger(save_aim_folder, experiment=experiment_name)
else:
self.writer = AimLogger(save_aim_folder)

Expand Down Expand Up @@ -147,5 +149,5 @@ def build_metric_logger(
f"Metrics logging active. Aim logs will be saved at /{save_aim_folder}"
)
enable_aim = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
return MetricLogger(job_config.metrics.aim_hash, log_dir, save_aim_folder, enable_aim)
return MetricLogger(job_config.metrics.aim_hash, job_config.metrics.aim_experiment_name, log_dir, save_aim_folder, enable_aim)

1 change: 1 addition & 0 deletions train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ enable_color_printing = true
enable_aim = true
save_aim_folder = "aim"
aim_hash = "c6b4d8b340f74287b82ef928"
aim_experiment_name = "hello"

[model]
name = "llama3"
Expand Down

0 comments on commit fa670db

Please sign in to comment.