From aaac0d4397971bcd31c1050439175056bd764561 Mon Sep 17 00:00:00 2001 From: Abukhoyer Shaik Date: Mon, 20 Jan 2025 09:28:48 +0000 Subject: [PATCH] Formatted Signed-off-by: Abukhoyer Shaik --- tests/qnn_tests/test_causal_lm_models_qnn.py | 6 +++--- tests/transformers/models/test_causal_lm_models.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/qnn_tests/test_causal_lm_models_qnn.py b/tests/qnn_tests/test_causal_lm_models_qnn.py index 9ab57b7a..f66224f3 100644 --- a/tests/qnn_tests/test_causal_lm_models_qnn.py +++ b/tests/qnn_tests/test_causal_lm_models_qnn.py @@ -86,9 +86,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - assert ( - pytorch_hf_tokens == pytorch_kv_tokens - ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + ) onnx_model_path = qeff_model.export() ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index babfc810..433f1ff1 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -110,9 +110,9 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - assert ( - pytorch_hf_tokens == pytorch_kv_tokens - ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + ) onnx_model_path = qeff_model.export() ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) @@ -204,9 +204,9 @@ def test_causal_lm_export_with_deprecated_api(model_name): new_api_ort_tokens = api_runner.run_kv_model_on_ort(new_api_onnx_model_path) old_api_ort_tokens = api_runner.run_kv_model_on_ort(old_api_onnx_model_path) - assert ( - new_api_ort_tokens == old_api_ort_tokens - ).all(), "New API output does not match old API output for ONNX export function" + assert (new_api_ort_tokens == old_api_ort_tokens).all(), ( + "New API output does not match old API output for ONNX export function" + ) @pytest.mark.on_qaic