Commit 7304bfe
Merge branch 'master' into enable-github-actions-ci
benieric authored Mar 11, 2024
2 parents 96fb7ed + 615a8ad commit 7304bfe
Showing 7 changed files with 259 additions and 31 deletions.
3 changes: 2 additions & 1 deletion src/sagemaker/huggingface/llm_utils.py
@@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] =
     Returns:
         dict: The model metadata retrieved with the HuggingFace API
     """
-
+    if not model_id:
+        raise ValueError("Model ID is empty. Please provide a valid Model ID.")
     hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
     hf_model_metadata_json = None
     try:
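With this guard, an empty model ID fails fast with a ValueError instead of sending a request to a malformed Hub URL. A minimal usage sketch of the function touched here:

    from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

    # Fetches https://huggingface.co/api/models/bert-base-uncased and returns the JSON metadata
    metadata = get_huggingface_model_metadata("bert-base-uncased")
    print(metadata.get("pipeline_tag"))  # e.g. "fill-mask"

    # get_huggingface_model_metadata("")  # now raises ValueError before any network call
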
8 changes: 4 additions & 4 deletions src/sagemaker/model.py
@@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:

     def _script_mode_env_vars(self):
         """Returns a mapping of environment variables for script mode execution"""
-        script_name = None
-        dir_name = None
+        script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "")
+        dir_name = self.env.get(DIR_PARAM_NAME.upper(), "")
         if self.uploaded_code:
             script_name = self.uploaded_code.script_name
         if self.repacked_model_data or self.enable_network_isolation():
@@ -783,8 +783,8 @@ def _script_mode_env_vars(self):
else "file://" + self.source_dir
)
return {
SCRIPT_PARAM_NAME.upper(): script_name or str(),
DIR_PARAM_NAME.upper(): dir_name or str(),
SCRIPT_PARAM_NAME.upper(): script_name,
DIR_PARAM_NAME.upper(): dir_name,
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
}
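Net effect: values the caller already set in the model's env (the uppercased SCRIPT_PARAM_NAME and DIR_PARAM_NAME keys, assumed here to be SAGEMAKER_PROGRAM and SAGEMAKER_SUBMIT_DIRECTORY) now seed the result instead of being replaced by empty strings when no code is uploaded. A standalone sketch of the before/after fallback:

    env = {"SAGEMAKER_PROGRAM": "inference.py"}  # value supplied by the caller

    # Before: a None placeholder meant the caller's value was lost
    script_name = None
    before = script_name or str()  # "" -- override silently discarded

    # After: the env value seeds the variable; only uploaded/repacked code overwrites it
    script_name = env.get("SAGEMAKER_PROGRAM", "")
    after = script_name  # "inference.py" -- override preserved
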
19 changes: 14 additions & 5 deletions src/sagemaker/serve/builder/model_builder.py
@@ -124,8 +124,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
             into a stream. All translations between the server and the client are handled
             automatically with the specified input and output.
         model (Optional[Union[object, str]]): Model object (with ``predict`` method to perform
-            inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or
-            ``inference_spec`` is required for the model builder to build the artifact.
+            inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec``
+            is required for the model builder to build the artifact.
         inference_spec (InferenceSpec): The inference spec file with your customized
             ``invoke`` and ``load`` functions.
         image_uri (Optional[str]): The container image uri (which is derived from a
@@ -145,6 +145,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
             to the model server). Possible values for this argument are
             ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
             ``TRITON``, and ``TGI``.
+        model_metadata (Optional[Dict[str, Any]]): Dictionary used to override the HuggingFace
+            model metadata. Currently ``HF_TASK`` is overridable.
     """

model_path: Optional[str] = field(
@@ -241,6 +243,10 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
     model_server: Optional[ModelServer] = field(
         default=None, metadata={"help": "Define the model server to deploy to."}
     )
+    model_metadata: Optional[Dict[str, Any]] = field(
+        default=None,
+        metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"},
+    )

     def _build_validations(self):
         """Placeholder docstring"""
@@ -616,6 +622,9 @@ def build(  # pylint: disable=R0911
         self._is_custom_image_uri = self.image_uri is not None

         if isinstance(self.model, str):
+            model_task = None
+            if self.model_metadata:
+                model_task = self.model_metadata.get("HF_TASK")
             if self._is_jumpstart_model_id():
                 return self._build_for_jumpstart()
             if self._is_djl():  # pylint: disable=R1705
@@ -625,10 +634,10 @@ def build(  # pylint: disable=R0911
                     self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
                 )

-                model_task = hf_model_md.get("pipeline_tag")
-                if self.schema_builder is None and model_task:
+                if model_task is None:
+                    model_task = hf_model_md.get("pipeline_tag")
+                if self.schema_builder is None and model_task is not None:
                     self._schema_builder_init(model_task)

                 if model_task == "text-generation":  # pylint: disable=R1705
                     return self._build_for_tgi()
                 elif self._can_fit_on_single_gpu():
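Together these changes let a caller pin the HuggingFace task up front, falling back to the Hub's pipeline_tag only when no override is given. A minimal sketch, mirroring the tests added below:

    from sagemaker.serve.builder.model_builder import ModelBuilder

    # Override the task instead of relying on the Hub's pipeline_tag
    model_builder = ModelBuilder(
        model="bert-base-uncased",
        model_metadata={"HF_TASK": "fill-mask"},
    )
    model = model_builder.build()  # a sagemaker_session can also be passed explicitly
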
40 changes: 20 additions & 20 deletions src/sagemaker/serve/schema/task.json
@@ -1,28 +1,28 @@ (indentation-only change)
{
  "fill-mask": {
    "sample_inputs": {
      "properties": {
        "inputs": "Paris is the [MASK] of France.",
        "parameters": {}
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "sequence": "Paris is the capital of France.",
          "score": 0.7
        }
      ]
    }
  },
  "question-answering": {
    "sample_inputs": {
      "properties": {
        "context": "I have a German Shepherd dog, named Coco.",
        "question": "What is my dog's breed?"
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "answer": "German Shepherd",
@@ -32,36 +32,36 @@ (indentation-only change)
        }
      ]
    }
  },
  "text-classification": {
    "sample_inputs": {
      "properties": {
        "inputs": "Where is the capital of France?, Paris is the capital of France.",
        "parameters": {}
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "label": "entailment",
          "score": 0.997
        }
      ]
    }
  },
  "text-generation": {
    "sample_inputs": {
      "properties": {
        "inputs": "Hello, I'm a language model",
        "parameters": {}
      }
    },
    "sample_outputs": {
      "properties": [
        {
          "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
        }
      ]
    }
  }
}
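These entries back the local schema lookup used by the schema builder. A sketch of how they surface through the task utility module (import path assumed to be sagemaker.serve.utils.task, as used by the tests below):

    from sagemaker.serve.utils import task  # module path assumed

    # Look up the sample input/output pair recorded for a task in task.json
    inputs, outputs = task.retrieve_local_schemas("fill-mask")
    print(inputs)   # {'inputs': 'Paris is the [MASK] of France.', 'parameters': {}}
    print(outputs)  # [{'sequence': 'Paris is the capital of France.', 'score': 0.7}]
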
66 changes: 66 additions & 0 deletions tests/integ/sagemaker/serve/test_schema_builder.py
@@ -99,3 +99,69 @@ def test_model_builder_negative_path(sagemaker_session):
match="Error Message: Schema builder for text-to-image could not be found.",
):
model_builder.build(sagemaker_session=sagemaker_session)


@pytest.mark.skipif(
PYTHON_VERSION_IS_NOT_310,
reason="Testing Schema Builder Simplification feature",
)
@pytest.mark.parametrize(
"model_id, task_provided",
[
("bert-base-uncased", "fill-mask"),
("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"),
],
)
def test_model_builder_happy_path_with_task_provided(
model_id, task_provided, sagemaker_session, gpu_instance_type
):
model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": task_provided})

model = model_builder.build(sagemaker_session=sagemaker_session)

assert model is not None
assert model_builder.schema_builder is not None

inputs, outputs = task.retrieve_local_schemas(task_provided)
assert model_builder.schema_builder.sample_input == inputs
assert model_builder.schema_builder.sample_output == outputs

with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
caught_ex = None
try:
iam_client = sagemaker_session.boto_session.client("iam")
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
predictor = model.deploy(
role=role_arn, instance_count=1, instance_type=gpu_instance_type
)

predicted_outputs = predictor.predict(inputs)
assert predicted_outputs is not None

except Exception as e:
caught_ex = e
finally:
cleanup_model_resources(
sagemaker_session=model_builder.sagemaker_session,
model_name=model.name,
endpoint_name=model.endpoint_name,
)
if caught_ex:
logger.exception(caught_ex)
assert (
False
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"


def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
model_builder = ModelBuilder(
model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"}
)

with pytest.raises(
TaskNotFoundException,
match="Error Message: Schema builder for invalid-task could not be found.",
):
model_builder.build(sagemaker_session=sagemaker_session)
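Outside pytest, the same negative path looks like this (exception import path assumed to be sagemaker.serve.utils.exceptions):

    from sagemaker.serve.builder.model_builder import ModelBuilder
    from sagemaker.serve.utils.exceptions import TaskNotFoundException  # path assumed

    builder = ModelBuilder(model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"})
    try:
        builder.build()
    except TaskNotFoundException as err:
        print(err)  # Error Message: Schema builder for invalid-task could not be found.
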
125 changes: 124 additions & 1 deletion tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -1076,7 +1076,7 @@ def test_build_negative_path_when_schema_builder_not_present(

model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")

self.assertRaisesRegexp(
self.assertRaisesRegex(
TaskNotFoundException,
"Error Message: Schema builder for text-to-image could not be found.",
lambda: model_builder.build(sagemaker_session=mock_session),
@@ -1593,3 +1593,126 @@ def test_total_inference_model_size_mib_throws(
        model_builder.build(sagemaker_session=mock_session)

        self.assertEqual(model_builder._can_fit_on_single_gpu(), False)

    @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
    @patch("sagemaker.image_uris.retrieve")
    @patch("sagemaker.djl_inference.model.urllib")
    @patch("sagemaker.djl_inference.model.json")
    @patch("sagemaker.huggingface.llm_utils.urllib")
    @patch("sagemaker.huggingface.llm_utils.json")
    @patch("sagemaker.model_uris.retrieve")
    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
    def test_build_happy_path_override_with_task_provided(
        self,
        mock_serveSettings,
        mock_model_uris_retrieve,
        mock_llm_utils_json,
        mock_llm_utils_urllib,
        mock_model_json,
        mock_model_urllib,
        mock_image_uris_retrieve,
        mock_hf_model,
    ):
        # Setup mocks

        mock_setting_object = mock_serveSettings.return_value
        mock_setting_object.role_arn = mock_role_arn
        mock_setting_object.s3_model_data_url = mock_s3_model_data_url

        # HF Pipeline Tag
        mock_model_uris_retrieve.side_effect = KeyError
        mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"}
        mock_llm_utils_urllib.request.Request.side_effect = Mock()

        # HF Model config
        mock_model_json.load.return_value = {"some": "config"}
        mock_model_urllib.request.Request.side_effect = Mock()

        mock_image_uris_retrieve.return_value = "https://some-image-uri"

        model_builder = ModelBuilder(
            model="bert-base-uncased", model_metadata={"HF_TASK": "text-generation"}
        )
        model_builder.build(sagemaker_session=mock_session)

        self.assertIsNotNone(model_builder.schema_builder)
        sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation")
        self.assertEqual(
            sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"]
        )
        self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output)

    @patch("sagemaker.image_uris.retrieve")
    @patch("sagemaker.djl_inference.model.urllib")
    @patch("sagemaker.djl_inference.model.json")
    @patch("sagemaker.huggingface.llm_utils.urllib")
    @patch("sagemaker.huggingface.llm_utils.json")
    @patch("sagemaker.model_uris.retrieve")
    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
    def test_build_task_override_with_invalid_task_provided(
        self,
        mock_serveSettings,
        mock_model_uris_retrieve,
        mock_llm_utils_json,
        mock_llm_utils_urllib,
        mock_model_json,
        mock_model_urllib,
        mock_image_uris_retrieve,
    ):
        # Setup mocks

        mock_setting_object = mock_serveSettings.return_value
        mock_setting_object.role_arn = mock_role_arn
        mock_setting_object.s3_model_data_url = mock_s3_model_data_url

        # HF Pipeline Tag
        mock_model_uris_retrieve.side_effect = KeyError
        mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"}
        mock_llm_utils_urllib.request.Request.side_effect = Mock()

        # HF Model config
        mock_model_json.load.return_value = {"some": "config"}
        mock_model_urllib.request.Request.side_effect = Mock()

        mock_image_uris_retrieve.return_value = "https://some-image-uri"
        model_ids_with_invalid_task = {
            "bert-base-uncased": "invalid-task",
            "bert-large-uncased-whole-word-masking-finetuned-squad": "",
        }
        for model_id in model_ids_with_invalid_task:
            provided_task = model_ids_with_invalid_task[model_id]
            model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": provided_task})

            self.assertRaisesRegex(
                TaskNotFoundException,
                f"Error Message: Schema builder for {provided_task} could not be found.",
                lambda: model_builder.build(sagemaker_session=mock_session),
            )

    @patch("sagemaker.image_uris.retrieve")
    @patch("sagemaker.model_uris.retrieve")
    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
    def test_build_task_override_with_invalid_model_provided(
        self,
        mock_serveSettings,
        mock_model_uris_retrieve,
        mock_image_uris_retrieve,
    ):
        # Setup mocks

        mock_setting_object = mock_serveSettings.return_value
        mock_setting_object.role_arn = mock_role_arn
        mock_setting_object.s3_model_data_url = mock_s3_model_data_url

        # HF Pipeline Tag
        mock_model_uris_retrieve.side_effect = KeyError

        mock_image_uris_retrieve.return_value = "https://some-image-uri"
        invalid_model_id = ""
        provided_task = "fill-mask"

        model_builder = ModelBuilder(
            model=invalid_model_id, model_metadata={"HF_TASK": provided_task}
        )
        with self.assertRaises(Exception):
            model_builder.build(sagemaker_session=mock_session)
(The remaining changed file did not load in this view.)
