From 34ba5fa4b99cadf7d4f156cf8992b947da9cc209 Mon Sep 17 00:00:00 2001 From: garywan Date: Mon, 6 Jan 2025 22:53:32 +0000 Subject: [PATCH] Model server might have already deserialized the request. Honor that by not encoding the request again if it is not already bytes or bytearray --- .../multi_model_server/inference.py | 22 ++++++++++++++----- .../model_server/torchserve/inference.py | 22 ++++++++++++++----- .../torchserve/xgboost_inference.py | 22 ++++++++++++++----- 3 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py index 908ffcc7aa..9a0639e508 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/inference.py +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -46,18 +46,28 @@ def input_fn(input_data, content_type, context=None): if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + input_data + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + input_data + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], ) diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 489cc1bc1e..4966809db8 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -68,18 +68,28 @@ def input_fn(input_data, content_type): if 
hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + input_data + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + input_data + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], ) diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 517c774bbc..f83c279da9 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -71,18 +71,28 @@ def input_fn(input_data, content_type): if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + input_data + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: return schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + input_data + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], )