diff --git a/.gitignore b/.gitignore index cf5f06e1..9488979d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__ .idea +.vscode .DS_Store *.egg-info build diff --git a/submitit_train.py b/submitit_train.py new file mode 100644 index 00000000..8e930b9d --- /dev/null +++ b/submitit_train.py @@ -0,0 +1,31 @@ +import submitit +import datetime +import yaml +import os + + +if __name__ == "__main__": + executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j") + executor.update_parameters( + name="titan", timeout_min=15, + gpus_per_node=2, + nodes=1, mem_gb=30, cpus_per_task=10, + slurm_array_parallelism=10 + ) + + jobs = [] + with executor.batch(): + for _ in range(1): + function = submitit.helpers.CommandFunction([ + 'python3', '-m', 'torch.distributed.run', + '--nproc_per_node', '2', + '--rdzv_backend', 'c10d', + '--rdzv_endpoint', 'localhost:0', + '--local-ranks-filter', '0', + '--role', 'rank', '--tee', '3', + 'train.py', '--job.config_file', './train_configs/galactica_125m.toml', + ]) + print(' '.join(function.command)) + # subprocess.run(function.command) + job = executor.submit(function) + jobs.append(job) diff --git a/test_runner.py b/test_runner.py index a7c95ce1..307842e4 100755 --- a/test_runner.py +++ b/test_runner.py @@ -61,38 +61,38 @@ def build_test_list(): requires_seed_checkpoint=True, ngpu=4, ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--experimental.pipeline_parallel_degree 2", - "--experimental.pipeline_parallel_split_points layers.4", - "--experimental.pipeline_parallel_schedule 1f1b", - "--training.data_parallel_degree 1", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP - ], - ], - "PP 1D test 1f1b", - "pp_1f1b", - requires_seed_checkpoint=True, - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--experimental.pipeline_parallel_degree 2", - "--experimental.pipeline_parallel_split_points layers.4", - "--experimental.pipeline_parallel_schedule gpipe", - "--training.data_parallel_degree 1", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP - ], - ], - "PP 1D test gpipe", - "pp_gpipe", - requires_seed_checkpoint=True, - ngpu=2, - ), + # OverrideDefinitions( + # [ + # [ + # "--checkpoint.enable_checkpoint", + # "--experimental.pipeline_parallel_degree 2", + # "--experimental.pipeline_parallel_split_points layers.4", + # "--experimental.pipeline_parallel_schedule 1f1b", + # "--training.data_parallel_degree 1", + # "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP + # ], + # ], + # "PP 1D test 1f1b", + # "pp_1f1b", + # requires_seed_checkpoint=True, + # ngpu=2, + # ), + # OverrideDefinitions( + # [ + # [ + # "--checkpoint.enable_checkpoint", + # "--experimental.pipeline_parallel_degree 2", + # "--experimental.pipeline_parallel_split_points layers.4", + # "--experimental.pipeline_parallel_schedule gpipe", + # "--training.data_parallel_degree 1", + # "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP + # ], + # ], + # "PP 1D test gpipe", + # "pp_gpipe", + # requires_seed_checkpoint=True, + # ngpu=2, + # ), OverrideDefinitions( [ [ diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 64d2f007..ee326733 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -235,7 +235,8 @@ def __init__( for idx, lr_scheduler in enumerate(lr_schedulers): self.states[f"lr_scheduler_{idx}"] = lr_scheduler - self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) + self.save_folder = os.path.join(job_config.job.dump_folder, ckpt_config.save_folder) + self.load_folder = os.path.join(job_config.job.dump_folder, ckpt_config.load_folder) self.interval_type = ( IntervalType.SECONDS if ckpt_config.interval_type == "seconds" @@ -280,7 +281,7 @@ def __init__( raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}") logger.info( - f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" + f"Checkpointing active. Checkpoints will be loaded from {self.load_folder} and saved to {self.save_folder}" ) def __del__(self): @@ -291,8 +292,8 @@ def __del__(self): def reset(self) -> None: self.begin_time = time.monotonic() - def _create_checkpoint_id(self, step: int) -> str: - return os.path.join(self.folder, f"step-{step}") + def _create_checkpoint_id(self, step: int, folder: str) -> str: + return os.path.join(folder, f"step-{step}") def _save_last_step(self, curr_step: int) -> None: # We only consider saving weights only at the end of the training. So @@ -323,7 +324,7 @@ def _save_last_step(self, curr_step: int) -> None: else: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") - dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step)) + dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step, self.save_folder)) self.reset() def _should_save(self, curr_step: int, force: bool = False) -> bool: @@ -411,7 +412,7 @@ def save(self, curr_step: int, force: bool = False) -> None: return begin = time.monotonic() - checkpoint_id = self._create_checkpoint_id(curr_step) + checkpoint_id = self._create_checkpoint_id(curr_step, self.save_folder) self._async_wait() if force: self._save_last_step(curr_step) @@ -448,16 +449,16 @@ def maybe_wait_for_staging(self) -> None: def load(self, step: int = -1) -> bool: if not self.enable_checkpoint: return False - if not os.path.isdir(self.folder): + if not os.path.isdir(self.load_folder): return False - if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)): + if step != -1 and not os.path.isdir(self._create_checkpoint_id(step, self.load_folder)): return False if step == -1: step_counts = [] - for filename in os.listdir(self.folder): + for filename in os.listdir(self.load_folder): match = re.search(r"step-(\d+)", filename) - metadata_probe = os.path.join(self.folder, filename, ".metadata") + metadata_probe = os.path.join(self.load_folder, filename, ".metadata") if match and os.path.isfile(metadata_probe): step_counts.append(int(match.group(1))) if not step_counts: @@ -470,7 +471,7 @@ def load(self, step: int = -1) -> bool: begin = time.monotonic() dcp.load( states, - checkpoint_id=self._create_checkpoint_id(step), + checkpoint_id=self._create_checkpoint_id(step, self.load_folder), ) logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -480,9 +481,9 @@ def load(self, step: int = -1) -> bool: def _purge_stale_checkpoints(self): if self.keep_latest_k > 0: discovered_checkpoints = [] - for filename in os.listdir(self.folder): + for filename in os.listdir(self.save_folder): match = re.search(r"step-(\d+)", filename) - path = os.path.join(self.folder, filename) + path = os.path.join(self.save_folder, filename) discovered_checkpoints.append((int(match.group(1)), path)) discovered_checkpoints.sort() diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index a58eb28b..231de76f 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -204,6 +204,12 @@ def __init__(self): self.parser.add_argument( "--training.batch_size", type=int, default=8, help="Batch size" ) + self.parser.add_argument( + "--training.gradient_accumulation_steps", + type=int, + default=1, + help="Interval in steps for gradient accumulation", + ) self.parser.add_argument( "--training.seq_len", type=int, default=2048, help="Sequence length" ) diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index c7bb16c6..a236fa77 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -5,15 +5,26 @@ # LICENSE file in the root directory of this source tree. from torchtitan.models.llama import llama2_configs, llama3_configs, Transformer +from torchtitan.models.opt import opt_configs, OPT, load_opt_weights models_config = { "llama2": llama2_configs, "llama3": llama3_configs, + "opt": opt_configs } -model_name_to_cls = {"llama2": Transformer, "llama3": Transformer} +model_name_to_cls = { + "llama2": Transformer, + "llama3": Transformer, + "opt": OPT +} model_name_to_tokenizer = { "llama2": "sentencepiece", "llama3": "tiktoken", + "opt": "tiktoken" } + +model_name_to_weights_loading_fns = { + "opt": load_opt_weights +} \ No newline at end of file diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 798c7c4d..ff54de9a 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -40,6 +40,8 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6): return nn.LayerNorm(dim, eps=eps, bias=False) elif norm_type == "np_layernorm": return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "layernorm_bias": + return nn.LayerNorm(dim, eps=eps, bias=True) elif norm_type == "rmsnorm": return RMSNorm(dim, eps=eps) elif norm_type == "compiled_rmsnorm": diff --git a/torchtitan/models/opt/__init__.py b/torchtitan/models/opt/__init__.py new file mode 100644 index 00000000..98178dce --- /dev/null +++ b/torchtitan/models/opt/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# is licensed under the , +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.models.opt.model import ModelArgs, OPT +from torchtitan.models.opt.utils import load_opt_weights + +__all__ = ["OPT", "load_opt_weights"] + +opt_configs = { + "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=8), + "125M": ModelArgs(dim=768, n_layers=12, n_heads=12), + # "1.3B": ModelArgs(dim=2048, n_layers=, n_heads=8), + # "6.7B": ModelArgs(dim=2048, n_layers=, n_heads=8) +} \ No newline at end of file diff --git a/torchtitan/models/opt/model.py b/torchtitan/models/opt/model.py new file mode 100644 index 00000000..8f4fac82 --- /dev/null +++ b/torchtitan/models/opt/model.py @@ -0,0 +1,382 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# is licensed under the , +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torchtitan.models.norms import build_norm + + +@dataclass +class ModelArgs: + dim: int = 768 + n_layers: int = 12 + n_heads: int = 12 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + dropout_p: float = 0.1 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + norm_type: str = "layersnorm" + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class LearnedPositionalEmbedding(nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, positions): + return super().forward(positions + self.offset - 1) # subtract one to offset the indices to 0 + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + self.dropout_p = model_args.dropout_p + + # use bias for q, k, v projections + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=True + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=True) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=True) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=True + ) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training, add attention dropout during the training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True, dropout_p=self.dropout_p if self.training else 0.0) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + dropout_p: float + ): + super().__init__() + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.dropout_p = dropout_p + + # use bias for ffn + self.w1 = nn.Linear(dim, hidden_dim, bias=True) + self.w2 = nn.Linear(hidden_dim, dim, bias=True) + + def forward(self, x): + # GELU activation function + x = self.w2(F.gelu(self.w1(x))) + x = F.dropout(x, p=self.dropout_p, training=self.training) + return x + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.w2.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (LayerNorm): Layer normalization for attention output. + ffn_norm (LayerNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + dropout_p=model_args.dropout_p + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + self.dropout_p = model_args.dropout_p + + self.attention_norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + self.ffn_norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + # attention + h = self.attention(self.attention_norm(x)) + # add dropout during the training + h = x + F.dropout(h, p=self.dropout_p, training=self.training) + # pointwise ffn + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class OPT(nn.Module): + """ + Transformer Module + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + model_args (ModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (LayerNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.pos_encoder = LearnedPositionalEmbedding(model_args.max_seq_len, model_args.dim) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() + + def init_weights(self): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + nn.init.normal_(self.pos_encoder.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def forward(self, tokens: torch.Tensor): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # get batch size and sequence length + batch_size, seq_length = tokens.shape + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + positions = torch.cumsum(torch.ones(batch_size, seq_length, device=h.device, dtype=torch.long), dim=1) + h = h + self.pos_encoder(positions) + + for layer in self.layers.values(): + h = layer(h) + + h = self.norm(h) if self.norm else h + output = self.output(h).float() if self.output else h + return output + + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """ + Initialize a Transformer model from a ModelArgs object. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. + + """ + return cls(model_args) diff --git a/torchtitan/models/opt/utils.py b/torchtitan/models/opt/utils.py new file mode 100644 index 00000000..ce822bb7 --- /dev/null +++ b/torchtitan/models/opt/utils.py @@ -0,0 +1,61 @@ +from transformers import OPTForCausalLM +from torchtitan.models.opt import OPT + + +def get_hf_opt_state_dict_keys_mapping(num_layers: int): + """ + Get a mapping between state dict keys of different implementations. + + Args: + num_layers (int): number of transformer layers (blocks). + + Returns: + dict: mapping between local implementation state dict keys and hf implementation state dict keys + + """ + keys_mapping = { + 'tok_embeddings.weight': 'model.decoder.embed_tokens.weight', + 'pos_encoder.weight': 'model.decoder.embed_positions.weight', + # add layer weight mappings here + 'norm.weight': 'model.decoder.final_layer_norm.weight', + 'norm.bias': 'model.decoder.final_layer_norm.bias', + "output.weight": 'lm_head.weight', + } + for layer in range(num_layers): + keys_mapping.update({ + f'layers.{layer}.attention.wq.weight': f'model.decoder.layers.{layer}.self_attn.q_proj.weight', + f'layers.{layer}.attention.wq.bias': f'model.decoder.layers.{layer}.self_attn.q_proj.bias', + f'layers.{layer}.attention.wk.weight': f'model.decoder.layers.{layer}.self_attn.k_proj.weight', + f'layers.{layer}.attention.wk.bias': f'model.decoder.layers.{layer}.self_attn.k_proj.bias', + f'layers.{layer}.attention.wv.weight': f'model.decoder.layers.{layer}.self_attn.v_proj.weight', + f'layers.{layer}.attention.wv.bias': f'model.decoder.layers.{layer}.self_attn.v_proj.bias', + f'layers.{layer}.attention.wo.weight': f'model.decoder.layers.{layer}.self_attn.out_proj.weight', + f'layers.{layer}.attention.wo.bias': f'model.decoder.layers.{layer}.self_attn.out_proj.bias', + f'layers.{layer}.feed_forward.w1.weight': f'model.decoder.layers.{layer}.fc1.weight', + f'layers.{layer}.feed_forward.w1.bias': f'model.decoder.layers.{layer}.fc1.bias', + f'layers.{layer}.feed_forward.w2.weight': f'model.decoder.layers.{layer}.fc2.weight', + f'layers.{layer}.feed_forward.w2.bias': f'model.decoder.layers.{layer}.fc2.bias', + f'layers.{layer}.attention_norm.weight': f'model.decoder.layers.{layer}.self_attn_layer_norm.weight', + f'layers.{layer}.attention_norm.bias': f'model.decoder.layers.{layer}.self_attn_layer_norm.bias', + f'layers.{layer}.ffn_norm.weight': f'model.decoder.layers.{layer}.final_layer_norm.weight', + f'layers.{layer}.ffn_norm.bias': f'model.decoder.layers.{layer}.final_layer_norm.bias' + }) + + return keys_mapping + + +def load_opt_weights(model: OPT, weights_path: str, source: str): + """ + write docs + """ + if source == "huggingface": + hf_model = OPTForCausalLM.from_pretrained(weights_path) + keys_mapping = get_hf_opt_state_dict_keys_mapping(model.n_layers) + hf_state_dict = hf_model.state_dict() + corrected_state_dict = {} + for key, value in keys_mapping.items(): + corrected_state_dict[key] = hf_state_dict[value] + + model.load_state_dict(corrected_state_dict) + else: + raise NotImplemented \ No newline at end of file diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index b75cb336..a6617d2d 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -19,8 +19,9 @@ models_parallelize_fns = { "llama2": parallelize_llama, "llama3": parallelize_llama, + 'opt': parallelize_llama, } models_pipelining_fns = { "llama2": pipeline_llama, - "llama3": pipeline_llama, + "llama3": pipeline_llama } diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index eb6d1a9c..475899f4 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -35,9 +35,7 @@ def _validate(self): def build_mesh(self, device_type): dims = [] names = [] - for d, name in zip( - [self.dp], ["dp"], strict=True - ): + for d, name in zip([self.dp], ["dp"], strict=True): if d > 1: dims.append(d) names.append(name) @@ -51,8 +49,7 @@ def dp_enabled(self): @property def loss_parallel_enabled(self): - return False # requires tensor parallelism - + return False # requires tensor parallelism @cached_property def model_parallel_size(self): diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 34ee4166..4432254d 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -14,22 +14,13 @@ from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.distributed._composable.replicate import replicate -from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, ) -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - parallelize_module, - PrepareModuleInput, - RowwiseParallel, - SequenceParallel, -) from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import logger from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.utils import check_strided_sharding_enabled def parallelize_llama( @@ -46,7 +37,6 @@ def parallelize_llama( the model must fit on GPU or CPU memory. """ - if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) @@ -196,7 +186,9 @@ def apply_fsdp( **fsdp_config, reshard_after_forward=reshard_after_forward, ) - fully_shard(model, **fsdp_config, reshard_after_forward=True) # in torch titan, this was "not pp_enabled" + fully_shard( + model, **fsdp_config, reshard_after_forward=True + ) # in torch titan, this was "not pp_enabled" logger.info("Applied FSDP to the model") diff --git a/train.py b/train.py index 43d830b7..f02b8a45 100644 --- a/train.py +++ b/train.py @@ -12,10 +12,6 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record from torch.fx import GraphModule -import torch.nn.functional as F -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.distributed.elastic.multiprocessing.errors import record from torchtitan import utils from torchtitan.checkpoint import CheckpointManager, TrainState @@ -24,13 +20,14 @@ from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger -from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.optimizer import build_lr_schedulers, build_optimizers -from torchtitan.parallelisms import ( - models_parallelize_fns, - models_pipelining_fns, - ParallelDims, +from torchtitan.models import ( + model_name_to_cls, + model_name_to_weights_loading_fns, + model_name_to_tokenizer, + models_config ) +from torchtitan.optimizer import build_lr_schedulers, build_optimizers +from torchtitan.parallelisms import models_parallelize_fns, ParallelDims from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling @@ -84,9 +81,9 @@ def main(job_config: JobConfig): else: dp_degree, dp_rank = 1, 0 - model_name = job_config.model.name world_mesh = parallel_dims.build_mesh(device_type="cuda") + init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda" # build tokenizer tokenizer_type = model_name_to_tokenizer[model_name] @@ -112,12 +109,24 @@ def main(job_config: JobConfig): # 3. max_seq_len base on inputs model_config.norm_type = job_config.model.norm_type model_config.vocab_size = tokenizer.n_words + model_config.vocab_size = 50000 model_config.max_seq_len = job_config.training.seq_len logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}") with torch.device("meta"): model = model_cls.from_model_args(model_config) + # load the model on rank 0 only, then FSDP will distribute the weights + if job_config.checkpoint.create_seed_checkpoint: + assert ( + world_size == 1 + ), "Must create seed-checkpoint using one gpu, to disable sharding" + model.to_empty(device=init_device) + model_name_to_weights_loading_fns[model_name]( + model, weights_path=job_config.checkpoint.load_folder, + source=job_config.checkpoint.weights_source + ) + # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear based on float8 configs @@ -146,7 +155,6 @@ def loss_fn(pred, labels): models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) # move sharded model to CPU/GPU and initialize weights via DTensor - init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda" model.to_empty(device=init_device) model_parts = [model] @@ -154,7 +162,8 @@ def loss_fn(pred, labels): # skip traced modules since we do not define init_weights in the traced module if isinstance(mod, GraphModule): continue - mod.init_weights() + if not job_config.checkpoint.create_seed_checkpoint: + mod.init_weights() mod.train() gpu_mem_stats = gpu_memory_monitor.get_peak_stats() @@ -190,7 +199,6 @@ def loss_fn(pred, labels): checkpoint_loaded = checkpoint.load() - metric_logger = build_metric_logger(job_config, parallel_dims) args, cmd_args = job_config.parse_args_from_command_line(job_config.args_list) job_config_dict = job_config._args_to_two_level_dict(args) @@ -226,8 +234,8 @@ def loss_fn(pred, labels): # train loop logger.info( f"Training starts at step {train_state.step + 1}, " - f"with local batch size {job_config.training.batch_size}, " - f"global batch size {job_config.training.batch_size * dp_degree}, " + f"with local batch size {job_config.training.batch_size * job_config.training.gradient_accumulation_steps}, " + f"global batch size {job_config.training.batch_size * job_config.training.gradient_accumulation_steps * dp_degree}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " f"(warmup {job_config.training.warmup_steps})" @@ -243,21 +251,23 @@ def loss_fn(pred, labels): # get batch data_load_start = time.perf_counter() - batch = next(data_iterator) - input_ids, labels = batch - ntokens_since_last_log += labels.numel() - data_loading_times.append(time.perf_counter() - data_load_start) - - input_ids = input_ids.cuda() - labels = labels.cuda() optimizers.zero_grad() - with train_context(): - pred = model(input_ids) - loss = loss_fn(pred, labels) - # pred.shape=(bs, seq_len, vocab_size) - # need to free to before bwd to avoid peaking memory - del pred - loss.backward() + + for _ in range(job_config.training.gradient_accumulation_steps): + batch = next(data_iterator) + input_ids, labels = batch + ntokens_since_last_log += labels.numel() + input_ids = input_ids.cuda() + labels = labels.cuda() + data_loading_times.append(time.perf_counter() - data_load_start) + + with train_context(): + pred = model(input_ids) + loss = loss_fn(pred, labels) + # pred.shape=(bs, seq_len, vocab_size) + # need to free to before bwd to avoid peaking memory + del pred + loss.backward() for m in model_parts: torch.nn.utils.clip_grad_norm_( m.parameters(), job_config.training.max_norm, foreach=True @@ -266,8 +276,8 @@ def loss_fn(pred, labels): # sync float8 amaxes and scales float8_handler.sync_float8_amax_and_scale_history(model_parts) - # optimizer step checkpoint.maybe_wait_for_staging() + # optimizer step optimizers.step() lr_schedulers.step() diff --git a/train_configs/chemlactica_125m.toml b/train_configs/chemlactica_125m.toml new file mode 100644 index 00000000..5f303f55 --- /dev/null +++ b/train_configs/chemlactica_125m.toml @@ -0,0 +1,63 @@ +# torchtitan Config.toml + +[job] +dump_folder = "/nfs/dgx/raid/chem/titan_outputs" +description = "Galactica training" +use_for_integration_test = false + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +enable_color_printing = true +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "opt" +flavor = "125M" +norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm +# test tokenizer.model, for debug purpose only +tokenizer_path = "./test/assets/test_tiktoken.model" + +[optimizer] +name = "AdamW" +lr = 8e-4 + +[training] +batch_size = 8 +seq_len = 2048 +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +max_norm = 1.0 # grad norm clipping +steps = 10 +data_parallel_degree = -1 +tensor_parallel_degree = 1 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[experimental] +pipeline_parallel_degree = 1 +enable_async_tensor_parallel = false + +[checkpoint] +enable_checkpoint = true +create_seed_checkpoint = false +load_folder = "facebook/galactica-125m" +save_folder = "yerevann/chemlactica-125m" +interval_type = "steps" +interval = 5 +model_weights_only = false +export_dtype = "float32" +async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'selective' # ['none', 'selective', 'full'] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index ff229e3f..27d9b4cc 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -25,13 +25,15 @@ flavor = "debugmodel" norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm # test tokenizer.model, for debug purpose only tokenizer_path = "./test/assets/test_tiktoken.model" +init_weights = true [optimizer] name = "AdamW" lr = 8e-4 [training] -batch_size = 8 +batch_size = 1 +gradient_accumulation_steps = 1 seq_len = 2048 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping diff --git a/train_configs/galactica_125m_hf_to_titan.toml b/train_configs/galactica_125m_hf_to_titan.toml new file mode 100644 index 00000000..1318d4cf --- /dev/null +++ b/train_configs/galactica_125m_hf_to_titan.toml @@ -0,0 +1,64 @@ +# torchtitan Config.toml + +[job] +dump_folder = "/nfs/dgx/raid/chem/titan_outputs" +description = "Galactica training" +use_for_integration_test = false + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +enable_color_printing = true +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "opt" +flavor = "125M" +norm_type = "layernorm_bias" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm +# test tokenizer.model, for debug purpose only +tokenizer_path = "./test/assets/test_tiktoken.model" + +[optimizer] +name = "AdamW" +lr = 8e-4 + +[training] +batch_size = 8 +seq_len = 2048 +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +max_norm = 1.0 # grad norm clipping +steps = 10 +data_parallel_degree = -1 +tensor_parallel_degree = 1 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[experimental] +pipeline_parallel_degree = 1 +enable_async_tensor_parallel = false + +[checkpoint] +enable_checkpoint = true +create_seed_checkpoint = true +load_folder = "facebook/galactica-125m" +weights_source = "huggingface" +save_folder = "facebook/galactica-125m" +interval_type = "steps" +interval = 5 +model_weights_only = false +export_dtype = "float32" +async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'selective' # ['none', 'selective', 'full'] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false