Added falcon model converter #2040
base: master
@@ -0,0 +1,60 @@
import numpy as np

from keras_hub.src.models.falcon import FalconBackbone
from keras_hub.src.utils.preset_utils import load_json

backbone_cls = FalconBackbone


def convert_backbone_config(transformers_config):
    return {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["num_hidden_layers"],
        "num_attention_heads": transformers_config["num_attention_heads"],
        "hidden_dim": transformers_config["hidden_size"],
        "intermediate_dim": 32 * 4,
    }


def transpose_and_reshape(x, shape):
    return np.reshape(np.transpose(x), shape)


def convert_weights(backbone, loader, transformers_config):
    # Embeddings
    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").embeddings,
        hf_weight_key="word_embeddings.weight",
    )

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

        # Norm layer
        loader.port_weight(
            keras_variable=decoder_layer.input_layernorm.gamma,
            hf_weight_key=f"h.{i}.input_layernorm.weight",
        )

        # Attention layers
        loader.port_weight(
            keras_variable=decoder_layer.attention_layer.output_dense.kernel,
            hf_weight_key=f"h.{i}.self_attention.dense.weight",
        )

        loader.port_weight(
            keras_variable=decoder_layer.post_attention_layernorm.gamma,
            hf_weight_key=f"h.{i}.self_attention.query_key_value.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.mean(
                np.reshape(hf_tensor, (-1, keras_shape[0])), axis=0
            ),
        )


def convert_tokenizer(cls, preset, **kwargs):
    tokenizer_data = load_json(preset, "tokenizer.json")
    vocab = tokenizer_data["model"]["vocab"]
    merges = tokenizer_data["model"].get("merges", None)

    tokenizer_kwargs = {"vocabulary": vocab, "merges": merges}
    return cls(**tokenizer_kwargs)
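For orientation, here is a minimal usage sketch of what this converter enables once it is wired into KerasHub's Hugging Face preset loading. It is not part of the diff; the checkpoint name is only an example, chosen because falcon-rw-1b is the ALiBi/MHA variant the current FalconBackbone targets (see the review discussion below).

# Minimal usage sketch (not part of this PR's diff). Assumes the converter
# above is registered with KerasHub's hf:// preset loader and that the
# checkpoint uses the ALiBi/MHA Falcon architecture the backbone supports.
from keras_hub.models import FalconCausalLM

causal_lm = FalconCausalLM.from_preset("hf://tiiuae/falcon-rw-1b")
causal_lm.generate(["What is your favorite condiment?"], max_length=15)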
@@ -0,0 +1,23 @@
import pytest

from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone
from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM
from keras_hub.src.tests.test_case import TestCase


class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        model = FalconCausalLM.from_preset("hf://tiiuae/falcon-7b")
        prompt = "What is your favorite condiment?"
        model.generate([prompt], max_length=15)

    @pytest.mark.large
    def test_class_detection(self):
        model = FalconCausalLM.from_preset("hf://tiiuae/falcon-7b")
(Review thread on the line above)

Does this work? I think we only have Falcon-1b support! The 7b model has a different attention mechanism which hasn't been added!

We should probably also attach a colab verifying that the output from the Hugging Face and KerasHub versions aligns. And it sounds like we might actually run into differences here, given what @SamanehSaadat is saying. @SamanehSaadat, how much work is needed on the architecture code to support the 7b and other variants? Is it something that could be added here, or a ton to do?

@mattdangerw I think adding support for the 7b is non-trivial. There are some major architectural differences like alibi, GQA vs. MHA, and rotary embedding (to me, it's almost like adding a new architecture!).

Thanks! Sounds like we will need to either throw in the converter if we encounter the Falcon Hugging Face options we don't currently support, or add them in (in a separate PR?). @mehtamansi29 we'd probably need a colab verifying that the output matches for some subset of Falcon checkpoints on Hugging Face, and ideally that we throw for Falcon checkpoints that need arch options we don't yet support.
        self.assertIsInstance(model, FalconCausalLM)
        model = FalconBackbone.from_preset(
            "hf://tiiuae/falcon-7b",
            load_weights=False,
        )
        self.assertIsInstance(model, FalconBackbone)
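Following up on the review thread above about throwing when the converter hits Falcon options we do not yet support: below is a hedged sketch of what such a guard could look like. It is an illustration, not code from the PR, and the flag names (alibi, multi_query, new_decoder_architecture) assume the Hugging Face Falcon config format; they should be double-checked against the checkpoints we actually want to cover.

# Hedged sketch of a config guard (not part of this PR). Raise early for
# Hugging Face Falcon checkpoints whose architecture options the current
# FalconBackbone does not implement (rotary embeddings, multi-query/GQA,
# the newer decoder layout). Flag names assume the HF Falcon config format.
def validate_backbone_config(transformers_config):
    if transformers_config.get("multi_query", False):
        raise ValueError(
            "Multi-query / grouped-query attention Falcon checkpoints "
            "(e.g. falcon-7b) are not supported by FalconBackbone yet."
        )
    if transformers_config.get("new_decoder_architecture", False):
        raise ValueError(
            "The newer Falcon decoder architecture is not supported yet."
        )
    if not transformers_config.get("alibi", False):
        raise ValueError(
            "Rotary-embedding Falcon checkpoints are not supported; only "
            "ALiBi-based checkpoints (e.g. falcon-rw-1b) are expected to "
            "convert correctly."
        )

convert_backbone_config could call this before building its return dict, so unsupported checkpoints fail with a clear error instead of silently producing wrong weights.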
I don't think we can afford to download this ~15 GB file in our testing setup. You could try the 1b model? Or create a small test model on HF, as was done for Llama and others.

@mattdangerw - I'll create a small test with the 1b Falcon model and commit again.
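For the colab-style verification the reviewers ask for, a rough parity check could run the same prompt through the Hugging Face and KerasHub models and compare the outputs. The sketch below is not part of the PR; the ~1b checkpoint name, the decode settings, and the assumption that both stacks decode greedily are illustrative and would need to be confirmed in the actual colab (a stricter version would compare logits).

# Hedged parity-check sketch (not part of this PR): generate from the same
# prompt with both stacks and compare the outputs by eye.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from keras_hub.models import FalconCausalLM

checkpoint = "tiiuae/falcon-rw-1b"  # assumed small-enough checkpoint
prompt = "What is your favorite condiment?"

hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
hf_model = AutoModelForCausalLM.from_pretrained(checkpoint)
hf_inputs = hf_tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    hf_ids = hf_model.generate(**hf_inputs, max_new_tokens=10, do_sample=False)
print("HF:      ", hf_tokenizer.decode(hf_ids[0]))

keras_model = FalconCausalLM.from_preset(f"hf://{checkpoint}")
print("KerasHub:", keras_model.generate(prompt, max_length=15))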