
Commit

Clean up code
MartinGleize committed Dec 24, 2024
1 parent bce9212 commit 87d8a58
Showing 4 changed files with 69 additions and 99 deletions.
36 changes: 0 additions & 36 deletions src/fairseq2/checkpoint/manager.py
@@ -10,7 +10,6 @@
from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping, Set
from contextlib import AbstractContextManager, nullcontext
import json
from pathlib import Path
from shutil import rmtree
from typing import final
@@ -95,21 +94,7 @@ def save_metadata(self, metadata: Mapping[str, object]) -> None:
:param metadata:
The metadata to save. Must be pickleable.
"""

@abstractmethod
def save_json_dict(
self,
output_name: str,
json_dict: Mapping[str, object],
) -> None:
"""Save a collection of key-values in JSON format, associated with the checkpoint.

:param output_name:
The name of the output json artifact.
:param json_dict:
The key-values to save. Must be json.dumps-able.
"""

@abstractmethod
def save_score(self, score: float | None) -> None:
"""Save the score of the checkpoint."""
@@ -459,27 +444,6 @@ def save_metadata(self, metadata: Mapping[str, object]) -> None:
) from ex

self._root_gang.barrier()

@override
def save_json_dict(
self,
output_name: str,
json_dict: Mapping[str, object],
) -> None:
to_write = json.dumps(json_dict, indent=2, sort_keys=True) + "\n"

if self._root_gang.rank == 0:
json_file = self._checkpoint_dir.joinpath(f"{output_name}")

try:
with json_file.open("w") as fp:
fp.write(to_write)
except OSError as ex:
raise CheckpointError(
f"The JSON file named {output_name} cannot be saved at training step {step_nr}. See the nested exception for details."
) from ex

self._root_gang.barrier()

@override
def save_score(self, score: float | None) -> None:
46 changes: 1 addition & 45 deletions src/fairseq2/models/llama/factory.py
@@ -8,15 +8,14 @@

import math
from dataclasses import dataclass, field
from typing import Any, Final
from typing import Final

import torch
from torch import Tensor

from fairseq2.config_registry import ConfigRegistry
from fairseq2.data import VocabularyInfo
from fairseq2.models.factory import model_factories
from fairseq2.models.llama.integ import get_ffn_dim_multipliers
from fairseq2.models.transformer import (
TransformerDecoderModel,
TransformerEmbeddingFrontend,
@@ -325,46 +324,3 @@ def get_llama_lora_config() -> LoRAConfig:
dropout_p=0.05,
keys=[".*decoder.layers.*.self_attn.*(q_proj|v_proj)$"],
)


def convert_to_huggingface_config(arch: str, config: LLaMAConfig) -> dict[str, Any]:

def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
"""From: https://github.com/huggingface/transformers/blob/82fcac0a7e40dc6cc5e3121d714b9b16775293ad/src/transformers/models/llama/convert_llama_weights_to_hf.py#L171"""
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)

if config.use_scaled_rope:
rope_scaling = {
"factor": 32.0 if "3_2" in arch else 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
else:
# mgleize: I'm not sure what's the to_json behavior is if rope_scaling ever is None
rope_scaling = None

# we only specify the parameters made explicit in the Huggingface converter
# https://github.com/huggingface/transformers/blob/93aafdc620d39b9ec714ffecf015a085ea221282/src/transformers/models/llama/convert_llama_weights_to_hf.py#L384
return {
"architectures": ["Fairseq2LlamaForCausalLM"],
"bos_token_id": config.vocab_info.bos_idx,
"eos_token_id": config.vocab_info.eos_idx,
"hidden_size": config.model_dim,
"intermediate_size": compute_intermediate_size(
config.model_dim,
get_ffn_dim_multipliers(arch),
config.ffn_inner_dim_to_multiple,
),
"max_position_embeddings": config.max_seq_len,
"model_type": "llama",
"num_attention_heads": config.num_attn_heads,
"num_hidden_layers": config.num_layers,
"num_key_value_heads": config.num_key_value_heads,
"rms_norm_eps": 1e-5,
"rope_scaling": rope_scaling,
"rope_theta": config.rope_theta,
"tie_word_embeddings": "3_2" in arch,
"vocab_size": config.vocab_info.size,
}
56 changes: 54 additions & 2 deletions src/fairseq2/models/llama/integ.py
@@ -8,11 +8,11 @@

from typing import Any

from fairseq2.models.llama.factory import LLaMAConfig
from fairseq2.models.utils.checkpoint import convert_model_state_dict


def get_ffn_dim_multipliers(architecture: str) -> float:
# we only specify archs where multiplier != 1.0
ffn_dim_multipliers = {
"llama2_70b": 1.3,
"llama3_8b": 1.3,
@@ -22,7 +22,7 @@ def get_ffn_dim_multipliers(architecture: str) -> float:
"llama3_1_405b": 1.2,
"llama3_2_1b": 1.5,
}

return ffn_dim_multipliers.get(architecture, 1.0)


@@ -53,3 +53,55 @@ def convert_to_reference_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any
}

return convert_model_state_dict(state_dict, key_map)


def convert_to_huggingface_config(arch: str, config: LLaMAConfig) -> dict[str, Any]:
"""Convert Llama's config to a dict mirroring Huggingface's format"""

def compute_intermediate_size(
n: int, ffn_dim_multiplier: float = 1, multiple_of: int = 256
) -> int:
"""From: https://github.com/huggingface/transformers/blob/82fcac0a7e40dc6cc5e3121d714b9b16775293ad/src/transformers/models/llama/convert_llama_weights_to_hf.py#L171"""
return multiple_of * (
(int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of
)

def is_llama_3_2(arch: str) -> bool:
# TODO: this seems too brittle
return "llama3_2_" in arch

if config.use_scaled_rope:
rope_scaling = {
"factor": 32.0 if is_llama_3_2(arch) else 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
else:
# mgleize: not sure of the json.dump behavior if rope_scaling is None
rope_scaling = None

# we only specify the parameters made explicit in the Huggingface converter
# https://github.com/huggingface/transformers/blob/93aafdc620d39b9ec714ffecf015a085ea221282/src/transformers/models/llama/convert_llama_weights_to_hf.py#L384
return {
"architectures": ["Fairseq2LlamaForCausalLM"],
"bos_token_id": config.vocab_info.bos_idx,
"eos_token_id": config.vocab_info.eos_idx,
"hidden_size": config.model_dim,
"intermediate_size": compute_intermediate_size(
config.model_dim,
get_ffn_dim_multipliers(arch),
config.ffn_inner_dim_to_multiple,
),
"max_position_embeddings": config.max_seq_len,
"model_type": "llama",
"num_attention_heads": config.num_attn_heads,
"num_hidden_layers": config.num_layers,
"num_key_value_heads": config.num_key_value_heads,
"rms_norm_eps": 1e-5,
"rope_scaling": rope_scaling,
"rope_theta": config.rope_theta,
"tie_word_embeddings": is_llama_3_2(arch),
"vocab_size": config.vocab_info.size,
}
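
A minimal, self-contained sketch of the FFN sizing arithmetic performed by compute_intermediate_size above. The multiple_of value of 1024 used here is an assumption matching the Llama 3 reference configs; it is not taken from this diff.

def compute_intermediate_size(
    n: int, ffn_dim_multiplier: float = 1.0, multiple_of: int = 256
) -> int:
    # Same rounding scheme as the Huggingface Llama weight converter:
    # scale 8n/3 by the multiplier, then round up to a multiple of multiple_of.
    return multiple_of * (
        (int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of
    )

# Hypothetical llama3_8b-style inputs: model_dim=4096, multiplier=1.3 (from the
# table above), multiple_of=1024 (assumed).
print(compute_intermediate_size(4096, 1.3, 1024))  # 14336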
30 changes: 14 additions & 16 deletions src/fairseq2/recipes/llama/write_hf_config.py
@@ -17,9 +17,8 @@
from fairseq2.assets import default_asset_store
from fairseq2.logging import get_log_writer
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.factory import convert_to_huggingface_config
from fairseq2.models.llama.integ import convert_to_huggingface_config
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.recipes.console import get_error_console
from fairseq2.setup import setup_fairseq2

log = get_log_writer(__name__)
@@ -57,28 +56,27 @@ def run(self, args: Namespace) -> int:
model_config = None

if model_config is None:
log.error("Model config could not be retrieved for model {}", args.model)
log.error("Config could not be retrieved for model {}", args.model)

sys.exit(1)

args.output_dir.mkdir(parents=True, exist_ok=True)

# Convert and write the config
with get_error_console().status("[bold green]Writing config...") as status:
log.info("Writing config...")

config = convert_to_huggingface_config(arch, model_config)
to_write = json.dumps(config, indent=2, sort_keys=True) + "\n"
config = convert_to_huggingface_config(arch, model_config)

json_file = args.output_dir.joinpath("config.json")
json_file = args.output_dir.joinpath("config.json")

try:
with json_file.open("w") as fp:
fp.write(to_write)
except OSError as ex:
raise RuntimeError(
f"The file config.json cannot be saved. See the nested exception for details."
) from ex
try:
with json_file.open("w") as fp:
json.dump(config, fp, indent=2, sort_keys=True)
except OSError as ex:
raise RuntimeError(
f"The file {json_file} cannot be saved. See the nested exception for details."
) from ex

log.info("Config converted and saved in {}", json_file)
log.info("Config converted and saved in {}", json_file)

return 0
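
A brief sketch of what json.dump writes for the converted config, including the rope_scaling = None case flagged in the integ.py comment above: Python None serializes to JSON null. The values below are hypothetical.

import json

config = {
    "model_type": "llama",
    "hidden_size": 4096,
    "rope_scaling": None,  # written as null in config.json
}

print(json.dumps(config, indent=2, sort_keys=True))
# {
#   "hidden_size": 4096,
#   "model_type": "llama",
#   "rope_scaling": null
# }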
