Skip to content

Commit

Permalink
add logic for using aim run by hash, but it deletes the previous results
Browse files Browse the repository at this point in the history
  • Loading branch information
FilyaGeikyan committed Aug 27, 2024
1 parent 0e7dc58 commit 5c078f3
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
9 changes: 9 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from torchtitan.logging import logger

from typing import Optional

TORCH_DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
Expand Down Expand Up @@ -540,6 +542,13 @@ def __init__(self):
action="store_true",
)

self.parser.add_argument(
    "--metrics.aim_hash",
    # argparse `type` must be a plain callable that converts the CLI string;
    # `Optional[str]` is a typing construct, not a converter, and fails when
    # the flag is actually supplied. `default=None` already makes it optional.
    type=str,
    default=None,
    help="The hash of the aim run to continue with",
)

def parse_args(self, args_list: list = sys.argv[1:]):
self.args_list = args_list

Expand Down
10 changes: 6 additions & 4 deletions torchtitan/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ def build_gpu_memory_monitor():


class MetricLogger:
def __init__(self, log_dir, save_aim_folder, enable_aim):
def __init__(self, hash, log_dir, save_aim_folder, enable_aim):
    """Create a metric logger backed by an optional Aim run.

    Args:
        hash: Existing Aim run hash to resume, or None to start a new run.
        log_dir: Log directory (kept for interface compatibility; not used here).
        save_aim_folder: Repository/folder path handed to AimLogger.
        enable_aim: When False, no Aim writer is created and logging is a no-op.
    """
    self.writer: Optional[AimLogger] = None
    if enable_aim:
        # Resume the prior run when a hash was given, otherwise open a fresh one.
        self.writer = (
            AimLogger(save_aim_folder)
            if hash is None
            else AimLogger(save_aim_folder, run_hash=hash)
        )

def log(self, metrics: Dict[str, Any], step: int):
if self.writer is not None:
Expand Down Expand Up @@ -144,6 +147,5 @@ def build_metric_logger(
f"Metrics logging active. Aim logs will be saved at /{save_aim_folder}"
)
enable_aim = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)

return MetricLogger(log_dir, save_aim_folder, enable_aim)
return MetricLogger(job_config.metrics.aim_hash, log_dir, save_aim_folder, enable_aim)

2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def loss_fn(pred, labels):
job_config_dict = job_config._args_to_two_level_dict(args)
metric_logger.log_hparams(job_config_dict)

# plot losses loaded from checkpoint (if any) to TensorBoard
# plot losses loaded from checkpoint (if any) to Aim
# NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
# This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
if train_state.step > 0:
Expand Down
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ log_freq = 1
enable_color_printing = true
enable_aim = true
save_aim_folder = "aim"
aim_hash = "bfccae7dda0640f89e390e9a"

[model]
name = "llama3"
Expand Down

0 comments on commit 5c078f3

Please sign in to comment.