Skip to content

Commit

Permalink
Fix Flake8 Violations
Browse files Browse the repository at this point in the history
  • Loading branch information
nargokul committed Dec 20, 2024
1 parent beb23ec commit 637bc2e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
14 changes: 12 additions & 2 deletions src/sagemaker/serve/model_server/multi_model_server/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,21 @@ def input_fn(input_data, content_type):
try:
if hasattr(schema_builder, "custom_input_translator"):
deserialized_data = schema_builder.custom_input_translator.deserialize(
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
),
content_type,
)
else:
deserialized_data = schema_builder.input_deserializer.deserialize(
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
),
content_type[0],
)

# Check if preprocess method is defined and call it
Expand Down
14 changes: 12 additions & 2 deletions src/sagemaker/serve/model_server/torchserve/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,21 @@ def input_fn(input_data, content_type):
try:
if hasattr(schema_builder, "custom_input_translator"):
deserialized_data = schema_builder.custom_input_translator.deserialize(
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
),
content_type,
)
else:
deserialized_data = schema_builder.input_deserializer.deserialize(
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
),
content_type[0],
)

# Check if preprocess method is defined and call it
Expand Down
14 changes: 12 additions & 2 deletions src/sagemaker/serve/model_server/torchserve/xgboost_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,21 @@ def input_fn(input_data, content_type):
try:
if hasattr(schema_builder, "custom_input_translator"):
return schema_builder.custom_input_translator.deserialize(
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
),
content_type,
)
else:
return schema_builder.input_deserializer.deserialize(
io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0]
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
),
content_type[0],
)
except Exception as e:
raise Exception("Encountered error in deserialize_request.") from e
Expand Down

0 comments on commit 637bc2e

Please sign in to comment.