diff --git a/src/fairseq2/checkpoint/manager.py b/src/fairseq2/checkpoint/manager.py
index a3c8b0fd0..0616b1e66 100644
--- a/src/fairseq2/checkpoint/manager.py
+++ b/src/fairseq2/checkpoint/manager.py
@@ -10,7 +10,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterator, Mapping, Set
 from contextlib import AbstractContextManager, nullcontext
-import json
 from pathlib import Path
 from shutil import rmtree
 from typing import final
@@ -95,21 +94,7 @@ def save_metadata(self, metadata: Mapping[str, object]) -> None:
         :param metadata: The metadata to save. Must be pickeable.
         """

-    @abstractmethod
-    def save_json_dict(
-        self,
-        output_name: str,
-        json_dict: Mapping[str, object],
-    ) -> None:
-        """Save a collection of key-values in JSON format, associated with the checkpoint.
-        :param output_name:
-            The name of the output json artifact.
-        :param json_dict:
-            The key-values to save. Must be json.dumps-able.
-        """
-
     @abstractmethod
     def save_score(self, score: float | None) -> None:
         """Save the score of the checkpoint."""
@@ -459,27 +444,6 @@ def save_metadata(self, metadata: Mapping[str, object]) -> None:
             ) from ex

         self._root_gang.barrier()

-    @override
-    def save_json_dict(
-        self,
-        output_name: str,
-        json_dict: Mapping[str, object],
-    ) -> None:
-        to_write = json.dumps(json_dict, indent=2, sort_keys=True) + "\n"
-
-        if self._root_gang.rank == 0:
-            json_file = self._checkpoint_dir.joinpath(f"{output_name}")
-
-            try:
-                with json_file.open("w") as fp:
-                    fp.write(to_write)
-            except OSError as ex:
-                raise CheckpointError(
-                    f"The JSON file named {output_name} cannot be saved at training step {step_nr}. See the nested exception for details."
-                ) from ex
-
-        self._root_gang.barrier()

     @override
     def save_score(self, score: float | None) -> None:
diff --git a/src/fairseq2/models/llama/factory.py b/src/fairseq2/models/llama/factory.py
index b636138be..5b81ee538 100644
--- a/src/fairseq2/models/llama/factory.py
+++ b/src/fairseq2/models/llama/factory.py
@@ -8,7 +8,7 @@
 import math
 from dataclasses import dataclass, field
-from typing import Any, Final
+from typing import Final

 import torch
 from torch import Tensor
@@ -16,7 +16,6 @@
 from fairseq2.config_registry import ConfigRegistry
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.factory import model_factories
-from fairseq2.models.llama.integ import get_ffn_dim_multipliers
 from fairseq2.models.transformer import (
     TransformerDecoderModel,
     TransformerEmbeddingFrontend,
@@ -325,46 +324,3 @@ def get_llama_lora_config() -> LoRAConfig:
         dropout_p=0.05,
         keys=[".*decoder.layers.*.self_attn.*(q_proj|v_proj)$"],
     )
-
-
-def convert_to_huggingface_config(arch: str, config: LLaMAConfig) -> dict[str, Any]:
-
-    def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
-        """From: https://github.com/huggingface/transformers/blob/82fcac0a7e40dc6cc5e3121d714b9b16775293ad/src/transformers/models/llama/convert_llama_weights_to_hf.py#L171"""
-        return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
-
-    if config.use_scaled_rope:
-        rope_scaling = {
-            "factor": 32.0 if "3_2" in arch else 8.0,
-            "low_freq_factor": 1.0,
-            "high_freq_factor": 4.0,
-            "original_max_position_embeddings": 8192,
-            "rope_type": "llama3",
-        }
-    else:
-        # mgleize: I'm not sure what's the to_json behavior is if rope_scaling ever is None
-        rope_scaling = None
-
-    # we only specify the parameters made explicit in the Huggingface converter
-    # https://github.com/huggingface/transformers/blob/93aafdc620d39b9ec714ffecf015a085ea221282/src/transformers/models/llama/convert_llama_weights_to_hf.py#L384
-    return {
-        "architectures": ["Fairseq2LlamaForCausalLM"],
-        "bos_token_id": config.vocab_info.bos_idx,
-        "eos_token_id": config.vocab_info.eos_idx,
-        "hidden_size": config.model_dim,
-        "intermediate_size": compute_intermediate_size(
-            config.model_dim,
-            get_ffn_dim_multipliers(arch),
-            config.ffn_inner_dim_to_multiple,
-        ),
-        "max_position_embeddings": config.max_seq_len,
-        "model_type": "llama",
-        "num_attention_heads": config.num_attn_heads,
-        "num_hidden_layers": config.num_layers,
-        "num_key_value_heads": config.num_key_value_heads,
-        "rms_norm_eps": 1e-5,
-        "rope_scaling": rope_scaling,
-        "rope_theta": config.rope_theta,
-        "tie_word_embeddings": "3_2" in arch,
-        "vocab_size": config.vocab_info.size,
-    }
diff --git a/src/fairseq2/models/llama/integ.py b/src/fairseq2/models/llama/integ.py
index 6db23a15f..b3f1170e3 100644
--- a/src/fairseq2/models/llama/integ.py
+++ b/src/fairseq2/models/llama/integ.py
@@ -8,11 +8,11 @@

 from typing import Any

+from fairseq2.models.llama.factory import LLaMAConfig
 from fairseq2.models.utils.checkpoint import convert_model_state_dict


 def get_ffn_dim_multipliers(architecture: str) -> float:
-    # we only specify archs where multiplier != 1.0
     ffn_dim_multipliers = {
         "llama2_70b": 1.3,
         "llama3_8b": 1.3,
@@ -22,7 +22,7 @@ def get_ffn_dim_multipliers(architecture: str) -> float:
         "llama3_1_405b": 1.2,
         "llama3_2_1b": 1.5,
     }
-
+
     return ffn_dim_multipliers.get(architecture, 1.0)
@@ -53,3 +53,55 @@ def convert_to_reference_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any
     }

     return convert_model_state_dict(state_dict, key_map)
+
+
+def convert_to_huggingface_config(arch: str, config: LLaMAConfig) -> dict[str, Any]:
+    """Convert Llama's config to a dict mirroring Huggingface's format."""
+
+    def compute_intermediate_size(
+        n: int, ffn_dim_multiplier: float = 1, multiple_of: int = 256
+    ) -> int:
+        """From: https://github.com/huggingface/transformers/blob/82fcac0a7e40dc6cc5e3121d714b9b16775293ad/src/transformers/models/llama/convert_llama_weights_to_hf.py#L171"""
+        return multiple_of * (
+            (int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of
+        )
+
+    def is_llama_3_2(arch: str) -> bool:
+        # TODO: this seems too brittle
+        return "llama3_2_" in arch
+
+    if config.use_scaled_rope:
+        rope_scaling = {
+            "factor": 32.0 if is_llama_3_2(arch) else 8.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_max_position_embeddings": 8192,
+            "rope_type": "llama3",
+        }
+    else:
+        # mgleize: not sure of the json.dump behavior if rope_scaling is None
+        rope_scaling = None
+
+    # we only specify the parameters made explicit in the Huggingface converter
+    # https://github.com/huggingface/transformers/blob/93aafdc620d39b9ec714ffecf015a085ea221282/src/transformers/models/llama/convert_llama_weights_to_hf.py#L384
+    return {
+        "architectures": ["Fairseq2LlamaForCausalLM"],
+        "bos_token_id": config.vocab_info.bos_idx,
+        "eos_token_id": config.vocab_info.eos_idx,
+        "hidden_size": config.model_dim,
+        "intermediate_size": compute_intermediate_size(
+            config.model_dim,
+            get_ffn_dim_multipliers(arch),
+            config.ffn_inner_dim_to_multiple,
+        ),
+        "max_position_embeddings": config.max_seq_len,
+        "model_type": "llama",
+        "num_attention_heads": config.num_attn_heads,
+        "num_hidden_layers": config.num_layers,
+        "num_key_value_heads": config.num_key_value_heads,
+        "rms_norm_eps": 1e-5,
+        "rope_scaling": rope_scaling,
+        "rope_theta": config.rope_theta,
+        "tie_word_embeddings": is_llama_3_2(arch),
+        "vocab_size": config.vocab_info.size,
+    }
diff --git a/src/fairseq2/recipes/llama/write_hf_config.py b/src/fairseq2/recipes/llama/write_hf_config.py
index eb4167ede..4135afdb1 100644
--- a/src/fairseq2/recipes/llama/write_hf_config.py
+++ b/src/fairseq2/recipes/llama/write_hf_config.py
@@ -17,9 +17,8 @@
 from fairseq2.assets import default_asset_store
 from fairseq2.logging import get_log_writer
 from fairseq2.models.llama import load_llama_config
-from fairseq2.models.llama.factory import convert_to_huggingface_config
+from fairseq2.models.llama.integ import convert_to_huggingface_config
 from fairseq2.recipes.cli import CliCommandHandler
-from fairseq2.recipes.console import get_error_console
 from fairseq2.setup import setup_fairseq2

 log = get_log_writer(__name__)
@@ -57,28 +56,27 @@ def run(self, args: Namespace) -> int:
             model_config = None

         if model_config is None:
-            log.error("Model config could not be retrieved for model {}", args.model)
+            log.error("Config could not be retrieved for model {}", args.model)
             sys.exit(1)
-
+
         args.output_dir.mkdir(parents=True, exist_ok=True)

         # Convert and write the config
-        with get_error_console().status("[bold green]Writing config...") as status:
+        log.info("Writing config...")

-            config = convert_to_huggingface_config(arch, model_config)
-            to_write = json.dumps(config, indent=2, sort_keys=True) + "\n"
+        config = convert_to_huggingface_config(arch, model_config)

-            json_file = args.output_dir.joinpath("config.json")
+        json_file = args.output_dir.joinpath("config.json")

-            try:
-                with json_file.open("w") as fp:
-                    fp.write(to_write)
-            except OSError as ex:
-                raise RuntimeError(
-                    f"The file config.json cannot be saved. See the nested exception for details."
-                ) from ex
+        try:
+            with json_file.open("w") as fp:
+                json.dump(config, fp, indent=2, sort_keys=True)
+        except OSError as ex:
+            raise RuntimeError(
+                f"The file {json_file} cannot be saved. See the nested exception for details."
+            ) from ex

-            log.info("Config converted and saved in {}", json_file)
+        log.info("Config converted and saved in {}", json_file)

         return 0
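
For reviewers, a minimal sketch of how the relocated converter could be exercised end to end. It only mirrors what the write_hf_config recipe above does with load_llama_config and json.dump; the model name and output directory are hypothetical, and it assumes the architecture string passed to convert_to_huggingface_config matches the card name, which is not guaranteed for every asset.

import json
from pathlib import Path

from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.integ import convert_to_huggingface_config

# Hypothetical model name; assumed here to double as the arch string.
model_name = "llama3_2_1b"

# Resolve the fairseq2 LLaMAConfig for the model, as the recipe does.
model_config = load_llama_config(model_name)

# Build the Huggingface-style config dict and write it out as config.json.
hf_config = convert_to_huggingface_config(model_name, model_config)

output_dir = Path("hf_export")  # hypothetical output location
output_dir.mkdir(parents=True, exist_ok=True)

with output_dir.joinpath("config.json").open("w") as fp:
    json.dump(hf_config, fp, indent=2, sort_keys=True)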