Model loading #25

Merged: 14 commits, Sep 24, 2024
13 changes: 8 additions & 5 deletions submitit_train.py
@@ -6,19 +6,22 @@

if __name__ == "__main__":
executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
n_gpus = 4
n_gpus = 8
executor.update_parameters(
name="titan", timeout_min=15,
name="titan", timeout_min=3 * 24 * 60,
gpus_per_node=n_gpus,
nodes=1, mem_gb=40, cpus_per_task=n_gpus * 2
nodes=1, mem_gb=80, cpus_per_task=n_gpus * 4,
slurm_additional_parameters={
"partition": "h100"
}
)

jobs = []
with executor.batch():
for _ in range(1):
# train_config = './train_configs/chemlactica_125m.toml'
train_config = './train_configs/chemlactica_125m.toml'
# train_config = './train_configs/chemlactica_1.3b.toml'
train_config = './train_configs/llama3_8b.toml'
# train_config = './train_configs/llama3_8b.toml'
# train_config = './train_configs/debug_model.toml'
function = submitit.helpers.CommandFunction([
'python3', '-m', 'torch.distributed.run',
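Note: as a rough, self-contained sketch of the launch flow this file implements, with the resources the diff now requests. The torch.distributed.run flags at the end are assumptions, since the diff is truncated before the command is finished:

import submitit

executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
n_gpus = 8
executor.update_parameters(
    name="titan",
    timeout_min=3 * 24 * 60,  # three days instead of the old 15-minute limit
    gpus_per_node=n_gpus,
    nodes=1,
    mem_gb=80,
    cpus_per_task=n_gpus * 4,
    slurm_additional_parameters={"partition": "h100"},  # partition name is cluster-specific
)

jobs = []
with executor.batch():  # queue submissions and flush them together on exit
    train_config = './train_configs/chemlactica_125m.toml'
    function = submitit.helpers.CommandFunction([
        'python3', '-m', 'torch.distributed.run',
        '--nproc_per_node', str(n_gpus),          # assumed flag; the diff cuts off here
        'train.py', '--job.config_file', train_config,
    ])
    jobs.append(executor.submit(function))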
3 changes: 2 additions & 1 deletion torchtitan/checkpoint.py
@@ -171,6 +171,7 @@ def __init__(
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
states: Dict[str, Any],
job_config: JobConfig,
experiment_hash: str,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
@@ -235,7 +236,7 @@ def __init__(
for idx, lr_scheduler in enumerate(lr_schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler

self.save_folder = os.path.join(job_config.job.dump_folder, ckpt_config.save_folder)
self.save_folder = os.path.join(job_config.job.dump_folder, os.path.join(ckpt_config.save_folder, experiment_hash))
self.load_folder = os.path.join(job_config.job.dump_folder, ckpt_config.load_folder)
self.interval_type = (
IntervalType.SECONDS
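The nested os.path.join above simply appends the experiment hash as one more path component; a tiny illustration with made-up values (POSIX separators assumed):

import os

dump_folder = "outputs"                     # job_config.job.dump_folder (example value)
save_folder = "yerevann/chemlactica-1.3b"   # ckpt_config.save_folder (example value)
experiment_hash = "a1b2c3d4"                # supplied by the metric logger (example value)

nested = os.path.join(dump_folder, os.path.join(save_folder, experiment_hash))
flat = os.path.join(dump_folder, save_folder, experiment_hash)
assert nested == flat == "outputs/yerevann/chemlactica-1.3b/a1b2c3d4"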
22 changes: 15 additions & 7 deletions torchtitan/config_manager.py
@@ -392,12 +392,20 @@ def __init__(self):
help="Whether to enable checkpoint",
)
self.parser.add_argument(
"--checkpoint.folder",
"--checkpoint.load_folder",
type=str,
default="",
help="""
The folder to load checkpoints from.
When enable_checkpoint is set to true, checkpoints will be loaded from {--job.dump_folder}/{--checkpoint.load_folder}.
""",
)
self.parser.add_argument(
"--checkpoint.save_folder",
type=str,
default="checkpoint",
help="""
The folder to store the checkpoints.
When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
When enable_checkpoint is set to true, checkpoints will be saved to {--job.dump_folder}/{--checkpoint.save_folder}.
""",
)
self.parser.add_argument(
@@ -643,13 +651,13 @@ def parse_args(self, args_list: list = sys.argv[1:]):
args, cmd_args = self.parse_args_from_command_line(args_list)
config_file = getattr(args, "job.config_file", None)
# build up a two level dict
args_dict = self._args_to_two_level_dict(args)
self.args_dict = self._args_to_two_level_dict(args)
if config_file is not None:
try:
with open(config_file, "rb") as f:
for k, v in tomllib.load(f).items():
# to prevent overwrite of non-specified keys
args_dict[k] |= v
self.args_dict[k] |= v
except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
logger.exception(
f"Error while loading the configuration file: {config_file}"
@@ -661,9 +669,9 @@ def parse_args(self, args_list: list = sys.argv[1:]):
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
for section, section_args in cmd_args_dict.items():
for k, v in section_args.items():
args_dict[section][k] = v
self.args_dict[section][k] = v

for k, v in args_dict.items():
for k, v in self.args_dict.items():
class_type = type(k.title(), (), v)
setattr(self, k, class_type())
self._validate_config()
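The parse flow now keeps the merged two-level dict on the instance (self.args_dict) so it can later be logged as hparams. A toy illustration of the resulting precedence, defaults < TOML < explicitly passed command-line flags; the CLI override value here is hypothetical:

defaults = {"checkpoint": {"load_folder": "", "save_folder": "checkpoint"}}
from_toml = {"checkpoint": {"load_folder": "facebook/galactica-1.3b",
                            "save_folder": "yerevann/chemlactica-1.3b"}}
from_cli = {"checkpoint": {"save_folder": "debug_ckpts"}}     # hypothetical override

args_dict = {section: dict(values) for section, values in defaults.items()}
for section, values in from_toml.items():
    args_dict[section] |= values           # TOML overrides defaults, as in the `|=` above
for section, values in from_cli.items():
    args_dict[section].update(values)      # explicitly passed CLI flags win last

assert args_dict["checkpoint"]["load_folder"] == "facebook/galactica-1.3b"
assert args_dict["checkpoint"]["save_folder"] == "debug_ckpts"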
18 changes: 12 additions & 6 deletions torchtitan/datasets/hf_datasets.py
@@ -6,6 +6,9 @@

import pickle
from typing import Any, Dict, List, Optional
from pathlib import Path
import glob
import os

import numpy as np

@@ -33,7 +36,8 @@
_supported_datasets = {
"c4_test": "test/assets/c4_test",
"c4": "allenai/c4",
"chemlactica_train_mini": "test/assets/chemlactica_train_mini"
"chemlactica_train_mini": "test/assets/chemlactica_train_mini",
"chemlactica_train": "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form/train_rdkit_computed_rel+form"
}

_supported_data_processing_styles = {
@@ -111,13 +115,16 @@ def __init__(
# c4 is huge, and requires both streaming and language selection
# (we default to en)
ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
else:
elif dataset_name == "c4_test":
ds = load_dataset(dataset_path, split="train")

else:
dataset_files = glob.glob(os.path.join(dataset_path, "*.jsonl"))
ds = load_dataset("text", data_files=dataset_files, split="train", streaming=True)
try:
data_processing_fn = _supported_data_processing_styles[data_processing_style]
except KeyError as e:
raise ValueError(f"Unsupported data processing style: {data_processing_style}")
# data_processing_fn = lambda x, e: str(x)

# TODO: support shuffling and checkpointing
self.dataset_name = dataset_name
@@ -217,9 +224,8 @@ class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, pin_memory: bool, num_workers: int):
super().__init__(hf_ds, batch_size)
super().__init__(hf_ds, batch_size, num_workers=num_workers)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"

@@ -251,7 +257,7 @@ def build_hf_data_loader(
rank,
infinite: bool = True,
pin_memory: bool = False,
num_workers: int = 0,
num_workers: int = 2,
special_mode = None,
context = "train",
):
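A standalone sketch of the new local-JSONL branch (the directory is a placeholder): each line of every shard arrives as a {"text": ...} record and is streamed lazily instead of being materialized in memory.

import glob
import os

from datasets import load_dataset

dataset_path = "/path/to/jsonl_shards"                    # placeholder directory
dataset_files = glob.glob(os.path.join(dataset_path, "*.jsonl"))

ds = load_dataset("text", data_files=dataset_files, split="train", streaming=True)

for sample in ds:
    raw_line = sample["text"]   # downstream, dataset_utils.py json.loads() this line
    break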
6 changes: 6 additions & 0 deletions torchtitan/logging.py
@@ -27,6 +27,9 @@ def init_logger(log_level):
# suppress verbose torch.profiler logging
os.environ["KINETO_LOG_LEVEL"] = "5"

# enable dataloader logging so the type of dataloading used is logged
enable_dataloader_logging(log_level)


class LogLevel(Enum):
DEBUG = "DEBUG"
@@ -46,3 +49,6 @@ def from_string(cls, value: str):
def validate_log_level(value):
return LogLevel.from_string(value)


def enable_dataloader_logging(log_level):
logging.getLogger('datasets.iterable_dataset').setLevel(log_level)
8 changes: 7 additions & 1 deletion torchtitan/metrics.py
@@ -116,6 +116,12 @@ def log_hparams(self, config):
if self.writer is not None:
self.writer.experiment['hparams'] = config

@property
def experiment_hash(self):
if self.writer is None:
return "default"
return self.writer._run.hash

def build_metric_logger(
job_config: JobConfig, parallel_dims: ParallelDims
):
@@ -127,7 +133,7 @@ def build_metric_logger(
"""
dump_dir = job_config.job.dump_folder
aim_config = job_config.metrics
save_aim_folder = aim_config.save_aim_folder
save_aim_folder = os.path.join(job_config.job.dump_folder, aim_config.save_aim_folder)
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, datetime_str)
5 changes: 2 additions & 3 deletions torchtitan/models/opt/model.py
@@ -24,18 +24,17 @@ class ModelArgs:
n_heads: int = 12
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
multiple_of: int = 256
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 10000
dropout_p: float = 0.1

max_batch_size: int = 32
max_seq_len: int = 2048
# If `True`, then each transformer block init uses its layer ID, and if
# `False`, each uses the total number of transformer blocks
depth_init: bool = True
norm_type: str = "layersnorm"
norm_type: str = "layernorm_bias"


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
2 changes: 1 addition & 1 deletion torchtitan/models/opt/utils.py
@@ -73,7 +73,7 @@ def export_opt_weights(model: OPT, save_dir: str, token_embedding_size: int):
"""
write docs
"""
hf_model = OPTForCausalLM.from_pretrained(map_n_layers_to_model_name(model.n_layers))
hf_model = OPTForCausalLM.from_pretrained(map_n_layers_to_model_name(model.n_layers), tie_word_embeddings=False)
hf_model.resize_token_embeddings(new_num_tokens=token_embedding_size)
keys_mapping = get_hf_opt_state_dict_keys_mapping(model.n_layers)
state_dict = model.state_dict()
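Passing tie_word_embeddings=False at load time overrides the config so the input embeddings and the lm_head stay separate tensors, which matters when separately trained weights are copied in afterwards. A hedged sketch with a placeholder model name and vocab size:

from transformers import OPTForCausalLM

hf_model = OPTForCausalLM.from_pretrained("facebook/opt-125m",   # placeholder model
                                           tie_word_embeddings=False)
hf_model.resize_token_embeddings(new_num_tokens=50066)           # placeholder vocab size

# With tying disabled, the two weight matrices are distinct parameters.
assert hf_model.get_input_embeddings().weight is not hf_model.get_output_embeddings().weight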
3 changes: 3 additions & 0 deletions torchtitan/tokenizers/tokenizer/custom.py
@@ -7,10 +7,13 @@
# copied and adjusted from https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py

from typing import List
import os

from torchtitan.logging import logger
from transformers import AutoTokenizer

os.environ["TOKENIZER_PARALLELISM"] = "true"


class CustomTokenizer:
"""
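For reference, the environment variable the Hugging Face fast-tokenizers backend reads is TOKENIZERS_PARALLELISM (plural). A minimal sketch of setting it before building the tokenizer used elsewhere in this PR:

import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"   # note the plural form read by `tokenizers`

from transformers import AutoTokenizer

# Path taken from the PR's train config; adjust to a local tokenizer directory.
tokenizer = AutoTokenizer.from_pretrained("./torchtitan/tokenizers/chemlactica-125m")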
3 changes: 2 additions & 1 deletion torchtitan/utils/dataset_utils.py
@@ -32,13 +32,14 @@ def load_jsonl_line(jsonl_line):

def chemlactica_style_data_processing(sample_json, rng):
try:
sample_json = json.loads(sample_json["text"])
compound = delete_empty_tags(sample_json)
sample_json = generate_formatted_string(
compound, rng
)
except Exception as e:
print(e)
sample_json = {}
sample_json = ""
return sample_json


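A toy illustration (made-up record, stand-in formatter) of why the function now parses sample_json["text"] and falls back to an empty string: the streaming "text" loader yields raw JSON lines, and returning "" on failure keeps the return type a string for the tokenizer downstream.

import json

sample = {"text": '{"SMILES": "CCO"}'}        # made-up record from the "text" loader

try:
    record = json.loads(sample["text"])
    formatted = f"[START_SMILES]{record['SMILES']}[END_SMILES]"  # stand-in for generate_formatted_string
except Exception as err:
    print(err)
    formatted = ""                             # a string, not {}, so downstream code sees one type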
20 changes: 11 additions & 9 deletions train.py
@@ -7,6 +7,7 @@
import contextlib
import os
import time
import logging
from datetime import timedelta

import torch
@@ -186,6 +187,9 @@ def loss_fn(pred, labels):

train_state = TrainState()

metric_logger = build_metric_logger(job_config, parallel_dims)
metric_logger.log_hparams(job_config.args_dict)

# load initial checkpoint
checkpoint = CheckpointManager(
dataloader=data_loader,
@@ -194,6 +198,7 @@ def loss_fn(pred, labels):
lr_schedulers=lr_schedulers.schedulers,
states={"train_state": train_state},
job_config=job_config,
experiment_hash=metric_logger.experiment_hash
)

if job_config.model_download_export.to_titan:
Expand All @@ -218,11 +223,6 @@ def loss_fn(pred, labels):
logger.info("Created huggingface checkpoint")
return

metric_logger = build_metric_logger(job_config, parallel_dims)
args, cmd_args = job_config.parse_args_from_command_line(job_config.args_list)
job_config_dict = job_config._args_to_two_level_dict(args)
metric_logger.log_hparams(job_config_dict)

data_iterator = iter(data_loader)

train_context = get_train_context(
@@ -284,12 +284,14 @@ def loss_fn(pred, labels):
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()

for m in model_parts:
torch.nn.utils.clip_grad_norm_(
m.parameters(), job_config.training.max_norm, foreach=True
)

if force_finish_train:
break
for m in model_parts:
torch.nn.utils.clip_grad_norm_(
m.parameters(), job_config.training.max_norm, foreach=True
)

# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model_parts)
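A toy sketch (names and values made up) of why the metric logger is now constructed before the CheckpointManager: the run hash it exposes becomes part of the checkpoint save path, so each run's checkpoints land in their own sub-folder, and "default" is used when no aim writer exists on that rank.

import os

class FakeMetricLogger:
    """Stand-in for the aim-backed logger; mirrors the experiment_hash property in metrics.py."""
    def __init__(self, run_hash=None):
        self._run_hash = run_hash

    @property
    def experiment_hash(self):
        return self._run_hash if self._run_hash is not None else "default"

for ml in (FakeMetricLogger("3f9c1e7a"), FakeMetricLogger()):
    print(os.path.join("outputs", "yerevann/chemlactica-1.3b", ml.experiment_hash))
# outputs/yerevann/chemlactica-1.3b/3f9c1e7a
# outputs/yerevann/chemlactica-1.3b/default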
31 changes: 17 additions & 14 deletions train_configs/chemlactica_1.3b.toml
@@ -15,49 +15,52 @@ save_memory_snapshot_folder = "memory_snapshot"
[metrics]
log_freq = 1
enable_color_printing = true
enable_tensorboard = true
save_tb_folder = "tb"
enable_aim = true
save_aim_folder = "aim"

[model]
name = "opt"
flavor = "1.3B"
# norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm"
norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./test/assets/test_tiktoken.model"
# tokenizer_path = "./test/assets/test_tiktoken.model"
tokenizer_path = "./torchtitan/tokenizers/chemlactica-125m"

[optimizer]
name = "AdamW"
lr = 8e-4
lr = 1.0e-4

[training]
batch_size = 10
batch_size = 13
gradient_accumulation_steps = 9
seq_len = 2048
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
warmup_steps = 500 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 50
steps = 18000
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
compile = true
# dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
# dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M), chemlactica_train_mini (4K)
dataset = "chemlactica_train"
data_process_style="chemlactica_style"

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = true
create_seed_checkpoint = false
load_folder = "facebook/galactica-1.3b"
save_folder = "yerevann/chemlactica-1.3b"
interval_type = "steps"
interval = 100
interval = 2000
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
mode = 'none' # ['none', 'selective', 'full']
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
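Back-of-the-envelope throughput for the updated config, assuming batch_size is the per-data-parallel-rank micro-batch size, steps counts optimizer steps, and all 8 GPUs from the submit script are pure data parallel (tensor_parallel_degree = 1):

batch_size = 13
gradient_accumulation_steps = 9
seq_len = 2048
dp_ranks = 8
steps = 18000

tokens_per_step = batch_size * gradient_accumulation_steps * seq_len * dp_ranks
total_tokens = tokens_per_step * steps
print(f"{tokens_per_step:,} tokens per optimizer step")         # 1,916,928
print(f"~{total_tokens / 1e9:.1f}B tokens over {steps} steps")  # ~34.5B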