From 637bc2e2b6ec1bf3ff93b69e687e8ab71c6340ab Mon Sep 17 00:00:00 2001 From: Gokul A <166456257+nargokul@users.noreply.github.com> Date: Fri, 20 Dec 2024 09:31:54 -0800 Subject: [PATCH] Fix Flake8 Violations --- .../model_server/multi_model_server/inference.py | 14 ++++++++++++-- .../serve/model_server/torchserve/inference.py | 14 ++++++++++++-- .../model_server/torchserve/xgboost_inference.py | 14 ++++++++++++-- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py index 3cece40c5e..1d2440f5f9 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/inference.py +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -45,11 +45,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) # Check if preprocess method is defined and call it diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 294c032ccc..489cc1bc1e 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -67,11 +67,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) # Check if preprocess method is defined and call it diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 6dab9bc6c6..517c774bbc 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -70,11 +70,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) except Exception as e: raise Exception("Encountered error in deserialize_request.") from e