Merge branch 'main' into aim
FilyaGeikyan committed Aug 27, 2024
2 parents 73d458f + 5cd10c7 commit dcb1253
Showing 17 changed files with 739 additions and 95 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
__pycache__
.idea
.vscode
.DS_Store
*.egg-info
build
31 changes: 31 additions & 0 deletions submitit_train.py
@@ -0,0 +1,31 @@
import submitit
import datetime
import yaml
import os


if __name__ == "__main__":
executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
executor.update_parameters(
name="titan", timeout_min=15,
gpus_per_node=2,
nodes=1, mem_gb=30, cpus_per_task=10,
slurm_array_parallelism=10
)

jobs = []
with executor.batch():
for _ in range(1):
function = submitit.helpers.CommandFunction([
'python3', '-m', 'torch.distributed.run',
'--nproc_per_node', '2',
'--rdzv_backend', 'c10d',
'--rdzv_endpoint', 'localhost:0',
'--local-ranks-filter', '0',
'--role', 'rank', '--tee', '3',
'train.py', '--job.config_file', './train_configs/galactica_125m.toml',
])
print(' '.join(function.command))
# subprocess.run(function.command)
job = executor.submit(function)
jobs.append(job)
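
Since the executor batches the submissions into a Slurm job array, the returned job handles can be polled or blocked on after submission. A minimal sketch using the standard submitit Job API (not part of this diff):

for job in jobs:
    # result() blocks until the job finishes and returns the captured stdout;
    # it raises if the underlying torchrun command exits non-zero.
    print(job.job_id, job.result())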
64 changes: 32 additions & 32 deletions test_runner.py
@@ -61,38 +61,38 @@ def build_test_list():
requires_seed_checkpoint=True,
ngpu=4,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test 1f1b",
"pp_1f1b",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test gpipe",
"pp_gpipe",
requires_seed_checkpoint=True,
ngpu=2,
),
# OverrideDefinitions(
# [
# [
# "--checkpoint.enable_checkpoint",
# "--experimental.pipeline_parallel_degree 2",
# "--experimental.pipeline_parallel_split_points layers.4",
# "--experimental.pipeline_parallel_schedule 1f1b",
# "--training.data_parallel_degree 1",
# "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
# ],
# ],
# "PP 1D test 1f1b",
# "pp_1f1b",
# requires_seed_checkpoint=True,
# ngpu=2,
# ),
# OverrideDefinitions(
# [
# [
# "--checkpoint.enable_checkpoint",
# "--experimental.pipeline_parallel_degree 2",
# "--experimental.pipeline_parallel_split_points layers.4",
# "--experimental.pipeline_parallel_schedule gpipe",
# "--training.data_parallel_degree 1",
# "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
# ],
# ],
# "PP 1D test gpipe",
# "pp_gpipe",
# requires_seed_checkpoint=True,
# ngpu=2,
# ),
OverrideDefinitions(
[
[
27 changes: 14 additions & 13 deletions torchtitan/checkpoint.py
@@ -235,7 +235,8 @@ def __init__(
for idx, lr_scheduler in enumerate(lr_schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.save_folder = os.path.join(job_config.job.dump_folder, ckpt_config.save_folder)
self.load_folder = os.path.join(job_config.job.dump_folder, ckpt_config.load_folder)
self.interval_type = (
IntervalType.SECONDS
if ckpt_config.interval_type == "seconds"
@@ -280,7 +281,7 @@ def __init__(
raise ValueError(f"Unknown checkpoint async_mode {ckpt_config.async_mode}")

logger.info(
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
f"Checkpointing active. Checkpoints will be loaded from {self.load_folder} and saved to {self.save_folder}"
)

def __del__(self):
@@ -291,8 +292,8 @@ def __del__(self):
def reset(self) -> None:
self.begin_time = time.monotonic()

def _create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")
def _create_checkpoint_id(self, step: int, folder: str) -> str:
return os.path.join(folder, f"step-{step}")

def _save_last_step(self, curr_step: int) -> None:
# We only consider saving weights only at the end of the training. So
@@ -323,7 +324,7 @@ def _save_last_step(self, curr_step: int) -> None:
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step, self.save_folder))
self.reset()

def _should_save(self, curr_step: int, force: bool = False) -> bool:
@@ -411,7 +412,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
return

begin = time.monotonic()
checkpoint_id = self._create_checkpoint_id(curr_step)
checkpoint_id = self._create_checkpoint_id(curr_step, self.save_folder)
self._async_wait()
if force:
self._save_last_step(curr_step)
@@ -448,16 +449,16 @@ def maybe_wait_for_staging(self) -> None:
def load(self, step: int = -1) -> bool:
if not self.enable_checkpoint:
return False
if not os.path.isdir(self.folder):
if not os.path.isdir(self.load_folder):
return False
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)):
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step, self.load_folder)):
return False

if step == -1:
step_counts = []
for filename in os.listdir(self.folder):
for filename in os.listdir(self.load_folder):
match = re.search(r"step-(\d+)", filename)
metadata_probe = os.path.join(self.folder, filename, ".metadata")
metadata_probe = os.path.join(self.load_folder, filename, ".metadata")
if match and os.path.isfile(metadata_probe):
step_counts.append(int(match.group(1)))
if not step_counts:
@@ -470,7 +471,7 @@ def load(self, step: int = -1) -> bool:
begin = time.monotonic()
dcp.load(
states,
checkpoint_id=self._create_checkpoint_id(step),
checkpoint_id=self._create_checkpoint_id(step, self.load_folder),
)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -480,9 +481,9 @@ def load(self, step: int = -1) -> bool:
def _purge_stale_checkpoints(self):
if self.keep_latest_k > 0:
discovered_checkpoints = []
for filename in os.listdir(self.folder):
for filename in os.listdir(self.save_folder):
match = re.search(r"step-(\d+)", filename)
path = os.path.join(self.folder, filename)
path = os.path.join(self.save_folder, filename)
discovered_checkpoints.append((int(match.group(1)), path))

discovered_checkpoints.sort()
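
The practical effect of splitting self.folder into save_folder and load_folder is that a run can warm-start from another run's checkpoints while writing its own to a separate directory. A minimal sketch of the resulting path layout (folder names are illustrative, not taken from this diff):

import os

def create_checkpoint_id(step: int, folder: str) -> str:
    # Mirrors CheckpointManager._create_checkpoint_id above.
    return os.path.join(folder, f"step-{step}")

print(create_checkpoint_id(0, "outputs/pretrained_opt"))    # outputs/pretrained_opt/step-0
print(create_checkpoint_id(500, "outputs/galactica_125m"))  # outputs/galactica_125m/step-500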
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -204,6 +204,12 @@ def __init__(self):
self.parser.add_argument(
"--training.batch_size", type=int, default=8, help="Batch size"
)
self.parser.add_argument(
"--training.gradient_accumulation_steps",
type=int,
default=1,
help="Interval in steps for gradient accumulation",
)
self.parser.add_argument(
"--training.seq_len", type=int, default=2048, help="Sequence length"
)
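
The new flag only registers the option; for reference, here is how a training loop typically consumes such a value. A hedged sketch with illustrative names (model, optimizer, and data_loader are not taken from this diff):

accum_steps = job_config.training.gradient_accumulation_steps  # hypothetical access path
optimizer.zero_grad()
for step, batch in enumerate(data_loader):
    loss = model(batch).mean()
    # Scale the loss so the accumulated gradient matches one large-batch step.
    (loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()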
13 changes: 12 additions & 1 deletion torchtitan/models/__init__.py
@@ -5,15 +5,26 @@
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer
from torchtitan.models.opt import opt_configs, OPT, load_opt_weights

models_config = {
"llama2": llama2_configs,
"llama3": llama3_configs,
"opt": opt_configs
}

model_name_to_cls = {"llama2": Transformer, "llama3": Transformer}
model_name_to_cls = {
"llama2": Transformer,
"llama3": Transformer,
"opt": OPT
}

model_name_to_tokenizer = {
"llama2": "sentencepiece",
"llama3": "tiktoken",
"opt": "tiktoken"
}

model_name_to_weights_loading_fns = {
"opt": load_opt_weights
}
2 changes: 2 additions & 0 deletions torchtitan/models/norms.py
@@ -40,6 +40,8 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
return nn.LayerNorm(dim, eps=eps, bias=False)
elif norm_type == "np_layernorm":
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "layernorm_bias":
return nn.LayerNorm(dim, eps=eps, bias=True)
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps)
elif norm_type == "compiled_rmsnorm":
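
The new branch enables the bias term that the other LayerNorm variants omit; usage mirrors the existing norm types, e.g.:

norm = build_norm("layernorm_bias", dim=768)
assert isinstance(norm, nn.LayerNorm) and norm.bias is not None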
20 changes: 20 additions & 0 deletions torchtitan/models/opt/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# <model name> is licensed under the <license name>,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.models.opt.model import ModelArgs, OPT
from torchtitan.models.opt.utils import load_opt_weights

__all__ = ["OPT", "load_opt_weights"]

opt_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=8),
"125M": ModelArgs(dim=768, n_layers=12, n_heads=12),
# "1.3B": ModelArgs(dim=2048, n_layers=, n_heads=8),
# "6.7B": ModelArgs(dim=2048, n_layers=, n_heads=8)
}