Commit

small fixes
tigranfah committed Oct 7, 2024
1 parent 40ae625 commit 2139ac6
Showing 3 changed files with 7 additions and 7 deletions.
8 changes: 5 additions & 3 deletions torchtitan/models/llama/utils.py
@@ -1,7 +1,9 @@

from transformers import AutoModelForCausalLM
import torch
from torchtitan.models.llama import Transformer
from torchtitan.logging import logger
import os


# reverse_permute for sliced rotary
@@ -101,7 +103,7 @@ def map_n_layers_to_model_name(n_layers):
}[n_layers]


-def export_llama3_weights(model: Transformer, save_dir: str, token_embedding_size: int):
+def export_llama3_weights(model: Transformer, save_dir, token_embedding_size: int):
"""
Export the torchtitan Transformer's weights as a HuggingFace checkpoint in save_dir.
"""
@@ -131,7 +133,7 @@ def export_llama3_weights(model: Transformer, save_dir: str, token_embedding_siz
assert hf_model.state_dict()[value].shape == state_dict[key].shape
corrected_state_dict["lm_head.weight"] = state_dict["tok_embeddings.weight"]

-# hf_model.load_state_dict(corrected_state_dict)
+hf_model.load_state_dict(corrected_state_dict)
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained(weights_path)
# device = "cuda"
@@ -143,4 +145,4 @@ def export_llama3_weights(model: Transformer, save_dir: str, token_embedding_siz
# logits = model(data.input_ids)
# print(torch.allclose(hf_logits, logits, atol=1e-2))
hf_model.save_pretrained(save_dir)
logger.info("Successfully exported Llama 3 model to huggingface model.")
logger.info(f"Successfully exported Llama 3 model to huggingface model at {save_dir}.")
4 changes: 1 addition & 3 deletions train.py
@@ -234,9 +234,7 @@ def loss_fn(pred, labels):
), "Must create seed-checkpoint using one gpu, to disable sharding"
model_name_to_weights_export_fns[model_name](
model,
-save_dir=os.path.join(
-    job_config.job.dump_folder, job_config.checkpoint.save_folder
-),
+save_dir=checkpoint._create_checkpoint_id(job_config.checkpoint.load_at_step, checkpoint.save_folder),
token_embedding_size=model_config.vocab_size,
)
logger.info("Created huggingface checkpoint")
2 changes: 1 addition & 1 deletion train_configs/llama3.2_1b_conversion.toml
@@ -54,7 +54,7 @@ enable_checkpoint = true
# load_folder = "meta-llama/Llama-3.2-1B"
# save_folder = "hf/meta-llama/Llama-3.2-1B"
load_folder = "yerevann/Llama-3.2-1B/145594a229f3458883c13c47"
-load_at_step = 18000
+load_at_step = 20000
save_folder = "hf/yerevann/Llama-3.2-1B/145594a229f3458883c13c47"
interval_type = "steps"
interval = 1000
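
These keys can be read back with Python's standard `tomllib`; the sketch below assumes they sit under a `[checkpoint]` table, as the `enable_checkpoint` context line in the hunk header suggests.

```python
import tomllib  # stdlib, Python 3.11+

with open("train_configs/llama3.2_1b_conversion.toml", "rb") as f:
    config = tomllib.load(f)

ckpt = config["checkpoint"]  # assumption: keys live under [checkpoint]
print(ckpt["load_folder"])   # yerevann/Llama-3.2-1B/145594a229f3458883c13c47
print(ckpt["load_at_step"])  # 20000 after this commit
print(ckpt["save_folder"])   # hf/yerevann/Llama-3.2-1B/145594a229f3458883c13c47
```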
