Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conversion script to HF #57

Open
wants to merge 2 commits into
base: magma
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Conversion to Robin Done. (Refine forward)
Yuchen Lu committed Sep 24, 2023
commit 9694d7f636598cb19ed5a3a371ea753d5c580910
363 changes: 363 additions & 0 deletions tools/convert_module_to_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,363 @@
# Copyright (c) 2023, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python tools/convert_module_to_hf.py \
--input_dir checkpoints/robin_ckpt \
--config_file checkpoints/robin_ckpt/configs/big_run_grid_8e-5_2208.yml \
checkpoints/robin_ckpt/configs/magma_pythia_410M.yml \
--output_dir checkpoints/robin_hf
"""
import os
import sys

import yaml
import argparse
from tqdm import tqdm
from typing import List

import torch
from transformers.models.robin import RobinConfig, RobinForCausalLM

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
from megatron.tokenizer import build_tokenizer



def load_partitions(
input_checkpoint_path, mp_partitions, layer_idx
) -> List[torch.Tensor]:
"""Returns a list containing all weights in a given layer from a model (across MP partitions)"""

loaded_tp_ranks = [
torch.load(
os.path.join(
input_checkpoint_path,
f"layer_{layer_idx:02}-model_{i:02}-model_states.pt",
),
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
for i in range(mp_partitions)
]

return loaded_tp_ranks


def get_key(loaded_config, key, default=None):
"""
Search for a given key in a NeoX yaml. normalizes underscores -> hyphens
"""
key = key.replace("_", "-")
try:
return loaded_config[key]
except KeyError:
key = key.replace("-", "_")
try:
return loaded_config[key]
except KeyError:
return default


def create_config(neox_config):
"""take in a loaded yaml from NeoX and assign relevant values to HF config.
Returns: RobinConfig() object
"""

class TokenizerArgs:
# kinda hacky.
# this is to get something with the same interface as is used in build_tokenizer()
# without diving into loading a neox_args object or using argparse etc.
def __init__(self, neox_config):
self.make_vocab_size_divisible_by = get_key(
neox_config, "make-vocab-size-divisible-by", default=128
)
self.model_parallel_size = get_key(neox_config, "model-parallel-size")
self.vocab_file = get_key(neox_config, "vocab-file")
self.merge_file = get_key(neox_config, "merge-file")
self.tokenizer_type = get_key(neox_config, "tokenizer-type")
self.seq_length = get_key(neox_config, "seq_length")
self.rank = 0
args = TokenizerArgs(neox_config)
tokenizer = build_tokenizer(args)
try: # GPT2TokenizerFast raises NotImplementedError
pad_token = tokenizer.pad
except:
pad_token = (
1 # pad defaulting to 1. follows convention from GPT-NeoX-20b tokenizer
)

# TODO: change the default value here based on discussion regarding `gpt_j_tied` config parameter's default
use_tied_lns = get_key(neox_config, "gpt-j-tied", False)

if use_tied_lns:
raise NotImplementedError(
"""ERROR: Huggingface Transformers does not yet support a single shared layernorm
per transformer block for GPT-NeoX models trained w/ GPT-J parallel residuals.
See https://github.com/EleutherAI/gpt-neox/pull/481 for further details."""
)


