Replacing assertions (quic#98)
* Replacing assertions

Signed-off-by: quic-meet <[email protected]>

* Changes

Signed-off-by: Meet Doshi <[email protected]>

* missing device ids param

Signed-off-by: Meet Doshi <[email protected]>

* changes

Signed-off-by: Meet Doshi <[email protected]>

* changes

Signed-off-by: Meet Doshi <[email protected]>

---------

Signed-off-by: quic-meet <[email protected]>
Signed-off-by: Meet Doshi <[email protected]>
Co-authored-by: Meet Doshi <[email protected]>
Co-authored-by: Amit Raj <[email protected]>
3 people authored Oct 23, 2024
1 parent c433c11 commit b8cb759
Showing 8 changed files with 43 additions and 33 deletions.
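
The pattern repeated across all eight files is the same: replace `assert` statements with explicit checks that raise typed exceptions. Assertions disappear when Python runs with `-O`/`-OO`, so they stop guarding anything in optimized deployments, and `AssertionError` gives callers nothing specific to catch. A minimal sketch of the before/after shape (the `validate_model_dir` name is illustrative, not from this repo):

import os

# Before: stripped under `python -O`, and only ever raises AssertionError
def validate_model_dir_old(path: str) -> None:
    assert os.path.isdir(path), f"{path} is not a directory"

# After: always enforced, and raises a typed exception callers can handle
def validate_model_dir(path: str) -> None:
    if not os.path.isdir(path):
        raise FileNotFoundError(f"{path} is not a directory")
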
12 changes: 6 additions & 6 deletions QEfficient/base/common.py
@@ -46,9 +46,10 @@ def get_hf_model_type(hf_model_path: str) -> QEFF_MODEL_TYPE:
     """
     Loads model config file and returns the type of the model (i.e. LLMs, SD, quantized etc.) as supported by the library.
     """
-    assert os.path.isdir(
-        hf_model_path
-    ), "Pleae pass local dir path where the model is downloaded; use `QEfficient.utils.login_and_download_hf_lm` for downloading hf model"
+    if not os.path.isdir(hf_model_path):
+        raise FileNotFoundError(
+            "Please pass local dir path where the model is downloaded; use `QEfficient.utils.login_and_download_hf_lm` for downloading hf model"
+        )
     config, kwargs = AutoConfig.from_pretrained(
         hf_model_path,
         return_unused_kwargs=True,
@@ -84,9 +85,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         pretrained_model_name_or_path = login_and_download_hf_lm(pretrained_model_name_or_path, *args, **kwargs)
         model_type = get_hf_model_type(hf_model_path=pretrained_model_name_or_path)
         qeff_auto_model_class = MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP[model_type]
-        assert issubclass(
-            qeff_auto_model_class, QEFFBaseModel
-        ), f"Expected class that inherits {QEFFBaseModel}, got {type(qeff_auto_model_class)}"
+        if not issubclass(qeff_auto_model_class, QEFFBaseModel):
+            raise Exception(f"Expected class that inherits {QEFFBaseModel}, got {type(qeff_auto_model_class)}")
 
         return qeff_auto_model_class.from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
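
Since get_hf_model_type now raises FileNotFoundError instead of tripping an assert, callers can distinguish a missing local checkout from other failures. A hypothetical usage sketch (the import path follows the file shown above; the handling is illustrative):

from QEfficient.base.common import get_hf_model_type

try:
    model_type = get_hf_model_type(hf_model_path="/path/to/local/model")
except FileNotFoundError:
    # Model is not on disk yet; the error message suggests
    # QEfficient.utils.login_and_download_hf_lm for downloading it first.
    raise
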
8 changes: 4 additions & 4 deletions QEfficient/compile/compile_helper.py
@@ -59,10 +59,10 @@ def compile_kv_model_on_cloud_ai_100(
     if os.path.isdir(aic_binary_dir):
         shutil.rmtree(aic_binary_dir)
 
-    assert os.path.isfile(
-        specializations_json
-    ), f"Please use 'QEfficient.compile', as {specializations_json} file was not found"
-    assert os.path.isfile(custom_io_path), f"{custom_io_path} file was not found!"
+    if not os.path.isfile(specializations_json):
+        raise FileNotFoundError(f"Please use 'QEfficient.compile', as {specializations_json} file was not found")
+    if not os.path.isfile(custom_io_path):
+        raise FileNotFoundError(f"{custom_io_path} file was not found!")
     command = [
         "/opt/qti-aic/exec/qaic-exec",
         f"-m={onnx_path}",
9 changes: 6 additions & 3 deletions QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -177,7 +177,8 @@ def convert_to_cloud_kvstyle(
         logger.warning(f"Overriding {onnx_dir_path}")
         shutil.rmtree(onnx_dir_path)
 
-    assert qeff_model.is_transformed, f"please pass the {qeff_model.__class__.__name__} after transform API"
+    if not qeff_model.is_transformed:
+        raise Exception(f"please pass the {qeff_model.__class__.__name__} after transform API")
 
     # Decide path for saving exported ONNX files.
     model_name = export_kvstyle_transformed_model_to_onnx(
@@ -220,8 +221,10 @@ def export_kvstyle_transformed_model_to_onnx(
     output_names = list(pt_outputs.keys())
 
     # Raise error if expected outputs are not present
-    assert "logits" in output_names, "logits not found in output"
-    assert "past_key_values" in output_names, "past_key_values not found in output"
+    if "logits" not in output_names:
+        raise KeyError("logits not found in output")
+    if "past_key_values" not in output_names:
+        raise KeyError("past_key_values not found in output")
 
     # Build inputs for next iteration from outputs
     # Build inputs for decode
21 changes: 13 additions & 8 deletions QEfficient/generation/cloud_infer.py
@@ -66,14 +66,15 @@ def __init__(
         self.context = qaicrt.Context()
         self.queue = qaicrt.Queue(self.context, 0) # Async API
         if enable_debug_logs:
-            assert (
-                self.context.setLogLevel(qaicrt.QLogLevel.QL_DEBUG) == qaicrt.QStatus.QS_SUCCESS
-            ), "Failed to setLogLevel"
+            if self.context.setLogLevel(qaicrt.QLogLevel.QL_DEBUG) != qaicrt.QStatus.QS_SUCCESS:
+                raise RuntimeError("Failed to setLogLevel")
 
         qpc = qaicrt.Qpc(qpc_path)
         # Load IO Descriptor
         iodesc = aicapi.IoDesc()
         status, iodesc_data = qpc.getIoDescriptor()
-        assert status == qaicrt.QStatus.QS_SUCCESS, "Failed to getIoDescriptor"
+        if status != qaicrt.QStatus.QS_SUCCESS:
+            raise RuntimeError("Failed to getIoDescriptor")
         iodesc.ParseFromString(bytes(iodesc_data))
         self.allowed_shapes = [
             [(aic_to_np_dtype_mapping[x.type].itemsize, list(x.dims)) for x in allowed_shape.shapes]
@@ -87,7 +88,8 @@ def __init__(
         if device_ids and len(device_ids) > 1:
             prog_properties.devMapping = ":".join(map(str, device_ids))
         self.program = qaicrt.Program(self.context, None, qpc, prog_properties)
-        assert self.program.load() == qaicrt.QStatus.QS_SUCCESS, "Failed to load program"
+        if self.program.load() != qaicrt.QStatus.QS_SUCCESS:
+            raise RuntimeError("Failed to load program")
         if activate:
             self.activate()
         # Create input qbuffers and buf_dims
@@ -157,11 +159,13 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         """
         # Set inputs
         self.set_buffers(inputs)
-        assert self.execObj.setData(self.qbuffers, self.buf_dims) == qaicrt.QStatus.QS_SUCCESS, "Failed to setData"
+        if self.execObj.setData(self.qbuffers, self.buf_dims) != qaicrt.QStatus.QS_SUCCESS:
+            raise MemoryError("Failed to setData")
         # # Run with sync API
         # if self.execObj.run(self.qbuffers) != qaicrt.QStatus.QS_SUCCESS:
         # Run with async API
-        assert self.queue.enqueue(self.execObj) == qaicrt.QStatus.QS_SUCCESS, "Failed to enqueue"
+        if self.queue.enqueue(self.execObj) != qaicrt.QStatus.QS_SUCCESS:
+            raise MemoryError("Failed to enqueue")
         if self.execObj.waitForCompletion() != qaicrt.QStatus.QS_SUCCESS:
             error_message = "Failed to run"
             # Print additional error messages for unmatched dimension error
@@ -187,7 +191,8 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
             raise ValueError(error_message)
         # Get output buffers
         status, output_qbuffers = self.execObj.getData()
-        assert status == qaicrt.QStatus.QS_SUCCESS, "Failed to getData"
+        if status != qaicrt.QStatus.QS_SUCCESS:
+            raise MemoryError("Failed to getData")
         # Build output
         outputs = {}
         for output_name in self.output_names:
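
The four QStatus checks above share one compare-and-raise shape. A possible consolidation, not part of this commit (`_check_status` is a hypothetical helper; it assumes the `qaicrt` module already imported at the top of cloud_infer.py):

def _check_status(status, message: str, exc_type: type = RuntimeError) -> None:
    # Raise exc_type(message) whenever a qaicrt call does not report success.
    if status != qaicrt.QStatus.QS_SUCCESS:
        raise exc_type(message)

# e.g. the setData/enqueue checks in run() would then read:
#   _check_status(self.execObj.setData(self.qbuffers, self.buf_dims), "Failed to setData", MemoryError)
#   _check_status(self.queue.enqueue(self.execObj), "Failed to enqueue", MemoryError)
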
8 changes: 4 additions & 4 deletions QEfficient/generation/text_generation_inference.py
@@ -157,9 +157,8 @@ def get_compilation_dims(qpc_path: str) -> Tuple[int, int]:
 
 
 def get_input_prompts(prompt: str, prompts_txt_file_path: str) -> List[str]:
-    assert (
-        prompt is not None or prompts_txt_file_path is not None
-    ), "Please pass at least one argument either using --prompt or --prompts_txt_file_path"
+    if prompt is None and prompts_txt_file_path is None:
+        raise ValueError("Please pass at least one argument either using --prompt or --prompts_txt_file_path")
     if prompts_txt_file_path is not None:
         if prompt is not None:
             logger.warning("Found inputs passed using txt file as well as CLI, taking inputs from given txt file")
@@ -444,7 +443,8 @@ def _fetch_generation_len(self, generation_len, max_gen_len):
                 "Passed generation_len is greater than allowed length. "
                 "Make sure this model supports sliding window, such as Mistral"
             )
-        assert generation_len > 0, "generation length should be greater than zero"
+        if generation_len <= 0:
+            raise ValueError("generation length should be greater than zero")
         return generation_len
 
     def prepare_decode_inputs(self):
7 changes: 4 additions & 3 deletions QEfficient/transformers/models/modeling_auto.py
@@ -211,7 +211,6 @@ def export(self) -> str:
         Returns:
             :str: Path of the generated ``ONNX`` graph.
         """
-        assert self.is_transformed, "Please first run transform on the QEFFAutoModelForCausalLM object"
         # Export
         _, onnx_model_path = QEfficient.export(
             model_name=self.model_card_name,
@@ -366,11 +365,13 @@ def generate(self, prompts: List[str], device_id: List[int] = None, runtime: str
         ``optional`` Args:
             :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
         """
-        assert Runtime(runtime) == Runtime.AI_100, "Only AI_100 runtime is supported right now via generate API"
+        if Runtime(runtime) != Runtime.AI_100:
+            raise ValueError("Only AI_100 runtime is supported right now via generate API")
         self.run_cloud_ai_100(prompts=prompts, device_id=device_id, **kwargs)
 
     def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs):
-        assert isinstance(self.qpc_path, str), "Please run compile API first!"
+        if not isinstance(self.qpc_path, str):
+            raise TypeError("Please run compile API first!")
         generation_len = kwargs.pop("generation_len", None)
         return QEfficient.cloud_ai_100_exec_kv(
             self.tokenizer,
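
With ValueError and TypeError in place of AssertionError, callers of the generate API can react to each failure mode separately. A hypothetical caller sketch (`model` is assumed to be an already-instantiated QEFFAutoModelForCausalLM; the handlers are illustrative):

try:
    model.generate(prompts=["My name is"], device_id=[0])
except TypeError:
    # qpc_path is unset, i.e. the compile API was not run on this model yet
    ...
except ValueError:
    # a runtime other than "AI_100" was requested
    ...
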
8 changes: 4 additions & 4 deletions QEfficient/transformers/transform.py
@@ -75,9 +75,8 @@ def transform_lm(model: nn.Module) -> nn.Module:
 
     # Check with new params hash
     later_params_hash = get_params_hash(model)
-    assert (
-        prior_params_hash == later_params_hash
-    ), "Weights were changed in the transform process, please report an issue"
+    if prior_params_hash != later_params_hash:
+        raise RuntimeError("Weights were changed in the transform process, please report an issue")
 
     # Replace the Dyanmic cache utils update api
     transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update
@@ -94,7 +93,8 @@ def transform(model: QEFFBaseModel, form_factor="cloud"):
         model (torch.nn.Module): object of any instance of class that is child of `QEFFBaseAutoModelFactory`
         form_factor (str): form factor configuration for optimizing the model, available options=["cloud", "edge"].
     """
-    assert form_factor == "cloud", "Only form_factor='cloud' is supported as of now!"
+    if form_factor != "cloud":
+        raise ValueError("Only form_factor='cloud' is supported as of now!")
     # FIXME: move this to class and use model.transform()
     if AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM:
         transform_lm(model.model) # type: ignore
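
transform_lm guards against weight mutation by comparing parameter hashes taken before and after the transform. get_params_hash itself is not part of this diff; one common way to implement such a check looks roughly like this (a sketch assuming float tensors that convert cleanly to NumPy, not the repo's actual code):

import hashlib

import torch.nn as nn

def params_hash_sketch(model: nn.Module) -> str:
    # Any change to a parameter's name or bytes changes the digest.
    digest = hashlib.sha256()
    for name, param in sorted(model.named_parameters(), key=lambda kv: kv[0]):
        digest.update(name.encode())
        digest.update(param.detach().cpu().numpy().tobytes())
    return digest.hexdigest()
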
3 changes: 2 additions & 1 deletion QEfficient/utils/_utils.py
@@ -224,7 +224,8 @@ def padding_check_and_fix(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokeni
         tokenizer.padding_side = "right"
 
     if tokenizer.pad_token_id is None:
-        assert tokenizer.eos_token_id is not None, "Found tokenizer.eos_token_id to be None, expected int"
+        if not isinstance(tokenizer.eos_token_id, int):
+            raise TypeError("Found tokenizer.eos_token_id to be None, expected int")
         # If Pad token is out of range of vocab size
         if tokenizer.eos_token_id < tokenizer.vocab_size:
             tokenizer.pad_token_id = tokenizer.eos_token_id
