
Merge pull request #10 from YerevaNN/aim
Aim
FilyaGeikyan authored Aug 29, 2024
2 parents 17d7d0a + c0182da commit 9fce8f4
Showing 18 changed files with 208 additions and 79 deletions.
3 changes: 2 additions & 1 deletion .ci/docker/requirements.txt
@@ -1,8 +1,9 @@
torchdata >= 0.8.0
datasets >= 2.21.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
aim
sentencepiece
tiktoken
blobfile
tabulate
transformers
21 changes: 1 addition & 20 deletions README.md
@@ -39,7 +39,7 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
3. Selective layer and operator activation checkpointing
4. Distributed checkpointing
5. 2 datasets pre-configured (45K - 144M)
6. GPU usage, MFU, tokens per second and more displayed via TensorBoard
6. GPU usage, MFU, tokens per second and more displayed via Aim
6. Learning rate scheduler, meta init, Optional Fused RMSNorm
7. All options easily configured via [toml files](train_configs/)
8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
@@ -87,25 +87,6 @@ CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
```


## TensorBoard

To visualize TensorBoard metrics of models trained on a remote server via a local web browser:

1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI).

2. Set up SSH tunneling, by running the following from local CLI
```
ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
```

3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend
```
tensorboard --logdir=./outputs/tb
```

4. In the local web browser, go to the URL it provides OR to http://localhost:6006/.


## Multi-Node Training
For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job.

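Note that the TensorBoard viewing instructions removed above get no Aim counterpart in this README. As a rough sketch, assuming the default `--metrics.save_aim_folder` of `aim` (an Aim repository created in the working directory) and Aim's default UI port of 43800, the equivalent remote-viewing workflow would be:

```
# from the local machine, forward Aim's UI port
ssh -L 43800:127.0.0.1:43800 [username]@[hostname]

# on the remote machine, inside the torchtitan repo
aim up --repo ./aim
```

Then open http://localhost:43800/ in the local browser.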
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ dev = [
"pre-commit",
"pytest",
"pytest-cov",
"tensorboard",
"aim",
]

[tool.setuptools.dynamic]
4 changes: 0 additions & 4 deletions test/test_job_config.py
@@ -14,12 +14,10 @@ class TestJobConfig:
def test_command_line_args(self):
config = JobConfig()
config.parse_args([])
assert config.training.steps == 10000

def test_job_config_file(self):
config = JobConfig()
config.parse_args(["--job.config_file", "./train_configs/debug_model.toml"])
assert config.training.steps == 10

def test_job_file_does_not_exist(self):
with pytest.raises(FileNotFoundError):
@@ -30,7 +28,6 @@ def test_empty_config_file(self):
with tempfile.NamedTemporaryFile() as fp:
config = JobConfig()
config.parse_args(["--job.config_file", fp.name])
assert config.job.description

def test_job_config_file_cmd_overrides(self):
config = JobConfig()
@@ -42,7 +39,6 @@ def test_job_config_file_cmd_overrides(self):
"/tmp/test_tt/",
]
)
assert config.job.dump_folder == "/tmp/test_tt/"

def test_print_help(self):
config = JobConfig()
106 changes: 106 additions & 0 deletions torchtitan/aim.py
@@ -0,0 +1,106 @@
import os
from typing import Any, Dict, Optional

from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
from aim.sdk.repo import Repo
from aim.sdk.run import Run
from aim.sdk.utils import clean_repo_path, get_aim_repo_name


class AimLogger:
    """Thin wrapper around an Aim Run, used by torchtitan's MetricLogger."""
def __init__(
self,
repo: Optional[str] = None,
experiment: Optional[str] = None,
system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
log_system_params: Optional[bool] = True,
capture_terminal_logs: Optional[bool] = True,
run_name: Optional[str] = None,
run_hash: Optional[str] = None,
train_metric_prefix: Optional[str] = 'train_',
val_metric_prefix: Optional[str] = 'val_',
test_metric_prefix: Optional[str] = 'test_',
):
super().__init__()

self._experiment_name = experiment
self._run_name = run_name
self._repo_path = repo

self._system_tracking_interval = system_tracking_interval
self._log_system_params = log_system_params
self._capture_terminal_logs = capture_terminal_logs

self._run = None
self._run_hash = run_hash

self._train_metric_prefix = train_metric_prefix
self._val_metric_prefix = val_metric_prefix
self._test_metric_prefix = test_metric_prefix

@property
def experiment(self) -> Run:
if self._run is None:
if self._run_hash:
self._run = Run(
self._run_hash,
repo=self._repo_path,
system_tracking_interval=self._system_tracking_interval,
capture_terminal_logs=self._capture_terminal_logs,
force_resume=True,
)
else:
self._run = Run(
repo=self._repo_path,
experiment=self._experiment_name,
system_tracking_interval=self._system_tracking_interval,
log_system_params=self._log_system_params,
capture_terminal_logs=self._capture_terminal_logs,
)
self._run_hash = self._run.hash
if self._run_name is not None:
self._run.name = self._run_name
return self._run

def log_hyperparams(self, params: Dict[str, Any]):
for key, value in params.items():
self.experiment.set(('hparams', key), value, strict=False)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        metric_items: Dict[str, Any] = dict(metrics)  # copy so the original metrics dict is left untouched
for k, v in metric_items.items():
name = k
context = {}
if self._train_metric_prefix and name.startswith(self._train_metric_prefix):
name = name[len(self._train_metric_prefix) :]
context['subset'] = 'train'
elif self._test_metric_prefix and name.startswith(self._test_metric_prefix):
name = name[len(self._test_metric_prefix) :]
context['subset'] = 'test'
elif self._val_metric_prefix and name.startswith(self._val_metric_prefix):
name = name[len(self._val_metric_prefix) :]
context['subset'] = 'val'
self.experiment.track(v, name=name, step=step, context=context)

def finalize(self, status: str = '') -> None:
if self._run:
self._run.close()
del self._run
self._run = None

def __del__(self):
self.finalize()

@property
def save_dir(self) -> str:
repo_path = clean_repo_path(self._repo_path) or Repo.default_repo_path()
return os.path.join(repo_path, get_aim_repo_name())

@property
def name(self) -> str:
return self._experiment_name

@property
def version(self) -> str:
return self.experiment.hash
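A minimal usage sketch of the `AimLogger` defined above; the repository path, experiment name, and metric names are illustrative rather than taken from the diff:

```python
from torchtitan.aim import AimLogger

# creates (or reuses) an Aim repository at ./aim and lazily starts a run on first use
aim_logger = AimLogger(repo="aim", experiment="llama3_debug")

# stored under the run's 'hparams' namespace
aim_logger.log_hyperparams({"model": "llama3_8b", "lr": 3e-4})

# the 'train_' prefix is stripped and recorded as context {'subset': 'train'};
# unprefixed names such as 'wps' are tracked as-is
aim_logger.log_metrics({"train_loss": 4.2, "wps": 1500}, step=10)

aim_logger.finalize()
```

Passing `run_hash=...` instead resumes an existing run with `force_resume=True`, which is what the `--metrics.aim_hash` option added below feeds through.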
2 changes: 2 additions & 0 deletions torchtitan/checkpoint.py
@@ -48,6 +48,8 @@ class TrainState(Stateful):
step: int = 0
global_avg_losses: List[float] = field(default_factory=list)
global_max_losses: List[float] = field(default_factory=list)
global_avg_perplexities: List[float] = field(default_factory=list)
global_max_perplexities: List[float] = field(default_factory=list)
log_steps: List[int] = field(default_factory=list)

def state_dict(self) -> Dict[str, Any]:
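The hunk above only adds the perplexity fields to `TrainState`; where they are computed is not shown in this diff. Presumably they follow the usual definition of perplexity as the exponential of the average cross-entropy loss, along the lines of:

```python
import math

global_avg_loss = 2.19  # example per-token cross-entropy loss in nats
global_avg_perplexity = math.exp(global_avg_loss)  # ~8.94
```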
30 changes: 23 additions & 7 deletions torchtitan/config_manager.py
@@ -18,6 +18,8 @@

from torchtitan.logging import logger

from typing import Optional

TORCH_DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
@@ -118,7 +120,7 @@ def __init__(self):
"--metrics.log_freq",
type=int,
default=10,
help="How often to log metrics to TensorBoard, in iterations",
help="How often to log metrics to aim, in iterations",
)
self.parser.add_argument(
"--metrics.enable_color_printing",
@@ -127,22 +129,22 @@
help="Whether to enable color printing",
)
self.parser.add_argument(
"--metrics.enable_tensorboard",
"--metrics.enable_aim",
action="store_true",
help="Whether to log metrics to TensorBoard",
help="Whether to log metrics to aim",
)
self.parser.add_argument(
"--metrics.save_tb_folder",
"--metrics.save_aim_folder",
type=str,
default="tb",
help="Folder to dump TensorBoard states",
default="aim",
help="Folder to dump Aim states",
)
self.parser.add_argument(
"--metrics.rank_0_only",
default=True,
action="store_true",
help="""
Whether to save TensorBoard metrics only for rank 0 or for all ranks.
Whether to save Aim metrics only for rank 0 or for all ranks.
When pipeline_parallel_degree is > 1, this option uses the 0th rank of the last stage pipeline group,
which is the only stage that computes loss metrics.
""",
@@ -546,7 +548,21 @@ def __init__(self):
action="store_true",
)

        self.parser.add_argument(
            "--metrics.aim_hash",
            type=str,
            default=None,
            help="The hash of the Aim run to continue with",
        )

        self.parser.add_argument(
            "--metrics.aim_experiment_name",
            type=str,
            default=None,
            help="The name of the Aim experiment to log the run under",
        )

    def parse_args(self, args_list: list = sys.argv[1:]):
        self.args_list = args_list

args, cmd_args = self.parse_args_from_command_line(args_list)
config_file = getattr(args, "job.config_file", None)
# build up a two level dict
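For illustration (mirroring the style of `test/test_job_config.py`; the folder and experiment names are made up), the new Aim options can be exercised from the command line or programmatically:

```python
from torchtitan.config_manager import JobConfig

config = JobConfig()
# equivalent to passing these flags to the training script
config.parse_args(
    [
        "--metrics.enable_aim",
        "--metrics.save_aim_folder",
        "aim",
        "--metrics.aim_experiment_name",
        "llama3_8b_baseline",
    ]
)
assert config.metrics.enable_aim
assert config.metrics.save_aim_folder == "aim"
```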
50 changes: 26 additions & 24 deletions torchtitan/metrics.py
@@ -10,10 +10,10 @@
from typing import Any, Dict, Optional

import torch
from torch.utils.tensorboard import SummaryWriter
from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.parallelisms import ParallelDims
from torchtitan.aim import AimLogger

# named tuple for passing GPU memory stats for logging
GPUMemStats = namedtuple(
@@ -94,21 +94,27 @@ def build_gpu_memory_monitor():


class MetricLogger:
def __init__(self, log_dir, tag, enable_tb):
self.tag = tag
self.writer: Optional[SummaryWriter] = None
if enable_tb:
self.writer = SummaryWriter(log_dir, max_queue=1000)
def __init__(self, hash, experiment_name, log_dir, save_aim_folder, enable_aim):
self.writer: Optional[AimLogger] = None
if enable_aim:
if hash is not None:
self.writer = AimLogger(save_aim_folder, run_hash=hash)
elif experiment_name is not None:
self.writer = AimLogger(save_aim_folder, experiment=experiment_name)
else:
self.writer = AimLogger(save_aim_folder)

def log(self, metrics: Dict[str, Any], step: int):
if self.writer is not None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)
self.writer.log_metrics(metrics, step)

def close(self):
if self.writer is not None:
self.writer.close()
self.writer.finalize()

def log_hparams(self, config):
if self.writer is not None:
self.writer.experiment['hparams'] = config


def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
@@ -122,30 +128,26 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:


def build_metric_logger(
job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
job_config: JobConfig, parallel_dims: ParallelDims
):
"""
parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
parallel_dims is used to determine the rank to log metrics from if 'aim_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information.
"""
dump_dir = job_config.job.dump_folder
tb_config = job_config.metrics
save_tb_folder = tb_config.save_tb_folder
aim_config = job_config.metrics
save_aim_folder = aim_config.save_aim_folder
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
log_dir = os.path.join(dump_dir, datetime_str)

enable_tb = tb_config.enable_tensorboard
if enable_tb:
enable_aim = aim_config.enable_aim
if enable_aim:
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
f"Metrics logging active. Aim logs will be saved at /{save_aim_folder}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)
enable_aim = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
return MetricLogger(job_config.metrics.aim_hash, job_config.metrics.aim_experiment_name, log_dir, save_aim_folder, enable_aim)

return MetricLogger(log_dir, tag, enable_tb)
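A small sketch of the reworked `MetricLogger` used in isolation; the names and values are illustrative, and in training it is produced by `build_metric_logger` above, which also decides which rank logs (non-logging ranks get a writer of `None`, so `log()` and `close()` become no-ops):

```python
from torchtitan.metrics import MetricLogger

# with hash=None, a fresh Aim run is started under the given experiment name
metric_logger = MetricLogger(
    hash=None,
    experiment_name="llama3_8b_baseline",
    log_dir="outputs/20240829-1200",  # accepted but not used by the Aim backend
    save_aim_folder="aim",
    enable_aim=True,
)

metric_logger.log(
    {
        "loss_metrics/global_avg_loss": 4.2,
        "loss_metrics/global_avg_perplexity": 66.7,
        "wps": 1500,
    },
    step=10,
)
metric_logger.close()
```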
