Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add WSD-S scheduler #35

Merged
merged 3 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion submitit_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
for _ in range(1):
# train_config = './train_configs/chemlactica_125m.toml'
# train_config = './train_configs/chemlactica_1.3b.toml'
train_config = "./train_configs/llama3.2_3b.toml"
train_config = "./train_configs/llama3.2_1b.toml"
# train_config = "./train_configs/llama3.2_3b.toml"
# train_config = './train_configs/debug_model.toml'
function = submitit.helpers.CommandFunction(
[
Expand Down
76 changes: 47 additions & 29 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
except ModuleNotFoundError:
import tomli as tomllib

from torchtitan.logging import logger, validate_log_level

from typing import Optional

from torchtitan.logging import logger, validate_log_level

TORCH_DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}


def string_list(raw_arg):
    """Split a comma-separated argument string into a list of its pieces.

    No whitespace stripping or empty-item filtering is applied: the pieces
    are exactly what ``str.split(",")`` produces.
    """
    pieces = raw_arg.split(",")
    return pieces

Expand Down Expand Up @@ -181,7 +182,10 @@ def __init__(self):
"--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
)
self.parser.add_argument(
"--optimizer.schedule", type=str, default="Linear", help="Optimization schedule to use"
"--optimizer.schedule",
type=str,
default="Linear",
help="Optimization schedule to use",
)
self.parser.add_argument(
"--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
Expand Down Expand Up @@ -225,17 +229,29 @@ def __init__(self):
help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
)
self.parser.add_argument(
"--training.decay_steps",
type=Optional[int],
default=None,
help="Steps for lr scheduler decay, default is decay starts immediately after warmup",
"--training.num_decays",
type=Optional[float],
default=1,
help="The number of total decays to perform throughout the training, following the WSD-S scheduler",
)
self.parser.add_argument(
"--training.decay_type",
type=str,
default="linear",
choices = ["linear","cosine"],
help="Steps for lr scheduler decay type, defaults to linear",
"--training.decay_steps",
type=Optional[int],
default=None,
help="Steps for lr scheduler decay, default is decay starts immediately after warmup",
)
self.parser.add_argument(
"--training.decay_steps_perc",
type=Optional[float],
default=1.0,
help="The percentage of the steps to use as decay steps",
)
self.parser.add_argument(
"--training.decay_type",
type=str,
default="linear",
choices=["linear", "cosine"],
help="Steps for lr scheduler decay type, defaults to linear",
)
self.parser.add_argument(
"--training.max_norm",
Expand Down Expand Up @@ -266,7 +282,7 @@ def __init__(self):
default=True,
action="store_true",
help="Whether to apply loss parallel when sequence parallel is enabled",
)
)
self.parser.add_argument(
"--training.representation_type",
default="SMILES",
Expand Down Expand Up @@ -387,9 +403,7 @@ def __init__(self):
)

# validation configs
self.parser.add_argument(
"--validation.batch_size", type=int, default=None
)
self.parser.add_argument("--validation.batch_size", type=int, default=None)
self.parser.add_argument(
"--validation.dataset", type=str, help="Dataset to use", default=None
)
Expand All @@ -402,10 +416,16 @@ def __init__(self):
default=None,
)
self.parser.add_argument(
"--validation.valid_freq", type=int, default=1024, help="How often to evaluate the model and log metrics to aim."
"--validation.valid_freq",
type=int,
default=1024,
help="How often to evaluate the model and log metrics to aim.",
)
self.parser.add_argument(
"--validation.enable_valid", type=bool, default=False, help="Whether to do validation."
"--validation.enable_valid",
type=bool,
default=False,
help="Whether to do validation.",
)

