From b6203741733d6af31a2c235aacdf9b07e725131c Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 18 Oct 2023 15:51:23 +0000 Subject: [PATCH 01/35] feature: Added components to support time series explainability with Clarify. These components are TimeSeriesDataConfig, TimeSeriesModelConfig, and AsymmetricSHAPConfig, along with unit tests for them. fix: Modified DataConfig, ModelConfig, and _AnalysisConfigGenerator to support the new components and time series explainability. documentation: added docstrings for the new components and their tests. --- src/sagemaker/clarify.py | 286 ++++++++++++++++- tests/unit/test_clarify.py | 633 +++++++++++++++++++++++++++++++++++++ 2 files changed, 914 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 9421d0e419..efcef6bb13 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -39,6 +39,19 @@ ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])" +# TODO: verify these are sensible/sound values +# timeseries predictor config default values +TS_MODEL_DEFAULT_FORECAST_HORIZON = 5 # predictor config +# asymmetric shap default values (timeseries) +ASYM_SHAP_DEFAULT_EXPLANATION_TYPE = "fine_grained" +ASYM_SHAP_DEFAULT_NUM_SAMPLES = 5 +ASYM_SHAP_EXPLANATION_TYPES = [ + "timewise_chronological", + "timewise_anti_chronological", + "timewise_bidirectional", + "fine_grained", +] + ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema( { SchemaOptional("version"): str, @@ -84,6 +97,13 @@ SchemaOptional("excluded_columns"): [Or(int, str)], SchemaOptional("joinsource_name_or_index"): Or(str, int), SchemaOptional("group_variable"): Or(str, int), + SchemaOptional("time_series_data_config"): { + "target_time_series": Or(str, int), + "item_id": Or(str, int), + "timestamp": Or(str, int), + SchemaOptional("related_time_series"): [Or(int, str)], + SchemaOptional("item_metadata"): [Or(int, str)], + }, "methods": { SchemaOptional("shap"): { SchemaOptional("baseline"): Or( @@ -277,6 +297,20 @@ 
SchemaOptional("top_k_features"): int, }, SchemaOptional("report"): {"name": str, SchemaOptional("title"): str}, + SchemaOptional("asymmetric_shap"): { + "explanation_type": And( + str, + Use(str.lower), + lambda s: s + in ( + "timewise_chronological", + "timewise_anti_chronological", + "timewise_bidirectional", + "fine_grained", + ), + ), + SchemaOptional("num_samples"): int, + }, }, SchemaOptional("predictor"): { SchemaOptional("endpoint_name"): str, @@ -310,6 +344,11 @@ SchemaOptional("content_template"): Or(str, {str: str}), SchemaOptional("record_template"): str, SchemaOptional("custom_attributes"): str, + SchemaOptional("time_series_predictor_config"): { + "forecast": str, + "forecast_horizon": int, + SchemaOptional("use_future_covariates"): bool, + }, }, } ) @@ -393,6 +432,94 @@ def to_dict(self) -> Dict[str, Any]: # pragma: no cover return segment_config_dict +class TimeSeriesDataConfig: + """Config object for TimeSeries explainability specific data fields.""" + + def __init__( + self, + target_time_series: Union[str, int], + item_id: Union[str, int], + timestamp: Union[str, int], + related_time_series: Optional[List[Union[str, int]]] = None, + item_metadata: Optional[List[Union[str, int]]] = None, + ): + """Initialises TimeSeries explainability data configuration fields. + + Args: #TODO: verify param descriptions are accurate + target_time_series (str or int): A string or a zero-based integer index. + Used to locate target time series in the shared input dataset. + item_id (str or int): A string or a zero-based integer index. Used to + locate item id in the shared input dataset. + timestamp (str or int): A string or a zero-based integer index. Used to + locate timestamp in the shared input dataset. + related_time_series (list[str] or list[int]): Optional. An array of strings + or array of zero-based integer indices. Used to locate all related time + series in the shared input dataset. + item_metadata (list[str] or list[int]): Optional. 
An array of strings or + array of zero-based integer indices. Used to locate all item metadata + fields in the shared input dataset. + + Raises: + AssertionError: If any required arguments are not provided. + ValueError: If any provided arguments are the wrong type. + """ + # check target_time_series, item_id, and timestamp are provided + assert target_time_series, "Please provide a target time series." + assert item_id, "Please provide an item id." + assert timestamp, "Please provide a timestamp." + # check all arguments are the right types + if not isinstance(target_time_series, (str, int)): + raise ValueError("Please provide a string or an int for ``target_time_series``") + if not isinstance(item_id, (str, int)): + raise ValueError("Please provide a string or an int for ``item_id``") + if not isinstance(timestamp, (str, int)): + raise ValueError("Please provide a string or an int for ``timestamp``") + # add remaining fields to an internal dictionary + self.analysis_config = dict() + _set(target_time_series, "target_time_series", self.analysis_config) + _set(item_id, "item_id", self.analysis_config) + _set(timestamp, "timestamp", self.analysis_config) + # check optional arguments are right types if provided + related_time_series_error_message = ( + "Please provide a list of strings or list of ints for ``related_time_series``" + ) + if related_time_series: + if not isinstance(related_time_series, list): + raise ValueError( + related_time_series_error_message + ) # related_time_series is not a list + if not ( + all([isinstance(value, str) for value in related_time_series]) + or all([isinstance(value, int) for value in related_time_series]) + ): + raise ValueError( + related_time_series_error_message + ) # related_time_series is not a list of strings or list of ints + _set( + related_time_series, "related_time_series", self.analysis_config + ) # related_time_series is valid, add it + item_metadata_series_error_message = ( + "Please provide a list of strings or list 
of ints for ``item_metadata``" + ) + if item_metadata: + if not isinstance(item_metadata, list): + raise ValueError(item_metadata_series_error_message) # item_metadata is not a list + if not ( + all([isinstance(value, str) for value in item_metadata]) + or all([isinstance(value, int) for value in item_metadata]) + ): + raise ValueError( + item_metadata_series_error_message + ) # item_metadata is not a list of strings or list of ints + _set( + item_metadata, "item_metadata", self.analysis_config + ) # item_metadata is valid, add it + + def get_config(self): + """Returns part of an analysis config dictionary.""" + return copy.deepcopy(self.analysis_config) + + class DataConfig: """Config object related to configurations of the input and output dataset.""" @@ -414,6 +541,7 @@ def __init__( predicted_label: Optional[Union[str, int]] = None, excluded_columns: Optional[Union[List[int], List[str]]] = None, segmentation_config: Optional[List[SegmentationConfig]] = None, + time_series_data_config: Optional[TimeSeriesDataConfig] = None, ): """Initializes a configuration of both input and output datasets. @@ -482,6 +610,8 @@ def __init__( which are to be excluded from making model inference API calls. segmentation_config (list[SegmentationConfig]): A list of ``SegmentationConfig`` objects. + time_series_data_config (TimeSeriesDataConfig): Optional. A config object for TimeSeries + data specific fields, required for TimeSeries explainability use cases. 
Raises: ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters @@ -573,6 +703,11 @@ def __init__( "segment_config", self.analysis_config, ) + _set( + time_series_data_config.get_config() if time_series_data_config else None, + "time_series_data_config", + self.analysis_config, + ) def get_config(self): """Returns part of an analysis config dictionary.""" @@ -663,6 +798,52 @@ def get_config(self): return copy.deepcopy(self.analysis_config) +class TimeSeriesModelConfig: + """Config object for TimeSeries predictor configuration fields.""" + + def __init__( + self, + forecast: str, + forecast_horizon: int = TS_MODEL_DEFAULT_FORECAST_HORIZON, + use_future_covariates: Optional[bool] = False, + ): + """Initializes model configuration fields for TimeSeries explainability use cases. + + Args: + forecast (str): JMESPath expression to extract the forecast result. + forecast_horizon (int): An integer that tells the forecast horizon. + use_future_covariates (None or bool): If set as True, future covariates + included in model input and used for forecasting + + Raises: + AssertionError: when either ``forecast`` or ``forecast_horizon`` are not provided + ValueError: when any provided argument are not of specified type + """ + # assert forecast and forecast_horizon are provided + assert ( + forecast + ), "Please provide ``forecast``, a JMESPath expression to extract the forecast result." + assert forecast_horizon, "Please provide an integer ``forecast_horizon``." 
+ # check provided arguments are of the right type + if not isinstance(forecast, str): + raise ValueError("Please provide a string JMESPath expression for ``forecast``.") + if not isinstance(forecast_horizon, int): + raise ValueError("Please provide an integer ``forecast_horizon``.") + if use_future_covariates and not isinstance(use_future_covariates, bool): + raise ValueError("Please provide a boolean value for ``use_future_covariates``.") + # add fields to an internal config dictionary + self.predictor_config = dict() + _set(forecast, "forecast", self.predictor_config) + _set(forecast_horizon, "forecast_horizon", self.predictor_config) + _set( + use_future_covariates, "use_future_covariates", self.predictor_config + ) # _set() does nothing if a given argument is None + + def get_predictor_config(self): + """Returns TimeSeries predictor config dictionary""" + return copy.deepcopy(self.predictor_config) + + class ModelConfig: """Config object related to a model and its endpoint to be created.""" @@ -680,6 +861,7 @@ def __init__( endpoint_name_prefix: Optional[str] = None, target_model: Optional[str] = None, endpoint_name: Optional[str] = None, + time_series_model_config: Optional[TimeSeriesModelConfig] = None, ): r"""Initializes a configuration of a model and the endpoint to be created for it. @@ -796,6 +978,9 @@ def __init__( endpoint_name (str): Sets the endpoint_name when re-uses an existing endpoint. Cannot be set when ``model_name``, ``instance_count``, and ``instance_type`` set + time_series_model_config (TimeSeriesModelConfig): Optional. A config object for + TimeSeries predictor specific fields, required for TimeSeries + explainability use cases. 
Raises: ValueError: when the @@ -884,6 +1069,11 @@ def __init__( _set(custom_attributes, "custom_attributes", self.predictor_config) _set(accelerator_type, "accelerator_type", self.predictor_config) _set(target_model, "target_model", self.predictor_config) + _set( + time_series_model_config.get_predictor_config() if time_series_model_config else None, + "time_series_predictor_config", + self.predictor_config, + ) def get_predictor_config(self): """Returns part of the predictor dictionary of the analysis config.""" @@ -1399,6 +1589,45 @@ def get_explainability_config(self): return copy.deepcopy({"shap": self.shap_config}) +class AsymmetricSHAPConfig(ExplainabilityConfig): + """Config class for Asymmetric SHAP algorithm for TimeSeries explainability""" + + def __init__( + self, + explanation_type: str = ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, + num_samples: Optional[int] = ASYM_SHAP_DEFAULT_NUM_SAMPLES, + ): + """Initialises config for asymmetric SHAP config. + + AsymmetricSHAPConfig is used specifically and only for TimeSeries explainability purposes. + + Args: + explanation_type (str): Type of explanation to be used + num_samples (None or int): Number of samples to be used in the Asymmetric SHAP + algorithm. 
+ + Raises: + AssertionError: when ``explanation_type`` is not valid + """ + self.asymmetric_shap_config = dict() + # validate explanation type + assert ( + explanation_type in ASYM_SHAP_EXPLANATION_TYPES + ), "Please provide a valid explanation type from: " + ", ".join(ASYM_SHAP_EXPLANATION_TYPES) + # validate num_samples if provided + if num_samples and not isinstance(num_samples, int): + raise ValueError("Please provide an integer value for ``num_samples``.") + # set explanation type and (if provided) num_samples in internal config dictionary + _set(explanation_type, "explanation_type", self.asymmetric_shap_config) + _set( + num_samples, "num_samples", self.asymmetric_shap_config + ) # _set() does nothing if a given argument is None + + def get_explainability_config(self): + """Returns an asymmetric shap config dictionary.""" + return copy.deepcopy({"asymmetric_shap": self.asymmetric_shap_config}) + + class SageMakerClarifyProcessor(Processor): """Handles SageMaker Processing tasks to compute bias metrics and model explanations.""" @@ -2092,12 +2321,26 @@ def explainability( explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], ): """Generates a config for Explainability""" + # determine if this is a timeseries explainability case by checking + # if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given + ts_data_config_present = "time_series_data_config" in data_config.analysis_config + ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config + + if ts_data_config_present and ts_model_config_present: + time_series_case = True + elif not ts_data_config_present and not ts_model_config_present: + time_series_case = False + else: + raise ValueError("Please provide both TimeSeriesDataConfig and TimeSeriesModelConfig.") + # construct whole analysis config analysis_config = data_config.analysis_config analysis_config = cls._add_predictor( analysis_config, model_config, model_predicted_label_config ) 
analysis_config = cls._add_methods( - analysis_config, explainability_config=explainability_config + analysis_config, + explainability_config=explainability_config, + time_series_case=time_series_case, ) return analysis_config @@ -2164,7 +2407,11 @@ def _add_predictor( if isinstance(model_config, ModelConfig): analysis_config["predictor"] = model_config.get_predictor_config() else: - if "shap" in analysis_config["methods"] or "pdp" in analysis_config["methods"]: + if ( + "shap" in analysis_config["methods"] + or "pdp" in analysis_config["methods"] + or "asymmetric_shap" in analysis_config["methods"] + ): raise ValueError( "model_config must be provided when explainability methods are selected." ) @@ -2195,11 +2442,16 @@ def _add_methods( pre_training_methods: Union[str, List[str]] = None, post_training_methods: Union[str, List[str]] = None, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None, - report=True, + report: bool = True, + time_series_case: bool = False, ): """Extends analysis config with methods.""" # validate params = [pre_training_methods, post_training_methods, explainability_config] + if time_series_case and not explainability_config: + raise AttributeError( + "At least one AsymmetricSHAPConfig must be provided for TimeSeries explainability." 
+ ) if not any(params): raise AttributeError( "analysis_config must have at least one working method: " @@ -2225,7 +2477,9 @@ def _add_methods( analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods} if explainability_config is not None: - explainability_methods = cls._merge_explainability_configs(explainability_config) + explainability_methods = cls._merge_explainability_configs( + explainability_config, time_series_case + ) analysis_config["methods"] = { **analysis_config["methods"], **explainability_methods, @@ -2236,6 +2490,7 @@ def _add_methods( def _merge_explainability_configs( cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], + time_series_case: bool = False, ): """Merges explainability configs, when more than one.""" if isinstance(explainability_config, list): @@ -2243,16 +2498,37 @@ def _merge_explainability_configs( if len(explainability_config) == 0: raise ValueError("Please provide at least one explainability config.") for config in explainability_config: + # ensure all provided explainability configs + # are AsymmetricSHAPConfig in time series case + is_asym_shap_config = isinstance(config, AsymmetricSHAPConfig) + if time_series_case and not is_asym_shap_config: + raise ValueError( + "Please provide only Asymmetric SHAP configs for TimeSeries explainability." + ) + if not time_series_case and is_asym_shap_config: + raise ValueError( + "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses." 
+ ) explain_config = config.get_explainability_config() explainability_methods.update(explain_config) if not len(explainability_methods) == len(explainability_config): raise ValueError("Duplicate explainability configs are provided") if ( - "shap" not in explainability_methods + not time_series_case + and "shap" not in explainability_methods and "features" not in explainability_methods["pdp"] ): raise ValueError("PDP features must be provided when ShapConfig is not provided") return explainability_methods + is_asym_shap_config = isinstance(explainability_config, AsymmetricSHAPConfig) + if time_series_case and not is_asym_shap_config: + raise ValueError( + "Please provide only Asymmetric SHAP configs for TimeSeries explainability." + ) + if not time_series_case and is_asym_shap_config: + raise ValueError( + "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses." + ) if ( isinstance(explainability_config, PDPConfig) and "features" not in explainability_config.get_explainability_config()["pdp"] diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 58d3f56639..a935215685 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -17,22 +17,30 @@ import pytest from mock import MagicMock, Mock, patch +from typing import List, NamedTuple, Optional, Union from sagemaker import Processor, image_uris from sagemaker.clarify import ( BiasConfig, DataConfig, + TimeSeriesDataConfig, ModelConfig, + TimeSeriesModelConfig, ModelPredictedLabelConfig, PDPConfig, SageMakerClarifyProcessor, SHAPConfig, + AsymmetricSHAPConfig, TextConfig, ImageConfig, _AnalysisConfigGenerator, DatasetType, ProcessingOutputHandler, SegmentationConfig, + TS_MODEL_DEFAULT_FORECAST_HORIZON, + ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, + ASYM_SHAP_DEFAULT_NUM_SAMPLES, + ASYM_SHAP_EXPLANATION_TYPES, ) JOB_NAME_PREFIX = "my-prefix" @@ -321,6 +329,241 @@ def test_s3_data_distribution_type_ignorance(): assert data_config.s3_data_distribution_type == 
"FullyReplicated" +class TimeSeriesDataConfigCase(NamedTuple): + target_time_series: Union[str, int] + item_id: Union[str, int] + timestamp: Union[str, int] + related_time_series: Optional[List[Union[str, int]]] + item_metadata: Optional[List[Union[str, int]]] + error: Exception + error_message: Optional[str] + + +class TestTimeSeriesDataConfig: + valid_ts_data_config_case_list = [ + TimeSeriesDataConfigCase( # no optional args provided + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + item_metadata=None, + error=None, + error_message=None, + ), + TimeSeriesDataConfigCase( # related_time_series provided + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=[1, 2, 3], + item_metadata=None, + error=None, + error_message=None, + ), + TimeSeriesDataConfigCase( # item_metadata provided + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + item_metadata=["a", "b", "c", "d"], + error=None, + error_message=None, + ), + TimeSeriesDataConfigCase( # both related_time_series and item_metadata provided + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=[1, 2, 3], + item_metadata=["a", "b", "c", "d"], + error=None, + error_message=None, + ), + ] + + @pytest.mark.parametrize("test_case", valid_ts_data_config_case_list) + def test_time_series_data_config(self, test_case): + """ + GIVEN A set of valid parameters are given + WHEN A TimeSeriesDataConfig object is instantiated + THEN the returned config dictionary matches what's expected + """ + # construct expected output + expected_output = { + "target_time_series": test_case.target_time_series, + "item_id": test_case.item_id, + "timestamp": test_case.timestamp, + } + if test_case.related_time_series: + expected_output["related_time_series"] = test_case.related_time_series + if 
test_case.item_metadata: + expected_output["item_metadata"] = test_case.item_metadata + # GIVEN, WHEN + ts_data_config = TimeSeriesDataConfig( + target_time_series=test_case.target_time_series, + item_id=test_case.item_id, + timestamp=test_case.timestamp, + related_time_series=test_case.related_time_series, + item_metadata=test_case.item_metadata, + ) + # THEN + assert ts_data_config.analysis_config == expected_output + + @pytest.mark.parametrize( + "test_case", + [ + TimeSeriesDataConfigCase( # no target_time_series provided + target_time_series=None, + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + item_metadata=None, + error=AssertionError, + error_message="Please provide a target time series.", + ), + TimeSeriesDataConfigCase( # no item_id provided + target_time_series="target_time_series", + item_id=None, + timestamp="timestamp", + related_time_series=None, + item_metadata=None, + error=AssertionError, + error_message="Please provide an item id.", + ), + TimeSeriesDataConfigCase( # no timestamp provided + target_time_series="target_time_series", + item_id="item_id", + timestamp=None, + related_time_series=None, + item_metadata=None, + error=AssertionError, + error_message="Please provide a timestamp.", + ), + TimeSeriesDataConfigCase( # target_time_series not int or str + target_time_series=["target_time_series"], + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + item_metadata=None, + error=ValueError, + error_message="Please provide a string or an int for ``target_time_series``", + ), + TimeSeriesDataConfigCase( # item_id not int or str + target_time_series="target_time_series", + item_id=["item_id"], + timestamp="timestamp", + related_time_series=None, + item_metadata=None, + error=ValueError, + error_message="Please provide a string or an int for ``item_id``", + ), + TimeSeriesDataConfigCase( # timestamp not int or str + target_time_series="target_time_series", + item_id="item_id", + 
timestamp=["timestamp"], + related_time_series=None, + item_metadata=None, + error=ValueError, + error_message="Please provide a string or an int for ``timestamp``", + ), + TimeSeriesDataConfigCase( # related_time_series not list of ints or list of strs + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=5, + item_metadata=None, + error=ValueError, + error_message="Please provide a list of strings or list of ints for ``related_time_series``", + ), + TimeSeriesDataConfigCase( # item_metadata not list of ints or list of strs + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + item_metadata=[4, 5, 6.0], + error=ValueError, + error_message="Please provide a list of strings or list of ints for ``item_metadata``", + ), + ], + ) + def test_time_series_data_config_invalid(self, test_case): + """ + GIVEN required parameters are incomplete or invalid + WHEN TimeSeriesDataConfig constructor is called + THEN the expected error and message are raised + """ + with pytest.raises(test_case.error, match=test_case.error_message): + TimeSeriesDataConfig( + target_time_series=test_case.target_time_series, + item_id=test_case.item_id, + timestamp=test_case.timestamp, + related_time_series=test_case.related_time_series, + item_metadata=test_case.item_metadata, + ) + + @pytest.mark.parametrize("test_case", valid_ts_data_config_case_list) + def test_data_config_with_time_series(self, test_case): + """ + GIVEN a TimeSeriesDataConfig object is created + WHEN a DataConfig object is created and given valid params + the TimeSeriesDataConfig + THEN the internal config dictionary matches what's expected + """ + # setup + headers = ["Label", "F1", "F2", "F3", "F4", "Predicted Label"] + dataset_type = "application/json" + segment_config = [ + SegmentationConfig( + name_or_index="F1", + segments=[[0]], + config_name="c1", + display_aliases=["a1"], + ) + ] + # construct 
expected output + mock_ts_data_config_dict = { + "target_time_series": test_case.target_time_series, + "item_id": test_case.item_id, + "timestamp": test_case.timestamp, + } + if test_case.related_time_series: + mock_ts_data_config_dict["related_time_series"] = test_case.related_time_series + if test_case.item_metadata: + mock_ts_data_config_dict["item_metadata"] = test_case.item_metadata + expected_config = { + "dataset_type": dataset_type, + "headers": headers, + "label": "Label", + "segment_config": [ + { + "config_name": "c1", + "display_aliases": ["a1"], + "name_or_index": "F1", + "segments": [[0]], + } + ], + "excluded_columns": ["F4"], + "features": "[*].[F1,F2,F3]", + "predicted_label": "Predicted Label", + "time_series_data_config": mock_ts_data_config_dict, + } + # GIVEN + ts_data_config = Mock() + ts_data_config.get_config.return_value = copy.deepcopy(mock_ts_data_config_dict) + # WHEN + data_config = DataConfig( + s3_data_input_path="s3://path/to/input.csv", + s3_output_path="s3://path/to/output", + features="[*].[F1,F2,F3]", + label="Label", + headers=headers, + dataset_type="application/json", + excluded_columns=["F4"], + predicted_label="Predicted Label", + segmentation_config=segment_config, + time_series_data_config=ts_data_config, + ) + # THEN + assert expected_config == data_config.get_config() + + def test_bias_config(): label_values = [1] facet_name = "F1" @@ -641,6 +884,229 @@ def test_invalid_model_predicted_label_config(): ) +class TestTimeSeriesModelConfig: + def test_time_series_model_config(self): + """ + GIVEN a valid forecast expression + WHEN a TimeSeriesModelConfig is constructed with it + THEN the predictor_config dictionary matches the expected + """ + # GIVEN + forecast = "results.[forecast]" # mock JMESPath expression for forecast + # create expected output + expected_config = { + "forecast": forecast, + "forecast_horizon": TS_MODEL_DEFAULT_FORECAST_HORIZON, + "use_future_covariates": False, + } + # WHEN + ts_model_config = 
TimeSeriesModelConfig( + forecast, + ) + # THEN + assert ts_model_config.predictor_config == expected_config + + def test_time_series_model_config_with_forecast_horizon(self): + """ + GIVEN a valid forecast expression and forecast horizon + WHEN a TimeSeriesModelConfig is constructed with it + THEN the predictor_config dictionary matches the expected + """ + # GIVEN + forecast = "results.[forecast]" # mock JMESPath expression for forecast + forecast_horizon = 25 # non-default forecast horizon + # create expected output + expected_config = { + "forecast": forecast, + "forecast_horizon": forecast_horizon, + "use_future_covariates": False, + } + # WHEN + ts_model_config = TimeSeriesModelConfig( + forecast, + forecast_horizon=forecast_horizon, + ) + # THEN + assert ts_model_config.predictor_config == expected_config + + def test_time_series_model_config_with_future_covariates(self): + """ + GIVEN a valid forecast expression + WHEN a TimeSeriesModelConfig is constructed with it and use_future_covariates is True + THEN the predictor_config dictionary matches the expected + """ + # GIVEN + forecast = "results.[forecast]" # mock JMESPath expression for forecast + # create expected output + expected_config = { + "forecast": forecast, + "forecast_horizon": TS_MODEL_DEFAULT_FORECAST_HORIZON, + "use_future_covariates": True, + } + # WHEN + ts_model_config = TimeSeriesModelConfig( + forecast, + use_future_covariates=True, + ) + # THEN + assert ts_model_config.predictor_config == expected_config + + def test_time_series_model_config_with_horizon_and_covariates(self): + """ + GIVEN a valid forecast expression and forecast horizon + WHEN a TimeSeriesModelConfig is constructed with it and use_future_covariates is True + THEN the predictor_config dictionary matches the expected + """ + # GIVEN + forecast = "results.[forecast]" # mock JMESPath expression for forecast + forecast_horizon = 25 # non-default forecast horizon + # create expected output + expected_config = { + "forecast": 
forecast, + "forecast_horizon": forecast_horizon, + "use_future_covariates": True, + } + # WHEN + ts_model_config = TimeSeriesModelConfig( + forecast, + forecast_horizon=forecast_horizon, + use_future_covariates=True, + ) + # THEN + assert ts_model_config.predictor_config == expected_config + + @pytest.mark.parametrize( + ("forecast", "forecast_horizon", "use_future_covariates", "error", "error_message"), + [ + ( + None, + TS_MODEL_DEFAULT_FORECAST_HORIZON, + None, + AssertionError, + "Please provide ``forecast``, a JMESPath expression to extract the forecast result.", + ), + ( + "results.[forecast]", + None, + None, + AssertionError, + "Please provide an integer ``forecast_horizon``.", + ), + ( + 123, + TS_MODEL_DEFAULT_FORECAST_HORIZON, + None, + ValueError, + "Please provide a string JMESPath expression for ``forecast``.", + ), + ( + "results.[forecast]", + "Not an int", + None, + ValueError, + "Please provide an integer ``forecast_horizon``.", + ), + ( + "results.[forecast]", + TS_MODEL_DEFAULT_FORECAST_HORIZON, + "Not a bool", + ValueError, + "Please provide a boolean value for ``use_future_covariates``.", + ), + ], + ) + def test_time_series_model_config_invalid( + self, + forecast, + forecast_horizon, + use_future_covariates, + error, + error_message, + ): + """ + GIVEN invalid args for a TimeSeriesModelConfig + WHEN TimeSeriesModelConfig constructor is called + THEN The appropriate error is raised + """ + with pytest.raises(error, match=error_message): + TimeSeriesModelConfig( + forecast=forecast, + forecast_horizon=forecast_horizon, + use_future_covariates=use_future_covariates, + ) + + def test_model_config_with_time_series(self): + """ + GIVEN valid fields for a ModelConfig and a TimeSeriesModelConfig + WHEN a ModelConfig is constructed with them + THEN actual predictor_config matches expected + """ + # setup + model_name = "xgboost-model" + instance_type = "ml.c5.xlarge" + instance_count = 1 + custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4" + 
target_model = "target_model_name" + accelerator_type = "ml.eia1.medium" + content_type = "application/x-npy" + accept_type = "text/csv" + content_template = ( + '{"instances":$features}' + if content_type == "application/jsonlines" + else "$records" + if content_type == "application/json" + else None + ) + record_template = "$features_kvp" if content_type == "application/json" else None + # create mock config for TimeSeriesModelConfig + forecast = "results.[forecast]" # mock JMESPath expression for forecast + forecast_horizon = 25 # non-default forecast horizon + mock_ts_model_config_dict = { + "forecast": forecast, + "forecast_horizon": forecast_horizon, + "use_future_covariates": True, + } + mock_ts_model_config = Mock() + mock_ts_model_config.get_predictor_config.return_value = mock_ts_model_config_dict + # create expected config + expected_config = { + "model_name": model_name, + "instance_type": instance_type, + "initial_instance_count": instance_count, + "accept_type": accept_type, + "content_type": content_type, + "custom_attributes": custom_attributes, + "accelerator_type": accelerator_type, + "target_model": target_model, + "time_series_predictor_config": mock_ts_model_config_dict, + } + if content_template is not None: + expected_config["content_template"] = content_template + if record_template is not None: + expected_config["record_template"] = record_template + # GIVEN + mock_ts_model_config = Mock() # create mock TimeSeriesModelConfig object + mock_ts_model_config.get_predictor_config.return_value = copy.deepcopy( + mock_ts_model_config_dict + ) # set the mock's get_config return value + # WHEN + model_config = ModelConfig( + model_name=model_name, + instance_type=instance_type, + instance_count=instance_count, + accept_type=accept_type, + content_type=content_type, + content_template=content_template, + record_template=record_template, + custom_attributes=custom_attributes, + accelerator_type=accelerator_type, + target_model=target_model, + 
time_series_model_config=mock_ts_model_config, + ) + # THEN + assert expected_config == model_config.get_predictor_config() + + @pytest.mark.parametrize( "baseline", [ @@ -783,6 +1249,81 @@ def test_shap_config_no_parameters(): assert expected_config == shap_config.get_explainability_config() +class AsymmetricSHAPConfigCase(NamedTuple): + explanation_type: str + num_samples: Optional[int] + + +class TestAsymmetricSHAPConfig: + @pytest.mark.parametrize( + "test_case", + [ + AsymmetricSHAPConfigCase( # cases for different explanation types + explanation_type=explanation_type, + num_samples=ASYM_SHAP_DEFAULT_NUM_SAMPLES, + ) + for explanation_type in ASYM_SHAP_EXPLANATION_TYPES + ] + + [ + AsymmetricSHAPConfigCase( # case for non-default number of samples + explanation_type=ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, + num_samples=50, + ), + ], + ) + def test_asymmetric_shap_config(self, test_case): + """ + GIVEN valid arguments for an AsymmetricSHAPConfig object + WHEN AsymmetricSHAPConfig object is instantiated with those arguments + THEN the asymmetric_shap_config dictionary matches expected + """ + # test case is GIVEN + # construct expected config + expected_config = { + "explanation_type": test_case.explanation_type, + "num_samples": test_case.num_samples, + } + # WHEN + asym_shap_config = AsymmetricSHAPConfig( + explanation_type=test_case.explanation_type, + num_samples=test_case.num_samples, + ) + # THEN + assert asym_shap_config.asymmetric_shap_config == expected_config + + def test_asymmetric_shap_config_invalid_explanation_type(self): + """ + GIVEN invalid explanation_type + WHEN AsymmetricSHAPConfig constructor is called with it + THEN ``AssertionError`` with correct message is raised + """ + # setup + error_message = "Please provide a valid explanation type from: " + ", ".join( + ASYM_SHAP_EXPLANATION_TYPES + ) + # GIVEN + explanation_type = "disaggregated_random" + # WHEN, THEN + with pytest.raises(AssertionError, match=error_message): + AsymmetricSHAPConfig( + 
explanation_type=explanation_type, + ) + + def test_asymmetric_shap_config_invalid_num_samples(self): + """ + GIVEN non-integer num_samples + WHEN AsymmetricSHAPConfig constructor is called with it + THEN ``ValueError`` with correct message is raised + """ + # setup + error_message = "Please provide an integer value for ``num_samples``." + # GIVEN + num_samples = "NaN" + # WHEN, THEN + with pytest.raises(ValueError, match=error_message): + AsymmetricSHAPConfig(num_samples=num_samples) + + def test_pdp_config(): pdp_config = PDPConfig(features=["f1", "f2"], grid_resolution=20) expected_config = { @@ -1917,6 +2458,98 @@ def test_invalid_analysis_config(data_config, data_bias_config, model_config): ) +def _build_pdp_config_mock(): + pdp_config_dict = { + "pdp": { + "grid_resolution": 15, + "top_k_features": 10, + "features": [ + "some", + "features", + ], + } + } + pdp_config = Mock(spec=PDPConfig) + pdp_config.get_explainability_config.return_value = pdp_config_dict + return pdp_config + + +def _build_asymmetric_shap_config_mock(): + asym_shap_config_dict = { + "asymmetric_shap": { + "explanation_type": ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, + "num_samples": ASYM_SHAP_DEFAULT_NUM_SAMPLES, + }, + } + asym_shap_config = Mock(spec=AsymmetricSHAPConfig) + asym_shap_config.get_explainability_config.return_value = asym_shap_config_dict + return asym_shap_config + + +class TestAnalysisConfigGeneratorForTimeSeriesExplainability: + @pytest.mark.parametrize( + ("mock_config", "time_series_case", "error", "error_message"), + [ + ( # single pdp config for TSX + _build_pdp_config_mock(), + True, + ValueError, + "Please provide only Asymmetric SHAP configs for TimeSeries explainability.", + ), + ( # single asym shap config for non TSX + _build_asymmetric_shap_config_mock(), + False, + ValueError, + "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses.", + ), + ( # list of duplicate asym_shap configs for TSX + [ + _build_asymmetric_shap_config_mock(), + 
_build_asymmetric_shap_config_mock(), + ], + True, + ValueError, + "Duplicate explainability configs are provided", + ), + ( # list with pdp config for TSX + [ + _build_asymmetric_shap_config_mock(), + _build_pdp_config_mock(), + ], + True, + ValueError, + "Please provide only Asymmetric SHAP configs for TimeSeries explainability.", + ), + ( # list with asym shap config for non-TSX + [ + _build_asymmetric_shap_config_mock(), + _build_pdp_config_mock(), + ], + False, + ValueError, + "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses.", + ), + ], + ) + def test_merge_explainability_configs_with_timeseries_invalid( + self, + mock_config, + time_series_case, + error, + error_message, + ): + """ + GIVEN _merge_explainability_configs is called with a explainability config or list thereof + WHEN the provided config(s) aren't the right type for the given case + THEN the function will raise the appropriate error + """ + with pytest.raises(error, match=error_message): + _AnalysisConfigGenerator._merge_explainability_configs( + explainability_config=mock_config, + time_series_case=time_series_case, + ) + + class TestProcessingOutputHandler: def test_get_s3_upload_mode_image(self): analysis_config = {"dataset_type": DatasetType.IMAGE.value} From 75204ade45dd29ca0b72ee6180d620a1eb4e9fcf Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 18 Oct 2023 19:47:46 +0000 Subject: [PATCH 02/35] fix: removed field use_future_covariates and related unit tests from TimeSeriesModelConfig --- src/sagemaker/clarify.py | 9 ------ tests/unit/test_clarify.py | 64 +------------------------------------- 2 files changed, 1 insertion(+), 72 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index efcef6bb13..cf42b52199 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -347,7 +347,6 @@ SchemaOptional("time_series_predictor_config"): { "forecast": str, "forecast_horizon": int, - SchemaOptional("use_future_covariates"): bool, }, }, 
} @@ -805,15 +804,12 @@ def __init__( self, forecast: str, forecast_horizon: int = TS_MODEL_DEFAULT_FORECAST_HORIZON, - use_future_covariates: Optional[bool] = False, ): """Initializes model configuration fields for TimeSeries explainability use cases. Args: forecast (str): JMESPath expression to extract the forecast result. forecast_horizon (int): An integer that tells the forecast horizon. - use_future_covariates (None or bool): If set as True, future covariates - included in model input and used for forecasting Raises: AssertionError: when either ``forecast`` or ``forecast_horizon`` are not provided @@ -829,15 +825,10 @@ def __init__( raise ValueError("Please provide a string JMESPath expression for ``forecast``.") if not isinstance(forecast_horizon, int): raise ValueError("Please provide an integer ``forecast_horizon``.") - if use_future_covariates and not isinstance(use_future_covariates, bool): - raise ValueError("Please provide a boolean value for ``use_future_covariates``.") # add fields to an internal config dictionary self.predictor_config = dict() _set(forecast, "forecast", self.predictor_config) _set(forecast_horizon, "forecast_horizon", self.predictor_config) - _set( - use_future_covariates, "use_future_covariates", self.predictor_config - ) # _set() does nothing if a given argument is None def get_predictor_config(self): """Returns TimeSeries predictor config dictionary""" diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index a935215685..cdc5165d07 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -897,7 +897,6 @@ def test_time_series_model_config(self): expected_config = { "forecast": forecast, "forecast_horizon": TS_MODEL_DEFAULT_FORECAST_HORIZON, - "use_future_covariates": False, } # WHEN ts_model_config = TimeSeriesModelConfig( @@ -919,7 +918,6 @@ def test_time_series_model_config_with_forecast_horizon(self): expected_config = { "forecast": forecast, "forecast_horizon": forecast_horizon, - 
"use_future_covariates": False, } # WHEN ts_model_config = TimeSeriesModelConfig( @@ -929,97 +927,39 @@ def test_time_series_model_config_with_forecast_horizon(self): # THEN assert ts_model_config.predictor_config == expected_config - def test_time_series_model_config_with_future_covariates(self): - """ - GIVEN a valid forecast expression - WHEN a TimeSeriesModelConfig is constructed with it and use_future_covariates is True - THEN the predictor_config dictionary matches the expected - """ - # GIVEN - forecast = "results.[forecast]" # mock JMESPath expression for forecast - # create expected output - expected_config = { - "forecast": forecast, - "forecast_horizon": TS_MODEL_DEFAULT_FORECAST_HORIZON, - "use_future_covariates": True, - } - # WHEN - ts_model_config = TimeSeriesModelConfig( - forecast, - use_future_covariates=True, - ) - # THEN - assert ts_model_config.predictor_config == expected_config - - def test_time_series_model_config_with_horizon_and_covariates(self): - """ - GIVEN a valid forecast expression and forecast horizon - WHEN a TimeSeriesModelConfig is constructed with it and use_future_covariates is True - THEN the predictor_config dictionary matches the expected - """ - # GIVEN - forecast = "results.[forecast]" # mock JMESPath expression for forecast - forecast_horizon = 25 # non-default forecast horizon - # create expected output - expected_config = { - "forecast": forecast, - "forecast_horizon": forecast_horizon, - "use_future_covariates": True, - } - # WHEN - ts_model_config = TimeSeriesModelConfig( - forecast, - forecast_horizon=forecast_horizon, - use_future_covariates=True, - ) - # THEN - assert ts_model_config.predictor_config == expected_config - @pytest.mark.parametrize( - ("forecast", "forecast_horizon", "use_future_covariates", "error", "error_message"), + ("forecast", "forecast_horizon", "error", "error_message"), [ ( None, TS_MODEL_DEFAULT_FORECAST_HORIZON, - None, AssertionError, "Please provide ``forecast``, a JMESPath expression to 
extract the forecast result.", ), ( "results.[forecast]", None, - None, AssertionError, "Please provide an integer ``forecast_horizon``.", ), ( 123, TS_MODEL_DEFAULT_FORECAST_HORIZON, - None, ValueError, "Please provide a string JMESPath expression for ``forecast``.", ), ( "results.[forecast]", "Not an int", - None, ValueError, "Please provide an integer ``forecast_horizon``.", ), - ( - "results.[forecast]", - TS_MODEL_DEFAULT_FORECAST_HORIZON, - "Not a bool", - ValueError, - "Please provide a boolean value for ``use_future_covariates``.", - ), ], ) def test_time_series_model_config_invalid( self, forecast, forecast_horizon, - use_future_covariates, error, error_message, ): @@ -1032,7 +972,6 @@ def test_time_series_model_config_invalid( TimeSeriesModelConfig( forecast=forecast, forecast_horizon=forecast_horizon, - use_future_covariates=use_future_covariates, ) def test_model_config_with_time_series(self): @@ -1064,7 +1003,6 @@ def test_model_config_with_time_series(self): mock_ts_model_config_dict = { "forecast": forecast, "forecast_horizon": forecast_horizon, - "use_future_covariates": True, } mock_ts_model_config = Mock() mock_ts_model_config.get_predictor_config.return_value = mock_ts_model_config_dict From f8852047e0ac574c125ba8167cf4794efa093d78 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 20 Oct 2023 06:11:54 +0000 Subject: [PATCH 03/35] change: rename ``TimeSeriesDataConfig.analysis_config`` to ``time_series_data_config`` --- src/sagemaker/clarify.py | 27 ++++++++++++++------------- tests/unit/test_clarify.py | 6 ++++-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index cf42b52199..4d52e3e5d9 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -474,10 +474,10 @@ def __init__( if not isinstance(timestamp, (str, int)): raise ValueError("Please provide a string or an int for ``timestamp``") # add remaining fields to an internal dictionary - self.analysis_config = 
dict() - _set(target_time_series, "target_time_series", self.analysis_config) - _set(item_id, "item_id", self.analysis_config) - _set(timestamp, "timestamp", self.analysis_config) + self.time_series_data_config = dict() + _set(target_time_series, "target_time_series", self.time_series_data_config) + _set(item_id, "item_id", self.time_series_data_config) + _set(timestamp, "timestamp", self.time_series_data_config) # check optional arguments are right types if provided related_time_series_error_message = ( "Please provide a list of strings or list of ints for ``related_time_series``" @@ -495,7 +495,7 @@ def __init__( related_time_series_error_message ) # related_time_series is not a list of strings or list of ints _set( - related_time_series, "related_time_series", self.analysis_config + related_time_series, "related_time_series", self.time_series_data_config ) # related_time_series is valid, add it item_metadata_series_error_message = ( "Please provide a list of strings or list of ints for ``item_metadata``" @@ -511,12 +511,12 @@ def __init__( item_metadata_series_error_message ) # item_metadata is not a list of strings or list of ints _set( - item_metadata, "item_metadata", self.analysis_config + item_metadata, "item_metadata", self.time_series_data_config ) # item_metadata is valid, add it - def get_config(self): + def get_time_series_data_config(self): """Returns part of an analysis config dictionary.""" - return copy.deepcopy(self.analysis_config) + return copy.deepcopy(self.time_series_data_config) class DataConfig: @@ -702,11 +702,12 @@ def __init__( "segment_config", self.analysis_config, ) - _set( - time_series_data_config.get_config() if time_series_data_config else None, - "time_series_data_config", - self.analysis_config, - ) + if time_series_data_config: + _set( + time_series_data_config.get_time_series_data_config(), + "time_series_data_config", + self.analysis_config, + ) def get_config(self): """Returns part of an analysis config dictionary.""" diff 
--git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index cdc5165d07..0bf3bd8d40 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -405,7 +405,7 @@ def test_time_series_data_config(self, test_case): item_metadata=test_case.item_metadata, ) # THEN - assert ts_data_config.analysis_config == expected_output + assert ts_data_config.time_series_data_config == expected_output @pytest.mark.parametrize( "test_case", @@ -546,7 +546,9 @@ def test_data_config_with_time_series(self, test_case): } # GIVEN ts_data_config = Mock() - ts_data_config.get_config.return_value = copy.deepcopy(mock_ts_data_config_dict) + ts_data_config.get_time_series_data_config.return_value = copy.deepcopy( + mock_ts_data_config_dict + ) # WHEN data_config = DataConfig( s3_data_input_path="s3://path/to/input.csv", From 30d447282001042559de9e3ec83513bb73aa2269 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 20 Oct 2023 06:20:03 +0000 Subject: [PATCH 04/35] change: validate DataConfig content_type and accept_type for TS exp. 
change: renamed TimeSeriesDataConfig.predictor_config to time_series_model_config change: modified default value of forecast_horizon to 1 --- src/sagemaker/clarify.py | 36 ++++++++++++++++++++++++------------ tests/unit/test_clarify.py | 21 +++++++++++++-------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 4d52e3e5d9..90ab39624d 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -41,7 +41,7 @@ # TODO: verify these are sensible/sound values # timeseries predictor config default values -TS_MODEL_DEFAULT_FORECAST_HORIZON = 5 # predictor config +TS_MODEL_DEFAULT_FORECAST_HORIZON = 1 # predictor config # asymmetric shap default values (timeseries) ASYM_SHAP_DEFAULT_EXPLANATION_TYPE = "fine_grained" ASYM_SHAP_DEFAULT_NUM_SAMPLES = 5 @@ -827,13 +827,13 @@ def __init__( if not isinstance(forecast_horizon, int): raise ValueError("Please provide an integer ``forecast_horizon``.") # add fields to an internal config dictionary - self.predictor_config = dict() - _set(forecast, "forecast", self.predictor_config) - _set(forecast_horizon, "forecast_horizon", self.predictor_config) + self.time_series_model_config = dict() + _set(forecast, "forecast", self.time_series_model_config) + _set(forecast_horizon, "forecast_horizon", self.time_series_model_config) - def get_predictor_config(self): - """Returns TimeSeries predictor config dictionary""" - return copy.deepcopy(self.predictor_config) + def get_time_series_model_config(self): + """Returns TimeSeries model config dictionary""" + return copy.deepcopy(self.time_series_model_config) class ModelConfig: @@ -1017,6 +1017,10 @@ def __init__( f"Invalid accept_type {accept_type}." f" Please choose text/csv or application/jsonlines." ) + if time_series_model_config and accept_type == "text/csv": + raise ValueError( + "``accept_type`` must be JSON or JSONLines for time series explainability." 
+ ) self.predictor_config["accept_type"] = accept_type if content_type is not None: if content_type not in [ @@ -1053,6 +1057,13 @@ def __init__( f"Invalid content_template {content_template}." f" Please include either placeholder $records or $record." ) + if time_series_model_config and content_type not in [ + "application/json", + "application/jsonlines" + ]: + raise ValueError( + "``content_type`` must be JSON or JSONLines for time series explainability." + ) self.predictor_config["content_type"] = content_type if content_template is not None: self.predictor_config["content_template"] = content_template @@ -1061,11 +1072,12 @@ def __init__( _set(custom_attributes, "custom_attributes", self.predictor_config) _set(accelerator_type, "accelerator_type", self.predictor_config) _set(target_model, "target_model", self.predictor_config) - _set( - time_series_model_config.get_predictor_config() if time_series_model_config else None, - "time_series_predictor_config", - self.predictor_config, - ) + if time_series_model_config: + _set( + time_series_model_config.get_time_series_model_config(), + "time_series_predictor_config", + self.predictor_config, + ) def get_predictor_config(self): """Returns part of the predictor dictionary of the analysis config.""" diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 0bf3bd8d40..95669d6f78 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -905,7 +905,7 @@ def test_time_series_model_config(self): forecast, ) # THEN - assert ts_model_config.predictor_config == expected_config + assert ts_model_config.time_series_model_config == expected_config def test_time_series_model_config_with_forecast_horizon(self): """ @@ -927,7 +927,7 @@ def test_time_series_model_config_with_forecast_horizon(self): forecast_horizon=forecast_horizon, ) # THEN - assert ts_model_config.predictor_config == expected_config + assert ts_model_config.time_series_model_config == expected_config @pytest.mark.parametrize( 
("forecast", "forecast_horizon", "error", "error_message"), @@ -976,7 +976,16 @@ def test_time_series_model_config_invalid( forecast_horizon=forecast_horizon, ) - def test_model_config_with_time_series(self): + @pytest.mark.parametrize( + ("content_type", "accept_type"), + [ + ("application/json", "application/json"), + ("application/json", "application/jsonlines"), + ("application/jsonlines", "application/json"), + ("application/jsonlines", "application/jsonlines"), + ], + ) + def test_model_config_with_time_series(self, content_type, accept_type): """ GIVEN valid fields for a ModelConfig and a TimeSeriesModelConfig WHEN a ModelConfig is constructed with them @@ -989,8 +998,6 @@ def test_model_config_with_time_series(self): custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4" target_model = "target_model_name" accelerator_type = "ml.eia1.medium" - content_type = "application/x-npy" - accept_type = "text/csv" content_template = ( '{"instances":$features}' if content_type == "application/jsonlines" @@ -1006,8 +1013,6 @@ def test_model_config_with_time_series(self): "forecast": forecast, "forecast_horizon": forecast_horizon, } - mock_ts_model_config = Mock() - mock_ts_model_config.get_predictor_config.return_value = mock_ts_model_config_dict # create expected config expected_config = { "model_name": model_name, @@ -1026,7 +1031,7 @@ def test_model_config_with_time_series(self): expected_config["record_template"] = record_template # GIVEN mock_ts_model_config = Mock() # create mock TimeSeriesModelConfig object - mock_ts_model_config.get_predictor_config.return_value = copy.deepcopy( + mock_ts_model_config.get_time_series_model_config.return_value = copy.deepcopy( mock_ts_model_config_dict ) # set the mock's get_config return value # WHEN From 615ed867e5d6e6db8382ecd97f2dc748865f0656 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 20 Oct 2023 06:25:22 +0000 Subject: [PATCH 05/35] change: reworked validation in AsymmetricSHAPConfig change: removed default 
value for num_samples change: changed default value for explanation_type change: added more explicit type hint for explanation_type documentation: added more information for parameters change: reworked unit tests for AsymmetricSHAPConfig --- src/sagemaker/clarify.py | 32 +++++++++----- tests/unit/test_clarify.py | 85 +++++++++++++++++++++----------------- 2 files changed, 68 insertions(+), 49 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 90ab39624d..888c032982 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -25,7 +25,7 @@ import tempfile from abc import ABC, abstractmethod -from typing import List, Union, Dict, Optional, Any +from typing import List, Literal, Union, Dict, Optional, Any from enum import Enum from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex @@ -43,8 +43,7 @@ # timeseries predictor config default values TS_MODEL_DEFAULT_FORECAST_HORIZON = 1 # predictor config # asymmetric shap default values (timeseries) -ASYM_SHAP_DEFAULT_EXPLANATION_TYPE = "fine_grained" -ASYM_SHAP_DEFAULT_NUM_SAMPLES = 5 +ASYM_SHAP_DEFAULT_EXPLANATION_TYPE = "timewise_chronological" ASYM_SHAP_EXPLANATION_TYPES = [ "timewise_chronological", "timewise_anti_chronological", @@ -1598,29 +1597,40 @@ class AsymmetricSHAPConfig(ExplainabilityConfig): def __init__( self, - explanation_type: str = ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, - num_samples: Optional[int] = ASYM_SHAP_DEFAULT_NUM_SAMPLES, + explanation_type: Literal[ + "timewise_chronological", + "timewise_anti_chronological", + "timewise_bidirectional", + "fine_grained", + ] = ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, + num_samples: Optional[int] = None, ): """Initialises config for asymmetric SHAP config. AsymmetricSHAPConfig is used specifically and only for TimeSeries explainability purposes. Args: - explanation_type (str): Type of explanation to be used + explanation_type (str): Type of explanation to be used. 
Available explanation + types are ``"timewise_chronological"``, ``"timewise_anti_chronological"``, + ``"timewise_bidirectional"``, and ``"fine_grained"``. num_samples (None or int): Number of samples to be used in the Asymmetric SHAP - algorithm. + algorithm. Only applicable when using ``"fine_grained"`` explanations. Raises: - AssertionError: when ``explanation_type`` is not valid + AssertionError: when ``explanation_type`` is not valid or ``num_samples`` + is not provided for fine-grained explanations + ValueError: when ``num_samples`` is provided for non fine-grained explanations """ self.asymmetric_shap_config = dict() # validate explanation type assert ( explanation_type in ASYM_SHAP_EXPLANATION_TYPES ), "Please provide a valid explanation type from: " + ", ".join(ASYM_SHAP_EXPLANATION_TYPES) - # validate num_samples if provided - if num_samples and not isinstance(num_samples, int): - raise ValueError("Please provide an integer value for ``num_samples``.") + # validate integer num_samples is provided when necessary + if explanation_type == "fine_grained": + assert isinstance(num_samples, int), "Please provide an integer for ``num_samples``." 
+ elif num_samples: # validate num_samples is not provided when unnecessary + raise ValueError("``num_samples`` is only used for fine-grained explanations.") # set explanation type and (if provided) num_samples in internal config dictionary _set(explanation_type, "explanation_type", self.asymmetric_shap_config) _set( diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 95669d6f78..db80ab4621 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -39,7 +39,6 @@ SegmentationConfig, TS_MODEL_DEFAULT_FORECAST_HORIZON, ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, - ASYM_SHAP_DEFAULT_NUM_SAMPLES, ASYM_SHAP_EXPLANATION_TYPES, ) @@ -1197,6 +1196,8 @@ def test_shap_config_no_parameters(): class AsymmetricSHAPConfigCase(NamedTuple): explanation_type: str num_samples: Optional[int] + error: Exception + error_message: str class TestAsymmetricSHAPConfig: @@ -1205,15 +1206,11 @@ class TestAsymmetricSHAPConfig: [ AsymmetricSHAPConfigCase( # cases for different explanation types explanation_type=explanation_type, - num_samples=ASYM_SHAP_DEFAULT_NUM_SAMPLES, + num_samples=1 if explanation_type == "fine_grained" else None, + error=None, + error_message=None, ) for explanation_type in ASYM_SHAP_EXPLANATION_TYPES - ] - + [ - AsymmetricSHAPConfigCase( # case for non-default number of samples - explanation_type=ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, - num_samples=50, - ), ], ) def test_asymmetric_shap_config(self, test_case): @@ -1225,9 +1222,10 @@ def test_asymmetric_shap_config(self, test_case): # test case is GIVEN # construct expected config expected_config = { - "explanation_type": test_case.explanation_type, - "num_samples": test_case.num_samples, + "explanation_type": test_case.explanation_type } + if test_case.explanation_type == "fine_grained": + expected_config["num_samples"] = test_case.num_samples # WHEN asym_shap_config = AsymmetricSHAPConfig( explanation_type=test_case.explanation_type, @@ -1236,38 +1234,49 @@ def 
test_asymmetric_shap_config(self, test_case): # THEN assert asym_shap_config.asymmetric_shap_config == expected_config - def test_asymmetric_shap_config_invalid_explanation_type(self): + @pytest.mark.parametrize( + "test_case", + [ + AsymmetricSHAPConfigCase( # case for invalid explanation_type + explanation_type="coarse_grained", + num_samples=None, + error=AssertionError, + error_message="Please provide a valid explanation type from: " + + ", ".join(ASYM_SHAP_EXPLANATION_TYPES), + ), + AsymmetricSHAPConfigCase( # case for fine_grained and no num_samples + explanation_type="fine_grained", + num_samples=None, + error=AssertionError, + error_message="Please provide an integer for ``num_samples``.", + ), + AsymmetricSHAPConfigCase( # case for fine_grained and non-integer num_samples + explanation_type="fine_grained", + num_samples="5", + error=AssertionError, + error_message="Please provide an integer for ``num_samples``.", + ), + AsymmetricSHAPConfigCase( # case for num_samples when non fine-grained explanation + explanation_type="timewise_chronological", + num_samples=5, + error=ValueError, + error_message="``num_samples`` is only used for fine-grained explanations.", + ), + ], + ) + def test_asymmetric_shap_config_invalid(self, test_case): """ - GIVEN invalid explanation_type - WHEN AsymmetricSHAPConfig constructor is called with it - THEN ``AssertionError`` with correct message is raised + GIVEN invalid parameters for AsymmetricSHAP + WHEN AsymmetricSHAPConfig constructor is called with them + THEN the expected error and message are raised """ - # setup - error_message = "Please provide a valid explanation type from: " + ", ".join( - ASYM_SHAP_EXPLANATION_TYPES - ) - # GIVEN - explanation_type = "disaggregated_random" - # WHEN, THEN - with pytest.raises(AssertionError, match=error_message): - AsymmetricSHAPConfig( - explanation_type=explanation_type, + # test case is GIVEN + with pytest.raises(test_case.error, match=test_case.error_message): # THEN + 
AsymmetricSHAPConfig( # WHEN + explanation_type=test_case.explanation_type, + num_samples=test_case.num_samples, ) - def test_asymmetric_shap_config_invalid_num_samples(self): - """ - GIVEN non-integer num_samples - WHEN AsymmetricSHAPConfig constructor is called with it - THEN ``ValueError`` with correct message is raised - """ - # setup - error_message = "Please provide an integer value for ``num_samples``." - # GIVEN - num_samples = "NaN" - # WHEN, THEN - with pytest.raises(ValueError, match=error_message): - AsymmetricSHAPConfig(num_samples=num_samples) - def test_pdp_config(): pdp_config = PDPConfig(features=["f1", "f2"], grid_resolution=20) From 3af583edb91663f96c5e811e9839962b9ef9da42 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 20 Oct 2023 06:30:37 +0000 Subject: [PATCH 06/35] change: time series case no longer uses _merge_explainability_configs change: _merge_explainability_configs reordered to put validation first change: unit tests reworked --- src/sagemaker/clarify.py | 56 +++++++++++++++++--------------------- tests/unit/test_clarify.py | 31 +-------------------- 2 files changed, 26 insertions(+), 61 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 888c032982..000238ba73 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -2462,9 +2462,9 @@ def _add_methods( """Extends analysis config with methods.""" # validate params = [pre_training_methods, post_training_methods, explainability_config] - if time_series_case and not explainability_config: + if time_series_case and not isinstance(explainability_config, AsymmetricSHAPConfig): raise AttributeError( - "At least one AsymmetricSHAPConfig must be provided for TimeSeriex explainability." + "Please provide one AsymmetricSHAPConfig for TimeSeries explainability." 
) if not any(params): raise AttributeError( @@ -2491,9 +2491,12 @@ def _add_methods( analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods} if explainability_config is not None: - explainability_methods = cls._merge_explainability_configs( - explainability_config, time_series_case - ) + if time_series_case: + explainability_methods = explainability_config.get_explainability_config() + else: + explainability_methods = cls._merge_explainability_configs( + explainability_config, + ) analysis_config["methods"] = { **analysis_config["methods"], **explainability_methods, @@ -2504,50 +2507,41 @@ def _add_methods( def _merge_explainability_configs( cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], - time_series_case: bool = False, ): """Merges explainability configs, when more than one.""" + # validation + if isinstance(explainability_config, AsymmetricSHAPConfig): + raise ValueError( + "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses." + ) + if ( + isinstance(explainability_config, PDPConfig) + and "features" not in explainability_config.get_explainability_config()["pdp"] + ): + raise ValueError("PDP features must be provided when ShapConfig is not provided") if isinstance(explainability_config, list): - explainability_methods = {} if len(explainability_config) == 0: raise ValueError("Please provide at least one explainability config.") + # list validation for config in explainability_config: - # ensure all provided explainability configs - # are AsymmetricSHAPConfig in time series case - is_asym_shap_config = isinstance(config, AsymmetricSHAPConfig) - if time_series_case and not is_asym_shap_config: - raise ValueError( - "Please provide only Asymmetric SHAP configs for TimeSeries explainability." 
- ) - if not time_series_case and is_asym_shap_config: + # ensure all provided explainability configs are not AsymmetricSHAPConfig + if isinstance(config, AsymmetricSHAPConfig): raise ValueError( "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses." ) + # main logic + explainability_methods = {} + for config in explainability_config: explain_config = config.get_explainability_config() explainability_methods.update(explain_config) if not len(explainability_methods) == len(explainability_config): raise ValueError("Duplicate explainability configs are provided") if ( - not time_series_case - and "shap" not in explainability_methods + "shap" not in explainability_methods and "features" not in explainability_methods["pdp"] ): raise ValueError("PDP features must be provided when ShapConfig is not provided") return explainability_methods - is_asym_shap_config = isinstance(explainability_config, AsymmetricSHAPConfig) - if time_series_case and not is_asym_shap_config: - raise ValueError( - "Please provide only Asymmetric SHAP configs for TimeSeries explainability." - ) - if not time_series_case and is_asym_shap_config: - raise ValueError( - "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses." 
- ) - if ( - isinstance(explainability_config, PDPConfig) - and "features" not in explainability_config.get_explainability_config()["pdp"] - ): - raise ValueError("PDP features must be provided when ShapConfig is not provided") return explainability_config.get_explainability_config() diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index db80ab4621..6bf6c14e5e 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -2432,7 +2432,6 @@ def _build_asymmetric_shap_config_mock(): asym_shap_config_dict = { "asymmetric_shap": { "explanation_type": ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, - "num_samples": ASYM_SHAP_DEFAULT_NUM_SAMPLES, }, } asym_shap_config = Mock(spec=AsymmetricSHAPConfig) @@ -2442,44 +2441,18 @@ def _build_asymmetric_shap_config_mock(): class TestAnalysisConfigGeneratorForTimeSeriesExplainability: @pytest.mark.parametrize( - ("mock_config", "time_series_case", "error", "error_message"), + ("mock_config", "error", "error_message"), [ - ( # single pdp config for TSX - _build_pdp_config_mock(), - True, - ValueError, - "Please provide only Asymmetric SHAP configs for TimeSeries explainability.", - ), ( # single asym shap config for non TSX _build_asymmetric_shap_config_mock(), - False, ValueError, "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses.", ), - ( # list of duplicate asym_shap configs for TSX - [ - _build_asymmetric_shap_config_mock(), - _build_asymmetric_shap_config_mock(), - ], - True, - ValueError, - "Duplicate explainability configs are provided", - ), - ( # list with pdp config for TSX - [ - _build_asymmetric_shap_config_mock(), - _build_pdp_config_mock(), - ], - True, - ValueError, - "Please provide only Asymmetric SHAP configs for TimeSeries explainability.", - ), ( # list with asym shap config for non-TSX [ _build_asymmetric_shap_config_mock(), _build_pdp_config_mock(), ], - False, ValueError, "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses.", ), @@ -2488,7 
+2461,6 @@ class TestAnalysisConfigGeneratorForTimeSeriesExplainability: def test_merge_explainability_configs_with_timeseries_invalid( self, mock_config, - time_series_case, error, error_message, ): @@ -2500,7 +2472,6 @@ def test_merge_explainability_configs_with_timeseries_invalid( with pytest.raises(error, match=error_message): _AnalysisConfigGenerator._merge_explainability_configs( explainability_config=mock_config, - time_series_case=time_series_case, ) From 39c485c9f106a3aff93256bd09bb3d5efa994736 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 20 Oct 2023 06:41:56 +0000 Subject: [PATCH 07/35] fix: minor style changes to meet formatting reqs --- src/sagemaker/clarify.py | 2 +- tests/unit/test_clarify.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 000238ba73..fa2997b6cf 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -1058,7 +1058,7 @@ def __init__( ) if time_series_model_config and content_type not in [ "application/json", - "application/jsonlines" + "application/jsonlines", ]: raise ValueError( "``content_type`` must be JSON or JSONLines for time series explainability." 
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 6bf6c14e5e..9e40d50e8e 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -1221,9 +1221,7 @@ def test_asymmetric_shap_config(self, test_case): """ # test case is GIVEN # construct expected config - expected_config = { - "explanation_type": test_case.explanation_type - } + expected_config = {"explanation_type": test_case.explanation_type} if test_case.explanation_type == "fine_grained": expected_config["num_samples"] = test_case.num_samples # WHEN From 9712b73627b99942a5ce7ae36f8983ca79bed0e4 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Sat, 21 Oct 2023 00:25:23 +0000 Subject: [PATCH 08/35] change: modified how time_series_case flag is set change: removed now-redundant check in time_series_case --- src/sagemaker/clarify.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index fa2997b6cf..779f60047e 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -2340,12 +2340,11 @@ def explainability( ts_data_config_present = "time_series_data_config" in data_config.analysis_config ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config - if ts_data_config_present and ts_model_config_present: + if isinstance(explainability_config, AsymmetricSHAPConfig): + assert ts_data_config_present, "Please provide a TimeSeriesDataConfig" + assert ts_model_config_present, "Please provide a TimeSeriesModelConfig" time_series_case = True - elif not ts_data_config_present and not ts_model_config_present: - time_series_case = False - else: - raise ValueError("Please provide both TimeSeriesDataConfig and TimeSeriesModelConfig.") + # construct whole analysis config analysis_config = data_config.analysis_config analysis_config = cls._add_predictor( @@ -2462,10 +2461,6 @@ def _add_methods( """Extends analysis config with methods.""" # validate params = 
[pre_training_methods, post_training_methods, explainability_config] - if time_series_case and not isinstance(explainability_config, AsymmetricSHAPConfig): - raise AttributeError( - "Please provide one AsymmetricSHAPConfig for TimeSeries explainability." - ) if not any(params): raise AttributeError( "analysis_config must have at least one working method: " From 99e8630fa5f9de805e85f71cd984c09ca647058c Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Sat, 21 Oct 2023 01:05:57 +0000 Subject: [PATCH 09/35] fix: set time_series_case to False to prevent exception --- src/sagemaker/clarify.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 779f60047e..51ac87eb5b 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -2344,6 +2344,8 @@ def explainability( assert ts_data_config_present, "Please provide a TimeSeriesDataConfig" assert ts_model_config_present, "Please provide a TimeSeriesModelConfig" time_series_case = True + else: + time_series_case = False # construct whole analysis config analysis_config = data_config.analysis_config From cc636f72fa7c761229885d4f1a6e7b0cd22c2b97 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Sat, 21 Oct 2023 01:06:40 +0000 Subject: [PATCH 10/35] change: params for ``TimeSeriesDataConfig`` now must all be same type change: updated ``TimeSeriesDataConfig`` unit tests to reflect above change --- src/sagemaker/clarify.py | 27 +++++----- tests/unit/test_clarify.py | 106 +++++++++++++++++++++++++++++++------ 2 files changed, 101 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 51ac87eb5b..ba0976955c 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -445,7 +445,9 @@ def __init__( Args: #TODO: verify param descriptions are accurate target_time_series (str or int): A string or a zero-based integer index. - Used to locate target time series in the shared input dataset. 
+ Used to locate target time series in the shared input dataset. Also + used to determine the correct type of all succeeding arguments, as + parameters must either be all strings or all ints. item_id (str or int): A string or a zero-based integer index. Used to locate item id in the shared input dataset. timestamp (str or int): A string or a zero-based integer index. Used to @@ -468,10 +470,11 @@ def __init__( # check all arguments are the right types if not isinstance(target_time_series, (str, int)): raise ValueError("Please provide a string or an int for ``target_time_series``") - if not isinstance(item_id, (str, int)): - raise ValueError("Please provide a string or an int for ``item_id``") - if not isinstance(timestamp, (str, int)): - raise ValueError("Please provide a string or an int for ``timestamp``") + params_type = type(target_time_series) + if not isinstance(item_id, params_type): + raise ValueError(f"Please provide {params_type} for ``item_id``") + if not isinstance(timestamp, params_type): + raise ValueError(f"Please provide {params_type} for ``timestamp``") # add remaining fields to an internal dictionary self.time_series_data_config = dict() _set(target_time_series, "target_time_series", self.time_series_data_config) @@ -479,17 +482,14 @@ def __init__( _set(timestamp, "timestamp", self.time_series_data_config) # check optional arguments are right types if provided related_time_series_error_message = ( - "Please provide a list of strings or list of ints for ``related_time_series``" + f"Please provide a list of {params_type} for ``related_time_series``" ) if related_time_series: if not isinstance(related_time_series, list): raise ValueError( related_time_series_error_message ) # related_time_series is not a list - if not ( - all([isinstance(value, str) for value in related_time_series]) - or all([isinstance(value, int) for value in related_time_series]) - ): + if not all([isinstance(value, params_type) for value in related_time_series]): raise 
ValueError( related_time_series_error_message ) # related_time_series is not a list of strings or list of ints @@ -497,15 +497,12 @@ def __init__( related_time_series, "related_time_series", self.time_series_data_config ) # related_time_series is valid, add it item_metadata_series_error_message = ( - "Please provide a list of strings or list of ints for ``item_metadata``" + f"Please provide a list of {params_type} for ``item_metadata``" ) if item_metadata: if not isinstance(item_metadata, list): raise ValueError(item_metadata_series_error_message) # item_metadata is not a list - if not ( - all([isinstance(value, str) for value in item_metadata]) - or all([isinstance(value, int) for value in item_metadata]) - ): + if not all([isinstance(value, params_type) for value in item_metadata]): raise ValueError( item_metadata_series_error_message ) # item_metadata is not a list of strings or list of ints diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 9e40d50e8e..7d48877337 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -340,7 +340,7 @@ class TimeSeriesDataConfigCase(NamedTuple): class TestTimeSeriesDataConfig: valid_ts_data_config_case_list = [ - TimeSeriesDataConfigCase( # no optional args provided + TimeSeriesDataConfigCase( # no optional args provided str case target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", @@ -349,16 +349,16 @@ class TestTimeSeriesDataConfig: error=None, error_message=None, ), - TimeSeriesDataConfigCase( # related_time_series provided + TimeSeriesDataConfigCase( # related_time_series provided str case target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", - related_time_series=[1, 2, 3], + related_time_series=["ts1", "ts2", "ts3"], item_metadata=None, error=None, error_message=None, ), - TimeSeriesDataConfigCase( # item_metadata provided + TimeSeriesDataConfigCase( # item_metadata provided str case 
target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", @@ -367,15 +367,51 @@ class TestTimeSeriesDataConfig: error=None, error_message=None, ), - TimeSeriesDataConfigCase( # both related_time_series and item_metadata provided + TimeSeriesDataConfigCase( # both related_time_series and item_metadata provided str case target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", - related_time_series=[1, 2, 3], + related_time_series=["ts1", "ts2", "ts3"], item_metadata=["a", "b", "c", "d"], error=None, error_message=None, ), + TimeSeriesDataConfigCase( # no optional args provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=None, + item_metadata=None, + error=None, + error_message=None, + ), + TimeSeriesDataConfigCase( # related_time_series provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6], + item_metadata=None, + error=None, + error_message=None, + ), + TimeSeriesDataConfigCase( # item_metadata provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=None, + item_metadata=[7, 8, 9, 10], + error=None, + error_message=None, + ), + TimeSeriesDataConfigCase( # both related_time_series and item_metadata provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6], + item_metadata=[7, 8, 9, 10], + error=None, + error_message=None, + ), ] @pytest.mark.parametrize("test_case", valid_ts_data_config_case_list) @@ -445,41 +481,77 @@ def test_time_series_data_config(self, test_case): error=ValueError, error_message="Please provide a string or an int for ``target_time_series``", ), - TimeSeriesDataConfigCase( # item_id not int or str + TimeSeriesDataConfigCase( # item_id differing type from str target_time_series target_time_series="target_time_series", - item_id=["item_id"], + item_id=5, timestamp="timestamp", related_time_series=None, item_metadata=None, 
error=ValueError, - error_message="Please provide a string or an int for ``item_id``", + error_message=f"Please provide {str} for ``item_id``", ), - TimeSeriesDataConfigCase( # timestamp not int or str + TimeSeriesDataConfigCase( # timestamp differing type from str target_time_series target_time_series="target_time_series", item_id="item_id", - timestamp=["timestamp"], + timestamp=10, related_time_series=None, item_metadata=None, error=ValueError, - error_message="Please provide a string or an int for ``timestamp``", + error_message=f"Please provide {str} for ``timestamp``", ), - TimeSeriesDataConfigCase( # related_time_series not list of ints or list of strs + TimeSeriesDataConfigCase( # related_time_series not str list if str target_time_series target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", - related_time_series=5, + related_time_series=["ts1", "ts2", "ts3", 4], item_metadata=None, error=ValueError, - error_message="Please provide a list of strings or list of ints for ``related_time_series``", + error_message=f"Please provide a list of {str} for ``related_time_series``", ), - TimeSeriesDataConfigCase( # item_metadata not list of ints or list of strs + TimeSeriesDataConfigCase( # item_metadata not str list if str target_time_series target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", related_time_series=None, item_metadata=[4, 5, 6.0], error=ValueError, - error_message="Please provide a list of strings or list of ints for ``item_metadata``", + error_message=f"Please provide a list of {str} for ``item_metadata``", + ), + TimeSeriesDataConfigCase( # item_id differing type from int target_time_series + target_time_series=1, + item_id="item_id", + timestamp=3, + related_time_series=None, + item_metadata=None, + error=ValueError, + error_message=f"Please provide {int} for ``item_id``", + ), + TimeSeriesDataConfigCase( # timestamp differing type from int target_time_series + target_time_series=1, + 
item_id=2, + timestamp="timestamp", + related_time_series=None, + item_metadata=None, + error=ValueError, + error_message=f"Please provide {int} for ``timestamp``", + ), + TimeSeriesDataConfigCase( # related_time_series not int list if int target_time_series + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6, "ts7"], + item_metadata=None, + error=ValueError, + error_message=f"Please provide a list of {int} for ``related_time_series``", + ), + TimeSeriesDataConfigCase( # item_metadata not int list if int target_time_series + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6, 7], + item_metadata=[8, 9, "10"], + error=ValueError, + error_message=f"Please provide a list of {int} for ``item_metadata``", ), ], ) From cf8512e91c24b4e8048666f307614adbf51cd833 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Tue, 24 Oct 2023 19:23:18 +0000 Subject: [PATCH 11/35] fix: schema entries for related_ts and item_metadata to keep list items same type --- src/sagemaker/clarify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index ba0976955c..ebe29e0d75 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -100,8 +100,8 @@ "target_time_series": Or(str, int), "item_id": Or(str, int), "timestamp": Or(str, int), - SchemaOptional("related_time_series"): [Or(int, str)], - SchemaOptional("item_metadata"): [Or(int, str)], + SchemaOptional("related_time_series"): Or([str], [int]), + SchemaOptional("item_metadata"): Or([str], [int]), }, "methods": { SchemaOptional("shap"): { From ec1d7fa2fe68b32c1658a79ebc63c912a42e0c7d Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Tue, 24 Oct 2023 19:24:01 +0000 Subject: [PATCH 12/35] change: remove forecast_horizon from TimeSeriesModelConfig --- src/sagemaker/clarify.py | 9 -------- tests/unit/test_clarify.py | 44 +------------------------------------- 2 files changed, 1 insertion(+), 52 deletions(-) 
diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index ebe29e0d75..19579f40ec 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -40,8 +40,6 @@ ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])" # TODO: verify these are sensible/sound values -# timeseries predictor config default values -TS_MODEL_DEFAULT_FORECAST_HORIZON = 1 # predictor config # asymmetric shap default values (timeseries) ASYM_SHAP_DEFAULT_EXPLANATION_TYPE = "timewise_chronological" ASYM_SHAP_EXPLANATION_TYPES = [ @@ -345,7 +343,6 @@ SchemaOptional("custom_attributes"): str, SchemaOptional("time_series_predictor_config"): { "forecast": str, - "forecast_horizon": int, }, }, } @@ -800,13 +797,11 @@ class TimeSeriesModelConfig: def __init__( self, forecast: str, - forecast_horizon: int = TS_MODEL_DEFAULT_FORECAST_HORIZON, ): """Initializes model configuration fields for TimeSeries explainability use cases. Args: forecast (str): JMESPath expression to extract the forecast result. - forecast_horizon (int): An integer that tells the forecast horizon. Raises: AssertionError: when either ``forecast`` or ``forecast_horizon`` are not provided @@ -816,16 +811,12 @@ def __init__( assert ( forecast ), "Please provide ``forecast``, a JMESPath expression to extract the forecast result." - assert forecast_horizon, "Please provide an integer ``forecast_horizon``." 
# check provided arguments are of the right type if not isinstance(forecast, str): raise ValueError("Please provide a string JMESPath expression for ``forecast``.") - if not isinstance(forecast_horizon, int): - raise ValueError("Please provide an integer ``forecast_horizon``.") # add fields to an internal config dictionary self.time_series_model_config = dict() _set(forecast, "forecast", self.time_series_model_config) - _set(forecast_horizon, "forecast_horizon", self.time_series_model_config) def get_time_series_model_config(self): """Returns TimeSeries model config dictionary""" diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 7d48877337..11336a542f 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -37,7 +37,6 @@ DatasetType, ProcessingOutputHandler, SegmentationConfig, - TS_MODEL_DEFAULT_FORECAST_HORIZON, ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, ASYM_SHAP_EXPLANATION_TYPES, ) @@ -969,7 +968,6 @@ def test_time_series_model_config(self): # create expected output expected_config = { "forecast": forecast, - "forecast_horizon": TS_MODEL_DEFAULT_FORECAST_HORIZON, } # WHEN ts_model_config = TimeSeriesModelConfig( @@ -978,61 +976,24 @@ def test_time_series_model_config(self): # THEN assert ts_model_config.time_series_model_config == expected_config - def test_time_series_model_config_with_forecast_horizon(self): - """ - GIVEN a valid forecast expression and forecast horizon - WHEN a TimeSeriesModelConfig is constructed with it - THEN the predictor_config dictionary matches the expected - """ - # GIVEN - forecast = "results.[forecast]" # mock JMESPath expression for forecast - forecast_horizon = 25 # non-default forecast horizon - # create expected output - expected_config = { - "forecast": forecast, - "forecast_horizon": forecast_horizon, - } - # WHEN - ts_model_config = TimeSeriesModelConfig( - forecast, - forecast_horizon=forecast_horizon, - ) - # THEN - assert ts_model_config.time_series_model_config == expected_config - 
@pytest.mark.parametrize( - ("forecast", "forecast_horizon", "error", "error_message"), + ("forecast", "error", "error_message"), [ ( None, - TS_MODEL_DEFAULT_FORECAST_HORIZON, AssertionError, "Please provide ``forecast``, a JMESPath expression to extract the forecast result.", ), - ( - "results.[forecast]", - None, - AssertionError, - "Please provide an integer ``forecast_horizon``.", - ), ( 123, - TS_MODEL_DEFAULT_FORECAST_HORIZON, ValueError, "Please provide a string JMESPath expression for ``forecast``.", ), - ( - "results.[forecast]", - "Not an int", - ValueError, - "Please provide an integer ``forecast_horizon``.", - ), ], ) def test_time_series_model_config_invalid( self, forecast, - forecast_horizon, error, error_message, ): @@ -1044,7 +1005,6 @@ def test_time_series_model_config_invalid( with pytest.raises(error, match=error_message): TimeSeriesModelConfig( forecast=forecast, - forecast_horizon=forecast_horizon, ) @pytest.mark.parametrize( @@ -1079,10 +1039,8 @@ def test_model_config_with_time_series(self, content_type, accept_type): record_template = "$features_kvp" if content_type == "application/json" else None # create mock config for TimeSeriesModelConfig forecast = "results.[forecast]" # mock JMESPath expression for forecast - forecast_horizon = 25 # non-default forecast horizon mock_ts_model_config_dict = { "forecast": forecast, - "forecast_horizon": forecast_horizon, } # create expected config expected_config = { From 61b02f8c3eb7759a0fe6fafc15df24dac2fcf25a Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Tue, 24 Oct 2023 19:58:54 +0000 Subject: [PATCH 13/35] fix: modified type hints in ``TimeSeriesDataConfig`` to match schema --- src/sagemaker/clarify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 19579f40ec..1e93e0cbeb 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -435,8 +435,8 @@ def __init__( target_time_series: Union[str, int], 
item_id: Union[str, int], timestamp: Union[str, int], - related_time_series: Optional[List[Union[str, int]]] = None, - item_metadata: Optional[List[Union[str, int]]] = None, + related_time_series: Optional[Union[List[str], List[int]]] = None, + item_metadata: Optional[Union[List[str], List[int]]] = None, ): """Initialises TimeSeries explainability data configuration fields. From 0c28c4f3af94a3763c26ab7e25cbc51366604f3b Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 25 Oct 2023 22:40:53 +0000 Subject: [PATCH 14/35] change: add errors when ts data or model config are given but no asym_shap config change: add unit tests for _AnalysisConfigGenerator.explainability fix: slightly modify AsymmetricSHAPConfig mock builder function documentation: minor docstring update in tests --- src/sagemaker/clarify.py | 18 +++-- tests/unit/test_clarify.py | 140 +++++++++++++++++++++++++++++++++++-- 2 files changed, 148 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 1e93e0cbeb..e7798d7795 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -435,8 +435,8 @@ def __init__( target_time_series: Union[str, int], item_id: Union[str, int], timestamp: Union[str, int], - related_time_series: Optional[Union[List[str], List[int]]] = None, - item_metadata: Optional[Union[List[str], List[int]]] = None, + related_time_series: Optional[List[Union[str, int]]] = None, + item_metadata: Optional[List[Union[str, int]]] = None, ): """Initialises TimeSeries explainability data configuration fields. @@ -2329,10 +2329,20 @@ def explainability( ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config if isinstance(explainability_config, AsymmetricSHAPConfig): - assert ts_data_config_present, "Please provide a TimeSeriesDataConfig" - assert ts_model_config_present, "Please provide a TimeSeriesModelConfig" + assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig." 
+ assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig." time_series_case = True else: + if ts_data_config_present: + raise ValueError( + "Please provide an AsymmetricSHAPConfig for time series explainability cases." + "For non time series cases, please do not provide a TimeSeriesDataConfig." + ) + if ts_model_config_present: + raise ValueError( + "Please provide an AsymmetricSHAPConfig for time series explainability cases." + "For non time series cases, please do not provide a TimeSeriesModelConfig." + ) time_series_case = False # construct whole analysis config diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 11336a542f..b0f3fa953b 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -18,6 +18,7 @@ import pytest from mock import MagicMock, Mock, patch from typing import List, NamedTuple, Optional, Union +from unittest.mock import ANY from sagemaker import Processor, image_uris from sagemaker.clarify import ( @@ -37,7 +38,6 @@ DatasetType, ProcessingOutputHandler, SegmentationConfig, - ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, ASYM_SHAP_EXPLANATION_TYPES, ) @@ -2458,16 +2458,144 @@ def _build_pdp_config_mock(): def _build_asymmetric_shap_config_mock(): asym_shap_config_dict = { - "asymmetric_shap": { - "explanation_type": ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, - }, + "explanation_type": "fine_grained", + "num_samples": 20, } asym_shap_config = Mock(spec=AsymmetricSHAPConfig) - asym_shap_config.get_explainability_config.return_value = asym_shap_config_dict + asym_shap_config.get_explainability_config.return_value = { + "asymmetric_shap": asym_shap_config_dict + } return asym_shap_config +def _build_data_config_mock(): + """ + Builds a mock DataConfig for the time series _AnalysisConfigGenerator unit tests. 
+ """ + # setup a time_series_data_config dictionary + time_series_data_config = { + "target_time_series": 1, + "item_id": 2, + "timestamp": 3, + "related_time_series": [4, 5, 6], + "item_metadata": [7, 8, 9, 10], + } + # setup DataConfig mock + data_config = Mock(spec=DataConfig) + data_config.analysis_config = {"time_series_data_config": time_series_data_config} + return data_config + + +def _build_model_config_mock(): + """ + Builds a mock ModelConfig for the time series _AnalysisConfigGenerator unit tests. + """ + time_series_model_config = {"forecast": "mean"} + model_config = Mock(spec=ModelConfig) + model_config.predictor_config = {"time_series_predictor_config": time_series_model_config} + return model_config + + class TestAnalysisConfigGeneratorForTimeSeriesExplainability: + @patch("sagemaker.clarify._AnalysisConfigGenerator._add_methods") + @patch("sagemaker.clarify._AnalysisConfigGenerator._add_predictor") + def test_explainability_for_time_series(self, _add_predictor, _add_methods): + """ + GIVEN a valid DataConfig and ModelConfig that contain time_series_data_config and + time_series_model_config respectively as well as an AsymmetricSHAPConfig + WHEN _AnalysisConfigGenerator.explainability() is called with those args + THEN _add_predictor and _add methods calls are as expected + """ + # GIVEN + # get DataConfig mock + data_config_mock = _build_data_config_mock() + # get ModelConfig mock + model_config_mock = _build_model_config_mock() + # get AsymmetricSHAPConfig mock for explainability_config + explainability_config = _build_asymmetric_shap_config_mock() + # get time_series_data_config dict from mock + time_series_data_config = copy.deepcopy( + data_config_mock.analysis_config.get("time_series_data_config") + ) + # get time_series_predictor_config from mock + time_series_model_config = copy.deepcopy( + model_config_mock.predictor_config.get("time_series_model_config") + ) + # setup _add_predictor call to return what would be expected at that stage + 
analysis_config_after_add_predictor = { + "time_series_data_config": time_series_data_config, + "time_series_predictor_config": time_series_model_config, + } + _add_predictor.return_value = analysis_config_after_add_predictor + # WHEN + _AnalysisConfigGenerator.explainability( + data_config=data_config_mock, + model_config=model_config_mock, + model_predicted_label_config=None, + explainability_config=explainability_config, + ) + # THEN + _add_predictor.assert_called_once_with( + data_config_mock.analysis_config, + model_config_mock, + ANY, + ) + _add_methods.assert_called_once_with( + ANY, + explainability_config=explainability_config, + time_series_case=True, + ) + + def test_explainability_for_time_series_invalid(self): + # data config mocks + data_config_with_ts = _build_data_config_mock() + data_config_without_ts = Mock(spec=DataConfig) + data_config_without_ts.analysis_config = dict() + # model config mocks + model_config_with_ts = _build_model_config_mock() + model_config_without_ts = Mock(spec=ModelConfig) + model_config_without_ts.predictor_config = dict() + # asymmetric shap config mock (for ts) + asym_shap_config_mock = _build_asymmetric_shap_config_mock() + # pdp config mock (for non-ts) + pdp_config_mock = _build_pdp_config_mock() + # case 1: asymmetric shap (ts case) and no timeseries data config given + with pytest.raises( + AssertionError, match="Please provide a TimeSeriesDataConfig to DataConfig." + ): + _AnalysisConfigGenerator.explainability( + data_config=data_config_without_ts, + model_config=model_config_with_ts, + model_predicted_label_config=None, + explainability_config=asym_shap_config_mock, + ) + # case 2: asymmetric shap (ts case) and no timeseries model config given + with pytest.raises( + AssertionError, match="Please provide a TimeSeriesModelConfig to ModelConfig." 
+ ): + _AnalysisConfigGenerator.explainability( + data_config=data_config_with_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=asym_shap_config_mock, + ) + # case 3: pdp (non ts case) and timeseries data config given + with pytest.raises(ValueError, match="please do not provide a TimeSeriesDataConfig."): + _AnalysisConfigGenerator.explainability( + data_config=data_config_with_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + ) + # case 4: pdp (non ts case) and timeseries model config given + with pytest.raises(ValueError, match="please do not provide a TimeSeriesModelConfig."): + _AnalysisConfigGenerator.explainability( + data_config=data_config_without_ts, + model_config=model_config_with_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + ) + @pytest.mark.parametrize( ("mock_config", "error", "error_message"), [ @@ -2494,7 +2622,7 @@ def test_merge_explainability_configs_with_timeseries_invalid( ): """ GIVEN _merge_explainability_configs is called with a explainability config or list thereof - WHEN the provided config(s) aren't the right type for the given case + WHEN explainability_config is or contains an AsymmetricSHAPConfig THEN the function will raise the appropriate error """ with pytest.raises(error, match=error_message): From 882dcd93375fe4bad9b2862a9b5a16a60f9cc87e Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Thu, 26 Oct 2023 00:06:33 +0000 Subject: [PATCH 15/35] change: remove flag ``time_series_case``, modify unit tests accordingly change: add a check to _AnalysisConfiGenerator.bias_and_explainability to prevent time series components being provided, added unit tests for this --- src/sagemaker/clarify.py | 13 ++++++---- tests/unit/test_clarify.py | 49 +++++++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/clarify.py 
b/src/sagemaker/clarify.py index e7798d7795..d01e463510 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -2302,6 +2302,13 @@ def bias_and_explainability( post_training_methods: Union[str, List[str]] = "all", ): """Generates a config for Bias and Explainability""" + # TimeSeries bias metrics are not supported + if ( + isinstance(explainability_config, AsymmetricSHAPConfig) + or "time_series_data_config" in data_config.analysis_config + or (model_config and "time_series_predictor_config" in model_config.predictor_config) + ): + raise ValueError("Bias metrics are unsupported for time series.") analysis_config = {**data_config.get_config(), **bias_config.get_config()} analysis_config = cls._add_methods( analysis_config, @@ -2331,7 +2338,6 @@ def explainability( if isinstance(explainability_config, AsymmetricSHAPConfig): assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig." assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig." - time_series_case = True else: if ts_data_config_present: raise ValueError( @@ -2343,7 +2349,6 @@ def explainability( "Please provide an AsymmetricSHAPConfig for time series explainability cases." "For non time series cases, please do not provide a TimeSeriesModelConfig." 
) - time_series_case = False # construct whole analysis config analysis_config = data_config.analysis_config @@ -2353,7 +2358,6 @@ def explainability( analysis_config = cls._add_methods( analysis_config, explainability_config=explainability_config, - time_series_case=time_series_case, ) return analysis_config @@ -2456,7 +2460,6 @@ def _add_methods( post_training_methods: Union[str, List[str]] = None, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None, report: bool = True, - time_series_case: bool = False, ): """Extends analysis config with methods.""" # validate @@ -2486,7 +2489,7 @@ def _add_methods( analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods} if explainability_config is not None: - if time_series_case: + if isinstance(explainability_config, AsymmetricSHAPConfig): explainability_methods = explainability_config.get_explainability_config() else: explainability_methods = cls._merge_explainability_configs( diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index b0f3fa953b..0a2f806a06 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -2543,7 +2543,6 @@ def test_explainability_for_time_series(self, _add_predictor, _add_methods): _add_methods.assert_called_once_with( ANY, explainability_config=explainability_config, - time_series_case=True, ) def test_explainability_for_time_series_invalid(self): @@ -2596,6 +2595,54 @@ def test_explainability_for_time_series_invalid(self): explainability_config=pdp_config_mock, ) + def test_bias_and_explainability_invalid_for_time_series(self): + """ + GIVEN user provides TimeSeriesDataConfig, TimeSeriesModelConfig, and/or + AsymmetricSHAPConfig for DataConfig, ModelConfig, and as explainability_config + respectively + WHEN _AnalysisConfigGenerator.bias_and_explainability is called + THEN the appropriate error is raised + """ + # data config mocks + data_config_with_ts = _build_data_config_mock() + 
data_config_without_ts = Mock(spec=DataConfig) + data_config_without_ts.analysis_config = dict() + # model config mocks + model_config_with_ts = _build_model_config_mock() + model_config_without_ts = Mock(spec=ModelConfig) + model_config_without_ts.predictor_config = dict() + # asymmetric shap config mock (for ts) + asym_shap_config_mock = _build_asymmetric_shap_config_mock() + # pdp config mock (for non-ts) + pdp_config_mock = _build_pdp_config_mock() + # case 1: asymmetric shap is given as explainability_config + with pytest.raises(ValueError, match="Bias metrics are unsupported for time series."): + _AnalysisConfigGenerator.bias_and_explainability( + data_config=data_config_without_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=asym_shap_config_mock, + bias_config=None, + ) + # case 2: TimeSeriesModelConfig given to ModelConfig + with pytest.raises(ValueError, match="Bias metrics are unsupported for time series."): + _AnalysisConfigGenerator.bias_and_explainability( + data_config=data_config_without_ts, + model_config=model_config_with_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + bias_config=None, + ) + # case 3: TimeSeriesDataConfig given to DataConfig + with pytest.raises(ValueError, match="Bias metrics are unsupported for time series."): + _AnalysisConfigGenerator.bias_and_explainability( + data_config=data_config_with_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + bias_config=None, + ) + @pytest.mark.parametrize( ("mock_config", "error", "error_message"), [ From a4f7170be543d5ab0cb751f9cbdb269d198748eb Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Thu, 26 Oct 2023 18:35:33 +0000 Subject: [PATCH 16/35] change: rename `AsymmetricSHAPConfig` to `AsymmetricShapleyValueConfig` --- src/sagemaker/clarify.py | 70 +++++++++++++++--------------- tests/unit/test_clarify.py | 88 
+++++++++++++++++++------------------- 2 files changed, 80 insertions(+), 78 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index d01e463510..cc7afeb05a 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -40,9 +40,9 @@ ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])" # TODO: verify these are sensible/sound values -# asymmetric shap default values (timeseries) -ASYM_SHAP_DEFAULT_EXPLANATION_TYPE = "timewise_chronological" -ASYM_SHAP_EXPLANATION_TYPES = [ +# asym shap val config default values (timeseries) +ASYM_SHAP_VAL_DEFAULT_EXPLANATION_TYPE = "timewise_chronological" +ASYM_SHAP_VAL_EXPLANATION_TYPES = [ "timewise_chronological", "timewise_anti_chronological", "timewise_bidirectional", @@ -294,7 +294,7 @@ SchemaOptional("top_k_features"): int, }, SchemaOptional("report"): {"name": str, SchemaOptional("title"): str}, - SchemaOptional("asymmetric_shap"): { + SchemaOptional("asymmetric_shapley_value"): { "explanation_type": And( str, Use(str.lower), @@ -1580,8 +1580,8 @@ def get_explainability_config(self): return copy.deepcopy({"shap": self.shap_config}) -class AsymmetricSHAPConfig(ExplainabilityConfig): - """Config class for Asymmetric SHAP algorithm for TimeSeries explainability""" +class AsymmetricShapleyValueConfig(ExplainabilityConfig): + """Config class for Asymmetric Shapley value algorithm for time series explainability.""" def __init__( self, @@ -1590,44 +1590,47 @@ def __init__( "timewise_anti_chronological", "timewise_bidirectional", "fine_grained", - ] = ASYM_SHAP_DEFAULT_EXPLANATION_TYPE, + ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_TYPE, num_samples: Optional[int] = None, ): - """Initialises config for asymmetric SHAP config. + """Initialises config for time series explainability with Asymmetric Shapley Values. - AsymmetricSHAPConfig is used specifically and only for TimeSeries explainability purposes. 
+ AsymmetricShapleyValueConfig is used specifically and only for TimeSeries explainability + purposes. Args: explanation_type (str): Type of explanation to be used. Available explanation types are ``"timewise_chronological"``, ``"timewise_anti_chronological"``, ``"timewise_bidirectional"``, and ``"fine_grained"``. - num_samples (None or int): Number of samples to be used in the Asymmetric SHAP - algorithm. Only applicable when using ``"fine_grained"`` explanations. + num_samples (None or int): Number of samples to be used in the Asymmetric Shapley + Value algorithm. Only applicable when using ``"fine_grained"`` explanations. Raises: AssertionError: when ``explanation_type`` is not valid or ``num_samples`` is not provided for fine-grained explanations ValueError: when ``num_samples`` is provided for non fine-grained explanations """ - self.asymmetric_shap_config = dict() + self.asymmetric_shapley_value_config = dict() # validate explanation type assert ( - explanation_type in ASYM_SHAP_EXPLANATION_TYPES - ), "Please provide a valid explanation type from: " + ", ".join(ASYM_SHAP_EXPLANATION_TYPES) + explanation_type in ASYM_SHAP_VAL_EXPLANATION_TYPES + ), "Please provide a valid explanation type from: " + ", ".join( + ASYM_SHAP_VAL_EXPLANATION_TYPES + ) # validate integer num_samples is provided when necessary if explanation_type == "fine_grained": assert isinstance(num_samples, int), "Please provide an integer for ``num_samples``." 
elif num_samples: # validate num_samples is not provided when unnecessary raise ValueError("``num_samples`` is only used for fine-grained explanations.") # set explanation type and (if provided) num_samples in internal config dictionary - _set(explanation_type, "explanation_type", self.asymmetric_shap_config) + _set(explanation_type, "explanation_type", self.asymmetric_shapley_value_config) _set( - num_samples, "num_samples", self.asymmetric_shap_config + num_samples, "num_samples", self.asymmetric_shapley_value_config ) # _set() does nothing if a given argument is None def get_explainability_config(self): """Returns an asymmetric shap config dictionary.""" - return copy.deepcopy({"asymmetric_shap": self.asymmetric_shap_config}) + return copy.deepcopy({"asymmetric_shapley_value": self.asymmetric_shapley_value_config}) class SageMakerClarifyProcessor(Processor): @@ -2304,7 +2307,7 @@ def bias_and_explainability( """Generates a config for Bias and Explainability""" # TimeSeries bias metrics are not supported if ( - isinstance(explainability_config, AsymmetricSHAPConfig) + isinstance(explainability_config, AsymmetricShapleyValueConfig) or "time_series_data_config" in data_config.analysis_config or (model_config and "time_series_predictor_config" in model_config.predictor_config) ): @@ -2335,19 +2338,21 @@ def explainability( ts_data_config_present = "time_series_data_config" in data_config.analysis_config ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config - if isinstance(explainability_config, AsymmetricSHAPConfig): + if isinstance(explainability_config, AsymmetricShapleyValueConfig): assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig." assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig." else: if ts_data_config_present: raise ValueError( - "Please provide an AsymmetricSHAPConfig for time series explainability cases." 
- "For non time series cases, please do not provide a TimeSeriesDataConfig." + "Please provide an AsymmetricShapleyValueConfig for time series " + "explainability cases. For non time series cases, please do not provide a " + "TimeSeriesDataConfig." ) if ts_model_config_present: raise ValueError( - "Please provide an AsymmetricSHAPConfig for time series explainability cases." - "For non time series cases, please do not provide a TimeSeriesModelConfig." + "Please provide an AsymmetricShapleyValueConfig for time series " + "explainability cases. For non time series cases, please do not provide a " + "TimeSeriesModelConfig." ) # construct whole analysis config @@ -2427,7 +2432,7 @@ def _add_predictor( if ( "shap" in analysis_config["methods"] or "pdp" in analysis_config["methods"] - or "asymmetric_shap" in analysis_config["methods"] + or "asymmetric_shapley_value" in analysis_config["methods"] ): raise ValueError( "model_config must be provided when explainability methods are selected." @@ -2489,7 +2494,7 @@ def _add_methods( analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods} if explainability_config is not None: - if isinstance(explainability_config, AsymmetricSHAPConfig): + if isinstance(explainability_config, AsymmetricShapleyValueConfig): explainability_methods = explainability_config.get_explainability_config() else: explainability_methods = cls._merge_explainability_configs( @@ -2507,11 +2512,10 @@ def _merge_explainability_configs( explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], ): """Merges explainability configs, when more than one.""" + non_ts = "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses." # validation - if isinstance(explainability_config, AsymmetricSHAPConfig): - raise ValueError( - "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses." 
- ) + if isinstance(explainability_config, AsymmetricShapleyValueConfig): + raise ValueError(non_ts) if ( isinstance(explainability_config, PDPConfig) and "features" not in explainability_config.get_explainability_config()["pdp"] @@ -2522,11 +2526,9 @@ def _merge_explainability_configs( raise ValueError("Please provide at least one explainability config.") # list validation for config in explainability_config: - # ensure all provided explainability configs are not AsymmetricSHAPConfig - if isinstance(config, AsymmetricSHAPConfig): - raise ValueError( - "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses." - ) + # ensure all provided explainability configs are not AsymmetricShapleyValueConfig + if isinstance(config, AsymmetricShapleyValueConfig): + raise ValueError(non_ts) # main logic explainability_methods = {} for config in explainability_config: diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 0a2f806a06..acb98edcfe 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -31,14 +31,14 @@ PDPConfig, SageMakerClarifyProcessor, SHAPConfig, - AsymmetricSHAPConfig, + AsymmetricShapleyValueConfig, TextConfig, ImageConfig, _AnalysisConfigGenerator, DatasetType, ProcessingOutputHandler, SegmentationConfig, - ASYM_SHAP_EXPLANATION_TYPES, + ASYM_SHAP_VAL_EXPLANATION_TYPES, ) JOB_NAME_PREFIX = "my-prefix" @@ -1223,31 +1223,31 @@ def test_shap_config_no_parameters(): assert expected_config == shap_config.get_explainability_config() -class AsymmetricSHAPConfigCase(NamedTuple): +class AsymmetricShapleyValueConfigCase(NamedTuple): explanation_type: str num_samples: Optional[int] error: Exception error_message: str -class TestAsymmetricSHAPConfig: +class TestAsymmetricShapleyValueConfig: @pytest.mark.parametrize( "test_case", [ - AsymmetricSHAPConfigCase( # cases for different explanation types + AsymmetricShapleyValueConfigCase( # cases for different explanation types explanation_type=explanation_type, 
num_samples=1 if explanation_type == "fine_grained" else None, error=None, error_message=None, ) - for explanation_type in ASYM_SHAP_EXPLANATION_TYPES + for explanation_type in ASYM_SHAP_VAL_EXPLANATION_TYPES ], ) - def test_asymmetric_shap_config(self, test_case): + def test_asymmetric_shapley_value_config(self, test_case): """ - GIVEN valid arguments for an AsymmetricSHAPConfig object - WHEN AsymmetricSHAPConfig object is instantiated with those arguments - THEN the asymmetric_shap_config dictionary matches expected + GIVEN valid arguments for an AsymmetricShapleyValueConfig object + WHEN AsymmetricShapleyValueConfig object is instantiated with those arguments + THEN the asymmetric_shapley_value_config dictionary matches expected """ # test case is GIVEN # construct expected config @@ -1255,36 +1255,36 @@ def test_asymmetric_shap_config(self, test_case): if test_case.explanation_type == "fine_grained": expected_config["num_samples"] = test_case.num_samples # WHEN - asym_shap_config = AsymmetricSHAPConfig( + asym_shap_val_config = AsymmetricShapleyValueConfig( explanation_type=test_case.explanation_type, num_samples=test_case.num_samples, ) # THEN - assert asym_shap_config.asymmetric_shap_config == expected_config + assert asym_shap_val_config.asymmetric_shapley_value_config == expected_config @pytest.mark.parametrize( "test_case", [ - AsymmetricSHAPConfigCase( # case for invalid explanation_type + AsymmetricShapleyValueConfigCase( # case for invalid explanation_type explanation_type="coarse_grained", num_samples=None, error=AssertionError, error_message="Please provide a valid explanation type from: " - + ", ".join(ASYM_SHAP_EXPLANATION_TYPES), + + ", ".join(ASYM_SHAP_VAL_EXPLANATION_TYPES), ), - AsymmetricSHAPConfigCase( # case for fine_grained and no num_samples + AsymmetricShapleyValueConfigCase( # case for fine_grained and no num_samples explanation_type="fine_grained", num_samples=None, error=AssertionError, error_message="Please provide an integer for 
``num_samples``.", ), - AsymmetricSHAPConfigCase( # case for fine_grained and non-integer num_samples + AsymmetricShapleyValueConfigCase( # case for fine_grained and non-integer num_samples explanation_type="fine_grained", num_samples="5", error=AssertionError, error_message="Please provide an integer for ``num_samples``.", ), - AsymmetricSHAPConfigCase( # case for num_samples when non fine-grained explanation + AsymmetricShapleyValueConfigCase( # case for num_samples when non fine-grained explanation explanation_type="timewise_chronological", num_samples=5, error=ValueError, @@ -1292,15 +1292,15 @@ def test_asymmetric_shap_config(self, test_case): ), ], ) - def test_asymmetric_shap_config_invalid(self, test_case): + def test_asymmetric_shapley_value_config_invalid(self, test_case): """ - GIVEN invalid parameters for AsymmetricSHAP - WHEN AsymmetricSHAPConfig constructor is called with them + GIVEN invalid parameters for AsymmetricShapleyValue + WHEN AsymmetricShapleyValueConfig constructor is called with them THEN the expected error and message are raised """ # test case is GIVEN with pytest.raises(test_case.error, match=test_case.error_message): # THEN - AsymmetricSHAPConfig( # WHEN + AsymmetricShapleyValueConfig( # WHEN explanation_type=test_case.explanation_type, num_samples=test_case.num_samples, ) @@ -2456,16 +2456,16 @@ def _build_pdp_config_mock(): return pdp_config -def _build_asymmetric_shap_config_mock(): - asym_shap_config_dict = { +def _build_asymmetric_shapley_value_config_mock(): + asym_shap_val_config_dict = { "explanation_type": "fine_grained", "num_samples": 20, } - asym_shap_config = Mock(spec=AsymmetricSHAPConfig) - asym_shap_config.get_explainability_config.return_value = { - "asymmetric_shap": asym_shap_config_dict + asym_shap_val_config = Mock(spec=AsymmetricShapleyValueConfig) + asym_shap_val_config.get_explainability_config.return_value = { + "asymmetric_shapley_value": asym_shap_val_config_dict } - return asym_shap_config + return 
asym_shap_val_config def _build_data_config_mock(): @@ -2502,7 +2502,7 @@ class TestAnalysisConfigGeneratorForTimeSeriesExplainability: def test_explainability_for_time_series(self, _add_predictor, _add_methods): """ GIVEN a valid DataConfig and ModelConfig that contain time_series_data_config and - time_series_model_config respectively as well as an AsymmetricSHAPConfig + time_series_model_config respectively as well as an AsymmetricShapleyValueConfig WHEN _AnalysisConfigGenerator.explainability() is called with those args THEN _add_predictor and _add methods calls are as expected """ @@ -2511,8 +2511,8 @@ def test_explainability_for_time_series(self, _add_predictor, _add_methods): data_config_mock = _build_data_config_mock() # get ModelConfig mock model_config_mock = _build_model_config_mock() - # get AsymmetricSHAPConfig mock for explainability_config - explainability_config = _build_asymmetric_shap_config_mock() + # get AsymmetricShapleyValueConfig mock for explainability_config + explainability_config = _build_asymmetric_shapley_value_config_mock() # get time_series_data_config dict from mock time_series_data_config = copy.deepcopy( data_config_mock.analysis_config.get("time_series_data_config") @@ -2554,11 +2554,11 @@ def test_explainability_for_time_series_invalid(self): model_config_with_ts = _build_model_config_mock() model_config_without_ts = Mock(spec=ModelConfig) model_config_without_ts.predictor_config = dict() - # asymmetric shap config mock (for ts) - asym_shap_config_mock = _build_asymmetric_shap_config_mock() + # asymmetric shapley value config mock (for ts) + asym_shap_val_config_mock = _build_asymmetric_shapley_value_config_mock() # pdp config mock (for non-ts) pdp_config_mock = _build_pdp_config_mock() - # case 1: asymmetric shap (ts case) and no timeseries data config given + # case 1: ASV (ts case) and no timeseries data config given with pytest.raises( AssertionError, match="Please provide a TimeSeriesDataConfig to DataConfig." 
): @@ -2566,9 +2566,9 @@ def test_explainability_for_time_series_invalid(self): data_config=data_config_without_ts, model_config=model_config_with_ts, model_predicted_label_config=None, - explainability_config=asym_shap_config_mock, + explainability_config=asym_shap_val_config_mock, ) - # case 2: asymmetric shap (ts case) and no timeseries model config given + # case 2: ASV (ts case) and no timeseries model config given with pytest.raises( AssertionError, match="Please provide a TimeSeriesModelConfig to ModelConfig." ): @@ -2576,7 +2576,7 @@ def test_explainability_for_time_series_invalid(self): data_config=data_config_with_ts, model_config=model_config_without_ts, model_predicted_label_config=None, - explainability_config=asym_shap_config_mock, + explainability_config=asym_shap_val_config_mock, ) # case 3: pdp (non ts case) and timeseries data config given with pytest.raises(ValueError, match="please do not provide a TimeSeriesDataConfig."): @@ -2598,7 +2598,7 @@ def test_explainability_for_time_series_invalid(self): def test_bias_and_explainability_invalid_for_time_series(self): """ GIVEN user provides TimeSeriesDataConfig, TimeSeriesModelConfig, and/or - AsymmetricSHAPConfig for DataConfig, ModelConfig, and as explainability_config + AsymmetricShapleyValueConfig for DataConfig, ModelConfig, and as explainability_config respectively WHEN _AnalysisConfigGenerator.bias_and_explainability is called THEN the appropriate error is raised @@ -2612,7 +2612,7 @@ def test_bias_and_explainability_invalid_for_time_series(self): model_config_without_ts = Mock(spec=ModelConfig) model_config_without_ts.predictor_config = dict() # asymmetric shap config mock (for ts) - asym_shap_config_mock = _build_asymmetric_shap_config_mock() + asym_shap_val_config_mock = _build_asymmetric_shapley_value_config_mock() # pdp config mock (for non-ts) pdp_config_mock = _build_pdp_config_mock() # case 1: asymmetric shap is given as explainability_config @@ -2621,7 +2621,7 @@ def 
test_bias_and_explainability_invalid_for_time_series(self): data_config=data_config_without_ts, model_config=model_config_without_ts, model_predicted_label_config=None, - explainability_config=asym_shap_config_mock, + explainability_config=asym_shap_val_config_mock, bias_config=None, ) # case 2: TimeSeriesModelConfig given to ModelConfig @@ -2647,17 +2647,17 @@ def test_bias_and_explainability_invalid_for_time_series(self): ("mock_config", "error", "error_message"), [ ( # single asym shap config for non TSX - _build_asymmetric_shap_config_mock(), + _build_asymmetric_shapley_value_config_mock(), ValueError, - "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses.", + "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses.", ), ( # list with asym shap config for non-TSX [ - _build_asymmetric_shap_config_mock(), + _build_asymmetric_shapley_value_config_mock(), _build_pdp_config_mock(), ], ValueError, - "Please do not provide Asymmetric SHAP configs for non-TimeSeries uses.", + "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses.", ), ], ) @@ -2669,7 +2669,7 @@ def test_merge_explainability_configs_with_timeseries_invalid( ): """ GIVEN _merge_explainability_configs is called with a explainability config or list thereof - WHEN explainability_config is or contains an AsymmetricSHAPConfig + WHEN explainability_config is or contains an AsymmetricShapleyValueConfig THEN the function will raise the appropriate error """ with pytest.raises(error, match=error_message): From 4fade41c6b8938abbd934a49fb0ef48cb9e4d8ff Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Thu, 26 Oct 2023 18:45:32 +0000 Subject: [PATCH 17/35] documentation: add description to ``AsymmetricShapleyValueConfig`` documentation: reword `target_time_series` parameter description documentation: remove TODOs --- src/sagemaker/clarify.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git 
a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index cc7afeb05a..b0e4fc469d 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -39,7 +39,6 @@ ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])" -# TODO: verify these are sensible/sound values # asym shap val config default values (timeseries) ASYM_SHAP_VAL_DEFAULT_EXPLANATION_TYPE = "timewise_chronological" ASYM_SHAP_VAL_EXPLANATION_TYPES = [ @@ -428,7 +427,7 @@ def to_dict(self) -> Dict[str, Any]: # pragma: no cover class TimeSeriesDataConfig: - """Config object for TimeSeries explainability specific data fields.""" + """Config object for TimeSeries explainability data configuration fields.""" def __init__( self, @@ -440,21 +439,22 @@ def __init__( ): """Initialises TimeSeries explainability data configuration fields. - Args: #TODO: verify param descriptions are accurate + Args: target_time_series (str or int): A string or a zero-based integer index. - Used to locate target time series in the shared input dataset. Also - used to determine the correct type of all succeeding arguments, as - parameters must either be all strings or all ints. + Used to locate the target time series in the shared input dataset. + If this parameter is a string, then all other parameters must also + be strings or lists of strings. If this parameter is an int, then + all others must be ints or lists of ints. item_id (str or int): A string or a zero-based integer index. Used to locate item id in the shared input dataset. timestamp (str or int): A string or a zero-based integer index. Used to locate timestamp in the shared input dataset. related_time_series (list[str] or list[int]): Optional. An array of strings or array of zero-based integer indices. Used to locate all related time - series in the shared input dataset. + series in the shared input dataset (if present). item_metadata (list[str] or list[int]): Optional. An array of strings or array of zero-based integer indices. 
Used to locate all item metadata - fields in the shared input dataset. + fields in the shared input dataset (if present). Raises: AssertionError: If any required arguments are not provided. @@ -1581,7 +1581,17 @@ def get_explainability_config(self): class AsymmetricShapleyValueConfig(ExplainabilityConfig): - """Config class for Asymmetric Shapley value algorithm for time series explainability.""" + """Config class for Asymmetric Shapley value algorithm for time series explainability. + + Asymmetric Shapley Values are a variant of the Shapley Value that drop the symmetry axiom [1]. + We use these to determine how features contribute to the forecasting outcome. Asymmetric + Shapley values can take into account the temporal dependencies of the time series that + forecasting models take as input. + + [1] Frye, Christopher, Colin Rowat, and Ilya Feige. "Asymmetric shapley values: incorporating + causal knowledge into model-agnostic explainability." NeurIPS (2020). + https://doi.org/10.48550/arXiv.1910.06358 + """ def __init__( self, @@ -1603,7 +1613,8 @@ def __init__( types are ``"timewise_chronological"``, ``"timewise_anti_chronological"``, ``"timewise_bidirectional"``, and ``"fine_grained"``. num_samples (None or int): Number of samples to be used in the Asymmetric Shapley - Value algorithm. Only applicable when using ``"fine_grained"`` explanations. + Value forecasting algorithm. Only applicable when using ``"fine_grained"`` + explanations. 
        Raises:
            AssertionError: when ``explanation_type`` is not valid or ``num_samples``

From d55e99a953d0d81f73172b80791cb3e708bc5a87 Mon Sep 17 00:00:00 2001
From: Rahul Sahu
Date: Tue, 31 Oct 2023 04:21:44 +0000
Subject: [PATCH 18/35] change: split ``explanation_type`` into
 ``explanation_direction`` and ``granularity``

update tests and documentation accordingly

---
 src/sagemaker/clarify.py   | 79 +++++++++++++++++++++++++-------------
 tests/unit/test_clarify.py | 68 ++++++++++++++++++++++++--------
 2 files changed, 103 insertions(+), 44 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index b0e4fc469d..46585103be 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -40,11 +40,15 @@
 ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"
 
 # asym shap val config default values (timeseries)
-ASYM_SHAP_VAL_DEFAULT_EXPLANATION_TYPE = "timewise_chronological"
-ASYM_SHAP_VAL_EXPLANATION_TYPES = [
-    "timewise_chronological",
-    "timewise_anti_chronological",
-    "timewise_bidirectional",
+ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION = "chronological"
+ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY = "timewise"
+ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS = [
+    "chronological",
+    "anti_chronological",
+    "bidirectional",
+]
+ASYM_SHAP_VAL_GRANULARITIES = [
+    "timewise",
     "fine_grained",
 ]
 
@@ -294,14 +298,22 @@
             },
             SchemaOptional("report"): {"name": str, SchemaOptional("title"): str},
             SchemaOptional("asymmetric_shapley_value"): {
-                "explanation_type": And(
+                "explanation_direction": And(
+                    str,
+                    Use(str.lower),
+                    lambda s: s
+                    in (
+                        "chronological",
+                        "anti_chronological",
+                        "bidirectional",
+                    ),
+                ),
+                "explanation_granularity": And(
                     str,
                     Use(str.lower),
                     lambda s: s
                     in (
-                        "timewise_chronological",
-                        "timewise_anti_chronological",
-                        "timewise_bidirectional",
+                        "timewise",
                         "fine_grained",
                     ),
                 ),
@@ -1595,12 +1607,15 @@ class AsymmetricShapleyValueConfig(ExplainabilityConfig):
 
     def __init__(
         self,
-        explanation_type: Literal[
"timewise_chronological", - "timewise_anti_chronological", - "timewise_bidirectional", + explanation_direction: Literal[ + "chronological", + "anti_chronological", + "bidirectional", + ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION, + granularity: Literal[ + "timewise", "fine_grained", - ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_TYPE, + ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY, num_samples: Optional[int] = None, ): """Initialises config for time series explainability with Asymmetric Shapley Values. @@ -1609,32 +1624,42 @@ def __init__( purposes. Args: - explanation_type (str): Type of explanation to be used. Available explanation - types are ``"timewise_chronological"``, ``"timewise_anti_chronological"``, - ``"timewise_bidirectional"``, and ``"fine_grained"``. + explanation_direction (str): Type of explanation to be used. Available explanation + types are ``"chronological"``, ``"anti_chronological"``, and ``"bidirectional"``. + granularity (str): Explanation granularity to be used. Available granularity options + are ``"timewise"`` and ``"fine_grained"``. num_samples (None or int): Number of samples to be used in the Asymmetric Shapley Value forecasting algorithm. Only applicable when using ``"fine_grained"`` explanations. Raises: - AssertionError: when ``explanation_type`` is not valid or ``num_samples`` - is not provided for fine-grained explanations - ValueError: when ``num_samples`` is provided for non fine-grained explanations + AssertionError: when ``explanation_direction`` or ``granularity`` are not valid, + or ``num_samples`` is not provided for fine-grained explanations + ValueError: when ``num_samples`` is provided for non fine-grained explanations, or + when explanation_direction is not ``"chronological"`` when granularity is + ``"fine_grained"``. 
""" self.asymmetric_shapley_value_config = dict() - # validate explanation type + # validate explanation direction assert ( - explanation_type in ASYM_SHAP_VAL_EXPLANATION_TYPES - ), "Please provide a valid explanation type from: " + ", ".join( - ASYM_SHAP_VAL_EXPLANATION_TYPES + explanation_direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS + ), "Please provide a valid explanation direction from: " + ", ".join( + ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS ) - # validate integer num_samples is provided when necessary - if explanation_type == "fine_grained": + # validate granularity + assert ( + granularity in ASYM_SHAP_VAL_GRANULARITIES + ), "Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES) + if granularity == "fine_grained": assert isinstance(num_samples, int), "Please provide an integer for ``num_samples``." + assert ( + explanation_direction == "chronological" + ), f"{explanation_direction} and {granularity} granularity are not supported together." elif num_samples: # validate num_samples is not provided when unnecessary raise ValueError("``num_samples`` is only used for fine-grained explanations.") # set explanation type and (if provided) num_samples in internal config dictionary - _set(explanation_type, "explanation_type", self.asymmetric_shapley_value_config) + _set(explanation_direction, "explanation_direction", self.asymmetric_shapley_value_config) + _set(granularity, "granularity", self.asymmetric_shapley_value_config) _set( num_samples, "num_samples", self.asymmetric_shapley_value_config ) # _set() does nothing if a given argument is None diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index acb98edcfe..eee27104a2 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -38,7 +38,7 @@ DatasetType, ProcessingOutputHandler, SegmentationConfig, - ASYM_SHAP_VAL_EXPLANATION_TYPES, + ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS, ) JOB_NAME_PREFIX = "my-prefix" @@ -1224,7 +1224,8 @@ def 
test_shap_config_no_parameters(): class AsymmetricShapleyValueConfigCase(NamedTuple): - explanation_type: str + explanation_direction: str + granularity: str num_samples: Optional[int] error: Exception error_message: str @@ -1234,13 +1235,23 @@ class TestAsymmetricShapleyValueConfig: @pytest.mark.parametrize( "test_case", [ - AsymmetricShapleyValueConfigCase( # cases for different explanation types - explanation_type=explanation_type, - num_samples=1 if explanation_type == "fine_grained" else None, + AsymmetricShapleyValueConfigCase( # cases for timewise granularity + explanation_direction=explanation_direction, + granularity="timewise", + num_samples=None, + error=None, + error_message=None, + ) + for explanation_direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS + ] + + [ + AsymmetricShapleyValueConfigCase( # cases for fine_grained granularity + explanation_direction="chronological", + granularity="fine_grained", + num_samples=1, error=None, error_message=None, ) - for explanation_type in ASYM_SHAP_VAL_EXPLANATION_TYPES ], ) def test_asymmetric_shapley_value_config(self, test_case): @@ -1251,12 +1262,16 @@ def test_asymmetric_shapley_value_config(self, test_case): """ # test case is GIVEN # construct expected config - expected_config = {"explanation_type": test_case.explanation_type} - if test_case.explanation_type == "fine_grained": + expected_config = { + "explanation_direction": test_case.explanation_direction, + "granularity": test_case.granularity, + } + if test_case.granularity == "fine_grained": expected_config["num_samples"] = test_case.num_samples # WHEN asym_shap_val_config = AsymmetricShapleyValueConfig( - explanation_type=test_case.explanation_type, + explanation_direction=test_case.explanation_direction, + granularity=test_case.granularity, num_samples=test_case.num_samples, ) # THEN @@ -1265,31 +1280,49 @@ def test_asymmetric_shapley_value_config(self, test_case): @pytest.mark.parametrize( "test_case", [ - AsymmetricShapleyValueConfigCase( # case for 
invalid explanation_type - explanation_type="coarse_grained", + AsymmetricShapleyValueConfigCase( # case for invalid explanation_direction + explanation_direction="non-directional", + granularity="timewise", num_samples=None, error=AssertionError, - error_message="Please provide a valid explanation type from: " - + ", ".join(ASYM_SHAP_VAL_EXPLANATION_TYPES), + error_message="Please provide a valid explanation direction from: " + + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS), ), AsymmetricShapleyValueConfigCase( # case for fine_grained and no num_samples - explanation_type="fine_grained", + explanation_direction="chronological", + granularity="fine_grained", num_samples=None, error=AssertionError, error_message="Please provide an integer for ``num_samples``.", ), AsymmetricShapleyValueConfigCase( # case for fine_grained and non-integer num_samples - explanation_type="fine_grained", + explanation_direction="chronological", + granularity="fine_grained", num_samples="5", error=AssertionError, error_message="Please provide an integer for ``num_samples``.", ), AsymmetricShapleyValueConfigCase( # case for num_samples when non fine-grained explanation - explanation_type="timewise_chronological", + explanation_direction="chronological", + granularity="timewise", num_samples=5, error=ValueError, error_message="``num_samples`` is only used for fine-grained explanations.", ), + AsymmetricShapleyValueConfigCase( # case for anti_chronological and fine_grained + explanation_direction="anti_chronological", + granularity="fine_grained", + num_samples=5, + error=AssertionError, + error_message="not supported together.", + ), + AsymmetricShapleyValueConfigCase( # case for bidirectional and fine_grained + explanation_direction="bidirectional", + granularity="fine_grained", + num_samples=5, + error=AssertionError, + error_message="not supported together.", + ), ], ) def test_asymmetric_shapley_value_config_invalid(self, test_case): @@ -1301,7 +1334,8 @@ def 
test_asymmetric_shapley_value_config_invalid(self, test_case): # test case is GIVEN with pytest.raises(test_case.error, match=test_case.error_message): # THEN AsymmetricShapleyValueConfig( # WHEN - explanation_type=test_case.explanation_type, + explanation_direction=test_case.explanation_direction, + granularity=test_case.granularity, num_samples=test_case.num_samples, ) From 1a09ad80300e4f9aae057efd67ce00e4ef26543a Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 1 Nov 2023 02:10:18 +0000 Subject: [PATCH 19/35] fix: rename ``explanation_granularity`` to ``granularity`` --- src/sagemaker/clarify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 46585103be..ea3c92cd24 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -308,7 +308,7 @@ "bidirectional", ), ), - "explanation_granularity": And( + "granularity": And( str, Use(str.lower), lambda s: s From d63fab66d8220b50da8b9b90425c1db2183fea56 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 3 Nov 2023 21:43:15 +0000 Subject: [PATCH 20/35] change: rename ``explanation_direction`` to ``direction`` --- src/sagemaker/clarify.py | 18 +++++++++--------- tests/unit/test_clarify.py | 28 ++++++++++++++-------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index ea3c92cd24..6f188f57ea 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -298,7 +298,7 @@ }, SchemaOptional("report"): {"name": str, SchemaOptional("title"): str}, SchemaOptional("asymmetric_shapley_value"): { - "explanation_direction": And( + "direction": And( str, Use(str.lower), lambda s: s @@ -1607,7 +1607,7 @@ class AsymmetricShapleyValueConfig(ExplainabilityConfig): def __init__( self, - explanation_direction: Literal[ + direction: Literal[ "chronological", "anti_chronological", "bidirectional", @@ -1624,7 +1624,7 @@ def __init__( purposes. 
         Args:
-            explanation_direction (str): Type of explanation to be used. Available explanation
+            direction (str): Type of explanation to be used. Available explanation
                 types are ``"chronological"``, ``"anti_chronological"``, and ``"bidirectional"``.
             granularity (str): Explanation granularity to be used. Available granularity options
                 are ``"timewise"`` and ``"fine_grained"``.
@@ -1633,16 +1633,16 @@ def __init__(
                 explanations.

         Raises:
-            AssertionError: when ``explanation_direction`` or ``granularity`` are not valid,
+            AssertionError: when ``direction`` or ``granularity`` are not valid,
                 or ``num_samples`` is not provided for fine-grained explanations
             ValueError: when ``num_samples`` is provided for non fine-grained explanations, or
-                when explanation_direction is not ``"chronological"`` when granularity is
+                when ``direction`` is not ``"chronological"`` and ``granularity`` is
                 ``"fine_grained"``.
         """
         self.asymmetric_shapley_value_config = dict()
         # validate explanation direction
         assert (
-            explanation_direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS
+            direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS
         ), "Please provide a valid explanation direction from: " + ", ".join(
             ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS
         )
@@ -1653,12 +1653,12 @@ def __init__(
         if granularity == "fine_grained":
             assert isinstance(num_samples, int), "Please provide an integer for ``num_samples``."
             assert (
-                explanation_direction == "chronological"
-            ), f"{explanation_direction} and {granularity} granularity are not supported together."
+                direction == "chronological"
+            ), f"{direction} and {granularity} granularity are not supported together."
elif num_samples: # validate num_samples is not provided when unnecessary raise ValueError("``num_samples`` is only used for fine-grained explanations.") # set explanation type and (if provided) num_samples in internal config dictionary - _set(explanation_direction, "explanation_direction", self.asymmetric_shapley_value_config) + _set(direction, "direction", self.asymmetric_shapley_value_config) _set(granularity, "granularity", self.asymmetric_shapley_value_config) _set( num_samples, "num_samples", self.asymmetric_shapley_value_config diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index eee27104a2..5aca8e2797 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -1224,7 +1224,7 @@ def test_shap_config_no_parameters(): class AsymmetricShapleyValueConfigCase(NamedTuple): - explanation_direction: str + direction: str granularity: str num_samples: Optional[int] error: Exception @@ -1236,17 +1236,17 @@ class TestAsymmetricShapleyValueConfig: "test_case", [ AsymmetricShapleyValueConfigCase( # cases for timewise granularity - explanation_direction=explanation_direction, + direction=direction, granularity="timewise", num_samples=None, error=None, error_message=None, ) - for explanation_direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS + for direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS ] + [ AsymmetricShapleyValueConfigCase( # cases for fine_grained granularity - explanation_direction="chronological", + direction="chronological", granularity="fine_grained", num_samples=1, error=None, @@ -1263,14 +1263,14 @@ def test_asymmetric_shapley_value_config(self, test_case): # test case is GIVEN # construct expected config expected_config = { - "explanation_direction": test_case.explanation_direction, + "direction": test_case.direction, "granularity": test_case.granularity, } if test_case.granularity == "fine_grained": expected_config["num_samples"] = test_case.num_samples # WHEN asym_shap_val_config = AsymmetricShapleyValueConfig( - 
explanation_direction=test_case.explanation_direction, + direction=test_case.direction, granularity=test_case.granularity, num_samples=test_case.num_samples, ) @@ -1280,8 +1280,8 @@ def test_asymmetric_shapley_value_config(self, test_case): @pytest.mark.parametrize( "test_case", [ - AsymmetricShapleyValueConfigCase( # case for invalid explanation_direction - explanation_direction="non-directional", + AsymmetricShapleyValueConfigCase( # case for invalid direction + direction="non-directional", granularity="timewise", num_samples=None, error=AssertionError, @@ -1289,35 +1289,35 @@ def test_asymmetric_shapley_value_config(self, test_case): + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS), ), AsymmetricShapleyValueConfigCase( # case for fine_grained and no num_samples - explanation_direction="chronological", + direction="chronological", granularity="fine_grained", num_samples=None, error=AssertionError, error_message="Please provide an integer for ``num_samples``.", ), AsymmetricShapleyValueConfigCase( # case for fine_grained and non-integer num_samples - explanation_direction="chronological", + direction="chronological", granularity="fine_grained", num_samples="5", error=AssertionError, error_message="Please provide an integer for ``num_samples``.", ), AsymmetricShapleyValueConfigCase( # case for num_samples when non fine-grained explanation - explanation_direction="chronological", + direction="chronological", granularity="timewise", num_samples=5, error=ValueError, error_message="``num_samples`` is only used for fine-grained explanations.", ), AsymmetricShapleyValueConfigCase( # case for anti_chronological and fine_grained - explanation_direction="anti_chronological", + direction="anti_chronological", granularity="fine_grained", num_samples=5, error=AssertionError, error_message="not supported together.", ), AsymmetricShapleyValueConfigCase( # case for bidirectional and fine_grained - explanation_direction="bidirectional", + direction="bidirectional", 
granularity="fine_grained", num_samples=5, error=AssertionError, @@ -1334,7 +1334,7 @@ def test_asymmetric_shapley_value_config_invalid(self, test_case): # test case is GIVEN with pytest.raises(test_case.error, match=test_case.error_message): # THEN AsymmetricShapleyValueConfig( # WHEN - explanation_direction=test_case.explanation_direction, + direction=test_case.direction, granularity=test_case.granularity, num_samples=test_case.num_samples, ) From d0002bc5f21918e97f7c4557965cbce98db5f290 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 6 Dec 2023 22:41:11 +0000 Subject: [PATCH 21/35] change: rename ``item_metadata`` to ``static_covariates`` change: add ``dataset_format`` as a parameter for time series cases change: allow features jmespaths to be none for time series cases change: add validation to prevent non-json dataset formats for time series cases test: update unit tests to reflect above changes --- src/sagemaker/clarify.py | 72 +++++++++---- tests/unit/test_clarify.py | 204 ++++++++++++++++++++++++------------- 2 files changed, 184 insertions(+), 92 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 6f188f57ea..d56aa7ae2b 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -102,7 +102,16 @@ "item_id": Or(str, int), "timestamp": Or(str, int), SchemaOptional("related_time_series"): Or([str], [int]), - SchemaOptional("item_metadata"): Or([str], [int]), + SchemaOptional("static_covariates"): Or([str], [int]), + SchemaOptional("dataset_format"): And( + str, + Use(str.lower), + lambda s: s + in ( + "columns", + "timestamp_records", + ), + ), }, "methods": { SchemaOptional("shap"): { @@ -370,6 +379,13 @@ class DatasetType(Enum): IMAGE = "application/x-image" +class TimeSeriesJSONDatasetFormat(Enum): + """Possible dataset formats for JSON time series data files.""" + + COLUMNS = "columns" + TIMESTAMP_RECORDS = "timestamp_records" + + class SegmentationConfig: """Config object that defines segment(s) of the 
dataset on which metrics are computed.""" @@ -447,16 +463,18 @@ def __init__( item_id: Union[str, int], timestamp: Union[str, int], related_time_series: Optional[List[Union[str, int]]] = None, - item_metadata: Optional[List[Union[str, int]]] = None, + static_covariates: Optional[List[Union[str, int]]] = None, + dataset_format: Optional[TimeSeriesJSONDatasetFormat] = None, ): """Initialises TimeSeries explainability data configuration fields. Args: target_time_series (str or int): A string or a zero-based integer index. Used to locate the target time series in the shared input dataset. - If this parameter is a string, then all other parameters must also - be strings or lists of strings. If this parameter is an int, then - all others must be ints or lists of ints. + If this parameter is a string, then all other parameters except + `dataset_format` must be strings or lists of strings. If + this parameter is an int, then all other parameters except + `dataset_format` must be ints or lists of ints. item_id (str or int): A string or a zero-based integer index. Used to locate item id in the shared input dataset. timestamp (str or int): A string or a zero-based integer index. Used to @@ -464,9 +482,12 @@ def __init__( related_time_series (list[str] or list[int]): Optional. An array of strings or array of zero-based integer indices. Used to locate all related time series in the shared input dataset (if present). - item_metadata (list[str] or list[int]): Optional. An array of strings or - array of zero-based integer indices. Used to locate all item metadata + static_covariates (list[str] or list[int]): Optional. An array of strings or + array of zero-based integer indices. Used to locate all static covariate fields in the shared input dataset (if present). + dataset_format (TimeSeriesJSONDatasetFormat): Describes the format + of the data files provided for analysis. Should only be provided + when dataset is in JSON format. 
Raises: AssertionError: If any required arguments are not provided. @@ -484,7 +505,7 @@ def __init__( raise ValueError(f"Please provide {params_type} for ``item_id``") if not isinstance(timestamp, params_type): raise ValueError(f"Please provide {params_type} for ``timestamp``") - # add remaining fields to an internal dictionary + # add mandatory fields to an internal dictionary self.time_series_data_config = dict() _set(target_time_series, "target_time_series", self.time_series_data_config) _set(item_id, "item_id", self.time_series_data_config) @@ -502,22 +523,32 @@ def __init__( raise ValueError( related_time_series_error_message ) # related_time_series is not a list of strings or list of ints + if params_type == str and not all(related_time_series): + raise ValueError("Please do not provide empty strings in ``related_time_series``.") _set( related_time_series, "related_time_series", self.time_series_data_config ) # related_time_series is valid, add it - item_metadata_series_error_message = ( - f"Please provide a list of {params_type} for ``item_metadata``" + static_covariates_series_error_message = ( + f"Please provide a list of {params_type} for ``static_covariates``" ) - if item_metadata: - if not isinstance(item_metadata, list): - raise ValueError(item_metadata_series_error_message) # item_metadata is not a list - if not all([isinstance(value, params_type) for value in item_metadata]): + if static_covariates: + if not isinstance(static_covariates, list): + raise ValueError(static_covariates_series_error_message) # static_covariates is not a list + if not all([isinstance(value, params_type) for value in static_covariates]): raise ValueError( - item_metadata_series_error_message - ) # item_metadata is not a list of strings or list of ints + static_covariates_series_error_message + ) # static_covariates is not a list of strings or list of ints + if params_type == str and not all(static_covariates): + raise ValueError("Please do not provide empty strings in 
``static_covariates``.") _set( - item_metadata, "item_metadata", self.time_series_data_config - ) # item_metadata is valid, add it + static_covariates, "static_covariates", self.time_series_data_config + ) # static_covariates is valid, add it + if params_type == str: + # check dataset_format is provided and valid + assert isinstance(dataset_format, TimeSeriesJSONDatasetFormat), "Please provide a valid dataset format." + _set(dataset_format.value, "dataset_format", self.time_series_data_config) + else: + assert not dataset_format, "Dataset format should only be provided when data files are JSONs." def get_time_series_data_config(self): """Returns part of an analysis config dictionary.""" @@ -666,8 +697,11 @@ def __init__( f" are not supported for dataset_type '{dataset_type}'." f" Please check the API documentation for the supported dataset types." ) + # check if any other format other than JSON is provided for time series case + if time_series_data_config and dataset_type != "application/json": + raise ValueError("Currently time series explainability only supports JSON format data") # features JMESPath is required for JSON as we can't derive it ourselves - if dataset_type == "application/json" and features is None: + if dataset_type == "application/json" and features is None and not time_series_data_config: raise ValueError("features JMESPath is required for application/json dataset_type") self.s3_data_input_path = s3_data_input_path self.s3_output_path = s3_output_path diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 5aca8e2797..7fb9f48d5e 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -16,9 +16,8 @@ import copy import pytest -from mock import MagicMock, Mock, patch +from mock import ANY, MagicMock, Mock, patch from typing import List, NamedTuple, Optional, Union -from unittest.mock import ANY from sagemaker import Processor, image_uris from sagemaker.clarify import ( @@ -39,6 +38,7 @@ ProcessingOutputHandler, 
SegmentationConfig, ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS, + TimeSeriesJSONDatasetFormat, ) JOB_NAME_PREFIX = "my-prefix" @@ -332,7 +332,8 @@ class TimeSeriesDataConfigCase(NamedTuple): item_id: Union[str, int] timestamp: Union[str, int] related_time_series: Optional[List[Union[str, int]]] - item_metadata: Optional[List[Union[str, int]]] + static_covariates: Optional[List[Union[str, int]]] + dataset_format: Optional[TimeSeriesJSONDatasetFormat] error: Exception error_message: Optional[str] @@ -344,7 +345,8 @@ class TestTimeSeriesDataConfig: item_id="item_id", timestamp="timestamp", related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=None, error_message=None, ), @@ -353,25 +355,28 @@ class TestTimeSeriesDataConfig: item_id="item_id", timestamp="timestamp", related_time_series=["ts1", "ts2", "ts3"], - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=None, error_message=None, ), - TimeSeriesDataConfigCase( # item_metadata provided str case + TimeSeriesDataConfigCase( # static_covariates provided str case target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", related_time_series=None, - item_metadata=["a", "b", "c", "d"], + static_covariates=["a", "b", "c", "d"], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=None, error_message=None, ), - TimeSeriesDataConfigCase( # both related_time_series and item_metadata provided str case + TimeSeriesDataConfigCase( # both related_time_series and static_covariates provided str case target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", related_time_series=["ts1", "ts2", "ts3"], - item_metadata=["a", "b", "c", "d"], + static_covariates=["a", "b", "c", "d"], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=None, error_message=None, ), @@ -380,7 +385,8 @@ class TestTimeSeriesDataConfig: item_id=2, timestamp=3, 
related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=None, error=None, error_message=None, ), @@ -389,32 +395,35 @@ class TestTimeSeriesDataConfig: item_id=2, timestamp=3, related_time_series=[4, 5, 6], - item_metadata=None, + static_covariates=None, + dataset_format=None, error=None, error_message=None, ), - TimeSeriesDataConfigCase( # item_metadata provided int case + TimeSeriesDataConfigCase( # static_covariates provided int case target_time_series=1, item_id=2, timestamp=3, related_time_series=None, - item_metadata=[7, 8, 9, 10], + static_covariates=[7, 8, 9, 10], + dataset_format=None, error=None, error_message=None, ), - TimeSeriesDataConfigCase( # both related_time_series and item_metadata provided int case + TimeSeriesDataConfigCase( # both related_time_series and static_covariates provided int case target_time_series=1, item_id=2, timestamp=3, related_time_series=[4, 5, 6], - item_metadata=[7, 8, 9, 10], + static_covariates=[7, 8, 9, 10], + dataset_format=None, error=None, error_message=None, ), ] @pytest.mark.parametrize("test_case", valid_ts_data_config_case_list) - def test_time_series_data_config(self, test_case): + def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): """ GIVEN A set of valid parameters are given WHEN A TimeSeriesDataConfig object is instantiated @@ -426,17 +435,20 @@ def test_time_series_data_config(self, test_case): "item_id": test_case.item_id, "timestamp": test_case.timestamp, } + if isinstance(test_case.target_time_series, str): + expected_output["dataset_format"] = test_case.dataset_format.value if test_case.related_time_series: expected_output["related_time_series"] = test_case.related_time_series - if test_case.item_metadata: - expected_output["item_metadata"] = test_case.item_metadata + if test_case.static_covariates: + expected_output["static_covariates"] = test_case.static_covariates # GIVEN, WHEN ts_data_config = TimeSeriesDataConfig( 
target_time_series=test_case.target_time_series, item_id=test_case.item_id, timestamp=test_case.timestamp, related_time_series=test_case.related_time_series, - item_metadata=test_case.item_metadata, + static_covariates=test_case.static_covariates, + dataset_format=test_case.dataset_format, ) # THEN assert ts_data_config.time_series_data_config == expected_output @@ -449,7 +461,8 @@ def test_time_series_data_config(self, test_case): item_id="item_id", timestamp="timestamp", related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=AssertionError, error_message="Please provide a target time series.", ), @@ -458,7 +471,8 @@ def test_time_series_data_config(self, test_case): item_id=None, timestamp="timestamp", related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=AssertionError, error_message="Please provide an item id.", ), @@ -467,7 +481,8 @@ def test_time_series_data_config(self, test_case): item_id="item_id", timestamp=None, related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=AssertionError, error_message="Please provide a timestamp.", ), @@ -476,7 +491,8 @@ def test_time_series_data_config(self, test_case): item_id="item_id", timestamp="timestamp", related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=ValueError, error_message="Please provide a string or an int for ``target_time_series``", ), @@ -485,7 +501,8 @@ def test_time_series_data_config(self, test_case): item_id=5, timestamp="timestamp", related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=ValueError, error_message=f"Please provide {str} for ``item_id``", ), @@ -494,7 +511,8 @@ def 
test_time_series_data_config(self, test_case): item_id="item_id", timestamp=10, related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=ValueError, error_message=f"Please provide {str} for ``timestamp``", ), @@ -503,25 +521,28 @@ def test_time_series_data_config(self, test_case): item_id="item_id", timestamp="timestamp", related_time_series=["ts1", "ts2", "ts3", 4], - item_metadata=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=ValueError, error_message=f"Please provide a list of {str} for ``related_time_series``", ), - TimeSeriesDataConfigCase( # item_metadata not str list if str target_time_series + TimeSeriesDataConfigCase( # static_covariates not str list if str target_time_series target_time_series="target_time_series", item_id="item_id", timestamp="timestamp", related_time_series=None, - item_metadata=[4, 5, 6.0], + static_covariates=[4, 5, 6.0], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=ValueError, - error_message=f"Please provide a list of {str} for ``item_metadata``", + error_message=f"Please provide a list of {str} for ``static_covariates``", ), TimeSeriesDataConfigCase( # item_id differing type from int target_time_series target_time_series=1, item_id="item_id", timestamp=3, related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=None, error=ValueError, error_message=f"Please provide {int} for ``item_id``", ), @@ -530,7 +551,8 @@ def test_time_series_data_config(self, test_case): item_id=2, timestamp="timestamp", related_time_series=None, - item_metadata=None, + static_covariates=None, + dataset_format=None, error=ValueError, error_message=f"Please provide {int} for ``timestamp``", ), @@ -539,22 +561,74 @@ def test_time_series_data_config(self, test_case): item_id=2, timestamp=3, related_time_series=[4, 5, 6, "ts7"], - item_metadata=None, + static_covariates=None, + 
dataset_format=None, error=ValueError, error_message=f"Please provide a list of {int} for ``related_time_series``", ), - TimeSeriesDataConfigCase( # item_metadata not int list if int target_time_series + TimeSeriesDataConfigCase( # static_covariates not int list if int target_time_series target_time_series=1, item_id=2, timestamp=3, related_time_series=[4, 5, 6, 7], - item_metadata=[8, 9, "10"], + static_covariates=[8, 9, "10"], + dataset_format=None, error=ValueError, - error_message=f"Please provide a list of {int} for ``item_metadata``", + error_message=f"Please provide a list of {int} for ``static_covariates``", + ), + TimeSeriesDataConfigCase( # related_time_series contains blank string + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3", ""], + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message=f"Please do not provide empty strings in ``related_time_series``", + ), + TimeSeriesDataConfigCase( # static_covariates contains blank string + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + static_covariates=["scv4", "scv5", "scv6", ""], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message=f"Please do not provide empty strings in ``static_covariates``", + ), + TimeSeriesDataConfigCase( # dataset_format provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6], + static_covariates=[7, 8, 9, 10], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=AssertionError, + error_message="Dataset format should only be provided when data files are JSONs.", + ), + TimeSeriesDataConfigCase( # dataset_format not provided str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3"], + 
static_covariates=["a", "b", "c", "d"], + dataset_format=None, + error=AssertionError, + error_message="Please provide a valid dataset format.", + ), + TimeSeriesDataConfigCase( # dataset_format wrong type str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3"], + static_covariates=["a", "b", "c", "d"], + dataset_format="made_up_format", + error=AssertionError, + error_message="Please provide a valid dataset format.", ), ], ) - def test_time_series_data_config_invalid(self, test_case): + def test_time_series_data_config_invalid(self, test_case: TimeSeriesDataConfigCase): """ GIVEN required parameters are incomplete or invalid WHEN TimeSeriesDataConfig constructor is called @@ -566,52 +640,40 @@ def test_time_series_data_config_invalid(self, test_case): item_id=test_case.item_id, timestamp=test_case.timestamp, related_time_series=test_case.related_time_series, - item_metadata=test_case.item_metadata, + static_covariates=test_case.static_covariates, + dataset_format=test_case.dataset_format, ) @pytest.mark.parametrize("test_case", valid_ts_data_config_case_list) - def test_data_config_with_time_series(self, test_case): + def test_data_config_with_time_series(self, test_case: TimeSeriesDataConfigCase): """ GIVEN a TimeSeriesDataConfig object is created WHEN a DataConfig object is created and given valid params + the TimeSeriesDataConfig THEN the internal config dictionary matches what's expected """ + # currently TSX only supports json so skip non-json tests + if isinstance(test_case.target_time_series, int): + return # setup - headers = ["Label", "F1", "F2", "F3", "F4", "Predicted Label"] - dataset_type = "application/json" - segment_config = [ - SegmentationConfig( - name_or_index="F1", - segments=[[0]], - config_name="c1", - display_aliases=["a1"], - ) - ] + headers = ["item_id", "timestamp", "target_ts", "rts1", "scv1"] # construct expected output mock_ts_data_config_dict = { 
"target_time_series": test_case.target_time_series, "item_id": test_case.item_id, "timestamp": test_case.timestamp, } + if isinstance(test_case.target_time_series, str): + dataset_type = "application/json" + mock_ts_data_config_dict["dataset_format"] = test_case.dataset_format.value + else: + dataset_type = "text/csv" if test_case.related_time_series: mock_ts_data_config_dict["related_time_series"] = test_case.related_time_series - if test_case.item_metadata: - mock_ts_data_config_dict["item_metadata"] = test_case.item_metadata + if test_case.static_covariates: + mock_ts_data_config_dict["static_covariates"] = test_case.static_covariates expected_config = { "dataset_type": dataset_type, "headers": headers, - "label": "Label", - "segment_config": [ - { - "config_name": "c1", - "display_aliases": ["a1"], - "name_or_index": "F1", - "segments": [[0]], - } - ], - "excluded_columns": ["F4"], - "features": "[*].[F1,F2,F3]", - "predicted_label": "Predicted Label", "time_series_data_config": mock_ts_data_config_dict, } # GIVEN @@ -623,13 +685,8 @@ def test_data_config_with_time_series(self, test_case): data_config = DataConfig( s3_data_input_path="s3://path/to/input.csv", s3_output_path="s3://path/to/output", - features="[*].[F1,F2,F3]", - label="Label", headers=headers, - dataset_type="application/json", - excluded_columns=["F4"], - predicted_label="Predicted Label", - segmentation_config=segment_config, + dataset_type=dataset_type, time_series_data_config=ts_data_config, ) # THEN @@ -2508,11 +2565,12 @@ def _build_data_config_mock(): """ # setup a time_series_data_config dictionary time_series_data_config = { - "target_time_series": 1, - "item_id": 2, - "timestamp": 3, - "related_time_series": [4, 5, 6], - "item_metadata": [7, 8, 9, 10], + "target_time_series": "target_ts", + "item_id": "id", + "timestamp": "timestamp", + "related_time_series": ["rts1", "rts2", "rts3"], + "static_covariates": ["scv1", "scv2", "scv3"], + "dataset_format": 
TimeSeriesJSONDatasetFormat.COLUMNS, } # setup DataConfig mock data_config = Mock(spec=DataConfig) From da0b8f7751a2958328bb59509534b776d6396854 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 6 Dec 2023 23:01:53 +0000 Subject: [PATCH 22/35] fix: update clarify files to meet formatting reqs --- src/sagemaker/clarify.py | 16 +++++++++++----- tests/unit/test_clarify.py | 4 ++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index d56aa7ae2b..d5473375b3 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -471,8 +471,8 @@ def __init__( Args: target_time_series (str or int): A string or a zero-based integer index. Used to locate the target time series in the shared input dataset. - If this parameter is a string, then all other parameters except - `dataset_format` must be strings or lists of strings. If + If this parameter is a string, then all other parameters except + `dataset_format` must be strings or lists of strings. If this parameter is an int, then all other parameters except `dataset_format` must be ints or lists of ints. item_id (str or int): A string or a zero-based integer index. Used to @@ -533,7 +533,9 @@ def __init__( ) if static_covariates: if not isinstance(static_covariates, list): - raise ValueError(static_covariates_series_error_message) # static_covariates is not a list + raise ValueError( + static_covariates_series_error_message + ) # static_covariates is not a list if not all([isinstance(value, params_type) for value in static_covariates]): raise ValueError( static_covariates_series_error_message @@ -545,10 +547,14 @@ def __init__( ) # static_covariates is valid, add it if params_type == str: # check dataset_format is provided and valid - assert isinstance(dataset_format, TimeSeriesJSONDatasetFormat), "Please provide a valid dataset format." + assert isinstance( + dataset_format, TimeSeriesJSONDatasetFormat + ), "Please provide a valid dataset format." 
_set(dataset_format.value, "dataset_format", self.time_series_data_config) else: - assert not dataset_format, "Dataset format should only be provided when data files are JSONs." + assert ( + not dataset_format + ), "Dataset format should only be provided when data files are JSONs." def get_time_series_data_config(self): """Returns part of an analysis config dictionary.""" diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 7fb9f48d5e..14f77aadd9 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -584,7 +584,7 @@ def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): static_covariates=None, dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=ValueError, - error_message=f"Please do not provide empty strings in ``related_time_series``", + error_message="Please do not provide empty strings in ``related_time_series``", ), TimeSeriesDataConfigCase( # static_covariates contains blank string target_time_series="target_time_series", @@ -594,7 +594,7 @@ def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): static_covariates=["scv4", "scv5", "scv6", ""], dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, error=ValueError, - error_message=f"Please do not provide empty strings in ``static_covariates``", + error_message="Please do not provide empty strings in ``static_covariates``", ), TimeSeriesDataConfigCase( # dataset_format provided int case target_time_series=1, From a849d1cec9840b638932c107b3edd406ad52577f Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Thu, 7 Dec 2023 22:01:56 +0000 Subject: [PATCH 23/35] change: require headers for time series explainability --- src/sagemaker/clarify.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index d5473375b3..044a09e0bb 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -704,8 +704,12 @@ def __init__( f" Please check the API 
documentation for the supported dataset types." ) # check if any other format other than JSON is provided for time series case - if time_series_data_config and dataset_type != "application/json": - raise ValueError("Currently time series explainability only supports JSON format data") + if time_series_data_config: + if dataset_type != "application/json": + raise ValueError( + "Currently time series explainability only supports JSON format data" + ) + assert headers, "Headers are required for time series explainability" # features JMESPath is required for JSON as we can't derive it ourselves if dataset_type == "application/json" and features is None and not time_series_data_config: raise ValueError("features JMESPath is required for application/json dataset_type") From 90bafe00ebfbf514a740c1f44d3a80403b89d99a Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 13 Dec 2023 18:53:43 +0000 Subject: [PATCH 24/35] feat: add (early version of) baseline config to asym shap val config --- src/sagemaker/clarify.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 044a09e0bb..5b073964bd 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -327,6 +327,14 @@ ), ), SchemaOptional("num_samples"): int, + SchemaOptional("baseline"): Or( + str, + { + SchemaOptional("target_ts", default="zero"): str, + SchemaOptional("related_ts"): str, + SchemaOptional("static_covariates"): [Or(str, int, float)], + }, + ), }, }, SchemaOptional("predictor"): { @@ -1661,6 +1669,7 @@ def __init__( "fine_grained", ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY, num_samples: Optional[int] = None, + baseline: Optional[Union[str, Dict[str, Any]]] = None, ): """Initialises config for time series explainability with Asymmetric Shapley Values. @@ -1675,6 +1684,8 @@ def __init__( num_samples (None or int): Number of samples to be used in the Asymmetric Shapley Value forecasting algorithm. 
Only applicable when using ``"fine_grained"`` explanations. + baseline (str or dict): Link to a baseline configuration or a dictionary for it. + # TODO: improve above. Raises: AssertionError: when ``direction`` or ``granularity`` are not valid, @@ -1707,6 +1718,8 @@ def __init__( _set( num_samples, "num_samples", self.asymmetric_shapley_value_config ) # _set() does nothing if a given argument is None + # TODO: add sdk-side validation to baseline + _set(baseline, "baseline", self.asymmetric_shapley_value_config) def get_explainability_config(self): """Returns an asymmetric shap config dictionary.""" From 567a4db856576a8d6097857b13800da115246def Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 13 Dec 2023 19:26:16 +0000 Subject: [PATCH 25/35] refactor: change baseline config param names to keep with convention and be more cx friendly --- src/sagemaker/clarify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 5b073964bd..6979201258 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -330,8 +330,8 @@ SchemaOptional("baseline"): Or( str, { - SchemaOptional("target_ts", default="zero"): str, - SchemaOptional("related_ts"): str, + SchemaOptional("target_time_series", default="zero"): str, + SchemaOptional("related_time_series"): str, SchemaOptional("static_covariates"): [Or(str, int, float)], }, ), From 467c6799262f1111d757670c9403f8697821d42f Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Wed, 13 Dec 2023 19:36:29 +0000 Subject: [PATCH 26/35] refactor: baseline config from list to dictionary where key is item_id value --- src/sagemaker/clarify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 6979201258..4b8031ad77 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -332,7 +332,7 @@ { SchemaOptional("target_time_series", default="zero"): str, 
SchemaOptional("related_time_series"): str, - SchemaOptional("static_covariates"): [Or(str, int, float)], + SchemaOptional("static_covariates"): {Or(str, int): [Or(str, int, float)]}, }, ), }, From 66081d40a3b9e278e18e187043c45110d0e8330d Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Thu, 14 Dec 2023 08:55:32 +0000 Subject: [PATCH 27/35] fix: set dataset_uri from s3_data_input_path --- src/sagemaker/clarify.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 4b8031ad77..79173ed537 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -765,6 +765,8 @@ def __init__( "time_series_data_config", self.analysis_config, ) + # Temporary bug fix + _set(s3_data_input_path, "dataset_uri", self.analysis_config) def get_config(self): """Returns part of an analysis config dictionary.""" From abafbddcf806d7544a9df94e521eaf6a83fd1648 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Thu, 14 Dec 2023 19:54:20 +0000 Subject: [PATCH 28/35] fix: undo previous bug fix what i believed was a bug was actually intended behaviour --- src/sagemaker/clarify.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 79173ed537..4b8031ad77 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -765,8 +765,6 @@ def __init__( "time_series_data_config", self.analysis_config, ) - # Temporary bug fix - _set(s3_data_input_path, "dataset_uri", self.analysis_config) def get_config(self): """Returns part of an analysis config dictionary.""" From 640dbd11b4d5db4c2241c3569b00c5c83ff27c9a Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Thu, 14 Mar 2024 19:55:10 +0000 Subject: [PATCH 29/35] feat: add ``ITEM_RECORDS`` as a supported dataset format change: remove ``headers`` as a requirement for time series doc: add example dataset formats to ``TimeSeriesJSONDatasetFormat`` --- src/sagemaker/clarify.py | 96 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 
93 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 5de2485791..3540f8e6a3 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -389,9 +389,95 @@ class DatasetType(Enum): class TimeSeriesJSONDatasetFormat(Enum): - """Possible dataset formats for JSON time series data files.""" + """Possible dataset formats for JSON time series data files. + + Below is an example ``COLUMNS`` dataset for time series explainability: + + ``` + { + "ids": [1, 2], + "timestamps": [3, 4], + "target_ts": [5, 6], + "rts1": [0.25, 0.5], + "rts2": [1.25, 1.5], + "scv1": [10, 20], + "scv2": [30, 40] + } + + ``` + + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows: + + ``` + item_id="ids" + timestamp="timestamps" + target_time_series="target_ts" + related_time_series=["rts1", "rts2"] + static_covariates=["scv1", "scv2"] + ``` + + Below is an example ``ITEM_RECORDS`` dataset for time series explainability: + + ``` + [ + { + "id": 1, + "scv1": 10, + "scv2": "red", + "timeseries": [ + {"timestamp": 1, "target_ts": 5, "rts1": 0.25, "rts2": 10}, + {"timestamp": 2, "target_ts": 6, "rts1": 0.35, "rts2": 20}, + {"timestamp": 3, "target_ts": 4, "rts1": 0.45, "rts2": 30} + ] + }, + { + "id": 2, + "scv1": 20, + "scv2": "blue", + "timeseries": [ + {"timestamp": 1, "target_ts": 4, "rts1": 0.25, "rts2": 40}, + {"timestamp": 2, "target_ts": 2, "rts1": 0.35, "rts2": 50} + ] + } + ] + ``` + + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows: + + ``` + item_id="[*].id" + timestamp="[*].timeseries[].timestamp" + target_time_series="[*].timeseries[].target_ts" + related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"] + static_covariates=["[*].scv1", "[*].scv2"] + ``` + + Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability: + + ``` + [ + {"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25}, + {"id": 1, 
"timestamp": 2, "target_ts": 6, "scv1": 10, "rts1": 0.5}, + {"id": 1, "timestamp": 3, "target_ts": 3, "scv1": 10, "rts1": 0.75}, + {"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1} + ] + + ``` + + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows: + + ``` + item_id="[*].id" + timestamp="[*].timestamp" + target_time_series="[*].target_ts" + related_time_series=["[*].rts1"] + static_covariates=["[*].scv1"] + ``` + + """ COLUMNS = "columns" + ITEM_RECORDS = "item_records" TIMESTAMP_RECORDS = "timestamp_records" @@ -607,6 +693,11 @@ def __init__( Note: For JSON, the JMESPath query must result in a list of labels for each sample. For JSON Lines, it must result in the label for each line. Only a single label per sample is supported at this time. + headers (str): List of column names in the dataset. If not provided, Clarify will + generate headers to use internally. For time series explainability cases, + please provide headers in the following order: + item_id, timestamp, target_time_series, all related_time_series columns, + all static_covariate columns features (str): JMESPath expression to locate the feature values if the dataset format is JSON/JSON Lines. Note: For JSON, the JMESPath query must result in a 2-D list (or a matrix) of @@ -716,9 +807,8 @@ def __init__( if time_series_data_config: if dataset_type != "application/json": raise ValueError( - "Currently time series explainability only supports JSON format data" + "Currently time series explainability only supports JSON format data." 
) - assert headers, "Headers are required for time series explainability" # features JMESPath is required for JSON as we can't derive it ourselves if dataset_type == "application/json" and features is None and not time_series_data_config: raise ValueError("features JMESPath is required for application/json dataset_type") From 8aae27fd32a374f6b9c6ccd5aa175ee36c1224be Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 15 Mar 2024 19:58:28 +0000 Subject: [PATCH 30/35] doc: make docs for TimeSeriesJSONDatasetFormat sphinx-compliant change: add ``item_records`` as a supported format to the schema doc: add documentation for baseline doc: remove references to deprecated ``forecast_horizon`` --- src/sagemaker/clarify.py | 162 +++++++++++++++++++-------------------- 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 3540f8e6a3..795c938ff2 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -110,6 +110,7 @@ lambda s: s in ( "columns", + "item_records", "timestamp_records", ), ), @@ -391,89 +392,74 @@ class DatasetType(Enum): class TimeSeriesJSONDatasetFormat(Enum): """Possible dataset formats for JSON time series data files. 
- Below is an example ``COLUMNS`` dataset for time series explainability: + Below is an example ``COLUMNS`` dataset for time series explainability.:: - ``` - { - "ids": [1, 2], - "timestamps": [3, 4], - "target_ts": [5, 6], - "rts1": [0.25, 0.5], - "rts2": [1.25, 1.5], - "scv1": [10, 20], - "scv2": [30, 40] - } + { + "ids": [1, 2], + "timestamps": [3, 4], + "target_ts": [5, 6], + "rts1": [0.25, 0.5], + "rts2": [1.25, 1.5], + "scv1": [10, 20], + "scv2": [30, 40] + } - ``` + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.:: - For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows: + item_id="ids" + timestamp="timestamps" + target_time_series="target_ts" + related_time_series=["rts1", "rts2"] + static_covariates=["scv1", "scv2"] - ``` - item_id="ids" - timestamp="timestamps" - target_time_series="target_ts" - related_time_series=["rts1", "rts2"] - static_covariates=["scv1", "scv2"] - ``` + Below is an example ``ITEM_RECORDS`` dataset for time series explainability.:: - Below is an example ``ITEM_RECORDS`` dataset for time series explainability: + [ + { + "id": 1, + "scv1": 10, + "scv2": "red", + "timeseries": [ + {"timestamp": 1, "target_ts": 5, "rts1": 0.25, "rts2": 10}, + {"timestamp": 2, "target_ts": 6, "rts1": 0.35, "rts2": 20}, + {"timestamp": 3, "target_ts": 4, "rts1": 0.45, "rts2": 30} + ] + }, + { + "id": 2, + "scv1": 20, + "scv2": "blue", + "timeseries": [ + {"timestamp": 1, "target_ts": 4, "rts1": 0.25, "rts2": 40}, + {"timestamp": 2, "target_ts": 2, "rts1": 0.35, "rts2": 50} + ] + } + ] - ``` - [ - { - "id": 1, - "scv1": 10, - "scv2": "red", - "timeseries": [ - {"timestamp": 1, "target_ts": 5, "rts1": 0.25, "rts2": 10}, - {"timestamp": 2, "target_ts": 6, "rts1": 0.35, "rts2": 20}, - {"timestamp": 3, "target_ts": 4, "rts1": 0.45, "rts2": 30} - ] - }, - { - "id": 2, - "scv1": 20, - "scv2": "blue", - "timeseries": [ - {"timestamp": 1, "target_ts": 4, "rts1": 0.25, "rts2": 40}, - 
{"timestamp": 2, "target_ts": 2, "rts1": 0.35, "rts2": 50} - ] - } - ] - ``` - - For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows: - - ``` - item_id="[*].id" - timestamp="[*].timeseries[].timestamp" - target_time_series="[*].timeseries[].target_ts" - related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"] - static_covariates=["[*].scv1", "[*].scv2"] - ``` - - Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability: - - ``` - [ - {"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25}, - {"id": 1, "timestamp": 2, "target_ts": 6, "scv1": 10, "rts1": 0.5}, - {"id": 1, "timestamp": 3, "target_ts": 3, "scv1": 10, "rts1": 0.75}, - {"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1} - ] + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.:: + + item_id="[*].id" + timestamp="[*].timeseries[].timestamp" + target_time_series="[*].timeseries[].target_ts" + related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"] + static_covariates=["[*].scv1", "[*].scv2"] - ``` + Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability.:: - For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows: + [ + {"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25}, + {"id": 1, "timestamp": 2, "target_ts": 6, "scv1": 10, "rts1": 0.5}, + {"id": 1, "timestamp": 3, "target_ts": 3, "scv1": 10, "rts1": 0.75}, + {"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1} + ] - ``` - item_id="[*].id" - timestamp="[*].timestamp" - target_time_series="[*].target_ts" - related_time_series=["[*].rts1"] - static_covariates=["[*].scv1"] - ``` + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.:: + item_id="[*].id" + timestamp="[*].timestamp" + target_time_series="[*].target_ts" + related_time_series=["[*].rts1"] + 
static_covariates=["[*].scv1"] """ COLUMNS = "columns" @@ -693,11 +679,10 @@ def __init__( Note: For JSON, the JMESPath query must result in a list of labels for each sample. For JSON Lines, it must result in the label for each line. Only a single label per sample is supported at this time. - headers (str): List of column names in the dataset. If not provided, Clarify will + headers ([str]): List of column names in the dataset. If not provided, Clarify will generate headers to use internally. For time series explainability cases, - please provide headers in the following order: - item_id, timestamp, target_time_series, all related_time_series columns, - all static_covariate columns + please provide headers in the order of item_id, timestamp, target_time_series, + all related_time_series columns, and then all static_covariate columns. features (str): JMESPath expression to locate the feature values if the dataset format is JSON/JSON Lines. Note: For JSON, the JMESPath query must result in a 2-D list (or a matrix) of @@ -959,10 +944,10 @@ def __init__( forecast (str): JMESPath expression to extract the forecast result. Raises: - AssertionError: when either ``forecast`` or ``forecast_horizon`` are not provided + AssertionError: when ``forecast`` is not provided ValueError: when any provided argument are not of specified type """ - # assert forecast and forecast_horizon are provided + # assert forecast is provided assert ( forecast ), "Please provide ``forecast``, a JMESPath expression to extract the forecast result." @@ -1775,8 +1760,23 @@ def __init__( num_samples (None or int): Number of samples to be used in the Asymmetric Shapley Value forecasting algorithm. Only applicable when using ``"fine_grained"`` explanations. - baseline (str or dict): Link to a baseline configuration or a dictionary for it. - # TODO: improve above. + baseline (str or dict): Link to a baseline configuration or a dictionary for it. 
The + baseline config is used to replace out-of-coalition values for the corresponding + datasets (also known as background data). For temporal data (target time series, + related time series), the baseline value types are "zero", where all + out-of-coalition values will be replaced with 0.0, or "mean", all out-of-coalition + values will be replaced with the average of a time series. For static data + (static covariates), a baseline value for each covariate should be provided for + each possible item_id. An example config follows, where ``item1`` and ``item2`` + are item ids.:: + { + "related_time_series": "zero", + "static_covariates": { + "item1": [1, 1], + "item2": [0, 1] + }, + "target_time_series": "zero" + } Raises: AssertionError: when ``direction`` or ``granularity`` are not valid, From bbeba313c40ce5e173eb730de7bdc0ecbbf8d74a Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Fri, 15 Mar 2024 23:01:11 +0000 Subject: [PATCH 31/35] feat: validation for asymmetric shapley value config baseline doc: fix baseline doc to be sphinx-compliant --- src/sagemaker/clarify.py | 41 +++++++++++++++++++++++++++++---- tests/unit/test_clarify.py | 47 +++++++++++++++++++++++++++++--------- 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 795c938ff2..af099ff2ef 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -332,8 +332,24 @@ SchemaOptional("baseline"): Or( str, { - SchemaOptional("target_time_series", default="zero"): str, - SchemaOptional("related_time_series"): str, + SchemaOptional("target_time_series", default="zero"): And( + str, + Use(str.lower), + lambda s: s + in ( + "zero", + "mean", + ), + ), + SchemaOptional("related_time_series"): And( + str, + Use(str.lower), + lambda s: s + in ( + "zero", + "mean", + ), + ), SchemaOptional("static_covariates"): {Or(str, int): [Or(str, int, float)]}, }, ), @@ -1769,13 +1785,14 @@ def __init__( (static covariates), a baseline value for 
each covariate should be provided for each possible item_id. An example config follows, where ``item1`` and ``item2`` are item ids.:: + { + "target_time_series": "zero", "related_time_series": "zero", "static_covariates": { "item1": [1, 1], "item2": [0, 1] - }, - "target_time_series": "zero" + } } Raises: @@ -1803,13 +1820,27 @@ def __init__( ), f"{direction} and {granularity} granularity are not supported together." elif num_samples: # validate num_samples is not provided when unnecessary raise ValueError("``num_samples`` is only used for fine-grained explanations.") + # validate baseline if provided as a dictionary + if isinstance(baseline, dict): + temporal_baselines = ["zero", "mean"] # possible baseline options for temporal fields + if "target_time_series" in baseline: + target_baseline = baseline.get("target_time_series") + assert target_baseline in temporal_baselines, ( + f"Provided value {target_baseline} for ``target_time_series`` is " + f"invalid. Please select one of {temporal_baselines}." + ) + if "related_time_series" in baseline: + related_baseline = baseline.get("related_time_series") + assert related_baseline in temporal_baselines, ( + f"Provided value {related_baseline} for ``related_time_series`` is " + f"invalid. Please select one of {temporal_baselines}." 
+ ) # set explanation type and (if provided) num_samples in internal config dictionary _set(direction, "direction", self.asymmetric_shapley_value_config) _set(granularity, "granularity", self.asymmetric_shapley_value_config) _set( num_samples, "num_samples", self.asymmetric_shapley_value_config ) # _set() does nothing if a given argument is None - # TODO: add sdk-side validation to baseline _set(baseline, "baseline", self.asymmetric_shapley_value_config) def get_explainability_config(self): diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 14f77aadd9..8228f7850a 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -17,7 +17,7 @@ import pytest from mock import ANY, MagicMock, Mock, patch -from typing import List, NamedTuple, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union from sagemaker import Processor, image_uris from sagemaker.clarify import ( @@ -1283,9 +1283,10 @@ def test_shap_config_no_parameters(): class AsymmetricShapleyValueConfigCase(NamedTuple): direction: str granularity: str - num_samples: Optional[int] - error: Exception - error_message: str + num_samples: Optional[int] = None + baseline: Optional[Union[str, Dict[str, Any]]] = None + error: Exception = None + error_message: str = None class TestAsymmetricShapleyValueConfig: @@ -1296,22 +1297,28 @@ class TestAsymmetricShapleyValueConfig: direction=direction, granularity="timewise", num_samples=None, - error=None, - error_message=None, ) for direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS ] + [ - AsymmetricShapleyValueConfigCase( # cases for fine_grained granularity + AsymmetricShapleyValueConfigCase( # case for fine_grained granularity direction="chronological", granularity="fine_grained", num_samples=1, - error=None, - error_message=None, - ) + ), + AsymmetricShapleyValueConfigCase( # case for target time series baseline + direction="chronological", + granularity="timewise", + baseline={"target_time_series": "mean"}, + 
), + AsymmetricShapleyValueConfigCase( # case for related time series baseline + direction="chronological", + granularity="timewise", + baseline={"related_time_series": "zero"}, + ), ], ) - def test_asymmetric_shapley_value_config(self, test_case): + def test_asymmetric_shapley_value_config(self, test_case: AsymmetricShapleyValueConfigCase): """ GIVEN valid arguments for an AsymmetricShapleyValueConfig object WHEN AsymmetricShapleyValueConfig object is instantiated with those arguments @@ -1325,11 +1332,14 @@ def test_asymmetric_shapley_value_config(self, test_case): } if test_case.granularity == "fine_grained": expected_config["num_samples"] = test_case.num_samples + if test_case.baseline: + expected_config["baseline"] = test_case.baseline # WHEN asym_shap_val_config = AsymmetricShapleyValueConfig( direction=test_case.direction, granularity=test_case.granularity, num_samples=test_case.num_samples, + baseline=test_case.baseline, ) # THEN assert asym_shap_val_config.asymmetric_shapley_value_config == expected_config @@ -1380,6 +1390,20 @@ def test_asymmetric_shapley_value_config(self, test_case): error=AssertionError, error_message="not supported together.", ), + AsymmetricShapleyValueConfigCase( # case for unsupported target time series baseline value + direction="chronological", + granularity="timewise", + baseline={"target_time_series": "median"}, + error=AssertionError, + error_message="for ``target_time_series`` is invalid.", + ), + AsymmetricShapleyValueConfigCase( # case for unsupported related time series baseline value + direction="chronological", + granularity="timewise", + baseline={"related_time_series": "mode"}, + error=AssertionError, + error_message="for ``related_time_series`` is invalid.", + ), ], ) def test_asymmetric_shapley_value_config_invalid(self, test_case): @@ -1394,6 +1418,7 @@ def test_asymmetric_shapley_value_config_invalid(self, test_case): direction=test_case.direction, granularity=test_case.granularity, 
num_samples=test_case.num_samples, + baseline=test_case.baseline, ) From a8ad1d1bdece5f760be9ae8bfa9a0e644bc71c0f Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Mon, 18 Mar 2024 17:25:44 +0000 Subject: [PATCH 32/35] feat: add validation for static covariates in tsx baseline --- src/sagemaker/clarify.py | 56 ++++++++++-- tests/unit/test_clarify.py | 176 ++++++++++++++++++++++++++++++++++++- 2 files changed, 223 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index af099ff2ef..6a19325add 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -408,7 +408,7 @@ class DatasetType(Enum): class TimeSeriesJSONDatasetFormat(Enum): """Possible dataset formats for JSON time series data files. - Below is an example ``COLUMNS`` dataset for time series explainability.:: + Below is an example ``COLUMNS`` dataset for time series explainability:: { "ids": [1, 2], @@ -420,7 +420,7 @@ class TimeSeriesJSONDatasetFormat(Enum): "scv2": [30, 40] } - For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.:: + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:: item_id="ids" timestamp="timestamps" @@ -428,7 +428,7 @@ class TimeSeriesJSONDatasetFormat(Enum): related_time_series=["rts1", "rts2"] static_covariates=["scv1", "scv2"] - Below is an example ``ITEM_RECORDS`` dataset for time series explainability.:: + Below is an example ``ITEM_RECORDS`` dataset for time series explainability:: [ { @@ -452,7 +452,7 @@ class TimeSeriesJSONDatasetFormat(Enum): } ] - For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.:: + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:: item_id="[*].id" timestamp="[*].timeseries[].timestamp" @@ -460,7 +460,7 @@ class TimeSeriesJSONDatasetFormat(Enum): related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"] 
static_covariates=["[*].scv1", "[*].scv2"] - Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability.:: + Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability:: [ {"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25}, @@ -469,7 +469,7 @@ class TimeSeriesJSONDatasetFormat(Enum): {"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1} ] - For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.:: + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:: item_id="[*].id" timestamp="[*].timestamp" @@ -1784,7 +1784,7 @@ def __init__( values will be replaced with the average of a time series. For static data (static covariates), a baseline value for each covariate should be provided for each possible item_id. An example config follows, where ``item1`` and ``item2`` - are item ids.:: + are item ids:: { "target_time_series": "zero", @@ -2548,7 +2548,7 @@ def explainability( explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], ): """Generates a config for Explainability""" - # determine if this is a timeseries explainability case by checking + # determine if this is a time series explainability case by checking # if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given ts_data_config_present = "time_series_data_config" in data_config.analysis_config ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config @@ -2556,6 +2556,8 @@ def explainability( if isinstance(explainability_config, AsymmetricShapleyValueConfig): assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig." assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig." 
+ # Check static covariates baseline matches number of provided static covariate columns + else: if ts_data_config_present: raise ValueError( @@ -2759,6 +2761,44 @@ def _merge_explainability_configs( return explainability_methods return explainability_config.get_explainability_config() + @classmethod + def _validate_time_series_static_covariates_baseline( + cls, + explainability_config: AsymmetricShapleyValueConfig, + data_config: DataConfig, + ): + """Validates static covariates in baseline for asymmetric shapley value (for time series). + + Checks that baseline values set for static covariate columns are + consistent between every item_id and the number of static covariate columns + provided in DataConfig. + """ + baseline = explainability_config.get_explainability_config()[ + "asymmetric_shapley_value" + ].get("baseline") + if baseline and "static_covariates" in baseline: + covariate_count = len( + data_config.get_config()["time_series_data_config"].get("static_covariates", []) + ) + if covariate_count > 0: + for item_id in baseline.get("static_covariates", []): + baseline_entry = baseline["static_covariates"][item_id] + assert isinstance(baseline_entry, list), ( + f"Baseline entry for {item_id} must be a list, is " + f"{type(baseline_entry)}." + ) + assert len(baseline_entry) == covariate_count, ( + f"Length of baseline entry for {item_id} does not match number " + f"of static covariate columns. Please ensure every covariate " + f"has a baseline value for every item id." + ) + else: + raise ValueError( + "Static covariate baselines are provided in AsymmetricShapleyValueConfig " + "when no static covariate columns are provided in TimeSeriesDataConfig. " + "Please check these configs." + ) + def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): """Uploads the local ``analysis_config_file`` to the ``s3_output_path``. 
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 8228f7850a..f716005988 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -16,6 +16,7 @@ import copy import pytest +from dataclasses import dataclass from mock import ANY, MagicMock, Mock, patch from typing import Any, Dict, List, NamedTuple, Optional, Union @@ -2574,7 +2575,8 @@ def _build_pdp_config_mock(): def _build_asymmetric_shapley_value_config_mock(): asym_shap_val_config_dict = { - "explanation_type": "fine_grained", + "direction": "chronological", + "granularity": "fine_grained", "num_samples": 20, } asym_shap_val_config = Mock(spec=AsymmetricShapleyValueConfig) @@ -2613,6 +2615,14 @@ def _build_model_config_mock(): return model_config +@dataclass +class ValidateTSXBaselineCase: + explainability_config: AsymmetricShapleyValueConfig + data_config: DataConfig + error: Optional[Exception] = None + error_msg: Optional[str] = None + + class TestAnalysisConfigGeneratorForTimeSeriesExplainability: @patch("sagemaker.clarify._AnalysisConfigGenerator._add_methods") @patch("sagemaker.clarify._AnalysisConfigGenerator._add_predictor") @@ -2794,6 +2804,170 @@ def test_merge_explainability_configs_with_timeseries_invalid( explainability_config=mock_config, ) + @pytest.mark.parametrize( + "case", + [ + ValidateTSXBaselineCase( + explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": [0.0, 1.0, 1.0], + "item4": [0.9, 0.6, 0.3], + "item5": [1.0, 0.5, 0.0], + }, + }, + ), + data_config=DataConfig( + s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + 
timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + ), + ], + ) + def test_time_series_baseline_valid_static_covariates(self, case: ValidateTSXBaselineCase): + """ + GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline + is provided + WHEN AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called + THEN no error is raised + """ + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( + explainability_config=case.explainability_config, + data_config=case.data_config, + ) + + @pytest.mark.parametrize( + "case", + [ + ValidateTSXBaselineCase( # some item ids are missing baseline values + explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": [0.0], + "item4": [0.9, 0.6, 0.3], + "item5": [1.0], + }, + }, + ), + data_config=DataConfig( + s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + error=AssertionError, + error_msg="baseline entry for item3 does not match number", + ), + ValidateTSXBaselineCase( # no static covariates are in data config + 
explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": [0.0, 1.0, 1.0], + "item4": [0.9, 0.6, 0.3], + "item5": [1.0, 0.5, 0.0], + }, + }, + ), + data_config=DataConfig( + s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + error=ValueError, + error_msg="no static covariate columns are provided in TimeSeriesDataConfig", + ), + ValidateTSXBaselineCase( # some item ids do not have a list as their baseline + explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": [0.0, 1.0, 1.0], + "item4": [0.9, 0.6, 0.3], + "item5": {"cov_1": 1.0, "cov_2": 0.5, "cov_3": 0.0}, + }, + }, + ), + data_config=DataConfig( + s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + 
error=AssertionError, + error_msg="Baseline entry for item5 must be a list", + ), + ], + ) + def test_time_series_baseline_invalid_static_covariates(self, case: ValidateTSXBaselineCase): + """ + GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline + is provided where the static covariates baseline values are misconfigured + WHEN AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called + THEN the appropriate error is raised + """ + with pytest.raises(case.error, match=case.error_msg): + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( + explainability_config=case.explainability_config, + data_config=case.data_config, + ) + class TestProcessingOutputHandler: def test_get_s3_upload_mode_image(self): From b535f530ab1211b38a63f81750fc4b4072a5af4b Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Mon, 18 Mar 2024 18:36:34 +0000 Subject: [PATCH 33/35] fix: add call to validate baseline static covariates method --- src/sagemaker/clarify.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 6a19325add..3342884bb7 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -2557,7 +2557,10 @@ def explainability( assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig." assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig." 
# Check static covariates baseline matches number of provided static covariate columns - + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( + explainability_config=explainability_config, + data_config=data_config, + ) else: if ts_data_config_present: raise ValueError( From 222fb759c0d4fa1c1b201ef111375bafc2c398c8 Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Mon, 18 Mar 2024 19:53:31 +0000 Subject: [PATCH 34/35] fix: check if baseline dict is s3 uri in scv validation function --- src/sagemaker/clarify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 3342884bb7..6d977d7f59 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -2779,7 +2779,7 @@ def _validate_time_series_static_covariates_baseline( baseline = explainability_config.get_explainability_config()[ "asymmetric_shapley_value" ].get("baseline") - if baseline and "static_covariates" in baseline: + if isinstance(baseline, dict) and "static_covariates" in baseline: covariate_count = len( data_config.get_config()["time_series_data_config"].get("static_covariates", []) ) From 311f48637ed7f472bfb3fb33f012f502f92080fb Mon Sep 17 00:00:00 2001 From: Rahul Sahu Date: Tue, 19 Mar 2024 18:29:36 +0000 Subject: [PATCH 35/35] fix: replace all added asserts with ValueError --- src/sagemaker/clarify.py | 126 ++++++++++++++++++++----------------- tests/unit/test_clarify.py | 78 ++++++++++------------- 2 files changed, 100 insertions(+), 104 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 6d977d7f59..246cdbcc2d 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -587,13 +587,15 @@ def __init__( when dataset is in JSON format. Raises: - AssertionError: If any required arguments are not provided. - ValueError: If any provided arguments are the wrong type. + ValueError: If any required arguments are not provided or are the wrong type. 
""" # check target_time_series, item_id, and timestamp are provided - assert target_time_series, "Please provide a target time series." - assert item_id, "Please provide an item id." - assert timestamp, "Please provide a timestamp." + if not target_time_series: + raise ValueError("Please provide a target time series.") + if not item_id: + raise ValueError("Please provide an item id.") + if not timestamp: + raise ValueError("Please provide a timestamp.") # check all arguments are the right types if not isinstance(target_time_series, (str, int)): raise ValueError("Please provide a string or an int for ``target_time_series``") @@ -644,14 +646,14 @@ def __init__( ) # static_covariates is valid, add it if params_type == str: # check dataset_format is provided and valid - assert isinstance( - dataset_format, TimeSeriesJSONDatasetFormat - ), "Please provide a valid dataset format." + if not isinstance(dataset_format, TimeSeriesJSONDatasetFormat): + raise ValueError("Please provide a valid dataset format.") _set(dataset_format.value, "dataset_format", self.time_series_data_config) else: - assert ( - not dataset_format - ), "Dataset format should only be provided when data files are JSONs." + if dataset_format: + raise ValueError( + "Dataset format should only be provided when data files are JSONs." + ) def get_time_series_data_config(self): """Returns part of an analysis config dictionary.""" @@ -960,16 +962,14 @@ def __init__( forecast (str): JMESPath expression to extract the forecast result. Raises: - AssertionError: when ``forecast`` is not provided - ValueError: when any provided argument are not of specified type + ValueError: when ``forecast`` is not a string or not provided """ - # assert forecast is provided - assert ( - forecast - ), "Please provide ``forecast``, a JMESPath expression to extract the forecast result." 
- # check provided arguments are of the right type + # check string forecast is provided if not isinstance(forecast, str): - raise ValueError("Please provide a string JMESPath expression for ``forecast``.") + raise ValueError( + "Please provide a string JMESPath expression for ``forecast`` " + "to extract the forecast result." + ) # add fields to an internal config dictionary self.time_series_model_config = dict() _set(forecast, "forecast", self.time_series_model_config) @@ -1796,28 +1796,30 @@ def __init__( } Raises: - AssertionError: when ``direction`` or ``granularity`` are not valid, - or ``num_samples`` is not provided for fine-grained explanations - ValueError: when ``num_samples`` is provided for non fine-grained explanations, or - when direction is not ``"chronological"`` when granularity is - ``"fine_grained"``. + ValueError: when ``direction`` or ``granularity`` are not valid, ``num_samples`` is not + provided for fine-grained explanations, ``num_samples`` is provided for non + fine-grained explanations, or when ``direction`` is not ``"chronological"`` while + ``granularity`` is ``"fine_grained"``. 
""" self.asymmetric_shapley_value_config = dict() # validate explanation direction - assert ( - direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS - ), "Please provide a valid explanation direction from: " + ", ".join( - ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS - ) + if direction not in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS: + raise ValueError( + "Please provide a valid explanation direction from: " + + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS) + ) # validate granularity - assert ( - granularity in ASYM_SHAP_VAL_GRANULARITIES - ), "Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES) + if granularity not in ASYM_SHAP_VAL_GRANULARITIES: + raise ValueError( + "Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES) + ) if granularity == "fine_grained": - assert isinstance(num_samples, int), "Please provide an integer for ``num_samples``." - assert ( - direction == "chronological" - ), f"{direction} and {granularity} granularity are not supported together." + if not isinstance(num_samples, int): + raise ValueError("Please provide an integer for ``num_samples``.") + if direction != "chronological": + raise ValueError( + f"{direction} and {granularity} granularity are not supported together." + ) elif num_samples: # validate num_samples is not provided when unnecessary raise ValueError("``num_samples`` is only used for fine-grained explanations.") # validate baseline if provided as a dictionary @@ -1825,16 +1827,18 @@ def __init__( temporal_baselines = ["zero", "mean"] # possible baseline options for temporal fields if "target_time_series" in baseline: target_baseline = baseline.get("target_time_series") - assert target_baseline in temporal_baselines, ( - f"Provided value {target_baseline} for ``target_time_series`` is " - f"invalid. Please select one of {temporal_baselines}." 
- ) + if target_baseline not in temporal_baselines: + raise ValueError( + f"Provided value {target_baseline} for ``target_time_series`` is " + f"invalid. Please select one of {temporal_baselines}." + ) if "related_time_series" in baseline: related_baseline = baseline.get("related_time_series") - assert related_baseline in temporal_baselines, ( - f"Provided value {related_baseline} for ``related_time_series`` is " - f"invalid. Please select one of {temporal_baselines}." - ) + if related_baseline not in temporal_baselines: + raise ValueError( + f"Provided value {related_baseline} for ``related_time_series`` is " + f"invalid. Please select one of {temporal_baselines}." + ) # set explanation type and (if provided) num_samples in internal config dictionary _set(direction, "direction", self.asymmetric_shapley_value_config) _set(granularity, "granularity", self.asymmetric_shapley_value_config) @@ -2550,25 +2554,27 @@ def explainability( """Generates a config for Explainability""" # determine if this is a time series explainability case by checking # if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given - ts_data_config_present = "time_series_data_config" in data_config.analysis_config - ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config + ts_data_conf_absent = "time_series_data_config" not in data_config.analysis_config + ts_model_conf_absent = "time_series_predictor_config" not in model_config.predictor_config if isinstance(explainability_config, AsymmetricShapleyValueConfig): - assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig." - assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig." 
+ if ts_data_conf_absent: + raise ValueError("Please provide a TimeSeriesDataConfig to DataConfig.") + if ts_model_conf_absent: + raise ValueError("Please provide a TimeSeriesModelConfig to ModelConfig.") # Check static covariates baseline matches number of provided static covariate columns _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( explainability_config=explainability_config, data_config=data_config, ) else: - if ts_data_config_present: + if not ts_data_conf_absent: raise ValueError( "Please provide an AsymmetricShapleyValueConfig for time series " "explainability cases. For non time series cases, please do not provide a " "TimeSeriesDataConfig." ) - if ts_model_config_present: + if not ts_model_conf_absent: raise ValueError( "Please provide an AsymmetricShapleyValueConfig for time series " "explainability cases. For non time series cases, please do not provide a " @@ -2786,15 +2792,17 @@ def _validate_time_series_static_covariates_baseline( if covariate_count > 0: for item_id in baseline.get("static_covariates", []): baseline_entry = baseline["static_covariates"][item_id] - assert isinstance(baseline_entry, list), ( - f"Baseline entry for {item_id} must be a list, is " - f"{type(baseline_entry)}." - ) - assert len(baseline_entry) == covariate_count, ( - f"Length of baseline entry for {item_id} does not match number " - f"of static covariate columns. Please ensure every covariate " - f"has a baseline value for every item id." - ) + if not isinstance(baseline_entry, list): + raise ValueError( + f"Baseline entry for {item_id} must be a list, is " + f"{type(baseline_entry)}." + ) + if len(baseline_entry) != covariate_count: + raise ValueError( + f"Length of baseline entry for {item_id} does not match number " + f"of static covariate columns. Please ensure every covariate " + f"has a baseline value for every item id." 
+ ) else: raise ValueError( "Static covariate baselines are provided in AsymmetricShapleyValueConfig " diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index f716005988..82ac2cd8bc 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -328,15 +328,16 @@ def test_s3_data_distribution_type_ignorance(): assert data_config.s3_data_distribution_type == "FullyReplicated" -class TimeSeriesDataConfigCase(NamedTuple): +@dataclass +class TimeSeriesDataConfigCase: target_time_series: Union[str, int] item_id: Union[str, int] timestamp: Union[str, int] related_time_series: Optional[List[Union[str, int]]] static_covariates: Optional[List[Union[str, int]]] - dataset_format: Optional[TimeSeriesJSONDatasetFormat] - error: Exception - error_message: Optional[str] + dataset_format: Optional[TimeSeriesJSONDatasetFormat] = None + error: Optional[Exception] = None + error_message: Optional[str] = None class TestTimeSeriesDataConfig: @@ -348,8 +349,6 @@ class TestTimeSeriesDataConfig: related_time_series=None, static_covariates=None, dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=None, - error_message=None, ), TimeSeriesDataConfigCase( # related_time_series provided str case target_time_series="target_time_series", @@ -358,8 +357,6 @@ class TestTimeSeriesDataConfig: related_time_series=["ts1", "ts2", "ts3"], static_covariates=None, dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=None, - error_message=None, ), TimeSeriesDataConfigCase( # static_covariates provided str case target_time_series="target_time_series", @@ -368,8 +365,6 @@ class TestTimeSeriesDataConfig: related_time_series=None, static_covariates=["a", "b", "c", "d"], dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=None, - error_message=None, ), TimeSeriesDataConfigCase( # both related_time_series and static_covariates provided str case target_time_series="target_time_series", @@ -378,8 +373,6 @@ class TestTimeSeriesDataConfig: 
related_time_series=["ts1", "ts2", "ts3"], static_covariates=["a", "b", "c", "d"], dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=None, - error_message=None, ), TimeSeriesDataConfigCase( # no optional args provided int case target_time_series=1, @@ -387,9 +380,6 @@ class TestTimeSeriesDataConfig: timestamp=3, related_time_series=None, static_covariates=None, - dataset_format=None, - error=None, - error_message=None, ), TimeSeriesDataConfigCase( # related_time_series provided int case target_time_series=1, @@ -397,9 +387,6 @@ class TestTimeSeriesDataConfig: timestamp=3, related_time_series=[4, 5, 6], static_covariates=None, - dataset_format=None, - error=None, - error_message=None, ), TimeSeriesDataConfigCase( # static_covariates provided int case target_time_series=1, @@ -407,9 +394,6 @@ class TestTimeSeriesDataConfig: timestamp=3, related_time_series=None, static_covariates=[7, 8, 9, 10], - dataset_format=None, - error=None, - error_message=None, ), TimeSeriesDataConfigCase( # both related_time_series and static_covariates provided int case target_time_series=1, @@ -417,9 +401,6 @@ class TestTimeSeriesDataConfig: timestamp=3, related_time_series=[4, 5, 6], static_covariates=[7, 8, 9, 10], - dataset_format=None, - error=None, - error_message=None, ), ] @@ -464,7 +445,7 @@ def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): related_time_series=None, static_covariates=None, dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=AssertionError, + error=ValueError, error_message="Please provide a target time series.", ), TimeSeriesDataConfigCase( # no item_id provided @@ -474,7 +455,7 @@ def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): related_time_series=None, static_covariates=None, dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=AssertionError, + error=ValueError, error_message="Please provide an item id.", ), TimeSeriesDataConfigCase( # no timestamp provided @@ -484,7 +465,7 @@ def 
test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): related_time_series=None, static_covariates=None, dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=AssertionError, + error=ValueError, error_message="Please provide a timestamp.", ), TimeSeriesDataConfigCase( # target_time_series not int or str @@ -604,7 +585,7 @@ def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): related_time_series=[4, 5, 6], static_covariates=[7, 8, 9, 10], dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, - error=AssertionError, + error=ValueError, error_message="Dataset format should only be provided when data files are JSONs.", ), TimeSeriesDataConfigCase( # dataset_format not provided str case @@ -614,7 +595,7 @@ def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): related_time_series=["ts1", "ts2", "ts3"], static_covariates=["a", "b", "c", "d"], dataset_format=None, - error=AssertionError, + error=ValueError, error_message="Please provide a valid dataset format.", ), TimeSeriesDataConfigCase( # dataset_format wrong type str case @@ -624,7 +605,7 @@ def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): related_time_series=["ts1", "ts2", "ts3"], static_covariates=["a", "b", "c", "d"], dataset_format="made_up_format", - error=AssertionError, + error=ValueError, error_message="Please provide a valid dataset format.", ), ], @@ -1039,13 +1020,13 @@ def test_time_series_model_config(self): [ ( None, - AssertionError, - "Please provide ``forecast``, a JMESPath expression to extract the forecast result.", + ValueError, + "Please provide a string JMESPath expression for ``forecast``", ), ( 123, ValueError, - "Please provide a string JMESPath expression for ``forecast``.", + "Please provide a string JMESPath expression for ``forecast``", ), ], ) @@ -1352,7 +1333,7 @@ def test_asymmetric_shapley_value_config(self, test_case: AsymmetricShapleyValue direction="non-directional", 
granularity="timewise", num_samples=None, - error=AssertionError, + error=ValueError, error_message="Please provide a valid explanation direction from: " + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS), ), @@ -1360,14 +1341,14 @@ def test_asymmetric_shapley_value_config(self, test_case: AsymmetricShapleyValue direction="chronological", granularity="fine_grained", num_samples=None, - error=AssertionError, + error=ValueError, error_message="Please provide an integer for ``num_samples``.", ), AsymmetricShapleyValueConfigCase( # case for fine_grained and non-integer num_samples direction="chronological", granularity="fine_grained", num_samples="5", - error=AssertionError, + error=ValueError, error_message="Please provide an integer for ``num_samples``.", ), AsymmetricShapleyValueConfigCase( # case for num_samples when non fine-grained explanation @@ -1381,28 +1362,28 @@ def test_asymmetric_shapley_value_config(self, test_case: AsymmetricShapleyValue direction="anti_chronological", granularity="fine_grained", num_samples=5, - error=AssertionError, + error=ValueError, error_message="not supported together.", ), AsymmetricShapleyValueConfigCase( # case for bidirectional and fine_grained direction="bidirectional", granularity="fine_grained", num_samples=5, - error=AssertionError, + error=ValueError, error_message="not supported together.", ), AsymmetricShapleyValueConfigCase( # case for unsupported target time series baseline value direction="chronological", granularity="timewise", baseline={"target_time_series": "median"}, - error=AssertionError, + error=ValueError, error_message="for ``target_time_series`` is invalid.", ), AsymmetricShapleyValueConfigCase( # case for unsupported related time series baseline value direction="chronological", granularity="timewise", baseline={"related_time_series": "mode"}, - error=AssertionError, + error=ValueError, error_message="for ``related_time_series`` is invalid.", ), ], @@ -2624,9 +2605,12 @@ class ValidateTSXBaselineCase: class 
TestAnalysisConfigGeneratorForTimeSeriesExplainability: + @patch( + "sagemaker.clarify._AnalysisConfigGenerator._validate_time_series_static_covariates_baseline" + ) @patch("sagemaker.clarify._AnalysisConfigGenerator._add_methods") @patch("sagemaker.clarify._AnalysisConfigGenerator._add_predictor") - def test_explainability_for_time_series(self, _add_predictor, _add_methods): + def test_explainability_for_time_series(self, _add_predictor, _add_methods, _validate_ts_scv): """ GIVEN a valid DataConfig and ModelConfig that contain time_series_data_config and time_series_model_config respectively as well as an AsymmetricShapleyValueConfig @@ -2671,6 +2655,10 @@ def test_explainability_for_time_series(self, _add_predictor, _add_methods): ANY, explainability_config=explainability_config, ) + _validate_ts_scv.assert_called_once_with( + explainability_config=explainability_config, + data_config=data_config_mock, + ) def test_explainability_for_time_series_invalid(self): # data config mocks @@ -2687,7 +2675,7 @@ def test_explainability_for_time_series_invalid(self): pdp_config_mock = _build_pdp_config_mock() # case 1: ASV (ts case) and no timeseries data config given with pytest.raises( - AssertionError, match="Please provide a TimeSeriesDataConfig to DataConfig." + ValueError, match="Please provide a TimeSeriesDataConfig to DataConfig." ): _AnalysisConfigGenerator.explainability( data_config=data_config_without_ts, @@ -2697,7 +2685,7 @@ def test_explainability_for_time_series_invalid(self): ) # case 2: ASV (ts case) and no timeseries model config given with pytest.raises( - AssertionError, match="Please provide a TimeSeriesModelConfig to ModelConfig." + ValueError, match="Please provide a TimeSeriesModelConfig to ModelConfig." 
): _AnalysisConfigGenerator.explainability( data_config=data_config_with_ts, @@ -2885,7 +2873,7 @@ def test_time_series_baseline_valid_static_covariates(self, case: ValidateTSXBas dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, ), ), - error=AssertionError, + error=ValueError, error_msg="baseline entry for item3 does not match number", ), ValidateTSXBaselineCase( # no static covariates are in data config @@ -2950,7 +2938,7 @@ def test_time_series_baseline_valid_static_covariates(self, case: ValidateTSXBas dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, ), ), - error=AssertionError, + error=ValueError, error_msg="Baseline entry for item5 must be a list", ), ],