diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py index 3cece40c5e..1d2440f5f9 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/inference.py +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -45,11 +45,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) # Check if preprocess method is defined and call it diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 294c032ccc..489cc1bc1e 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -67,11 +67,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) # Check if preprocess method is defined and call it diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 6dab9bc6c6..517c774bbc 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -70,11 +70,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) except Exception as e: raise Exception("Encountered error in deserialize_request.") from e