From 73d458f823bf080ac0c0c74bc5be510066708b17 Mon Sep 17 00:00:00 2001
From: Filya Geikyan
Date: Tue, 27 Aug 2024 15:22:00 +0400
Subject: [PATCH] perplexity

---
 torchtitan/checkpoint.py       |  2 ++
 train.py                       | 14 ++++++++++++++
 train_configs/debug_model.toml |  4 ++--
 3 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index b71419c6..64d2f007 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -48,6 +48,8 @@ class TrainState(Stateful):
     step: int = 0
     global_avg_losses: List[float] = field(default_factory=list)
     global_max_losses: List[float] = field(default_factory=list)
+    global_avg_perplexities: List[float] = field(default_factory=list)
+    global_max_perplexities: List[float] = field(default_factory=list)
     log_steps: List[int] = field(default_factory=list)
 
     def state_dict(self) -> Dict[str, Any]:
diff --git a/train.py b/train.py
index 216d28de..43d830b7 100644
--- a/train.py
+++ b/train.py
@@ -283,19 +283,31 @@ def loss_fn(pred, labels):
                 or train_state.step % job_config.metrics.log_freq == 0
             ):
                 losses = [loss.item() for loss in losses_since_last_log]
+
+                perplexities = [torch.exp(loss).item() for loss in losses_since_last_log]
+
                 avg_loss, max_loss = sum(losses) / len(losses), max(losses)
+                avg_perplexity, max_perplexity = sum(perplexities) / len(perplexities), max(perplexities)
+
                 if parallel_dims.dp_enabled:
                     global_avg_loss, global_max_loss = (
                         utils.dist_mean(avg_loss, dp_mesh),
                         utils.dist_max(max_loss, dp_mesh),
                     )
+                    global_avg_perplexity, global_max_perplexity = (
+                        utils.dist_mean(avg_perplexity, dp_mesh),
+                        utils.dist_max(max_perplexity, dp_mesh),
+                    )
                 else:
                     global_avg_loss, global_max_loss = avg_loss, max_loss
+                    global_avg_perplexity, global_max_perplexity = avg_perplexity, max_perplexity
 
                 # update train state
                 train_state.log_steps.append(train_state.step)
                 train_state.global_avg_losses.append(global_avg_loss)
                 train_state.global_max_losses.append(global_max_loss)
+                train_state.global_avg_perplexities.append(global_avg_perplexity)
+                train_state.global_max_perplexities.append(global_max_perplexity)
 
                 time_delta = time.perf_counter() - time_last_log
 
@@ -317,6 +329,8 @@ def loss_fn(pred, labels):
                 metrics = {
                     "loss_metrics/global_avg_loss": global_avg_loss,
                     "loss_metrics/global_max_loss": global_max_loss,
+                    "loss_metrics/global_avg_perplexity": global_avg_perplexity,
+                    "loss_metrics/global_max_perplexity": global_max_perplexity,
                     "wps": wps,
                     "mfu(%)": mfu,
                     "time_metrics/end_to_end(s)": time_end_to_end,
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index e8d58839..ff229e3f 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -17,7 +17,7 @@ log_freq = 1
 enable_color_printing = true
 enable_aim = true
 save_aim_folder = "aim"
-aim_hash = "bfccae7dda0640f89e390e9a"
+aim_hash = "c6b4d8b340f74287b82ef928"
 
 [model]
 name = "llama3"
@@ -35,7 +35,7 @@ batch_size = 8
 seq_len = 2048
 warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
-steps = 10
+steps = 20
 data_parallel_degree = -1
 tensor_parallel_degree = 1
 compile = false
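
Note on the metric: the perplexity values logged above are derived from the mean per-token cross-entropy that train.py already collects, and PyTorch's cross-entropy loss is measured in nats, so perplexity is the exponential of the loss. A minimal standalone sketch of that conversion (not part of the patch; the function and the window_losses values are illustrative, assumed to be mean per-token losses in nats):

import math
from typing import List

def perplexities_from_losses(window_losses: List[float]) -> List[float]:
    # Convert mean per-token cross-entropy values (in nats) to perplexities.
    return [math.exp(loss) for loss in window_losses]

# Example: a loss of 2.0 nats corresponds to a perplexity of exp(2.0) ~ 7.39.
print(perplexities_from_losses([2.0, 2.5]))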