diff --git a/docs/sagemaker/inference.md b/docs/sagemaker/inference.md
index c7ff4df6d..b5c5f7fc4 100644
--- a/docs/sagemaker/inference.md
+++ b/docs/sagemaker/inference.md
@@ -346,25 +346,64 @@ The `inference.py` file contains your custom inference module, and the `requirem
 Here is an example of a custom inference module with `model_fn`, `input_fn`, `predict_fn`, and `output_fn`:
 
 ```python
+from sagemaker_huggingface_inference_toolkit import decoder_encoder
+
 def model_fn(model_dir):
-    return "model"
+    # implement custom code to load the model
+    loaded_model = ...
+
+    return loaded_model
 
-def input_fn(data, content_type):
-    return "data"
+def input_fn(input_data, content_type):
+    # decode the input data (e.g. JSON string -> dict)
+    data = decoder_encoder.decode(input_data, content_type)
+    return data
 
 def predict_fn(data, model):
-    return "output"
+    # call your custom model with the decoded data
+    outputs = model(data, ...)
+    return outputs
 
 def output_fn(prediction, accept):
-    return prediction
+    # convert the model output to the desired output format (e.g. dict -> JSON string)
+    response = decoder_encoder.encode(prediction, accept)
+    return response
 ```
 
 Customize your inference module with only `model_fn` and `transform_fn`:
 
 ```python
+from sagemaker_huggingface_inference_toolkit import decoder_encoder
+
 def model_fn(model_dir):
-    return "loading model"
+    # implement custom code to load the model
+    loaded_model = ...
+
+    return loaded_model
 
 def transform_fn(model, input_data, content_type, accept):
-    return f"output"
+    # decode the input data (e.g. JSON string -> dict)
+    data = decoder_encoder.decode(input_data, content_type)
+
+    # call your custom model with the decoded data
+    outputs = model(data, ...)
+
+    # convert the model output to the desired output format (e.g. dict -> JSON string)
+    response = decoder_encoder.encode(outputs, accept)
+
+    return response
 ```
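+
+As a minimal sketch of what the custom code could look like, `model_fn` and `predict_fn` might wrap the `transformers` `pipeline` API (this assumes a text-classification model was saved to `model_dir`; adapt the task and the call to your own model):
+
+```python
+from transformers import pipeline
+
+def model_fn(model_dir):
+    # assumption: the artifacts in model_dir are a text-classification model
+    return pipeline("text-classification", model=model_dir)
+
+def predict_fn(data, model):
+    # data is the decoded request, e.g. {"inputs": "This is great!"}
+    return model(data["inputs"])
+```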