diff --git a/torchtitan/models/llama/utils.py b/torchtitan/models/llama/utils.py
index 79a22af3..780d01f1 100644
--- a/torchtitan/models/llama/utils.py
+++ b/torchtitan/models/llama/utils.py
@@ -1,7 +1,9 @@
+from transformers import AutoModelForCausalLM
 import torch
 
 from torchtitan.models.llama import Transformer
 from torchtitan.logging import logger
+import os
 
 
 # reverse_permute for sliced rotary
@@ -101,7 +103,7 @@ def map_n_layers_to_model_name(n_layers):
     }[n_layers]
 
 
-def export_llama3_weights(model: Transformer, save_dir: str, token_embedding_size: int):
+def export_llama3_weights(model: Transformer, save_dir, token_embedding_size: int):
     """
     write docs
     """
@@ -131,7 +133,7 @@ def export_llama3_weights(model: Transformer, save_dir: str, token_embedding_siz
         assert hf_model.state_dict()[value].shape == state_dict[key].shape
 
     corrected_state_dict["lm_head.weight"] = state_dict["tok_embeddings.weight"]
-    # hf_model.load_state_dict(corrected_state_dict)
+    hf_model.load_state_dict(corrected_state_dict)
     # from transformers import AutoTokenizer
     # tok = AutoTokenizer.from_pretrained(weights_path)
     # device = "cuda"
@@ -143,4 +145,4 @@ def export_llama3_weights(model: Transformer, save_dir: str, token_embedding_siz
     # logits = model(data.input_ids)
     # print(torch.allclose(hf_logits, logits, atol=1e-2))
     hf_model.save_pretrained(save_dir)
-    logger.info("Successfully exported Llama 3 model to huggingface model.")
\ No newline at end of file
+    logger.info(f"Successfully exported Llama 3 model to huggingface model at {save_dir}.")
diff --git a/train.py b/train.py
index b67ccc3b..0895a27d 100644
--- a/train.py
+++ b/train.py
@@ -234,9 +234,7 @@ def loss_fn(pred, labels):
        ), "Must create seed-checkpoint using one gpu, to disable sharding"
        model_name_to_weights_export_fns[model_name](
            model,
-            save_dir=os.path.join(
-                job_config.job.dump_folder, job_config.checkpoint.save_folder
-            ),
+            save_dir=checkpoint._create_checkpoint_id(job_config.checkpoint.load_at_step, checkpoint.save_folder),
            token_embedding_size=model_config.vocab_size,
        )
        logger.info("Created huggingface checkpoint")
diff --git a/train_configs/llama3.2_1b_conversion.toml b/train_configs/llama3.2_1b_conversion.toml
index cfc1f71b..78fa7bd6 100644
--- a/train_configs/llama3.2_1b_conversion.toml
+++ b/train_configs/llama3.2_1b_conversion.toml
@@ -54,7 +54,7 @@ enable_checkpoint = true
 # load_folder = "meta-llama/Llama-3.2-1B"
 # save_folder = "hf/meta-llama/Llama-3.2-1B"
 load_folder = "yerevann/Llama-3.2-1B/145594a229f3458883c13c47"
-load_at_step = 18000
+load_at_step = 20000
 save_folder = "hf/yerevann/Llama-3.2-1B/145594a229f3458883c13c47"
 interval_type = "steps"
 interval = 1000
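Note (not part of the patch): a minimal sketch of how the export entry point above might be invoked, mirroring the call site changed in train.py. The `checkpoint` object, its `_create_checkpoint_id` helper, and the `job_config`/`model_config` fields are taken from the hunks above; the wrapper name `export_seed_checkpoint` is hypothetical.

    # Sketch only: mirrors the train.py call site from the patch above.
    # `checkpoint`, `job_config`, and `model_config` are assumed to be the
    # objects already available in train.py; `export_seed_checkpoint` is a
    # hypothetical wrapper used here purely for illustration.
    from torchtitan.models.llama.utils import export_llama3_weights


    def export_seed_checkpoint(model, checkpoint, job_config, model_config):
        # The HF export is written into the checkpoint id derived from
        # load_at_step and save_folder, as in the patched call site.
        save_dir = checkpoint._create_checkpoint_id(
            job_config.checkpoint.load_at_step, checkpoint.save_folder
        )
        export_llama3_weights(
            model,
            save_dir=save_dir,
            token_embedding_size=model_config.vocab_size,
        )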