Skip to content

Commit

Permalink
use jumpstart deployment config image as default optimization image
Browse files Browse the repository at this point in the history
  • Loading branch information
gwang111 committed Jan 14, 2025
1 parent a58654e commit 3aaf596
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 3 deletions.
33 changes: 32 additions & 1 deletion src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,13 @@ def _optimize_for_jumpstart(
self.pysdk_model._enable_network_isolation = False

if quantization_config or sharding_config or is_compilation:
return create_optimization_job_args
# only apply default image for vLLM usecases.
# vLLM does not support compilation for now so skip on compilation
return (
create_optimization_job_args
if is_compilation
else self._set_optimization_image_default(create_optimization_job_args)
)
return None

def _is_gated_model(self, model=None) -> bool:
Expand Down Expand Up @@ -986,3 +992,28 @@ def _get_neuron_model_env_vars(
)
return job_model.env
return None

def _set_optimization_image_default(
self, create_optimization_job_args: Dict[str, Any]
) -> Dict[str, Any]:
"""Defaults the optimization image to the JumpStart deployment config default
Args:
create_optimization_job_args (Dict[str, Any]): create optimization job request
Returns:
Dict[str, Any]: create optimization job request with image uri default
"""

for optimization_config in create_optimization_job_args.get("OptimizationConfigs"):
if optimization_config.get("ModelQuantizationConfig"):
model_quantization_config = optimization_config.get("ModelQuantizationConfig")
if not model_quantization_config.get("Image"):
model_quantization_config["Image"] = self.pysdk_model.init_kwargs["image_uri"]

if optimization_config.get("ModelShardingConfig"):
model_sharding_config = optimization_config.get("ModelShardingConfig")
if not model_sharding_config.get("Image"):
model_sharding_config["Image"] = self.pysdk_model.init_kwargs["image_uri"]

return create_optimization_job_args
18 changes: 18 additions & 0 deletions tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
iam_client = sagemaker_session.boto_session.client("iam")
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()

schema_builder = SchemaBuilder("test", "test")
model_builder = ModelBuilder(
model="meta-textgeneration-llama-3-1-8b-instruct",
Expand All @@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
accept_eula=True,
)

assert not sagemaker_session.sagemaker_client.create_optimization_job.called

optimized_model.deploy()

mock_create_model.assert_called_once_with(
Expand Down Expand Up @@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
accept_eula=True,
)

assert (
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
"OptimizationConfigs"
][0]["ModelShardingConfig"]["Image"]
is not None
)

optimized_model.deploy(
resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
)
Expand Down Expand Up @@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
accept_eula=True,
)

assert (
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
"OptimizationConfigs"
][0]["ModelQuantizationConfig"]["Image"]
is not None
)

optimized_model.deploy()

mock_create_model.assert_called_once_with(
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,7 @@ def test_optimize_quantize_for_jumpstart(
mock_pysdk_model.image_uri = mock_tgi_image_uri
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
mock_pysdk_model.init_kwargs = {"image_uri": "mock_js_image"}

sample_input = {
"inputs": "The diamondback terrapin or simply terrapin is a species "
Expand Down Expand Up @@ -1201,6 +1202,9 @@ def test_optimize_quantize_for_jumpstart(
)

self.assertIsNotNone(out_put)
self.assertEqual(
out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], "mock_js_image"
)

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
Expand Down Expand Up @@ -1287,6 +1291,7 @@ def test_optimize_quantize_and_compile_for_jumpstart(
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
mock_pysdk_model.config_name = "config_name"
mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config}
mock_pysdk_model.init_kwargs = {"image_uri": "mock_js_image"}

sample_input = {
"inputs": "The diamondback terrapin or simply terrapin is a species "
Expand Down Expand Up @@ -1319,6 +1324,8 @@ def test_optimize_quantize_and_compile_for_jumpstart(
)

self.assertIsNotNone(out_put)
self.assertIsNone(out_put["OptimizationConfigs"][1]["ModelCompilationConfig"].get("Image"))
self.assertIsNone(out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"].get("Image"))

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3733,6 +3733,7 @@ def test_optimize_sharding_with_override_for_js(
pysdk_model.env = {"key": "val"}
pysdk_model._enable_network_isolation = True
pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None
pysdk_model.init_kwargs = {"image_uri": "mock_js_image"}

mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model
mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
Expand Down Expand Up @@ -3803,8 +3804,9 @@ def test_optimize_sharding_with_override_for_js(
OptimizationConfigs=[
{
"ModelShardingConfig": {
"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}
}
"Image": "mock_js_image",
"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"},
},
}
],
OutputConfig={
Expand Down

0 comments on commit 3aaf596

Please sign in to comment.