Commit 73d458f

perplexity

FilyaGeikyan committed Aug 27, 2024
1 parent 5c078f3
Showing 3 changed files with 18 additions and 2 deletions.
2 changes: 2 additions & 0 deletions torchtitan/checkpoint.py

```diff
@@ -48,6 +48,8 @@ class TrainState(Stateful):
     step: int = 0
     global_avg_losses: List[float] = field(default_factory=list)
     global_max_losses: List[float] = field(default_factory=list)
+    global_avg_perplexities: List[float] = field(default_factory=list)
+    global_max_perplexities: List[float] = field(default_factory=list)
     log_steps: List[int] = field(default_factory=list)
 
     def state_dict(self) -> Dict[str, Any]:
```
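The two new fields extend the per-step metric history that TrainState carries through checkpoints. As a minimal sketch of the pattern (illustrative only, not torchtitan's actual serialization logic), list-valued dataclass fields like these round-trip through state_dict/load_state_dict as plain Python objects:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class MiniTrainState:
    """Toy stand-in for torchtitan's TrainState, for illustration only."""
    step: int = 0
    global_avg_perplexities: List[float] = field(default_factory=list)
    global_max_perplexities: List[float] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # A new field must be reflected here, or it is silently dropped on save.
        return {
            "step": self.step,
            "global_avg_perplexities": self.global_avg_perplexities,
            "global_max_perplexities": self.global_max_perplexities,
        }

    def load_state_dict(self, sd: Dict[str, Any]) -> None:
        self.step = sd["step"]
        self.global_avg_perplexities = list(sd["global_avg_perplexities"])
        self.global_max_perplexities = list(sd["global_max_perplexities"])
```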
14 changes: 14 additions & 0 deletions train.py

```diff
@@ -283,19 +283,31 @@ def loss_fn(pred, labels):
             or train_state.step % job_config.metrics.log_freq == 0
         ):
             losses = [loss.item() for loss in losses_since_last_log]
+
+            # loss comes from F.cross_entropy and is in nats, so perplexity = e**loss, not 2**loss
+            perplexities = [torch.exp(loss).item() for loss in losses_since_last_log]
             avg_loss, max_loss = sum(losses) / len(losses), max(losses)
+            avg_perplexity, max_perplexity = sum(perplexities) / len(perplexities), max(perplexities)
+
             if parallel_dims.dp_enabled:
                 global_avg_loss, global_max_loss = (
                     utils.dist_mean(avg_loss, dp_mesh),
                     utils.dist_max(max_loss, dp_mesh),
                 )
+                global_avg_perplexity, global_max_perplexity = (
+                    utils.dist_mean(avg_perplexity, dp_mesh),
+                    utils.dist_max(max_perplexity, dp_mesh),
+                )
             else:
                 global_avg_loss, global_max_loss = avg_loss, max_loss
+                global_avg_perplexity, global_max_perplexity = avg_perplexity, max_perplexity
 
             # update train state
             train_state.log_steps.append(train_state.step)
             train_state.global_avg_losses.append(global_avg_loss)
             train_state.global_max_losses.append(global_max_loss)
+            train_state.global_avg_perplexities.append(global_avg_perplexity)
+            train_state.global_max_perplexities.append(global_max_perplexity)
 
             time_delta = time.perf_counter() - time_last_log
```
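As a sanity check on the arithmetic the hunk above adds, here is a self-contained sketch of the loss-to-perplexity conversion; the loss values are made up, and the Jensen's-inequality caveat is an editorial observation rather than something this commit addresses:

```python
import math

# Average cross-entropy losses (in nats, as torch.nn.functional.cross_entropy
# reports them) from three hypothetical logging intervals.
losses = [2.31, 2.17, 2.42]

# Perplexity is e raised to the natural-log loss.
perplexities = [math.exp(loss) for loss in losses]

avg_loss = sum(losses) / len(losses)

# Averaging perplexities directly, as the hunk above does, is not the same as
# exponentiating the average loss: exp is convex, so by Jensen's inequality
# mean(exp(loss)) >= exp(mean(loss)).
mean_of_ppl = sum(perplexities) / len(perplexities)
ppl_of_mean = math.exp(avg_loss)
assert mean_of_ppl >= ppl_of_mean
print(f"{mean_of_ppl:.2f} vs {ppl_of_mean:.2f}")  # ~10.03 vs ~9.97
```

The second train.py hunk then adds the aggregated values to the metrics dict that gets logged: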
```diff
@@ -317,6 +329,8 @@ def loss_fn(pred, labels):
             metrics = {
                 "loss_metrics/global_avg_loss": global_avg_loss,
                 "loss_metrics/global_max_loss": global_max_loss,
+                "loss_metrics/global_avg_perplexity": global_avg_perplexity,
+                "loss_metrics/global_max_perplexity": global_max_perplexity,
                 "wps": wps,
                 "mfu(%)": mfu,
                 "time_metrics/end_to_end(s)": time_end_to_end,
```
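Both hunks lean on utils.dist_mean and utils.dist_max, which are not part of this diff. The sketch below is one plausible shape for them, assuming they all-reduce a per-rank scalar over the data-parallel process group; it is a hypothetical stand-in, not the repository's actual implementation:

```python
import torch
import torch.distributed as dist


def dist_mean(value: float, mesh) -> float:
    # Hypothetical: average a per-rank scalar across the dp mesh's group.
    # ReduceOp.AVG requires the NCCL backend.
    t = torch.tensor(value, dtype=torch.float32, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.AVG, group=mesh.get_group())
    return t.item()


def dist_max(value: float, mesh) -> float:
    # Hypothetical: take the max of a per-rank scalar across the group.
    t = torch.tensor(value, dtype=torch.float32, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=mesh.get_group())
    return t.item()
```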
4 changes: 2 additions & 2 deletions train_configs/debug_model.toml

```diff
@@ -17,7 +17,7 @@ log_freq = 1
 enable_color_printing = true
 enable_aim = true
 save_aim_folder = "aim"
-aim_hash = "bfccae7dda0640f89e390e9a"
+aim_hash = "c6b4d8b340f74287b82ef928"
 
 [model]
 name = "llama3"
@@ -35,7 +35,7 @@ batch_size = 8
 seq_len = 2048
 warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
-steps = 10
+steps = 20
 data_parallel_degree = -1
 tensor_parallel_degree = 1
 compile = false
```
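Finally, a quick read-back of the two edited config values. This assumes the keys sit under [metrics] and [training] as in torchtitan's config layout (an assumption, since the section headers are outside the hunks shown), and tomllib requires Python 3.11+:

```python
import tomllib  # stdlib in Python 3.11+

with open("train_configs/debug_model.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["metrics"]["aim_hash"])  # "c6b4d8b340f74287b82ef928" after this commit
print(cfg["training"]["steps"])    # 20 after this commit
```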
