Commit
Use eager mode attention implementation in Infer API (quic#54)
* Update from_pretrained method to always use eager mode attention implementation

Signed-off-by: Mamta Singh <[email protected]>

* Update number of layers in falcon transformed config

Signed-off-by: Mamta Singh <[email protected]>

---------

Signed-off-by: Mamta Singh <[email protected]>
quic-mamta authored Jun 28, 2024
1 parent 30ce13b commit 171bcd8
Showing 2 changed files with 76 additions and 40 deletions.
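Background for the change: Hugging Face transformers selects an attention backend at model load time through the attn_implementation argument ("eager", "sdpa", or "flash_attention_2"). The KV-cache transform and ONNX-export path in this repo work on the eager attention modules, so the loader now pins that backend unconditionally. A minimal sketch of the effect using plain transformers ("gpt2" is an arbitrary example model card, not taken from this patch):

from transformers import AutoModelForCausalLM

# Load the way the patched from_pretrained below does: use_cache=True so the
# model returns past_key_values (needed for KV-style ONNX export), and eager
# attention so export does not trace SDPA/flash kernels.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    use_cache=True,
    attn_implementation="eager",
)
print(model.config._attn_implementation)  # -> "eager"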
115 changes: 75 additions & 40 deletions QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -1,6 +1,6 @@
 # -----------------------------------------------------------------------------
 #
-# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
@@ -47,9 +47,9 @@ def convert_to_cloud_bertstyle(
     if os.path.exists(onnx_dir_path):
         logger.warning(f"Overriding {onnx_dir_path}")
         shutil.rmtree(onnx_dir_path)
 
     # Decide path for saving exported ONNX files.
-    model_name = export_bertstyle_model_to_onnx(model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len) # type: ignore
+    model_name = export_bertstyle_model_to_onnx(model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len)  # type: ignore
 
     # return the model path for automation.
     return os.path.join(onnx_dir_path, f"{model_name}.onnx")
@@ -134,7 +134,7 @@ def export_bertstyle_model_to_onnx(model_name, model, tokenizer, onnx_dir_path,
         inputs=inputs,
         input_list_file=input_list_file,
     )
-
+
     return model_name
 
 
@@ -160,25 +160,33 @@ def convert_to_cloud_kvstyle(
         onnx_dir_path (str, optional): The path where the model is stored. If None, the model is loaded from the default location.
         seq_len (int, optional): The length of the sequence. Default is 128.
     """
-    warnings.warn("\033[93mThis function will be deprecated soon, use QEfficient.export instead\033[0m", DeprecationWarning, stacklevel=2)
+    warnings.warn(
+        "\033[93mThis function will be deprecated soon, use QEfficient.export instead\033[0m",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     if os.path.exists(onnx_dir_path):
         logger.warning(f"Overriding {onnx_dir_path}")
         shutil.rmtree(onnx_dir_path)
 
     assert qeff_model.is_transformed, f"please pass the {qeff_model.__class__.__name__} after transform API"
 
     # Decide path for saving exported ONNX files.
-    model_name = export_kvstyle_transformed_model_to_onnx(model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len) # type: ignore
+    model_name = export_kvstyle_transformed_model_to_onnx(
+        model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len
+    )  # type: ignore
 
     # return the model path for automation.
     return os.path.join(onnx_dir_path, f"{model_name}.onnx")
 
 
-def export_kvstyle_transformed_model_to_onnx(model_name: str,
-                                             transformed_model: torch.nn.Module,
-                                             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-                                             onnx_dir_path: str, seq_len: int) -> str:
-
+def export_kvstyle_transformed_model_to_onnx(
+    model_name: str,
+    transformed_model: torch.nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    onnx_dir_path: str,
+    seq_len: int,
+) -> str:
     # Disabling requires_grad on all parameters
     for j, p in enumerate(transformed_model.parameters()):
         p.requires_grad_(False)
@@ -212,12 +220,10 @@ def export_kvstyle_transformed_model_to_onnx(model_name: str,
         multi_query_value = getattr(config, "multi_query")
         if multi_query_value:
             n_heads = 1  # MQA
-            d_head = config.hidden_size // config.num_attention_heads
-            n_layer = 1  # Due to multi query
         else:
             n_heads = config.num_attention_heads
-            d_head = config.hidden_size // config.num_attention_heads
-            n_layer = config.num_hidden_layers
+        d_head = config.hidden_size // config.num_attention_heads
+        n_layer = config.num_hidden_layers
     else:
         raise ValueError("Invalid model configuration: n_head/n_heads or num_key_value_heads not found.")
     inputs["past_key_values"] = [
@@ -277,7 +283,7 @@ def export_kvstyle_transformed_model_to_onnx(model_name: str,
 
     model_base_name = model_name.replace("/", "_") + "_kv"
     os.makedirs(onnx_dir_path, exist_ok=True)
-
+
     # Export and simplify ONNX model
     model_name = export_onnx(
         pt_model=transformed_model,
@@ -286,28 +292,29 @@ def export_kvstyle_transformed_model_to_onnx(model_name: str,
         gen_models_path=onnx_dir_path,
         model_base_name=model_base_name,
     )
 
     # Replace nested past_key_values inputs with separate KV tensors
     inputs.pop("past_key_values")
     for i, (key, value) in enumerate(pkv):
         inputs[f"past_key.{i}"] = key
         inputs[f"past_value.{i}"] = value
 
     # Run onnxrt inference
     input_names, ort_outputs = run_model_on_ort(
         onnx_path=os.path.join(onnx_dir_path, f"{model_name}.onnx"),
         inputs=inputs,
         output_names=output_names,
         pt_outputs=pt_outputs,
     )
 
     model_name = fix_onnx_fp16(
-        inputs=inputs,
-        output_names=output_names,
-        ort_outputs=ort_outputs,
-        gen_models_path=onnx_dir_path,
-        model_base_name=model_name,
-        pt_outputs=pt_outputs)
+        inputs=inputs,
+        output_names=output_names,
+        ort_outputs=ort_outputs,
+        gen_models_path=onnx_dir_path,
+        model_base_name=model_name,
+        pt_outputs=pt_outputs,
+    )
 
     # Generate custom-IO files for fp16 and int8 kv
     with open(os.path.join(onnx_dir_path, "custom_io_fp16.yaml"), "w") as fp:
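For orientation: the custom-IO file tells the downstream AI 100 compile step which tensors to keep in reduced precision. The diff only shows the fp16 file being opened; the entry layout in this sketch is an assumption for illustration, not taken from this patch:

# Hypothetical shape of custom_io_fp16.yaml: one entry per KV-cache tensor,
# pinning it to float16 on device. The names match the past_key.{i} /
# past_value.{i} inputs created above; the exact YAML schema is assumed.
with open("custom_io_fp16.yaml", "w") as fp:
    fp.write("# Model Inputs\n\n")
    for i in range(2):  # two layers, for illustration only
        for name in (f"past_key.{i}", f"past_value.{i}"):
            fp.write(f" - IOName: {name}\n   Precision: float16\n\n")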
@@ -335,7 +342,7 @@ def export_kvstyle_transformed_model_to_onnx(model_name: str,
         inputs=inputs,
         input_list_file=input_list_file,
     )
-
+
     return model_name
 
 
@@ -360,9 +367,14 @@ def export_for_cloud(
             f"Only model type {QEFFAutoModelForCausalLM.__class__.__name__} is supported for export, got {type(qeff_model)}"
         )
 
-def export_lm_model_for_cloud(model_name:str, qeff_model: QEFFAutoModelForCausalLM,
-                              tokenizer:Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-                              onnx_dir_path: str, seq_length: int) -> str:
+
+def export_lm_model_for_cloud(
+    model_name: str,
+    qeff_model: QEFFAutoModelForCausalLM,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    onnx_dir_path: str,
+    seq_length: int,
+) -> str:
     if os.path.exists(onnx_dir_path):
         logger.warning(f"Overriding {onnx_dir_path}")
         shutil.rmtree(onnx_dir_path)
@@ -373,33 +385,34 @@ def export_lm_model_for_cloud(model_name:str, qeff_model: QEFFAutoModelForCausal
             transformed_model=qeff_model.model,
             tokenizer=tokenizer,
             onnx_dir_path=onnx_dir_path,
-            seq_len=seq_length) # type: ignore
+            seq_len=seq_length,
+        )  # type: ignore
 
     else:
         model_name = export_bertstyle_model_to_onnx(
             model_name=model_name,
             model=qeff_model.model,
             tokenizer=tokenizer,
             onnx_dir_path=onnx_dir_path,
-            seq_len=seq_length) # type: ignore
+            seq_len=seq_length,
+        )  # type: ignore
 
     # return the model path for automation.
     return os.path.join(onnx_dir_path, f"{model_name}.onnx")
 
 
 def qualcomm_efficient_converter(
     model_name: str,
-    model_kv: QEFFBaseModel = None, # type: ignore
+    model_kv: QEFFBaseModel = None,  # type: ignore
     local_model_dir: Optional[str] = None,
-    tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]=None,
+    tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
     cache_dir: Optional[str] = None,
     onnx_dir_path: Optional[str] = None,
     hf_token: Optional[str] = None,
     seq_length: int = Constants.seq_length,
     kv: bool = True,
-    form_factor: str="cloud",
+    form_factor: str = "cloud",
 ) -> Tuple[str, str]:
-
     """
     Function to convert the input string using the specified model and returns the result.
@@ -419,9 +432,21 @@ def qualcomm_efficient_converter(
         None, if automation is False, else path to exported Onnx file
     """
-    warnings.warn("\033[93mmodel_kv argument will be replaced by qeff_model of type QEFFBaseModel\033[0m", DeprecationWarning, stacklevel=2)
+    warnings.warn(
+        "\033[93mmodel_kv argument will be replaced by qeff_model of type QEFFBaseModel\033[0m",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     # Get model_kv first
-    model_kv = model_kv if model_kv else QEFFCommonLoader.from_pretrained(pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name), hf_token=hf_token, cache_dir=cache_dir)
+    model_kv = (
+        model_kv
+        if model_kv
+        else QEFFCommonLoader.from_pretrained(
+            pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
+            hf_token=hf_token,
+            cache_dir=cache_dir,
+        )
+    )
 
     # Transform if required
     if model_kv.is_transformed and not kv:
@@ -434,15 +459,25 @@ def qualcomm_efficient_converter(
         onnx_dir_path = os.path.join(model_card_dir, "onnx")
 
     # Load tokenizer if not passed
-    tokenizer = tokenizer if tokenizer else load_hf_tokenizer(pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name), hf_token=hf_token, cache_dir=cache_dir, local_model_dir=local_model_dir)
+    tokenizer = (
+        tokenizer
+        if tokenizer
+        else load_hf_tokenizer(
+            pretrained_model_name_or_path=(local_model_dir if local_model_dir else model_name),
+            hf_token=hf_token,
+            cache_dir=cache_dir,
+            local_model_dir=local_model_dir,
+        )
+    )
 
     if form_factor == "cloud":
         generated_onnx_model_path = export_for_cloud(
             model_name=model_name,
             qeff_model=model_kv,
             tokenizer=tokenizer,
             onnx_dir_path=onnx_dir_path,
-            seq_length=seq_length)
+            seq_length=seq_length,
+        )
         return onnx_dir_path, generated_onnx_model_path
     else:
         # [TODO]: Apply the class transformation to make changes for the KV models in edge use cases
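Read after the reformat, the converter's flow is unchanged: load (or reuse) the QEff model, apply the KV transform if needed, pick the ONNX directory, load the tokenizer, then export. A usage sketch under the signature shown above ("gpt2" is an arbitrary model card; the import path matches this file):

from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter

# Export a causal LM with KV-cache inputs for the cloud form factor; per the
# Tuple[str, str] signature this returns the ONNX directory and the generated
# model path.
onnx_dir_path, onnx_model_path = qualcomm_efficient_converter(
    model_name="gpt2",
    kv=True,
    form_factor="cloud",
)
print(onnx_model_path)  # e.g. <model_card_dir>/onnx/gpt2_kv.onnx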
1 change: 1 addition & 0 deletions QEfficient/src/_transformers/auto.py
@@ -50,6 +50,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
         """
         transform: bool = kwargs.get("transform", True)
         kwargs.update({"use_cache": True})  # Always pass use_cache = True, to get KV values as output during ONNX export
+        kwargs.update({"attn_implementation" : "eager"})  # Always use eager mode for attention implementation
 
         model = QEFFAutoModelToTransformersAutoModelMap[cls.__name__].from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
         return cls(model, transform=transform)
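The single functional line in this file injects attn_implementation="eager" alongside the existing use_cache=True, so every model loaded through the QEff auto classes comes up in eager attention mode whether or not the caller asks for it. A sketch of what callers can now rely on (the import path mirrors the file shown above; _attn_implementation is where transformers records the chosen backend):

from QEfficient.src._transformers.auto import QEFFAutoModelForCausalLM

qeff_model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
# Both kwargs were injected by the patched from_pretrained:
assert qeff_model.model.config.use_cache is True
assert qeff_model.model.config._attn_implementation == "eager"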