# set all config values.
hf_config = RobinConfig(
vocab_size=args.padded_vocab_size,
hidden_size=get_key(neox_config, "hidden-size"),
num_hidden_layers=get_key(neox_config, "num-layers"),
num_attention_heads=get_key(neox_config, "num-attention-heads"),
intermediate_size=(get_key(neox_config, "hidden-size") * 4),
hidden_act=get_key(neox_config, "activation", default="gelu"),
rotary_pct=get_key(neox_config, "rotary-pct", default=1.0),
rotary_emb_base=get_key(neox_config, "rotary-emb-base", default=10000),
max_position_embeddings=get_key(neox_config, "max-position-embeddings"),
initializer_range=get_key(neox_config, "init-method-std", 0.02),
layer_norm_eps=get_key(neox_config, "layernorm-epsilon", 1e-5),
use_cache=True,
bos_token_id=tokenizer.eod,
eos_token_id=tokenizer.eod,
tie_word_embeddings=(not get_key(neox_config, "no-weight-tying", False)),
use_parallel_residual=get_key(neox_config, "gpt-j-residual", False),

# image prefix config
encoder_name=get_key(neox_config, "encoder_name"),
pretrained_img_encoder=get_key(neox_config, "pretrained_img_encoder"),
load_clip=get_key(neox_config, "load_clip"),
image_embed_dropout_prob=get_key(neox_config, "image_embed_dropout_prob"),
use_image_embed_layernorm=get_key(neox_config, "use_image_embed_layernorm"),

# adaptor config
adapter_downsample_factor=get_key(neox_config, "adaper_downsample_factor")
)
return hf_config


def convert(input_checkpoint_path, loaded_config, output_checkpoint_path):
"""convert a NeoX checkpoint to a HF model format.
should perform model-parallel merging correctly
but only supports features allowed by HF GPT-NeoX implementation (e.g. rotary embeddings)
"""
hf_config = create_config(loaded_config)
hf_model = RobinForCausalLM(hf_config)

# save model in fp16/bf16 if Deepspeed fp16 or bf16 mixed precision was used in config, else 32 bit weights
fp16 = get_key(loaded_config, "fp16")
if fp16:
try:
# this conditional is quite messy because there were a number of ways to specify bf16 or fp16 training
# in DeeperSpeed v1.0 .
if (fp16.get("fp16", None) or fp16["enabled"]) and not (fp16.get("type", None) == "bfloat16"):
hf_model.half()
print("Saving weights in fp16 precision...")
elif fp16.get("type", None) == "bfloat16":
hf_model.to(dtype=torch.bfloat16)
print("Saving weights in bf16 precision...")
except:
print("Model not trained in fp16 / bf16 mixed precision, saving weights in fp32...")

mp_partitions = get_key(loaded_config, "model-parallel-size")

### Embedding layer ###
loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, 0)
hf_model.gpt_neox.embed_in.load_state_dict(
{
"weight": torch.cat(
[t["word_embeddings.weight"] for t in loaded_tp_ranks], dim=0
)
}
)

assert (
hf_config.vocab_size == hf_model.gpt_neox.embed_in.weight.shape[0]
), f"ERROR: calculated vocab size {hf_config.vocab_size} != embed param size {hf_model.gpt_neox.embed_in.shape[0]}"

# Load image prefix.
# No TP here. So just take one.
# The image prefix state_dict are inside the emb layer
IMG_PREFIX_NAME = 'image_prefix.'
img_prefix_state_dict = {}
for k in loaded_tp_ranks[0]:
if k.startswith(IMG_PREFIX_NAME):
new_key = k[k.find(IMG_PREFIX_NAME) + len(IMG_PREFIX_NAME): ]
img_prefix_state_dict[new_key] = loaded_tp_ranks[0][k]
hf_model.gpt_neox.image_prefix.load_state_dict(img_prefix_state_dict, strict=True)

### End Embedding Layer ###
for layer_i in tqdm(range(get_key(loaded_config, "num-layers"))):

# get layer from hf model
hf_layer = hf_model.gpt_neox.layers[layer_i]

# + 2 bc of embed layer and a dummy _pre_transformer_block
loaded_tp_ranks = load_partitions(
input_checkpoint_path, mp_partitions, layer_i + 2
)

state_dict = {}
# Used in mpu.RowParallelLinear
for key in [
"attention.attn_block.dense.weight",
"mlp.attn_block.dense_4h_to_h.weight",
]:
state_dict[key] = torch.cat([t[key] for t in loaded_tp_ranks], dim=1)

# average layernorm stats over mp ranks
for key in [
"input_layernorm.weight",
"input_layernorm.bias",
"post_attention_layernorm.weight",
"post_attention_layernorm.bias",
]:
state_dict[key] = (sum([t[key] for t in loaded_tp_ranks])) / len(
loaded_tp_ranks
)

