Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
kc1998dp authored Dec 24, 2024
2 parents e1b9420 + fa178be commit 338f608
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")


def model_fn(model_dir):
def model_fn(model_dir, context=None):
"""Overrides default method for loading a model"""
shared_libs_path = Path(model_dir + "/shared_libs")

Expand All @@ -40,7 +40,7 @@ def model_fn(model_dir):
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))


def input_fn(input_data, content_type):
def input_fn(input_data, content_type, context=None):
"""Deserializes the bytes that were received from the model server"""
try:
if hasattr(schema_builder, "custom_input_translator"):
Expand Down Expand Up @@ -72,12 +72,12 @@ def input_fn(input_data, content_type):
raise Exception("Encountered error in deserialize_request.") from e


def predict_fn(input_data, predict_callable):
def predict_fn(input_data, predict_callable, context=None):
"""Invokes the model that is taken in by model server"""
return predict_callable(input_data)


def output_fn(predictions, accept_type):
def output_fn(predictions, accept_type, context=None):
"""Prediction is serialized to bytes and sent back to the customer"""
try:
if hasattr(inference_spec, "postprocess"):
Expand Down

0 comments on commit 338f608

Please sign in to comment.