Commit 2c105b0

Merge branch 'quic:main' into finetune

quic-mamta authored Jan 20, 2025
2 parents c3079b0 + 2904183 commit 2c105b0

Showing 6 changed files with 542 additions and 56 deletions.
29 changes: 18 additions & 11 deletions QEfficient/transformers/models/modeling_auto.py
@@ -83,11 +83,14 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase):
.. code-block:: python
from QEfficient import QEFFAutoModelForCausalLM
from transformers import AutoTokenizer
model_name = "gpt2"
model = QEFFAutoModelForCausalLM.from_pretrained(model_name, num_hidden_layers=2)
model.compile(prefill_seq_len=32, ctx_len=1024)
model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16, num_devices=1)
model.generate(prompts=["Hi there!!"])
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.generate(prompts=["Hi there!!"], tokenizer=tokenizer)
"""

_hf_auto_class = AutoModelForCausalLM
@@ -141,15 +144,18 @@ def from_pretrained(
.. code-block:: python
from QEfficient import QEFFAutoModelForCausalLM
from transformers import AutoTokenizer
# Initialize the model using from_pretrained similar to transformers.AutoModelForCausalLM
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
model_name = "gpt2"
model = QEFFAutoModelForCausalLM.from_pretrained(model_name)
# Now you can directly compile the model for Cloud AI 100
model.compile(num_cores=6, device_group=[0]) # Considering you have a Cloud AI 100 Standard SKU
model.compile(num_cores=16) # Considering you have a Cloud AI 100 Standard SKU
# You can now execute the model
model.generate(prompts=["Hi there!!"])
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.generate(prompts=["Hi there!!"], tokenizer=tokenizer)
"""

if kwargs.pop("full_batch_size", None):
@@ -391,9 +397,11 @@ def generate(
If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped.
``Mandatory`` Args:
:tokenizer (Union[PreTrainedTokenizerFast, PreTrainedTokenizer]): Tokenizer of the model.
:prompts (List[str]): List of prompts to run the execution on.
:device_id (List[int]): IDs of the devices used to run the QPC; pass [0] for a normal model or [0, 1, 2, 3] for a tensor-sliced model.
``optional`` Args:
:device_id (List[int]): IDs of the devices used to run the QPC; pass [0] for a normal model or [0, 1, 2, 3] for a tensor-sliced model.
:runtime_ai100 (bool, optional): Only ``AI_100`` and ``PyTorch`` runtimes are supported for now. Defaults to ``True`` for the ``AI_100`` runtime.
"""
@@ -430,7 +438,7 @@ class QEFFAutoModel(QEFFTransformersBase):
model = QEFFAutoModel.from_pretrained("model_name")
# Now you can directly compile the model for Cloud AI 100
model.compile(num_cores=16, device_group=[0]) # Considering you have a Cloud AI 100 SKU
model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU
#prepare input
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -469,7 +477,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
model = QEFFAutoModel.from_pretrained("model_name")
# Now you can directly compile the model for Cloud AI 100
model.compile(num_cores=16, device_group=[0]) # Considering you have a Cloud AI 100 SKU
model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU
#prepare input
tokenizer = AutoTokenizer.from_pretrained(model_name)
Expand Down Expand Up @@ -594,10 +602,9 @@ def generate(
This method generates output by executing the PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` hardware cards.
``Mandatory`` Args:
:inputs (Union[torch.Tensor, np.ndarray]): Inputs to run the execution on.
:device_id (List[int]): IDs of the devices used to run the QPC; pass [0] for a normal model or [0, 1, 2, 3] for a tensor-sliced model.
``optional`` Args:
:device_id (List[int]): IDs of the devices used to run the QPC; pass [0] for a normal model or [0, 1, 2, 3] for a tensor-sliced model.
:runtime_ai100 (bool, optional): Only ``AI_100`` and ``PyTorch`` runtimes are supported for now. Defaults to ``True`` for the ``AI_100`` runtime.
:eq_len (int, optional): Sequence length for the inputs. Defaults to constants.Constants.CTX_LEN.
Returns:
:dict: Output from the ``AI_100`` or ``PyTorch`` runtime.
"""
@@ -660,7 +667,7 @@ def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray
Generates features from a list of text prompts using a PyTorch model.
``Mandatory`` Args:
model: The transformed PyTorch model used for generating features.
:model: The transformed PyTorch model used for generating features.
:inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
Returns:
7 changes: 0 additions & 7 deletions docs/source/hl_api.md
@@ -47,13 +47,6 @@
import QEfficient
base_path, onnx_model_path = QEfficient.export(model_name="gpt2")
qpc_path = QEfficient.compile(onnx_path=onnx_model_path, qpc_path=os.path.join(base_path, "qpc"), num_cores=14, device_group=[0])
# Similarly, for a QPC compiled via the QNN SDK:
# 1. export QNN_SDK_ROOT=/path/to/qnn_sdk_folder
# 2. add --enable_qnn in the command
# 3. An optional config file can be provided via qnn_config if the user wishes to override the default parameters.
qpc_path_qnn = QEfficient.compile(onnx_path=onnx_model_path, qpc_path=os.path.join(base_path, "qpc"), num_cores=14, device_group=[0],
enable_qnn=True, qnn_config = "QEfficient/compile/qnn_config.json")
.. deprecated::
This function will be deprecated in version 1.19, please use QEFFAutoModelForCausalLM.compile instead
```
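Given the deprecation note above, the equivalent high-level flow would go through ``QEFFAutoModelForCausalLM``; a minimal sketch (model card and compile options are illustrative) is:

```python
from QEfficient import QEFFAutoModelForCausalLM

# High-level replacement suggested by the deprecation note: compile() on the
# auto class returns the path of the generated QPC for Cloud AI 100.
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
qpc_path = model.compile(num_cores=14)
```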
97 changes: 70 additions & 27 deletions docs/source/quick_start.md
@@ -50,16 +50,6 @@ You can also pass path of txt file with input prompts when you want to run infer
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 3 --prompt_len 32 --ctx_len 128 --num_cores 16 --device_group [0] --prompts_txt_file_path examples/prompts.txt --mxfp6 --mos 1 --aic_enable_depth_first
```

For QNN compilation, export QNN_SDK_ROOT=/path/to/qnn_sdk_folder, add --enable_qnn to the command, and pass an optional config file if you wish to override the default parameters.
Without QNN Config
```bash
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first --enable_qnn
```

With QNN Config
```bash
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first --enable_qnn QEfficient/compile/qnn_config.json
```
### QEfficient.cloud.execute
You can first run the `infer` API and then use `execute` to run the pre-compiled model on Cloud AI 100 cards.
Once the QPC has been compiled, it can be passed to the `execute` API to run inference on different prompts. Make sure to pass the same `--device_group` as used during infer. Refer to the [Execute API doc](execute_api) for more details.
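The exact command for this step is not expanded in this diff; an illustrative invocation (the QPC path is an assumption based on the default output layout) might be:

```bash
python -m QEfficient.cloud.execute --model_name gpt2 --qpc_path qeff_models/gpt2/qpc_16cores_1BS_32PL_128CL_1devices_mxfp6/qpcs --prompt "Once upon a time in" --device_group [0]
```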
@@ -83,10 +73,6 @@ You can also enable MQ, just based on the number of devices. Based on the `--dev
python -m QEfficient.cloud.infer --model_name Salesforce/codegen-2B-mono --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device-group [0,1] --prompt "def fibonacci(n):" --mos 2 --aic_enable_depth_first
```

For QNN compilation, export QNN_SDK_ROOT=/path/to/qnn_sdk_folder, add --enable_qnn to the command, and pass an optional config file if you wish to override the default parameters.
```bash
python -m QEfficient.cloud.infer --model_name Salesforce/codegen-2B-mono --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device-group [0,1] --prompt "def fibonacci(n):" --mos 2 --aic_enable_depth_first --enable_qnn QEfficient/compile/qnn_config.json
```
The above step will save the `qpc` files under `efficient-transformers/qeff_models/{model_card_name}`; you can then use the execute API to run inference on different prompts. This will automatically pick the pre-compiled `qpc` files.

```bash
@@ -99,12 +85,6 @@ To disable MQ, just pass single soc like below, below step will compile the mode
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device-group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first
```

For QNN compilation, export QNN_SDK_ROOT=/path/to/qnn_sdk_folder, add --enable_qnn to the command, and pass an optional config file if you wish to override the default parameters.
```bash
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device-group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first --enable_qnn QEfficient/compile/qnn_config.json
```


### Continuous Batching

Users can compile a model using the continuous batching feature by specifying full_batch_size <full_batch_size_value> in the infer and compile APIs. If full_batch_size is not provided, the model will be compiled in the regular way. A Python API sketch of the same flow follows the example command below.
@@ -118,11 +98,77 @@ python -m QEfficient.cloud.infer --model_name TinyLlama/TinyLlama_v1.1 --prompt_
theory is the belief that|The sun rises from" --mxfp6 --mos 1 --aic_enable_depth_first --full_batch_size 3
```
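A rough sketch of the same continuous-batching flow through the Python API is shown below; the `continuous_batching` argument to `from_pretrained` is an assumption (check the library for the exact name), while `full_batch_size` mirrors the option described above.

```python
from QEfficient import QEFFAutoModelForCausalLM
from transformers import AutoTokenizer

model_name = "TinyLlama/TinyLlama_v1.1"
# continuous_batching as a from_pretrained flag is an assumption -- verify against the library
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, continuous_batching=True)
qeff_model.compile(prefill_seq_len=32, ctx_len=128, num_cores=16, full_batch_size=3)

tokenizer = AutoTokenizer.from_pretrained(model_name)
qeff_model.generate(
    prompts=["My name is", "The flat earth theory is the belief that", "The sun rises from"],
    tokenizer=tokenizer,
)
```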

For QNN compilation, export QNN_SDK_ROOT=/path/to/qnn_sdk_folder, add --enable_qnn to the command, and pass an optional config file if you wish to override the default parameters.
### QNN Compilation

Users can compile a model with the QNN SDK by following the steps below:

* Set the QNN SDK path: export QNN_SDK_ROOT=/path/to/qnn_sdk_folder
* Enable QNN by passing the enable_qnn flag, i.e. add --enable_qnn to the CLI command.
* An optional config file can be passed to override the default parameters.

**CLI Inference Command**

Without QNN Config
```bash
python -m QEfficient.cloud.infer --model_name TinyLlama/TinyLlama_v1.1 --prompt_len 32 --ctx_len 128 --num_cores 16 --device_group [0] --prompt "My name is|The flat earth
theory is the belief that|The sun rises from" --mxfp6 --mos 1 --aic_enable_depth_first --full_batch_size 3 --enable_qnn QEfficient/compile/qnn_config.json
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first --enable_qnn
```

With QNN Config
```bash
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first --enable_qnn QEfficient/compile/qnn_config.json
```

**CLI Compile Command**

Users can also use the `compile` API to compile pre-exported ONNX models using the QNN SDK.

Without QNN Config
```bash
python -m QEfficient.cloud.compile --onnx_path <path to gpt2 onnx file> --qpc-path <path to save qpc files> --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first --enable_qnn
```

With QNN Config
```bash
python -m QEfficient.cloud.compile --onnx_path <path to gpt2 onnx file> --qpc-path <path to save qpc files> --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first --enable_qnn QEfficient/compile/qnn_config.json
```
**CLI Execute Command**
Once we have compiled the QPC using the `infer` or `compile` API, we can use the precompiled QPC in the `execute` API to run inference on different prompts.
Make sure to pass the same `--device_group` as used during infer. Refer to the [Execute API doc](execute_api) for more details.
```bash
python -m QEfficient.cloud.execute --model_name gpt2 --qpc_path qeff_models/gpt2/qpc_qnn_16cores_1BS_32PL_128CL_1devices_mxfp6/qpcs --prompt "Once upon a time in" --device_group [0]
```

**QNN Compilation via Python API**

Users can also use the Python API to export, compile, and execute ONNX models using the QNN SDK.

```Python
# We can now export the modified models to the ONNX framework.
# This will generate a single ONNX model for both the prefill and decode variations,
# optimized for the Cloud AI 100 platform.
from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM

# Model-Card name (This is HF Model Card name) : https://huggingface.co/gpt2-xl
model_name = "gpt2" # Similarly, we can change the model name to generate the corresponding model, provided support has been added in the library.

qeff_model = AutoModelForCausalLM.from_pretrained(model_name)

generated_qpc_path = qeff_model.compile(
num_cores=14,
mxfp6=True,
enable_qnn=True,
qnn_config = qnn_config_file_path # QNN compilation configuration is passed.
)

qeff_model.generate(prompts=["My name is"])
```

**Users can also take advantage of features like multi-Qranium inference and continuous batching with QNN SDK Compilation.**
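Continuing the example above, a sketch of how these options might be combined is shown below; `num_devices` and `full_batch_size` mirror options used elsewhere in this guide, and `continuous_batching` is an assumption to be verified against the library.

```python
# Sketch only: QNN compilation combined with multi-device (multi-Qranium)
# execution and continuous batching. The parameter combination is illustrative.
qeff_model = AutoModelForCausalLM.from_pretrained(model_name, continuous_batching=True)  # assumption
generated_qpc_path = qeff_model.compile(
    num_cores=14,
    num_devices=4,        # tensor-sliced across four devices
    full_batch_size=3,    # continuous batching
    mxfp6=True,
    enable_qnn=True,
    qnn_config=qnn_config_file_path,
)
```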

## Python API

### 1. Model download and Optimize for Cloud AI 100
@@ -169,9 +215,6 @@ Use the qualcomm_efficient_converter API to export the KV transformed Model to O
generated_qpc_path = qeff_model.compile(
num_cores=14,
mxfp6=True,
device_group=[0],
enable_qnn=True # if QNN Compilation path {default = False}
qnn_config = qnn_config_file_path # if QNN compilation configuration is passed {default = None}.
)
```

@@ -202,4 +245,4 @@ tlm.compile(num_speculative_tokens=k)
dlm.compile()
```

The `is_tlm` flag is passed during model instantiation because slight changes to the ONNX graph are required. Once complete, the user can specify `num_speculative_tokens` to define the actual number of speculations that the TLM will take as input during the decode phase. As for the DLM, no new changes are required at the ONNX or compile level.
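Putting the pieces together, a fuller sketch of the TLM/DLM setup described above might look like the following; the model cards are placeholders, and the `is_tlm` keyword on `from_pretrained` is inferred from the description above.

```python
from QEfficient import QEFFAutoModelForCausalLM

# Target LM (verifier): is_tlm=True applies the slight ONNX-graph changes
# needed to accept speculated tokens during decode.
tlm = QEFFAutoModelForCausalLM.from_pretrained("target-model-card", is_tlm=True)  # placeholder card
# Draft LM: no graph-level changes are required.
dlm = QEFFAutoModelForCausalLM.from_pretrained("draft-model-card")  # placeholder card

tlm.compile(num_speculative_tokens=4)  # number of speculations verified per decode step
dlm.compile()
```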
2 changes: 2 additions & 0 deletions docs/source/validate.md
@@ -17,6 +17,8 @@
| [Gemma-2-2b](https://huggingface.co/google/gemma-2-2b) |✔️ |
| [Gemma-2-9b](https://huggingface.co/google/gemma-2-9b) |✔️ |
| [Gemma-2-27b](https://huggingface.co/google/gemma-2-27b) |✔️ |
| [Granite-20b-code-base](https://huggingface.co/ibm-granite/granite-20b-code-base-8k) | ✔️ |
| [Granite-20b-code-instruct-8k](https://huggingface.co/ibm-granite/granite-20b-code-instruct-8k) | ✔️ |
| [Jais-adapted-7b](https://huggingface.co/inceptionai/jais-adapted-7b) |✔️ |
| [Jais-adapted-13b-chat](https://huggingface.co/inceptionai/jais-adapted-13b-chat) |✔️ |
| [Jais-adapted-70b](https://huggingface.co/inceptionai/jais-adapted-70b) |✔️ |