# LinearWithTPMerge
# Used in mpu.ColumnParallelLinear
for key in [
"mlp.attn_block.dense_h_to_4h.weight",
"mlp.attn_block.dense_h_to_4h.bias",
"attention.attn_block.query_key_value.weight",
"attention.attn_block.query_key_value.bias",
]:
state_dict[key] = torch.cat([t[key] for t in loaded_tp_ranks], dim=0)

# LinearWithTPSplitBias RowParallelLinear
for key in [
"mlp.attn_block.dense_4h_to_h.bias",
"attention.attn_block.dense.bias",
]:
state_dict[key] = sum([t[key] for t in loaded_tp_ranks])

# Just take one
state_dict["attention.attn_block.rotary_emb.inv_freq"] = loaded_tp_ranks[0][
"attention.attn_block.rotary_emb.inv_freq"
]

# Adaptor currently have to TP strategy. Just take one
for key in [
'attention.adapter.2.weight',
'attention.adapter.0.bias',
'mlp.adapter.2.bias',
'attention.adapter.2.bias',
'attention.adapter.0.weight',
'mlp.adapter.2.weight',
'mlp.adapter.0.bias',
'mlp.adapter.0.weight',
]:
state_dict[key] = loaded_tp_ranks[0][key]

# load state_dict into layer
hf_layer.load_state_dict(state_dict, strict=True)

# Load final layer norm
loaded_tp_ranks = load_partitions(
input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 3
)

hf_model.gpt_neox.final_layer_norm.load_state_dict(
{
"weight": (sum([t["norm.weight"] for t in loaded_tp_ranks]))
/ len(loaded_tp_ranks),
"bias": (sum([t["norm.bias"] for t in loaded_tp_ranks]))
/ len(loaded_tp_ranks),
}
)
del loaded_tp_ranks

# Load output embedding
loaded_tp_ranks = load_partitions(
input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 4
)
hf_model.embed_out.load_state_dict(
{
"weight": torch.cat(
[t["final_linear.weight"] for t in loaded_tp_ranks], dim=0
),
},
strict=True
)

del loaded_tp_ranks

return hf_model


if __name__ == "__main__":

# before running script:
# `pip install --upgrade transformers`
# `huggingface-cli login`
#
from huggingface_hub import create_repo, HfApi

parser = argparse.ArgumentParser(
description="Merge MP partitions and convert to HF Model."
)
parser.add_argument(
"--input_dir",
type=str,
help="Path to NeoX checkpoint, e.g. /path/to/model/global_step143000",
)
parser.add_argument(
"--config_file",
type=str,
nargs="+",
help="Path to config file for the input NeoX checkpoint.",
)
parser.add_argument(
"--output_dir",
type=str,
help="Output dir, where to save the HF Model, tokenizer, and configs",
)
parser.add_argument(
"--upload",
action="store_true",
help="Set to true in order to upload to the HF Hub directly.",
)
args = parser.parse_args()

loaded_config = {}
for config_file in args.config_file:
with open(config_file) as f:
loaded = yaml.full_load(f)
for k,v in loaded.items():
loaded_config[k] = v

hf_model = convert(args.input_dir, loaded_config, args.output_dir)

hf_model.save_pretrained(args.output_dir)

# save tokenizer to directory as well, for easy loading of model as a HF model
tokenizer_type = get_key(loaded_config, "tokenizer-type")

if tokenizer_type == "HFTokenizer":
print(f"saving tokenizer from file {get_key(loaded_config, 'vocab-file')}")
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(
tokenizer_file=get_key(loaded_config, "vocab-file")
)
print("loaded tokenizer: ", tokenizer)
tokenizer.save_pretrained(args.output_dir)
print("tokenizer saved!")

if args.upload:
repo_name = input("Provide a repository name for the HF Hub: ")
create_repo(repo_name, repo_type="model", private=False, use_auth_token=True)

api = HfApi()
api.upload_folder(
folder_path=args.output_dir,
repo_id=repo_name,
repo_type="model",
)