Skip to content

Commit

Permalink
added full_batch_size to replicate_kv script (quic#185)
Browse files · Browse the repository at this point in the history
added full_batch_size to replicate_kv script and fixed num_hidden_layers bug

Signed-off-by: Onkar Chougule <[email protected]>
  • Loading branch information
ochougul authored Nov 28, 2024
1 parent d093912 commit 21c11b6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion QEfficient/cloud/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def main(
)
parser.add_argument(
"--full_batch_size",
"--full_batch_size",
"--full-batch-size",
type=int,
default=None,
help="Set full batch size to enable continuous batching mode, default is None",
Expand Down
18 changes: 15 additions & 3 deletions scripts/replicate_kv_head/replicate_kv_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def main(args):
replace_transformers_quantizers()
model = AutoModelForCausalLM.from_pretrained(
model_name,
num_hidden_layers=1,
# num_hidden_layers=1, # Use for generating smaller model
attn_implementation="eager",
)
# Undo the effect of replace_transformers_quantizers
Expand Down Expand Up @@ -104,23 +104,35 @@ def main(args):
)

# Export the modified model
q_model = QEFFAutoModelForCausalLM(model, model_name)
q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if args.full_batch_size else False))
export(
model_name,
q_model,
tokenizer=tokenizer,
onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
full_batch_size=(args.full_batch_size if args.full_batch_size else None),
)


if __name__ == "__main__":
# Set up argument parser
parser = argparse.ArgumentParser(description="Modify and export a causal language model.")
parser.add_argument(
"--model_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", help="Name of the model to use."
"--model_name",
"--model-name",
type=str,
default="meta-llama/Meta-Llama-3-8B-Instruct",
help="Name of the model to use.",
)
parser.add_argument("--prompt", type=str, default="My name is", help="Prompt to use for the model.")
parser.add_argument("--repeat", type=int, default=2, help="Factor to repeat key-value heads.")
parser.add_argument(
"--full_batch_size",
"--full-batch-size",
type=int,
default=None,
help="Set full batch size to enable continuous batching mode, default is None",
)

args = parser.parse_args()
main(args)

0 comments on commit 21c11b6

Please sign in to comment.