From 092a9fc4ef48073760afb93367aa6cb5bc216c5a Mon Sep 17 00:00:00 2001 From: rvasahu-amazon <106207732+rvasahu-amazon@users.noreply.github.com> Date: Thu, 21 Mar 2024 12:07:16 -0700 Subject: [PATCH] feature: add support to ``clarify.py`` for time series explainability jobs (#4503) * feature: Added components to support time series explainability with Clarify. These components are TimeSeriesDataConfig, TimeSeriesModelConfig, and AsymmetricSHAPConfig, along with unit tests for them. fix: Modified DataConfig, ModelConfig, and _AnalysisConfigGenerator to support the new components and time series explainability. documentation: added docstrings for the new components and their tests. * fix: removed field use_future_covariates and related unit tests from TimeSeriesModelConfig * change: rename ``TimeSeriesDataConfig.analysis_config`` to ``time_series_data_config`` * change: validate DataConfig content_type and accept_type for TS exp. change: renamed TimeSeriesDataConfig.predictor_config to time_series_model_config change: modified default value of forecast_horizon to 1 * change: reworked validation in AsymmetricSHAPConfig change: removed default value for num_samples change: changed default value for explanation_type change: added more explicit type hint for explanation_type documentation: added more information for parameters change: reworked unit tests for AsymmetricSHAPConfig * change: time series case no longer uses _merge_explainability_configs change: _merge_explainability_configs reordered to put validation first change: unit tests reworked * fix: minor style changes to meet formatting reqs * change: modified how time_series_case flag is set change: removed now-redundant check in time_series_case * fix: set time_series_case to False to prevent exception * change: params for ``TimeSeriesDataConfig`` now must all be same type change: updated ``TimeSeriesDataConfig`` unit tests to reflect above change * fix: schema entries for related_ts and item_metadata to keep list 
items same type * change: remove forecast_horizon from TimeSeriesModelConfig * fix: modified type hints in ``TimeSeriesDataConfig`` to match schema * change: add errors when ts data or model config are given but no asym_shap config change: add unit tests for _AnalysisConfigGenerator.explainability fix: slightly modify AsymmetricSHAPConfig mock builder function documentation: minor docstring update in tests * change: remove flag ``time_series_case``, modify unit tests accordingly change: add a check to _AnalysisConfigGenerator.bias_and_explainability to prevent time series components being provided, added unit tests for this * change: rename `AsymmetricSHAPConfig` to `AsymmetricShapleyValueConfig` * documentation: add description to ``AsymmetricShapleyValueConfig`` documentation: reword `target_time_series` parameter description documentation: remove TODOs * change: split ``explanation_type`` into ``explanation_direction`` and ``granularity`` update tests and documentation accordingly * fix: rename ``explanation_granularity`` to ``granularity`` * change: rename ``explanation_direction`` to ``direction`` * change: rename ``item_metadata`` to ``static_covariates`` change: add ``dataset_format`` as a parameter for time series cases change: allow features jmespaths to be none for time series cases change: add validation to prevent non-json dataset formats for time series cases test: update unit tests to reflect above changes * fix: update clarify files to meet formatting reqs * change: require headers for time series explainability * feat: add (early version of) baseline config to asym shap val config * refactor: change baseline config param names to keep with convention and be more cx friendly * refactor: baseline config from list to dictionary where key is item_id value * fix: set dataset_uri from s3_data_input_path * fix: undo previous bug fix; what I believed was a bug was actually intended behaviour * feat: add ``ITEM_RECORDS`` as a supported dataset format change: 
remove ``headers`` as a requirement for time series doc: add example dataset formats to ``TimeSeriesJSONDatasetFormat`` * doc: make docs for TimeSeriesJSONDatasetFormat sphinx-compliant change: add ``item_records`` as a supported format to the schema doc: add documentation for baseline doc: remove references to deprecated ``forecast_horizon`` * feat: validation for asymmetric shapley value config baseline doc: fix baseline doc to be sphinx-compliant * feat: add validation for static covariates in tsx baseline * fix: add call to validate baseline static covariates method * fix: check if baseline dict is s3 uri in scv validation function * fix: replace all added asserts with ValueError --------- Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> --- src/sagemaker/clarify.py | 573 +++++++++++++++++++- tests/unit/test_clarify.py | 1042 +++++++++++++++++++++++++++++++++++- 2 files changed, 1602 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 11bc43c43a..246cdbcc2d 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -25,7 +25,7 @@ import tempfile from abc import ABC, abstractmethod -from typing import List, Union, Dict, Optional, Any +from typing import List, Literal, Union, Dict, Optional, Any from enum import Enum from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex @@ -40,6 +40,19 @@ ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])" +# asym shap val config default values (timeseries) +ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION = "chronological" +ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY = "timewise" +ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS = [ + "chronological", + "anti_chronological", + "bidirectional", +] +ASYM_SHAP_VAL_GRANULARITIES = [ + "timewise", + "fine_grained", +] + ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema( { SchemaOptional("version"): str, @@ -85,6 +98,23 @@ SchemaOptional("excluded_columns"): [Or(int, str)], 
SchemaOptional("joinsource_name_or_index"): Or(str, int), SchemaOptional("group_variable"): Or(str, int), + SchemaOptional("time_series_data_config"): { + "target_time_series": Or(str, int), + "item_id": Or(str, int), + "timestamp": Or(str, int), + SchemaOptional("related_time_series"): Or([str], [int]), + SchemaOptional("static_covariates"): Or([str], [int]), + SchemaOptional("dataset_format"): And( + str, + Use(str.lower), + lambda s: s + in ( + "columns", + "item_records", + "timestamp_records", + ), + ), + }, "methods": { SchemaOptional("shap"): { SchemaOptional("baseline"): Or( @@ -278,6 +308,52 @@ SchemaOptional("top_k_features"): int, }, SchemaOptional("report"): {"name": str, SchemaOptional("title"): str}, + SchemaOptional("asymmetric_shapley_value"): { + "direction": And( + str, + Use(str.lower), + lambda s: s + in ( + "chronological", + "anti_chronological", + "bidirectional", + ), + ), + "granularity": And( + str, + Use(str.lower), + lambda s: s + in ( + "timewise", + "fine_grained", + ), + ), + SchemaOptional("num_samples"): int, + SchemaOptional("baseline"): Or( + str, + { + SchemaOptional("target_time_series", default="zero"): And( + str, + Use(str.lower), + lambda s: s + in ( + "zero", + "mean", + ), + ), + SchemaOptional("related_time_series"): And( + str, + Use(str.lower), + lambda s: s + in ( + "zero", + "mean", + ), + ), + SchemaOptional("static_covariates"): {Or(str, int): [Or(str, int, float)]}, + }, + ), + }, }, SchemaOptional("predictor"): { SchemaOptional("endpoint_name"): str, @@ -311,6 +387,9 @@ SchemaOptional("content_template"): Or(str, {str: str}), SchemaOptional("record_template"): str, SchemaOptional("custom_attributes"): str, + SchemaOptional("time_series_predictor_config"): { + "forecast": str, + }, }, } ) @@ -326,6 +405,84 @@ class DatasetType(Enum): IMAGE = "application/x-image" +class TimeSeriesJSONDatasetFormat(Enum): + """Possible dataset formats for JSON time series data files. 
+ + Below is an example ``COLUMNS`` dataset for time series explainability:: + + { + "ids": [1, 2], + "timestamps": [3, 4], + "target_ts": [5, 6], + "rts1": [0.25, 0.5], + "rts2": [1.25, 1.5], + "scv1": [10, 20], + "scv2": [30, 40] + } + + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:: + + item_id="ids" + timestamp="timestamps" + target_time_series="target_ts" + related_time_series=["rts1", "rts2"] + static_covariates=["scv1", "scv2"] + + Below is an example ``ITEM_RECORDS`` dataset for time series explainability:: + + [ + { + "id": 1, + "scv1": 10, + "scv2": "red", + "timeseries": [ + {"timestamp": 1, "target_ts": 5, "rts1": 0.25, "rts2": 10}, + {"timestamp": 2, "target_ts": 6, "rts1": 0.35, "rts2": 20}, + {"timestamp": 3, "target_ts": 4, "rts1": 0.45, "rts2": 30} + ] + }, + { + "id": 2, + "scv1": 20, + "scv2": "blue", + "timeseries": [ + {"timestamp": 1, "target_ts": 4, "rts1": 0.25, "rts2": 40}, + {"timestamp": 2, "target_ts": 2, "rts1": 0.35, "rts2": 50} + ] + } + ] + + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:: + + item_id="[*].id" + timestamp="[*].timeseries[].timestamp" + target_time_series="[*].timeseries[].target_ts" + related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"] + static_covariates=["[*].scv1", "[*].scv2"] + + Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability:: + + [ + {"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25}, + {"id": 1, "timestamp": 2, "target_ts": 6, "scv1": 10, "rts1": 0.5}, + {"id": 1, "timestamp": 3, "target_ts": 3, "scv1": 10, "rts1": 0.75}, + {"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1} + ] + + For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows:: + + item_id="[*].id" + timestamp="[*].timestamp" + target_time_series="[*].target_ts" + related_time_series=["[*].rts1"] + static_covariates=["[*].scv1"] + """ + + 
COLUMNS = "columns" + ITEM_RECORDS = "item_records" + TIMESTAMP_RECORDS = "timestamp_records" + + class SegmentationConfig: """Config object that defines segment(s) of the dataset on which metrics are computed.""" @@ -394,6 +551,115 @@ def to_dict(self) -> Dict[str, Any]: # pragma: no cover return segment_config_dict +class TimeSeriesDataConfig: + """Config object for TimeSeries explainability data configuration fields.""" + + def __init__( + self, + target_time_series: Union[str, int], + item_id: Union[str, int], + timestamp: Union[str, int], + related_time_series: Optional[List[Union[str, int]]] = None, + static_covariates: Optional[List[Union[str, int]]] = None, + dataset_format: Optional[TimeSeriesJSONDatasetFormat] = None, + ): + """Initialises TimeSeries explainability data configuration fields. + + Args: + target_time_series (str or int): A string or a zero-based integer index. + Used to locate the target time series in the shared input dataset. + If this parameter is a string, then all other parameters except + `dataset_format` must be strings or lists of strings. If + this parameter is an int, then all other parameters except + `dataset_format` must be ints or lists of ints. + item_id (str or int): A string or a zero-based integer index. Used to + locate item id in the shared input dataset. + timestamp (str or int): A string or a zero-based integer index. Used to + locate timestamp in the shared input dataset. + related_time_series (list[str] or list[int]): Optional. An array of strings + or array of zero-based integer indices. Used to locate all related time + series in the shared input dataset (if present). + static_covariates (list[str] or list[int]): Optional. An array of strings or + array of zero-based integer indices. Used to locate all static covariate + fields in the shared input dataset (if present). + dataset_format (TimeSeriesJSONDatasetFormat): Describes the format + of the data files provided for analysis. 
Should only be provided + when dataset is in JSON format. + + Raises: + ValueError: If any required arguments are not provided or are the wrong type. + """ + # check target_time_series, item_id, and timestamp are provided + if not target_time_series: + raise ValueError("Please provide a target time series.") + if not item_id: + raise ValueError("Please provide an item id.") + if not timestamp: + raise ValueError("Please provide a timestamp.") + # check all arguments are the right types + if not isinstance(target_time_series, (str, int)): + raise ValueError("Please provide a string or an int for ``target_time_series``") + params_type = type(target_time_series) + if not isinstance(item_id, params_type): + raise ValueError(f"Please provide {params_type} for ``item_id``") + if not isinstance(timestamp, params_type): + raise ValueError(f"Please provide {params_type} for ``timestamp``") + # add mandatory fields to an internal dictionary + self.time_series_data_config = dict() + _set(target_time_series, "target_time_series", self.time_series_data_config) + _set(item_id, "item_id", self.time_series_data_config) + _set(timestamp, "timestamp", self.time_series_data_config) + # check optional arguments are right types if provided + related_time_series_error_message = ( + f"Please provide a list of {params_type} for ``related_time_series``" + ) + if related_time_series: + if not isinstance(related_time_series, list): + raise ValueError( + related_time_series_error_message + ) # related_time_series is not a list + if not all([isinstance(value, params_type) for value in related_time_series]): + raise ValueError( + related_time_series_error_message + ) # related_time_series is not a list of strings or list of ints + if params_type == str and not all(related_time_series): + raise ValueError("Please do not provide empty strings in ``related_time_series``.") + _set( + related_time_series, "related_time_series", self.time_series_data_config + ) # related_time_series is valid, add it 
+ static_covariates_series_error_message = ( + f"Please provide a list of {params_type} for ``static_covariates``" + ) + if static_covariates: + if not isinstance(static_covariates, list): + raise ValueError( + static_covariates_series_error_message + ) # static_covariates is not a list + if not all([isinstance(value, params_type) for value in static_covariates]): + raise ValueError( + static_covariates_series_error_message + ) # static_covariates is not a list of strings or list of ints + if params_type == str and not all(static_covariates): + raise ValueError("Please do not provide empty strings in ``static_covariates``.") + _set( + static_covariates, "static_covariates", self.time_series_data_config + ) # static_covariates is valid, add it + if params_type == str: + # check dataset_format is provided and valid + if not isinstance(dataset_format, TimeSeriesJSONDatasetFormat): + raise ValueError("Please provide a valid dataset format.") + _set(dataset_format.value, "dataset_format", self.time_series_data_config) + else: + if dataset_format: + raise ValueError( + "Dataset format should only be provided when data files are JSONs." + ) + + def get_time_series_data_config(self): + """Returns part of an analysis config dictionary.""" + return copy.deepcopy(self.time_series_data_config) + + class DataConfig: """Config object related to configurations of the input and output dataset.""" @@ -415,6 +681,7 @@ def __init__( predicted_label: Optional[Union[str, int]] = None, excluded_columns: Optional[Union[List[int], List[str]]] = None, segmentation_config: Optional[List[SegmentationConfig]] = None, + time_series_data_config: Optional[TimeSeriesDataConfig] = None, ): """Initializes a configuration of both input and output datasets. @@ -430,6 +697,10 @@ def __init__( Note: For JSON, the JMESPath query must result in a list of labels for each sample. For JSON Lines, it must result in the label for each line. Only a single label per sample is supported at this time. 
+ headers ([str]): List of column names in the dataset. If not provided, Clarify will + generate headers to use internally. For time series explainability cases, + please provide headers in the order of item_id, timestamp, target_time_series, + all related_time_series columns, and then all static_covariate columns. features (str): JMESPath expression to locate the feature values if the dataset format is JSON/JSON Lines. Note: For JSON, the JMESPath query must result in a 2-D list (or a matrix) of @@ -483,6 +754,8 @@ def __init__( which are to be excluded from making model inference API calls. segmentation_config (list[SegmentationConfig]): A list of ``SegmentationConfig`` objects. + time_series_data_config (TimeSeriesDataConfig): Optional. A config object for TimeSeries + data specific fields, required for TimeSeries explainability use cases. Raises: ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters @@ -533,8 +806,14 @@ def __init__( f" are not supported for dataset_type '{dataset_type}'." f" Please check the API documentation for the supported dataset types." ) + # check if any other format other than JSON is provided for time series case + if time_series_data_config: + if dataset_type != "application/json": + raise ValueError( + "Currently time series explainability only supports JSON format data." 
+ ) # features JMESPath is required for JSON as we can't derive it ourselves - if dataset_type == "application/json" and features is None: + if dataset_type == "application/json" and features is None and not time_series_data_config: raise ValueError("features JMESPath is required for application/json dataset_type") self.s3_data_input_path = s3_data_input_path self.s3_output_path = s3_output_path @@ -574,6 +853,12 @@ def __init__( "segment_config", self.analysis_config, ) + if time_series_data_config: + _set( + time_series_data_config.get_time_series_data_config(), + "time_series_data_config", + self.analysis_config, + ) def get_config(self): """Returns part of an analysis config dictionary.""" @@ -664,6 +949,36 @@ def get_config(self): return copy.deepcopy(self.analysis_config) +class TimeSeriesModelConfig: + """Config object for TimeSeries predictor configuration fields.""" + + def __init__( + self, + forecast: str, + ): + """Initializes model configuration fields for TimeSeries explainability use cases. + + Args: + forecast (str): JMESPath expression to extract the forecast result. + + Raises: + ValueError: when ``forecast`` is not a string or not provided + """ + # check string forecast is provided + if not isinstance(forecast, str): + raise ValueError( + "Please provide a string JMESPath expression for ``forecast`` " + "to extract the forecast result." 
+ ) + # add fields to an internal config dictionary + self.time_series_model_config = dict() + _set(forecast, "forecast", self.time_series_model_config) + + def get_time_series_model_config(self): + """Returns TimeSeries model config dictionary""" + return copy.deepcopy(self.time_series_model_config) + + class ModelConfig: """Config object related to a model and its endpoint to be created.""" @@ -681,6 +996,7 @@ def __init__( endpoint_name_prefix: Optional[str] = None, target_model: Optional[str] = None, endpoint_name: Optional[str] = None, + time_series_model_config: Optional[TimeSeriesModelConfig] = None, ): r"""Initializes a configuration of a model and the endpoint to be created for it. @@ -797,6 +1113,9 @@ def __init__( endpoint_name (str): Sets the endpoint_name when re-uses an existing endpoint. Cannot be set when ``model_name``, ``instance_count``, and ``instance_type`` set + time_series_model_config (TimeSeriesModelConfig): Optional. A config object for + TimeSeries predictor specific fields, required for TimeSeries + explainability use cases. Raises: ValueError: when the @@ -841,6 +1160,10 @@ def __init__( f"Invalid accept_type {accept_type}." f" Please choose text/csv or application/jsonlines." ) + if time_series_model_config and accept_type == "text/csv": + raise ValueError( + "``accept_type`` must be JSON or JSONLines for time series explainability." + ) self.predictor_config["accept_type"] = accept_type if content_type is not None: if content_type not in [ @@ -877,6 +1200,13 @@ def __init__( f"Invalid content_template {content_template}." f" Please include either placeholder $records or $record." ) + if time_series_model_config and content_type not in [ + "application/json", + "application/jsonlines", + ]: + raise ValueError( + "``content_type`` must be JSON or JSONLines for time series explainability." 
+ ) self.predictor_config["content_type"] = content_type if content_template is not None: self.predictor_config["content_template"] = content_template @@ -885,6 +1215,12 @@ def __init__( _set(custom_attributes, "custom_attributes", self.predictor_config) _set(accelerator_type, "accelerator_type", self.predictor_config) _set(target_model, "target_model", self.predictor_config) + if time_series_model_config: + _set( + time_series_model_config.get_time_series_model_config(), + "time_series_predictor_config", + self.predictor_config, + ) def get_predictor_config(self): """Returns part of the predictor dictionary of the analysis config.""" @@ -1400,6 +1736,122 @@ def get_explainability_config(self): return copy.deepcopy({"shap": self.shap_config}) +class AsymmetricShapleyValueConfig(ExplainabilityConfig): + """Config class for Asymmetric Shapley value algorithm for time series explainability. + + Asymmetric Shapley Values are a variant of the Shapley Value that drop the symmetry axiom [1]. + We use these to determine how features contribute to the forecasting outcome. Asymmetric + Shapley values can take into account the temporal dependencies of the time series that + forecasting models take as input. + + [1] Frye, Christopher, Colin Rowat, and Ilya Feige. "Asymmetric shapley values: incorporating + causal knowledge into model-agnostic explainability." NeurIPS (2020). + https://doi.org/10.48550/arXiv.1910.06358 + """ + + def __init__( + self, + direction: Literal[ + "chronological", + "anti_chronological", + "bidirectional", + ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_DIRECTION, + granularity: Literal[ + "timewise", + "fine_grained", + ] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY, + num_samples: Optional[int] = None, + baseline: Optional[Union[str, Dict[str, Any]]] = None, + ): + """Initialises config for time series explainability with Asymmetric Shapley Values. + + AsymmetricShapleyValueConfig is used specifically and only for TimeSeries explainability + purposes. 
+ + Args: + direction (str): Type of explanation to be used. Available explanation + types are ``"chronological"``, ``"anti_chronological"``, and ``"bidirectional"``. + granularity (str): Explanation granularity to be used. Available granularity options + are ``"timewise"`` and ``"fine_grained"``. + num_samples (None or int): Number of samples to be used in the Asymmetric Shapley + Value forecasting algorithm. Only applicable when using ``"fine_grained"`` + explanations. + baseline (str or dict): Link to a baseline configuration or a dictionary for it. The + baseline config is used to replace out-of-coalition values for the corresponding + datasets (also known as background data). For temporal data (target time series, + related time series), the baseline value types are "zero", where all + out-of-coalition values will be replaced with 0.0, or "mean", all out-of-coalition + values will be replaced with the average of a time series. For static data + (static covariates), a baseline value for each covariate should be provided for + each possible item_id. An example config follows, where ``item1`` and ``item2`` + are item ids:: + + { + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [1, 1], + "item2": [0, 1] + } + } + + Raises: + ValueError: when ``direction`` or ``granularity`` are not valid, ``num_samples`` is not + provided for fine-grained explanations, ``num_samples`` is provided for non + fine-grained explanations, or when ``direction`` is not ``"chronological"`` while + ``granularity`` is ``"fine_grained"``. 
+ """ + self.asymmetric_shapley_value_config = dict() + # validate explanation direction + if direction not in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS: + raise ValueError( + "Please provide a valid explanation direction from: " + + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS) + ) + # validate granularity + if granularity not in ASYM_SHAP_VAL_GRANULARITIES: + raise ValueError( + "Please provide a valid granularity from: " + ", ".join(ASYM_SHAP_VAL_GRANULARITIES) + ) + if granularity == "fine_grained": + if not isinstance(num_samples, int): + raise ValueError("Please provide an integer for ``num_samples``.") + if direction != "chronological": + raise ValueError( + f"{direction} and {granularity} granularity are not supported together." + ) + elif num_samples: # validate num_samples is not provided when unnecessary + raise ValueError("``num_samples`` is only used for fine-grained explanations.") + # validate baseline if provided as a dictionary + if isinstance(baseline, dict): + temporal_baselines = ["zero", "mean"] # possible baseline options for temporal fields + if "target_time_series" in baseline: + target_baseline = baseline.get("target_time_series") + if target_baseline not in temporal_baselines: + raise ValueError( + f"Provided value {target_baseline} for ``target_time_series`` is " + f"invalid. Please select one of {temporal_baselines}." + ) + if "related_time_series" in baseline: + related_baseline = baseline.get("related_time_series") + if related_baseline not in temporal_baselines: + raise ValueError( + f"Provided value {related_baseline} for ``related_time_series`` is " + f"invalid. Please select one of {temporal_baselines}." 
+ ) + # set explanation type and (if provided) num_samples in internal config dictionary + _set(direction, "direction", self.asymmetric_shapley_value_config) + _set(granularity, "granularity", self.asymmetric_shapley_value_config) + _set( + num_samples, "num_samples", self.asymmetric_shapley_value_config + ) # _set() does nothing if a given argument is None + _set(baseline, "baseline", self.asymmetric_shapley_value_config) + + def get_explainability_config(self): + """Returns an asymmetric shap config dictionary.""" + return copy.deepcopy({"asymmetric_shapley_value": self.asymmetric_shapley_value_config}) + + class SageMakerClarifyProcessor(Processor): """Handles SageMaker Processing tasks to compute bias metrics and model explanations.""" @@ -2072,6 +2524,13 @@ def bias_and_explainability( post_training_methods: Union[str, List[str]] = "all", ): """Generates a config for Bias and Explainability""" + # TimeSeries bias metrics are not supported + if ( + isinstance(explainability_config, AsymmetricShapleyValueConfig) + or "time_series_data_config" in data_config.analysis_config + or (model_config and "time_series_predictor_config" in model_config.predictor_config) + ): + raise ValueError("Bias metrics are unsupported for time series.") analysis_config = {**data_config.get_config(), **bias_config.get_config()} analysis_config = cls._add_methods( analysis_config, @@ -2093,12 +2552,43 @@ def explainability( explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], ): """Generates a config for Explainability""" + # determine if this is a time series explainability case by checking + # if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given + ts_data_conf_absent = "time_series_data_config" not in data_config.analysis_config + ts_model_conf_absent = "time_series_predictor_config" not in model_config.predictor_config + + if isinstance(explainability_config, AsymmetricShapleyValueConfig): + if ts_data_conf_absent: + raise ValueError("Please 
provide a TimeSeriesDataConfig to DataConfig.") + if ts_model_conf_absent: + raise ValueError("Please provide a TimeSeriesModelConfig to ModelConfig.") + # Check static covariates baseline matches number of provided static covariate columns + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( + explainability_config=explainability_config, + data_config=data_config, + ) + else: + if not ts_data_conf_absent: + raise ValueError( + "Please provide an AsymmetricShapleyValueConfig for time series " + "explainability cases. For non time series cases, please do not provide a " + "TimeSeriesDataConfig." + ) + if not ts_model_conf_absent: + raise ValueError( + "Please provide an AsymmetricShapleyValueConfig for time series " + "explainability cases. For non time series cases, please do not provide a " + "TimeSeriesModelConfig." + ) + + # construct whole analysis config analysis_config = data_config.analysis_config analysis_config = cls._add_predictor( analysis_config, model_config, model_predicted_label_config ) analysis_config = cls._add_methods( - analysis_config, explainability_config=explainability_config + analysis_config, + explainability_config=explainability_config, ) return analysis_config @@ -2165,7 +2655,11 @@ def _add_predictor( if isinstance(model_config, ModelConfig): analysis_config["predictor"] = model_config.get_predictor_config() else: - if "shap" in analysis_config["methods"] or "pdp" in analysis_config["methods"]: + if ( + "shap" in analysis_config["methods"] + or "pdp" in analysis_config["methods"] + or "asymmetric_shapley_value" in analysis_config["methods"] + ): raise ValueError( "model_config must be provided when explainability methods are selected." 
) @@ -2196,7 +2690,7 @@ def _add_methods( pre_training_methods: Union[str, List[str]] = None, post_training_methods: Union[str, List[str]] = None, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None, - report=True, + report: bool = True, ): """Extends analysis config with methods.""" # validate @@ -2226,7 +2720,12 @@ def _add_methods( analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods} if explainability_config is not None: - explainability_methods = cls._merge_explainability_configs(explainability_config) + if isinstance(explainability_config, AsymmetricShapleyValueConfig): + explainability_methods = explainability_config.get_explainability_config() + else: + explainability_methods = cls._merge_explainability_configs( + explainability_config, + ) analysis_config["methods"] = { **analysis_config["methods"], **explainability_methods, @@ -2239,10 +2738,25 @@ def _merge_explainability_configs( explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], ): """Merges explainability configs, when more than one.""" + non_ts = "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses." 
+ # validation + if isinstance(explainability_config, AsymmetricShapleyValueConfig): + raise ValueError(non_ts) + if ( + isinstance(explainability_config, PDPConfig) + and "features" not in explainability_config.get_explainability_config()["pdp"] + ): + raise ValueError("PDP features must be provided when ShapConfig is not provided") if isinstance(explainability_config, list): - explainability_methods = {} if len(explainability_config) == 0: raise ValueError("Please provide at least one explainability config.") + # list validation + for config in explainability_config: + # ensure all provided explainability configs are not AsymmetricShapleyValueConfig + if isinstance(config, AsymmetricShapleyValueConfig): + raise ValueError(non_ts) + # main logic + explainability_methods = {} for config in explainability_config: explain_config = config.get_explainability_config() explainability_methods.update(explain_config) @@ -2254,13 +2768,48 @@ def _merge_explainability_configs( ): raise ValueError("PDP features must be provided when ShapConfig is not provided") return explainability_methods - if ( - isinstance(explainability_config, PDPConfig) - and "features" not in explainability_config.get_explainability_config()["pdp"] - ): - raise ValueError("PDP features must be provided when ShapConfig is not provided") return explainability_config.get_explainability_config() + @classmethod + def _validate_time_series_static_covariates_baseline( + cls, + explainability_config: AsymmetricShapleyValueConfig, + data_config: DataConfig, + ): + """Validates static covariates in baseline for asymmetric shapley value (for time series). + + Checks that baseline values set for static covariate columns are + consistent between every item_id and the number of static covariate columns + provided in DataConfig. 
+ """ + baseline = explainability_config.get_explainability_config()[ + "asymmetric_shapley_value" + ].get("baseline") + if isinstance(baseline, dict) and "static_covariates" in baseline: + covariate_count = len( + data_config.get_config()["time_series_data_config"].get("static_covariates", []) + ) + if covariate_count > 0: + for item_id in baseline.get("static_covariates", []): + baseline_entry = baseline["static_covariates"][item_id] + if not isinstance(baseline_entry, list): + raise ValueError( + f"Baseline entry for {item_id} must be a list, is " + f"{type(baseline_entry)}." + ) + if len(baseline_entry) != covariate_count: + raise ValueError( + f"Length of baseline entry for {item_id} does not match number " + f"of static covariate columns. Please ensure every covariate " + f"has a baseline value for every item id." + ) + else: + raise ValueError( + "Static covariate baselines are provided in AsymmetricShapleyValueConfig " + "when no static covariate columns are provided in TimeSeriesDataConfig. " + "Please check these configs." + ) + def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): """Uploads the local ``analysis_config_file`` to the ``s3_output_path``. 
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 58d3f56639..82ac2cd8bc 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -16,23 +16,30 @@ import copy import pytest -from mock import MagicMock, Mock, patch +from dataclasses import dataclass +from mock import ANY, MagicMock, Mock, patch +from typing import Any, Dict, List, NamedTuple, Optional, Union from sagemaker import Processor, image_uris from sagemaker.clarify import ( BiasConfig, DataConfig, + TimeSeriesDataConfig, ModelConfig, + TimeSeriesModelConfig, ModelPredictedLabelConfig, PDPConfig, SageMakerClarifyProcessor, SHAPConfig, + AsymmetricShapleyValueConfig, TextConfig, ImageConfig, _AnalysisConfigGenerator, DatasetType, ProcessingOutputHandler, SegmentationConfig, + ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS, + TimeSeriesJSONDatasetFormat, ) JOB_NAME_PREFIX = "my-prefix" @@ -321,6 +328,353 @@ def test_s3_data_distribution_type_ignorance(): assert data_config.s3_data_distribution_type == "FullyReplicated" +@dataclass +class TimeSeriesDataConfigCase: + target_time_series: Union[str, int] + item_id: Union[str, int] + timestamp: Union[str, int] + related_time_series: Optional[List[Union[str, int]]] + static_covariates: Optional[List[Union[str, int]]] + dataset_format: Optional[TimeSeriesJSONDatasetFormat] = None + error: Optional[Exception] = None + error_message: Optional[str] = None + + +class TestTimeSeriesDataConfig: + valid_ts_data_config_case_list = [ + TimeSeriesDataConfigCase( # no optional args provided str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + ), + TimeSeriesDataConfigCase( # related_time_series provided str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3"], + static_covariates=None, + 
dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + ), + TimeSeriesDataConfigCase( # static_covariates provided str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + static_covariates=["a", "b", "c", "d"], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + ), + TimeSeriesDataConfigCase( # both related_time_series and static_covariates provided str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3"], + static_covariates=["a", "b", "c", "d"], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + ), + TimeSeriesDataConfigCase( # no optional args provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=None, + static_covariates=None, + ), + TimeSeriesDataConfigCase( # related_time_series provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6], + static_covariates=None, + ), + TimeSeriesDataConfigCase( # static_covariates provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=None, + static_covariates=[7, 8, 9, 10], + ), + TimeSeriesDataConfigCase( # both related_time_series and static_covariates provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6], + static_covariates=[7, 8, 9, 10], + ), + ] + + @pytest.mark.parametrize("test_case", valid_ts_data_config_case_list) + def test_time_series_data_config(self, test_case: TimeSeriesDataConfigCase): + """ + GIVEN A set of valid parameters are given + WHEN A TimeSeriesDataConfig object is instantiated + THEN the returned config dictionary matches what's expected + """ + # construct expected output + expected_output = { + "target_time_series": test_case.target_time_series, + "item_id": test_case.item_id, + "timestamp": test_case.timestamp, + } + if isinstance(test_case.target_time_series, str): + 
expected_output["dataset_format"] = test_case.dataset_format.value + if test_case.related_time_series: + expected_output["related_time_series"] = test_case.related_time_series + if test_case.static_covariates: + expected_output["static_covariates"] = test_case.static_covariates + # GIVEN, WHEN + ts_data_config = TimeSeriesDataConfig( + target_time_series=test_case.target_time_series, + item_id=test_case.item_id, + timestamp=test_case.timestamp, + related_time_series=test_case.related_time_series, + static_covariates=test_case.static_covariates, + dataset_format=test_case.dataset_format, + ) + # THEN + assert ts_data_config.time_series_data_config == expected_output + + @pytest.mark.parametrize( + "test_case", + [ + TimeSeriesDataConfigCase( # no target_time_series provided + target_time_series=None, + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message="Please provide a target time series.", + ), + TimeSeriesDataConfigCase( # no item_id provided + target_time_series="target_time_series", + item_id=None, + timestamp="timestamp", + related_time_series=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message="Please provide an item id.", + ), + TimeSeriesDataConfigCase( # no timestamp provided + target_time_series="target_time_series", + item_id="item_id", + timestamp=None, + related_time_series=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message="Please provide a timestamp.", + ), + TimeSeriesDataConfigCase( # target_time_series not int or str + target_time_series=["target_time_series"], + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message="Please provide a string or 
an int for ``target_time_series``", + ), + TimeSeriesDataConfigCase( # item_id differing type from str target_time_series + target_time_series="target_time_series", + item_id=5, + timestamp="timestamp", + related_time_series=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message=f"Please provide {str} for ``item_id``", + ), + TimeSeriesDataConfigCase( # timestamp differing type from str target_time_series + target_time_series="target_time_series", + item_id="item_id", + timestamp=10, + related_time_series=None, + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message=f"Please provide {str} for ``timestamp``", + ), + TimeSeriesDataConfigCase( # related_time_series not str list if str target_time_series + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3", 4], + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message=f"Please provide a list of {str} for ``related_time_series``", + ), + TimeSeriesDataConfigCase( # static_covariates not str list if str target_time_series + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + static_covariates=[4, 5, 6.0], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message=f"Please provide a list of {str} for ``static_covariates``", + ), + TimeSeriesDataConfigCase( # item_id differing type from int target_time_series + target_time_series=1, + item_id="item_id", + timestamp=3, + related_time_series=None, + static_covariates=None, + dataset_format=None, + error=ValueError, + error_message=f"Please provide {int} for ``item_id``", + ), + TimeSeriesDataConfigCase( # timestamp differing type from int target_time_series + target_time_series=1, + item_id=2, + timestamp="timestamp", + 
related_time_series=None, + static_covariates=None, + dataset_format=None, + error=ValueError, + error_message=f"Please provide {int} for ``timestamp``", + ), + TimeSeriesDataConfigCase( # related_time_series not int list if int target_time_series + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6, "ts7"], + static_covariates=None, + dataset_format=None, + error=ValueError, + error_message=f"Please provide a list of {int} for ``related_time_series``", + ), + TimeSeriesDataConfigCase( # static_covariates not int list if int target_time_series + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6, 7], + static_covariates=[8, 9, "10"], + dataset_format=None, + error=ValueError, + error_message=f"Please provide a list of {int} for ``static_covariates``", + ), + TimeSeriesDataConfigCase( # related_time_series contains blank string + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3", ""], + static_covariates=None, + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message="Please do not provide empty strings in ``related_time_series``", + ), + TimeSeriesDataConfigCase( # static_covariates contains blank string + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=None, + static_covariates=["scv4", "scv5", "scv6", ""], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message="Please do not provide empty strings in ``static_covariates``", + ), + TimeSeriesDataConfigCase( # dataset_format provided int case + target_time_series=1, + item_id=2, + timestamp=3, + related_time_series=[4, 5, 6], + static_covariates=[7, 8, 9, 10], + dataset_format=TimeSeriesJSONDatasetFormat.COLUMNS, + error=ValueError, + error_message="Dataset format should only be provided when data files are JSONs.", + ), + TimeSeriesDataConfigCase( # 
dataset_format not provided str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3"], + static_covariates=["a", "b", "c", "d"], + dataset_format=None, + error=ValueError, + error_message="Please provide a valid dataset format.", + ), + TimeSeriesDataConfigCase( # dataset_format wrong type str case + target_time_series="target_time_series", + item_id="item_id", + timestamp="timestamp", + related_time_series=["ts1", "ts2", "ts3"], + static_covariates=["a", "b", "c", "d"], + dataset_format="made_up_format", + error=ValueError, + error_message="Please provide a valid dataset format.", + ), + ], + ) + def test_time_series_data_config_invalid(self, test_case: TimeSeriesDataConfigCase): + """ + GIVEN required parameters are incomplete or invalid + WHEN TimeSeriesDataConfig constructor is called + THEN the expected error and message are raised + """ + with pytest.raises(test_case.error, match=test_case.error_message): + TimeSeriesDataConfig( + target_time_series=test_case.target_time_series, + item_id=test_case.item_id, + timestamp=test_case.timestamp, + related_time_series=test_case.related_time_series, + static_covariates=test_case.static_covariates, + dataset_format=test_case.dataset_format, + ) + + @pytest.mark.parametrize("test_case", valid_ts_data_config_case_list) + def test_data_config_with_time_series(self, test_case: TimeSeriesDataConfigCase): + """ + GIVEN a TimeSeriesDataConfig object is created + WHEN a DataConfig object is created and given valid params + the TimeSeriesDataConfig + THEN the internal config dictionary matches what's expected + """ + # currently TSX only supports json so skip non-json tests + if isinstance(test_case.target_time_series, int): + return + # setup + headers = ["item_id", "timestamp", "target_ts", "rts1", "scv1"] + # construct expected output + mock_ts_data_config_dict = { + "target_time_series": test_case.target_time_series, + "item_id": 
test_case.item_id, + "timestamp": test_case.timestamp, + } + if isinstance(test_case.target_time_series, str): + dataset_type = "application/json" + mock_ts_data_config_dict["dataset_format"] = test_case.dataset_format.value + else: + dataset_type = "text/csv" + if test_case.related_time_series: + mock_ts_data_config_dict["related_time_series"] = test_case.related_time_series + if test_case.static_covariates: + mock_ts_data_config_dict["static_covariates"] = test_case.static_covariates + expected_config = { + "dataset_type": dataset_type, + "headers": headers, + "time_series_data_config": mock_ts_data_config_dict, + } + # GIVEN + ts_data_config = Mock() + ts_data_config.get_time_series_data_config.return_value = copy.deepcopy( + mock_ts_data_config_dict + ) + # WHEN + data_config = DataConfig( + s3_data_input_path="s3://path/to/input.csv", + s3_output_path="s3://path/to/output", + headers=headers, + dataset_type=dataset_type, + time_series_data_config=ts_data_config, + ) + # THEN + assert expected_config == data_config.get_config() + + def test_bias_config(): label_values = [1] facet_name = "F1" @@ -641,6 +995,131 @@ def test_invalid_model_predicted_label_config(): ) +class TestTimeSeriesModelConfig: + def test_time_series_model_config(self): + """ + GIVEN a valid forecast expression + WHEN a TimeSeriesModelConfig is constructed with it + THEN the predictor_config dictionary matches the expected + """ + # GIVEN + forecast = "results.[forecast]" # mock JMESPath expression for forecast + # create expected output + expected_config = { + "forecast": forecast, + } + # WHEN + ts_model_config = TimeSeriesModelConfig( + forecast, + ) + # THEN + assert ts_model_config.time_series_model_config == expected_config + + @pytest.mark.parametrize( + ("forecast", "error", "error_message"), + [ + ( + None, + ValueError, + "Please provide a string JMESPath expression for ``forecast``", + ), + ( + 123, + ValueError, + "Please provide a string JMESPath expression for ``forecast``", + 
), + ], + ) + def test_time_series_model_config_invalid( + self, + forecast, + error, + error_message, + ): + """ + GIVEN invalid args for a TimeSeriesModelConfig + WHEN TimeSeriesModelConfig constructor is called + THEN The appropriate error is raised + """ + with pytest.raises(error, match=error_message): + TimeSeriesModelConfig( + forecast=forecast, + ) + + @pytest.mark.parametrize( + ("content_type", "accept_type"), + [ + ("application/json", "application/json"), + ("application/json", "application/jsonlines"), + ("application/jsonlines", "application/json"), + ("application/jsonlines", "application/jsonlines"), + ], + ) + def test_model_config_with_time_series(self, content_type, accept_type): + """ + GIVEN valid fields for a ModelConfig and a TimeSeriesModelConfig + WHEN a ModelConfig is constructed with them + THEN actual predictor_config matches expected + """ + # setup + model_name = "xgboost-model" + instance_type = "ml.c5.xlarge" + instance_count = 1 + custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4" + target_model = "target_model_name" + accelerator_type = "ml.eia1.medium" + content_template = ( + '{"instances":$features}' + if content_type == "application/jsonlines" + else "$records" + if content_type == "application/json" + else None + ) + record_template = "$features_kvp" if content_type == "application/json" else None + # create mock config for TimeSeriesModelConfig + forecast = "results.[forecast]" # mock JMESPath expression for forecast + mock_ts_model_config_dict = { + "forecast": forecast, + } + # create expected config + expected_config = { + "model_name": model_name, + "instance_type": instance_type, + "initial_instance_count": instance_count, + "accept_type": accept_type, + "content_type": content_type, + "custom_attributes": custom_attributes, + "accelerator_type": accelerator_type, + "target_model": target_model, + "time_series_predictor_config": mock_ts_model_config_dict, + } + if content_template is not None: + 
expected_config["content_template"] = content_template + if record_template is not None: + expected_config["record_template"] = record_template + # GIVEN + mock_ts_model_config = Mock() # create mock TimeSeriesModelConfig object + mock_ts_model_config.get_time_series_model_config.return_value = copy.deepcopy( + mock_ts_model_config_dict + ) # set the mock's get_config return value + # WHEN + model_config = ModelConfig( + model_name=model_name, + instance_type=instance_type, + instance_count=instance_count, + accept_type=accept_type, + content_type=content_type, + content_template=content_template, + record_template=record_template, + custom_attributes=custom_attributes, + accelerator_type=accelerator_type, + target_model=target_model, + time_series_model_config=mock_ts_model_config, + ) + # THEN + assert expected_config == model_config.get_predictor_config() + + @pytest.mark.parametrize( "baseline", [ @@ -783,6 +1262,148 @@ def test_shap_config_no_parameters(): assert expected_config == shap_config.get_explainability_config() +class AsymmetricShapleyValueConfigCase(NamedTuple): + direction: str + granularity: str + num_samples: Optional[int] = None + baseline: Optional[Union[str, Dict[str, Any]]] = None + error: Exception = None + error_message: str = None + + +class TestAsymmetricShapleyValueConfig: + @pytest.mark.parametrize( + "test_case", + [ + AsymmetricShapleyValueConfigCase( # cases for timewise granularity + direction=direction, + granularity="timewise", + num_samples=None, + ) + for direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS + ] + + [ + AsymmetricShapleyValueConfigCase( # case for fine_grained granularity + direction="chronological", + granularity="fine_grained", + num_samples=1, + ), + AsymmetricShapleyValueConfigCase( # case for target time series baseline + direction="chronological", + granularity="timewise", + baseline={"target_time_series": "mean"}, + ), + AsymmetricShapleyValueConfigCase( # case for related time series baseline + 
direction="chronological", + granularity="timewise", + baseline={"related_time_series": "zero"}, + ), + ], + ) + def test_asymmetric_shapley_value_config(self, test_case: AsymmetricShapleyValueConfigCase): + """ + GIVEN valid arguments for an AsymmetricShapleyValueConfig object + WHEN AsymmetricShapleyValueConfig object is instantiated with those arguments + THEN the asymmetric_shapley_value_config dictionary matches expected + """ + # test case is GIVEN + # construct expected config + expected_config = { + "direction": test_case.direction, + "granularity": test_case.granularity, + } + if test_case.granularity == "fine_grained": + expected_config["num_samples"] = test_case.num_samples + if test_case.baseline: + expected_config["baseline"] = test_case.baseline + # WHEN + asym_shap_val_config = AsymmetricShapleyValueConfig( + direction=test_case.direction, + granularity=test_case.granularity, + num_samples=test_case.num_samples, + baseline=test_case.baseline, + ) + # THEN + assert asym_shap_val_config.asymmetric_shapley_value_config == expected_config + + @pytest.mark.parametrize( + "test_case", + [ + AsymmetricShapleyValueConfigCase( # case for invalid direction + direction="non-directional", + granularity="timewise", + num_samples=None, + error=ValueError, + error_message="Please provide a valid explanation direction from: " + + ", ".join(ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS), + ), + AsymmetricShapleyValueConfigCase( # case for fine_grained and no num_samples + direction="chronological", + granularity="fine_grained", + num_samples=None, + error=ValueError, + error_message="Please provide an integer for ``num_samples``.", + ), + AsymmetricShapleyValueConfigCase( # case for fine_grained and non-integer num_samples + direction="chronological", + granularity="fine_grained", + num_samples="5", + error=ValueError, + error_message="Please provide an integer for ``num_samples``.", + ), + AsymmetricShapleyValueConfigCase( # case for num_samples when non fine-grained 
explanation + direction="chronological", + granularity="timewise", + num_samples=5, + error=ValueError, + error_message="``num_samples`` is only used for fine-grained explanations.", + ), + AsymmetricShapleyValueConfigCase( # case for anti_chronological and fine_grained + direction="anti_chronological", + granularity="fine_grained", + num_samples=5, + error=ValueError, + error_message="not supported together.", + ), + AsymmetricShapleyValueConfigCase( # case for bidirectional and fine_grained + direction="bidirectional", + granularity="fine_grained", + num_samples=5, + error=ValueError, + error_message="not supported together.", + ), + AsymmetricShapleyValueConfigCase( # case for unsupported target time series baseline value + direction="chronological", + granularity="timewise", + baseline={"target_time_series": "median"}, + error=ValueError, + error_message="for ``target_time_series`` is invalid.", + ), + AsymmetricShapleyValueConfigCase( # case for unsupported related time series baseline value + direction="chronological", + granularity="timewise", + baseline={"related_time_series": "mode"}, + error=ValueError, + error_message="for ``related_time_series`` is invalid.", + ), + ], + ) + def test_asymmetric_shapley_value_config_invalid(self, test_case): + """ + GIVEN invalid parameters for AsymmetricShapleyValue + WHEN AsymmetricShapleyValueConfig constructor is called with them + THEN the expected error and message are raised + """ + # test case is GIVEN + with pytest.raises(test_case.error, match=test_case.error_message): # THEN + AsymmetricShapleyValueConfig( # WHEN + direction=test_case.direction, + granularity=test_case.granularity, + num_samples=test_case.num_samples, + baseline=test_case.baseline, + ) + + def test_pdp_config(): pdp_config = PDPConfig(features=["f1", "f2"], grid_resolution=20) expected_config = { @@ -1917,6 +2538,425 @@ def test_invalid_analysis_config(data_config, data_bias_config, model_config): ) +def _build_pdp_config_mock(): + 
pdp_config_dict = { + "pdp": { + "grid_resolution": 15, + "top_k_features": 10, + "features": [ + "some", + "features", + ], + } + } + pdp_config = Mock(spec=PDPConfig) + pdp_config.get_explainability_config.return_value = pdp_config_dict + return pdp_config + + +def _build_asymmetric_shapley_value_config_mock(): + asym_shap_val_config_dict = { + "direction": "chronological", + "granularity": "fine_grained", + "num_samples": 20, + } + asym_shap_val_config = Mock(spec=AsymmetricShapleyValueConfig) + asym_shap_val_config.get_explainability_config.return_value = { + "asymmetric_shapley_value": asym_shap_val_config_dict + } + return asym_shap_val_config + + +def _build_data_config_mock(): + """ + Builds a mock DataConfig for the time series _AnalysisConfigGenerator unit tests. + """ + # setup a time_series_data_config dictionary + time_series_data_config = { + "target_time_series": "target_ts", + "item_id": "id", + "timestamp": "timestamp", + "related_time_series": ["rts1", "rts2", "rts3"], + "static_covariates": ["scv1", "scv2", "scv3"], + "dataset_format": TimeSeriesJSONDatasetFormat.COLUMNS, + } + # setup DataConfig mock + data_config = Mock(spec=DataConfig) + data_config.analysis_config = {"time_series_data_config": time_series_data_config} + return data_config + + +def _build_model_config_mock(): + """ + Builds a mock ModelConfig for the time series _AnalysisConfigGenerator unit tests. 
+ """ + time_series_model_config = {"forecast": "mean"} + model_config = Mock(spec=ModelConfig) + model_config.predictor_config = {"time_series_predictor_config": time_series_model_config} + return model_config + + +@dataclass +class ValidateTSXBaselineCase: + explainability_config: AsymmetricShapleyValueConfig + data_config: DataConfig + error: Optional[Exception] = None + error_msg: Optional[str] = None + + +class TestAnalysisConfigGeneratorForTimeSeriesExplainability: + @patch( + "sagemaker.clarify._AnalysisConfigGenerator._validate_time_series_static_covariates_baseline" + ) + @patch("sagemaker.clarify._AnalysisConfigGenerator._add_methods") + @patch("sagemaker.clarify._AnalysisConfigGenerator._add_predictor") + def test_explainability_for_time_series(self, _add_predictor, _add_methods, _validate_ts_scv): + """ + GIVEN a valid DataConfig and ModelConfig that contain time_series_data_config and + time_series_model_config respectively as well as an AsymmetricShapleyValueConfig + WHEN _AnalysisConfigGenerator.explainability() is called with those args + THEN _add_predictor and _add methods calls are as expected + """ + # GIVEN + # get DataConfig mock + data_config_mock = _build_data_config_mock() + # get ModelConfig mock + model_config_mock = _build_model_config_mock() + # get AsymmetricShapleyValueConfig mock for explainability_config + explainability_config = _build_asymmetric_shapley_value_config_mock() + # get time_series_data_config dict from mock + time_series_data_config = copy.deepcopy( + data_config_mock.analysis_config.get("time_series_data_config") + ) + # get time_series_predictor_config from mock + time_series_model_config = copy.deepcopy( + model_config_mock.predictor_config.get("time_series_model_config") + ) + # setup _add_predictor call to return what would be expected at that stage + analysis_config_after_add_predictor = { + "time_series_data_config": time_series_data_config, + "time_series_predictor_config": time_series_model_config, + } + 
_add_predictor.return_value = analysis_config_after_add_predictor + # WHEN + _AnalysisConfigGenerator.explainability( + data_config=data_config_mock, + model_config=model_config_mock, + model_predicted_label_config=None, + explainability_config=explainability_config, + ) + # THEN + _add_predictor.assert_called_once_with( + data_config_mock.analysis_config, + model_config_mock, + ANY, + ) + _add_methods.assert_called_once_with( + ANY, + explainability_config=explainability_config, + ) + _validate_ts_scv.assert_called_once_with( + explainability_config=explainability_config, + data_config=data_config_mock, + ) + + def test_explainability_for_time_series_invalid(self): + # data config mocks + data_config_with_ts = _build_data_config_mock() + data_config_without_ts = Mock(spec=DataConfig) + data_config_without_ts.analysis_config = dict() + # model config mocks + model_config_with_ts = _build_model_config_mock() + model_config_without_ts = Mock(spec=ModelConfig) + model_config_without_ts.predictor_config = dict() + # asymmetric shapley value config mock (for ts) + asym_shap_val_config_mock = _build_asymmetric_shapley_value_config_mock() + # pdp config mock (for non-ts) + pdp_config_mock = _build_pdp_config_mock() + # case 1: ASV (ts case) and no timeseries data config given + with pytest.raises( + ValueError, match="Please provide a TimeSeriesDataConfig to DataConfig." + ): + _AnalysisConfigGenerator.explainability( + data_config=data_config_without_ts, + model_config=model_config_with_ts, + model_predicted_label_config=None, + explainability_config=asym_shap_val_config_mock, + ) + # case 2: ASV (ts case) and no timeseries model config given + with pytest.raises( + ValueError, match="Please provide a TimeSeriesModelConfig to ModelConfig." 
+ ): + _AnalysisConfigGenerator.explainability( + data_config=data_config_with_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=asym_shap_val_config_mock, + ) + # case 3: pdp (non ts case) and timeseries data config given + with pytest.raises(ValueError, match="please do not provide a TimeSeriesDataConfig."): + _AnalysisConfigGenerator.explainability( + data_config=data_config_with_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + ) + # case 4: pdp (non ts case) and timeseries model config given + with pytest.raises(ValueError, match="please do not provide a TimeSeriesModelConfig."): + _AnalysisConfigGenerator.explainability( + data_config=data_config_without_ts, + model_config=model_config_with_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + ) + + def test_bias_and_explainability_invalid_for_time_series(self): + """ + GIVEN user provides TimeSeriesDataConfig, TimeSeriesModelConfig, and/or + AsymmetricShapleyValueConfig for DataConfig, ModelConfig, and as explainability_config + respectively + WHEN _AnalysisConfigGenerator.bias_and_explainability is called + THEN the appropriate error is raised + """ + # data config mocks + data_config_with_ts = _build_data_config_mock() + data_config_without_ts = Mock(spec=DataConfig) + data_config_without_ts.analysis_config = dict() + # model config mocks + model_config_with_ts = _build_model_config_mock() + model_config_without_ts = Mock(spec=ModelConfig) + model_config_without_ts.predictor_config = dict() + # asymmetric shap config mock (for ts) + asym_shap_val_config_mock = _build_asymmetric_shapley_value_config_mock() + # pdp config mock (for non-ts) + pdp_config_mock = _build_pdp_config_mock() + # case 1: asymmetric shap is given as explainability_config + with pytest.raises(ValueError, match="Bias metrics are unsupported for time series."): + 
_AnalysisConfigGenerator.bias_and_explainability( + data_config=data_config_without_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=asym_shap_val_config_mock, + bias_config=None, + ) + # case 2: TimeSeriesModelConfig given to ModelConfig + with pytest.raises(ValueError, match="Bias metrics are unsupported for time series."): + _AnalysisConfigGenerator.bias_and_explainability( + data_config=data_config_without_ts, + model_config=model_config_with_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + bias_config=None, + ) + # case 3: TimeSeriesDataConfig given to DataConfig + with pytest.raises(ValueError, match="Bias metrics are unsupported for time series."): + _AnalysisConfigGenerator.bias_and_explainability( + data_config=data_config_with_ts, + model_config=model_config_without_ts, + model_predicted_label_config=None, + explainability_config=pdp_config_mock, + bias_config=None, + ) + + @pytest.mark.parametrize( + ("mock_config", "error", "error_message"), + [ + ( # single asym shap config for non TSX + _build_asymmetric_shapley_value_config_mock(), + ValueError, + "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses.", + ), + ( # list with asym shap config for non-TSX + [ + _build_asymmetric_shapley_value_config_mock(), + _build_pdp_config_mock(), + ], + ValueError, + "Please do not provide Asymmetric Shapley Value configs for non-TimeSeries uses.", + ), + ], + ) + def test_merge_explainability_configs_with_timeseries_invalid( + self, + mock_config, + error, + error_message, + ): + """ + GIVEN _merge_explainability_configs is called with a explainability config or list thereof + WHEN explainability_config is or contains an AsymmetricShapleyValueConfig + THEN the function will raise the appropriate error + """ + with pytest.raises(error, match=error_message): + _AnalysisConfigGenerator._merge_explainability_configs( + explainability_config=mock_config, 
+ ) + + @pytest.mark.parametrize( + "case", + [ + ValidateTSXBaselineCase( + explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": [0.0, 1.0, 1.0], + "item4": [0.9, 0.6, 0.3], + "item5": [1.0, 0.5, 0.0], + }, + }, + ), + data_config=DataConfig( + s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + ), + ], + ) + def test_time_series_baseline_valid_static_covariates(self, case: ValidateTSXBaselineCase): + """ + GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline + is provided + WHEN _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called + THEN no error is raised + """ + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( + explainability_config=case.explainability_config, + data_config=case.data_config, + ) + + @pytest.mark.parametrize( + "case", + [ + ValidateTSXBaselineCase( # some item ids are missing baseline values + explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": [0.0], + "item4": [0.9, 0.6, 0.3], + "item5": [1.0], + }, + }, + ), + data_config=DataConfig( + 
s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + error=ValueError, + error_msg="baseline entry for item3 does not match number", + ), + ValidateTSXBaselineCase( # no static covariates are in data config + explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": [0.0, 1.0, 1.0], + "item4": [0.9, 0.6, 0.3], + "item5": [1.0, 0.5, 0.0], + }, + }, + ), + data_config=DataConfig( + s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + error=ValueError, + error_msg="no static covariate columns are provided in TimeSeriesDataConfig", + ), + ValidateTSXBaselineCase( # some item ids do not have a list as their baseline + explainability_config=AsymmetricShapleyValueConfig( + direction="chronological", + granularity="timewise", + baseline={ + "target_time_series": "zero", + "related_time_series": "zero", + "static_covariates": { + "item1": [0.0, 0.5, 1.0], + "item2": [0.3, 0.6, 0.9], + "item3": 
[0.0, 1.0, 1.0], + "item4": [0.9, 0.6, 0.3], + "item5": {"cov_1": 1.0, "cov_2": 0.5, "cov_3": 0.0}, + }, + }, + ), + data_config=DataConfig( + s3_data_input_path="s3://data/input", + s3_output_path="s3://data/output", + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], + dataset_type="application/json", + time_series_data_config=TimeSeriesDataConfig( + item_id="[].id", + timestamp="[].temporal[].timestamp", + target_time_series="[].temporal[].target", + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, + ), + ), + error=ValueError, + error_msg="Baseline entry for item5 must be a list", + ), + ], + ) + def test_time_series_baseline_invalid_static_covariates(self, case: ValidateTSXBaselineCase): + """ + GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline + is provided where the static covariates baseline values are misconfigured + WHEN _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called + THEN the appropriate error is raised + """ + with pytest.raises(case.error, match=case.error_msg): + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( + explainability_config=case.explainability_config, + data_config=case.data_config, + ) + + class TestProcessingOutputHandler: def test_get_s3_upload_mode_image(self): analysis_config = {"dataset_type": DatasetType.IMAGE.value}