# checkpointing configs
Expand Down Expand Up @@ -647,35 +667,33 @@ def __init__(self):
)
self.parser.add_argument(
"--logging.log_level",
default = "INFO",
default="INFO",
choices=["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"],
type=str,
help="Set the log level, INFO by default"
help="Set the log level, INFO by default",
)
self.parser.add_argument(
"--dataloader.num_workers",
default = 0,
default=0,
type=int,
help="""Set the number of dataloader workers PER RANK, default is 0. 1 is non-blocking.
More than 1 may lead to issues with data splitting / duplication"""
More than 1 may lead to issues with data splitting / duplication""",
)
self.parser.add_argument(
"--dataloader.pin_memory",
default = False,
default=False,
type=bool,
help= "Whether or not to pin dataloader memory"
help="Whether or not to pin dataloader memory",
)

self.parser.add_argument(
"--dataloader.special_mode",
default = None,
choices = ["yield_tensor"],
default=None,
choices=["yield_tensor"],
type=str,
help= "Enable a special dataloading mode, useful for debugging"
help="Enable a special dataloading mode, useful for debugging",
)



def parse_args(self, args_list: list = sys.argv[1:]):
self.args_list = args_list

Expand Down
87 changes: 75 additions & 12 deletions torchtitan/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

import functools
import math
from enum import Enum

import torch
from torch.optim.lr_scheduler import LambdaLR
from torchtitan.config_manager import JobConfig
from enum import Enum


def build_optimizers(model_parts, job_config: JobConfig):
Expand Down Expand Up @@ -57,12 +57,14 @@ def zero_grad(self):

return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def linear_warmup(warmup_steps: int, current_step: int) -> float:
    """Computes the linear warmup scaling factor.

    The factor is ``(current_step + 1) / (warmup_steps + 1)``, so step 0
    already yields a non-zero multiplier. The result is capped at 1.0 so a
    call made past the warmup window can never scale the learning rate above
    its base value.

    Args:
        warmup_steps: Length of the warmup window; must be positive.
        current_step: Zero-based scheduler step.

    Returns:
        Multiplicative learning-rate factor, at most 1.0.

    Raises:
        ValueError: If ``warmup_steps`` is not positive.
    """
    if warmup_steps <= 0:
        raise ValueError("warmup_steps must be positive.")
    # True division already yields a float; min() provides the upper clamp.
    return min(1.0, (current_step + 1) / (warmup_steps + 1))


# Decay functions
def linear_decay(decay_steps: int, current_step: int, start_step: int) -> float:
"""Computes the linear decay scaling factor."""
Expand All @@ -71,61 +73,121 @@ def linear_decay(decay_steps: int, current_step: int, start_step: int) -> float:
progress = float((current_step - start_step) / decay_steps)
return max(0.0, 1 - progress)


def cosine_decay(decay_steps: int, current_step: int, start_step: int) -> float:
    """Computes the cosine decay scaling factor.

    The factor follows half a cosine period: 1.0 at ``start_step`` and 0.0
    once ``decay_steps`` steps have elapsed, staying at 0.0 afterwards.

    Args:
        decay_steps: Number of steps over which the factor decays; must be
            positive.
        current_step: Zero-based scheduler step.
        start_step: Step at which the decay begins.

    Returns:
        Multiplicative learning-rate factor in ``[0.0, 1.0]``.

    Raises:
        ValueError: If ``decay_steps`` is not positive.
    """
    if decay_steps <= 0:
        raise ValueError("decay_steps must be positive.")
    # Clamp on both sides: past the window the factor stays at 0, and before
    # start_step it stays at 1 (cosine is symmetric, so a negative elapsed
    # count would otherwise produce a factor below 1).
    elapsed = min(max(current_step - start_step, 0), decay_steps)
    return 0.5 * (1 + math.cos(math.pi * elapsed / decay_steps))


class Decay(Enum):
    """Decay curves selectable for the decay phase of the LR schedule.

    Each member's ``value`` is a callable invoked as
    ``value(decay_steps, current_step, start_step)`` that returns the
    multiplicative learning-rate factor.
    """

    # The functions are wrapped in functools.partial because a plain function
    # assigned in an Enum body is treated as a method, not as a member.
    # NOTE(review): recent Python releases warn that partial objects in Enum
    # bodies will stop becoming members and suggest enum.member() instead;
    # confirm against the Python versions this project supports.
    LINEAR = functools.partial(linear_decay)
    COSINE = functools.partial(cosine_decay)

    @staticmethod
    def from_string(decay_type: str) -> "Decay":
        """Converts a string to the corresponding Decay enum value.

        The lookup is case-insensitive (the input is upper-cased first).

        Raises:
            ValueError: If ``decay_type`` does not name a member; the
                original ``KeyError`` is chained as the cause.
        """
        try:
            return Decay[decay_type.upper()]
        except KeyError as e:
            raise ValueError(
                f"Invalid decay type: {decay_type}. Expected one of {list(Decay.__members__.keys())}"
            ) from e


def warmup_stable_decay(
    decay_type: "Decay",
    warmup_steps: int,
    decay_steps: int,
    training_steps: int,
    current_step: int,
) -> float:
    """Computes linear warmup, an optional stable phase, then a final decay.

    Per LambdaLR requirement, this is accomplished by returning
    a multiplicative factor to adjust the learning rate to
    create the desired schedule.

    Args:
        decay_type: Decay member whose ``value`` is called as
            ``value(decay_steps, current_step, start_decay_step)``.
        warmup_steps: Number of warmup steps.
        decay_steps: Number of steps in the decay phase, which ends at
            ``training_steps``.
        training_steps: Total number of training steps.
        current_step: Zero-based scheduler step.

    Returns:
        The multiplicative learning-rate factor for ``current_step``.
    """
    start_decay_step = training_steps - decay_steps

    if current_step < warmup_steps:
        # warmup phase
        return linear_warmup(warmup_steps, current_step)

    if current_step < start_decay_step:
        # stable phase, no adjustment to lr
        return 1.0

    # decay phase supporting multiple decay functions
    return decay_type.value(decay_steps, current_step, start_decay_step)


# implementation of WSD-S scheduler
def warmup_stable_decay_simplified(
    decay_type: "Decay",
    warmup_steps: int,
    decay_steps_perc: float,
    num_decays: int,
    training_steps: int,
    current_step: int,
) -> float:
    """Computes the WSD-S learning-rate factor for ``current_step``.

    Training is split into ``num_decays`` equal cycles. Each cycle holds the
    learning rate at its base value and then decays it so the decay finishes
    exactly on the cycle's last step; the next cycle starts back at the base
    value. The decay window of a cycle is ``decay_steps_perc`` of that
    cycle's END step (not of the cycle length), so later cycles decay over
    more steps. In the first cycle the window is shrunk if needed so it does
    not overlap the warmup.

    Args:
        decay_type: Decay member whose ``value`` is called as
            ``value(decay_steps, current_step, start_decay_step)``.
        warmup_steps: Number of linear warmup steps.
        decay_steps_perc: Fraction of a cycle's end step used for its decay.
        num_decays: Number of decay cycles; must be positive.
        training_steps: Total number of training steps; must be at least
            ``num_decays``.
        current_step: Zero-based scheduler step.

    Returns:
        The multiplicative learning-rate factor for ``current_step``.

    Raises:
        ValueError: If ``num_decays`` is not positive or exceeds
            ``training_steps``.
    """
    if num_decays <= 0:
        raise ValueError("num_decays must be positive.")
    # num steps for each decay
    per_decay_num_steps = training_steps // num_decays
    if per_decay_num_steps <= 0:
        raise ValueError("training_steps must be at least num_decays.")

    # current (1-based) decay index; step 0 belongs to the first cycle, which
    # a bare ceil() would report as cycle 0 with a zero-length decay window
    decay_index = max(1, math.ceil(current_step / per_decay_num_steps))
    # the step at which lr is fully decayed
    decay_at_step = decay_index * per_decay_num_steps
    if decay_index == 1:
        # make sure the decay_steps_perc does not include the warmup_steps
        decay_steps_perc = min(decay_steps_perc, 1 - warmup_steps / decay_at_step)

    # number of decay steps
    decay_steps = int(decay_at_step * decay_steps_perc)
    # the step at which to start the decay
    start_decay_step = decay_at_step - decay_steps

    if current_step < warmup_steps:
        # warmup phase; unlike linear_warmup, the factor is 0.0 at step 0
        return current_step / warmup_steps

    if current_step < start_decay_step or decay_steps <= 0:
        # stable phase, no adjustment to lr; also covers a decay window that
        # rounded down to zero steps, which the decay functions reject
        return 1.0

    # decay phase supporting multiple decay functions
    return decay_type.value(decay_steps, current_step, start_decay_step)


def build_lr_schedulers(optimizers, job_config: JobConfig) -> LambdaLR:
def _build_lr_scheduler(optimizer):
"""Build a linear warmup optionally stable and linear decay scheduler"""
warmup_steps = int(job_config.training.warmup_steps)
post_warmup_steps = float(max(1, job_config.training.steps - warmup_steps))

# If decay steps is not set in config, decay will begin immediately after warmup
decay_steps = job_config.training.decay_steps if job_config.training.decay_steps else post_warmup_steps
decay_steps = (
job_config.training.decay_steps
if job_config.training.decay_steps
else post_warmup_steps
)
decay_steps_perc = job_config.training.decay_steps_perc
num_decays = job_config.training.num_decays
decay_type = Decay.from_string(job_config.training.decay_type)

# lr_lambda = functools.partial(
# warmup_stable_decay, decay_type, warmup_steps, decay_steps, job_config.training.steps
# )
lr_lambda = functools.partial(
warmup_stable_decay, decay_type ,warmup_steps, decay_steps, job_config.training.steps
warmup_stable_decay_simplified,
decay_type,
warmup_steps,
decay_steps_perc,
num_decays,
job_config.training.steps,
)
warmup_stable_decay_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return warmup_stable_decay_scheduler
Expand All @@ -139,7 +201,8 @@ def __init__(self, schedulers):
def step(self):
    """Advance every scheduler held in ``self.schedulers`` by one step."""
    for scheduler in self.schedulers:
        scheduler.step()
@property
def last_lr(self):
    """Return the first value of ``get_last_lr()`` from the first scheduler.

    Only ``self.schedulers[0]`` is consulted; the other schedulers are not
    read here.
    """
    return self.schedulers[0].get_last_lr()[0]

Expand Down
12 changes: 7 additions & 5 deletions train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ lr = 8e-4

[training]
batch_size = 1
gradient_accumulation_steps = 48
gradient_accumulation_steps = 1
seq_len = 2048
warmup_steps = 20 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 100
warmup_steps = 5 # lr scheduler warm up, normally 20% of the train steps
steps = 200
decay_steps_perc = 0.1
num_decays = 4
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
Expand All @@ -56,12 +58,12 @@ enable_async_tensor_parallel = false

[checkpoint]
enable_checkpoint = false
# load_folder = "yerevann/Llama-debug/b00ef18db9d447ff84b9035a"
# load_folder = "yerevann/Llama-debug/bab005ed36ef4e02a3e62333"
save_folder = "yerevann/Llama-debug"
# load_at_step = 100
create_seed_checkpoint = false
interval_type = "steps"
interval = 50
interval = 100
model_weights_only = false
export_dtype = "float32"
async_mode = "async_with_pinned_mem" # ["disabled", "async", "async_with_pinned_mem"]
Expand Down
Loading
Loading