Merge pull request #9 from YerevaNN/model_loading
Model loading
tigranfah authored Aug 26, 2024
2 parents 21d8e10 + ba91e87 commit 5cd10c7
Showing 6 changed files with 139 additions and 39 deletions.
31 changes: 31 additions & 0 deletions submitit_train.py
@@ -0,0 +1,31 @@
import submitit
import datetime
import yaml
import os


if __name__ == "__main__":
executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
executor.update_parameters(
name="titan", timeout_min=15,
gpus_per_node=2,
nodes=1, mem_gb=30, cpus_per_task=10,
slurm_array_parallelism=10
)

jobs = []
with executor.batch():
for _ in range(1):
function = submitit.helpers.CommandFunction([
'python3', '-m', 'torch.distributed.run',
'--nproc_per_node', '2',
'--rdzv_backend', 'c10d',
'--rdzv_endpoint', 'localhost:0',
'--local-ranks-filter', '0',
'--role', 'rank', '--tee', '3',
'train.py', '--job.config_file', './train_configs/galactica_125m.toml',
])
print(' '.join(function.command))
# subprocess.run(function.command)
job = executor.submit(function)
jobs.append(job)
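
As a side note, a minimal sketch (not part of this PR) of how the submitted jobs could be monitored after the script above runs; job_id, state and result() are standard submitit Job attributes, the loop itself is illustrative.

# Hedged sketch: continues from the `jobs` list built above.
for job in jobs:
    print(job.job_id, job.state)
# job.result() blocks until the torchrun command exits and re-raises failures;
# for a CommandFunction it returns the captured stdout:
# outputs = [job.result() for job in jobs]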
27 changes: 14 additions & 13 deletions torchtitan/checkpoint.py
@@ -233,7 +233,8 @@ def __init__(
for idx, lr_scheduler in enumerate(lr_schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.save_folder = os.path.join(job_config.job.dump_folder, ckpt_config.save_folder)
self.load_folder = os.path.join(job_config.job.dump_folder, ckpt_config.load_folder)
self.interval_type = (
IntervalType.SECONDS
if ckpt_config.interval_type == "seconds"
@@ -278,7 +279,7 @@ def __init__(
raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")

logger.info(
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
f"Checkpointing active. Checkpoints will be loaded from {self.load_folder} and saved to {self.save_folder}"
)

def __del__(self):
@@ -289,8 +290,8 @@ def __del__(self):
def reset(self) -> None:
self.begin_time = time.monotonic()

def _create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")
def _create_checkpoint_id(self, step: int, folder: str) -> str:
return os.path.join(folder, f"step-{step}")

def _save_last_step(self, curr_step: int) -> None:
# We only consider saving weights only at the end of the training. So
@@ -321,7 +322,7 @@ def _save_last_step(self, curr_step: int) -> None:
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step, self.save_folder))
self.reset()

def _should_save(self, curr_step: int, force: bool = False) -> bool:
@@ -409,7 +410,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
return

begin = time.monotonic()
checkpoint_id = self._create_checkpoint_id(curr_step)
checkpoint_id = self._create_checkpoint_id(curr_step, self.save_folder)
self._async_wait()
if force:
self._save_last_step(curr_step)
@@ -446,16 +447,16 @@ def maybe_wait_for_staging(self) -> None:
def load(self, step: int = -1) -> bool:
if not self.enable_checkpoint:
return False
if not os.path.isdir(self.folder):
if not os.path.isdir(self.load_folder):
return False
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)):
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step, self.load_folder)):
return False

if step == -1:
step_counts = []
for filename in os.listdir(self.folder):
for filename in os.listdir(self.load_folder):
match = re.search(r"step-(\d+)", filename)
metadata_probe = os.path.join(self.folder, filename, ".metadata")
metadata_probe = os.path.join(self.load_folder, filename, ".metadata")
if match and os.path.isfile(metadata_probe):
step_counts.append(int(match.group(1)))
if not step_counts:
@@ -468,7 +469,7 @@ def load(self, step: int = -1) -> bool:
begin = time.monotonic()
dcp.load(
states,
checkpoint_id=self._create_checkpoint_id(step),
checkpoint_id=self._create_checkpoint_id(step, self.load_folder),
)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -478,9 +479,9 @@ def load(self, step: int = -1) -> bool:
def _purge_stale_checkpoints(self):
if self.keep_latest_k > 0:
discovered_checkpoints = []
for filename in os.listdir(self.folder):
for filename in os.listdir(self.save_folder):
match = re.search(r"step-(\d+)", filename)
path = os.path.join(self.folder, filename)
path = os.path.join(self.save_folder, filename)
discovered_checkpoints.append((int(match.group(1)), path))

discovered_checkpoints.sort()
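
For orientation, a small sketch of how the new save_folder/load_folder split resolves into checkpoint ids under the chemlactica_125m.toml config added below; the paths come from that config, the step values are illustrative.

import os

dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
load_folder = os.path.join(dump_folder, "facebook/galactica-125m")    # ckpt_config.load_folder
save_folder = os.path.join(dump_folder, "yerevann/chemlactica-125m")  # ckpt_config.save_folder

# _create_checkpoint_id(step, folder) now takes the folder explicitly:
print(os.path.join(load_folder, "step-0"))   # where load() looks for checkpoints
print(os.path.join(save_folder, "step-5"))   # where save() writes at interval 5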
15 changes: 11 additions & 4 deletions torchtitan/models/opt/model.py
@@ -172,20 +172,24 @@ def __init__(
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
dropout_p: float
):
super().__init__()
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.dropout_p = dropout_p

# use bias for ffn
self.w1 = nn.Linear(dim, hidden_dim, bias=True)
self.w2 = nn.Linear(hidden_dim, dim, bias=True)

def forward(self, x):
# use GELU activation function
return self.w2(F.gelu(self.w1(x)))
# GELU activation function
x = self.w2(F.gelu(self.w1(x)))
x = F.dropout(x, p=self.dropout_p, training=self.training)
return x

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=init_std)
@@ -222,6 +226,7 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
hidden_dim=4 * model_args.dim,
multiple_of=model_args.multiple_of,
ffn_dim_multiplier=model_args.ffn_dim_multiplier,
dropout_p=model_args.dropout_p
)
self.layer_id = layer_id
self.num_layers = model_args.n_layers
@@ -253,9 +258,11 @@ def forward(
torch.Tensor: Output tensor after applying attention and feedforward layers.
"""
h = x + self.attention(self.attention_norm(x))
# attention
h = self.attention(self.attention_norm(x))
# apply dropout during training
h = F.dropout(h, p=self.dropout_p, training=self.training)
h = x + F.dropout(h, p=self.dropout_p, training=self.training)
# pointwise ffn
out = h + self.feed_forward(self.ffn_norm(h))
return out

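
As a self-contained illustration of the dropout change above, a minimal runnable sketch of the bias + GELU feed-forward with output dropout (dimensions are arbitrary, and the multiple_of / ffn_dim_multiplier handling is omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Bias + GELU FFN with output dropout, mirroring the change above."""
    def __init__(self, dim: int, hidden_dim: int, dropout_p: float):
        super().__init__()
        self.dropout_p = dropout_p
        self.w1 = nn.Linear(dim, hidden_dim, bias=True)
        self.w2 = nn.Linear(hidden_dim, dim, bias=True)

    def forward(self, x):
        x = self.w2(F.gelu(self.w1(x)))
        # dropout is only active in training mode
        return F.dropout(x, p=self.dropout_p, training=self.training)

ffn = FeedForward(dim=768, hidden_dim=3072, dropout_p=0.1)
print(ffn(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 16, 768])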
30 changes: 14 additions & 16 deletions train.py
@@ -113,21 +113,19 @@ def main(job_config: JobConfig):
model_config.max_seq_len = job_config.training.seq_len

logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
# with torch.device("meta"):
model = model_cls.from_model_args(model_config)
with torch.device("meta"):
model = model_cls.from_model_args(model_config)

# load the model on rank 0 only, then FSDP will distribute the weights
if job_config.model.init_weights:
if dp_rank == 0:
# model.to_empty(device=init_device)
model.init_weights()
else:
if dp_rank == 0:
# model.to_empty(device=init_device)
model_name_to_weights_loading_fns[model_name](
model, weights_path=job_config.model.load_weights_path,
source=job_config.model.weights_source
)
if job_config.checkpoint.create_seed_checkpoint:
assert (
world_size == 1
), "Must create seed-checkpoint using one gpu, to disable sharding"
model.to_empty(device=init_device)
model_name_to_weights_loading_fns[model_name](
model, weights_path=job_config.checkpoint.load_folder,
source=job_config.checkpoint.weights_source
)

# a no-op handler if float8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
@@ -157,15 +155,15 @@ def loss_fn(pred, labels):
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

# move sharded model to CPU/GPU and initialize weights via DTensor
model.to(device=init_device)
model.to_empty(device=init_device)
model_parts = [model]

for mod in model_parts:
# skip traced modules since we do not define init_weights in the traced module
if isinstance(mod, GraphModule):
continue
# if job_config.model.init_weights:
# mod.init_weights()
if not job_config.checkpoint.create_seed_checkpoint:
mod.init_weights()
mod.train()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
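
A minimal sketch of the meta-device flow train.py now follows: build the model on the meta device (no memory allocated), materialize it with to_empty, then either initialize weights or load them from load_folder; the nn.Linear and the init calls below are placeholders for the real model and its init_weights().

import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(1024, 1024)   # stands in for model_cls.from_model_args(model_config)

model = model.to_empty(device="cpu")   # allocate real, uninitialized storage
# seed-checkpoint / fresh run: initialize; otherwise weights come from load_folder
nn.init.trunc_normal_(model.weight, mean=0.0, std=0.02)
nn.init.zeros_(model.bias)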
63 changes: 63 additions & 0 deletions train_configs/chemlactica_125m.toml
@@ -0,0 +1,63 @@
# torchtitan Config.toml

[job]
dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
description = "Galactica training"
use_for_integration_test = false

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
enable_color_printing = true
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "opt"
flavor = "125M"
norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./test/assets/test_tiktoken.model"

[optimizer]
name = "AdamW"
lr = 8e-4

[training]
batch_size = 8
seq_len = 2048
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = true
create_seed_checkpoint = false
load_folder = "facebook/galactica-125m"
save_folder = "yerevann/chemlactica-125m"
interval_type = "steps"
interval = 5
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = false
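
The [checkpoint] section above drives the new load/save split; a small sketch (Python 3.11+ tomllib, path assumed relative to the repo root) of reading those keys back:

import tomllib  # standard library in Python 3.11+

with open("train_configs/chemlactica_125m.toml", "rb") as f:
    cfg = tomllib.load(f)

ckpt = cfg["checkpoint"]
print(ckpt["load_folder"])                       # facebook/galactica-125m
print(ckpt["save_folder"])                       # yerevann/chemlactica-125m
print(ckpt["interval_type"], ckpt["interval"])   # steps 5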
12 changes: 6 additions & 6 deletions train_configs/galactica_125m.toml
@@ -1,7 +1,7 @@
# torchtitan Config.toml

[job]
dump_folder = "./outputs"
dump_folder = "/nfs/dgx/raid/chem/titan_outputs"
description = "Galactica training"
use_for_integration_test = false

@@ -23,9 +23,6 @@ name = "opt"
flavor = "125M"
norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
init_weights = false
load_weights_path = "facebook/galactica-125m"
weights_source = "huggingface"
tokenizer_path = "./test/assets/test_tiktoken.model"

[optimizer]
@@ -48,8 +45,11 @@ pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
enable_checkpoint = true
create_seed_checkpoint = true
load_folder = "facebook/galactica-125m"
weights_source = "huggingface"
save_folder = "facebook/galactica-125m"
interval_type = "steps"
interval = 5
model_weights_only = false
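
With create_seed_checkpoint = true this config is meant for a single-GPU run (the world_size == 1 assertion in train.py). A hedged sketch of launching it directly, mirroring the torchrun command from submitit_train.py with nproc_per_node reduced to 1; the exact invocation is an assumption, not part of the PR.

import subprocess

subprocess.run([
    "python3", "-m", "torch.distributed.run",
    "--nproc_per_node", "1",   # world_size must be 1 when creating a seed checkpoint
    "--rdzv_backend", "c10d", "--rdzv_endpoint", "localhost:0",
    "train.py", "--job.config_file", "./train_configs/galactica_125m.toml",
], check=True)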
