From 7b3594397852181fd4a26518fd2322a5879f123b Mon Sep 17 00:00:00 2001 From: Hao Zhou Date: Wed, 8 Jan 2025 16:18:31 -0800 Subject: [PATCH] Rewrite the deprecated functions in widget_view PiperOrigin-RevId: 713451970 --- .../plugin.py | 167 +++++++++++++++++- 1 file changed, 162 insertions(+), 5 deletions(-) diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py index f3e1856..869a7bf 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py @@ -13,28 +13,187 @@ # limitations under the License. # ============================================================================== """TensorBoard Fairnss Indicators plugin.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function import os +from typing import Any from absl import logging from tensorboard_plugin_fairness_indicators import metadata import six +import tensorflow as tf import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.addons.fairness.view import widget_view from werkzeug import wrappers + from google.protobuf import json_format from tensorboard.backend import http_util from tensorboard.plugins import base_plugin + _TEMPLATE_LOCATION = os.path.normpath( os.path.join( __file__, '../../' 'tensorflow_model_analysis/static/vulcanized_tfma.js')) +def stringify_slice_key_value( + slice_key: tfma.slicer.slicer_lib.SliceKeyType, +) -> str: + """Stringifies a slice key value. + + The string representation of a SingletonSliceKeyType is "feature:value". This + function returns value. + + When + multiple columns / features are specified, the string representation of a + SliceKeyType's value is "v1_X_v2_X_..." where v1, v2, ... are values. For + example, + ('gender, 'f'), ('age', 5) becomes f_X_5. If no columns / feature + specified, return "Overall". + + Note that we do not perform special escaping for slice values that contain + '_X_'. This stringified representation is meant to be human-readbale rather + than a reversible encoding. + + The columns will be in the same order as in SliceKeyType. If they are + generated using SingleSliceSpec.generate_slices, they will be in sorted order, + ascending. + + Technically float values are not supported, but we don't check for them here. + + Args: + slice_key: Slice key to stringify. The constituent SingletonSliceKeyTypes + should be sorted in ascending order. + + Returns: + String representation of the slice key's value. + """ + if not slice_key: + return 'Overall' + + # Since this is meant to be a human-readable string, we assume that the + # feature values are valid UTF-8 strings (might not be true in cases where + # people store serialised protos in the features for instance). + # We need to call as_str_any to convert non-string (e.g. integer) values to + # string first before converting to text. + # We use u'{}' instead of '{}' here to avoid encoding a unicode character with + # ascii codec. + values = [ + '{}'.format(tf.compat.as_text(tf.compat.as_str_any(value))) + for _, value in slice_key + ] + return '_X_'.join(values) + + +def _add_cross_slice_key_data( + slice_key: tfma.slicer.slicer_lib.CrossSliceKeyType, + metrics: tfma.view.view_types.MetricsByTextKey, + data: list[Any], +): + """Adds data for cross slice key. + + Baseline and comparison slice keys are joined by '__XX__'. + Args: + slice_key: Cross slice key. + metrics: Metrics data for the cross slice key. + data: List where UI data is to be appended. + """ + baseline_key = slice_key[0] + comparison_key = slice_key[1] + stringify_slice_value = ( + stringify_slice_key_value(baseline_key) + + '__XX__' + + stringify_slice_key_value(comparison_key) + ) + stringify_slice = ( + tfma.slicer.slicer_lib.stringify_slice_key(baseline_key) + + '__XX__' + + tfma.slicer.slicer_lib.stringify_slice_key(comparison_key) + ) + data.append({ + 'sliceValue': stringify_slice_value, + 'slice': stringify_slice, + 'metrics': metrics, + }) + + +def convert_slicing_metrics_to_ui_input( + slicing_metrics: list[ + tuple[ + tfma.slicer.slicer_lib.SliceKeyOrCrossSliceKeyType, + tfma.view.view_types.MetricsByOutputName, + ] + ], + slicing_column: str | None = None, + slicing_spec: tfma.slicer.slicer_lib.SingleSliceSpec | None = None, + output_name: str = '', + multi_class_key: str = '', +) -> list[dict[str, Any]] | None: + """Renders the Fairness Indicator view. + + Args: + slicing_metrics: tfma.EvalResult.slicing_metrics. + slicing_column: The slicing column to to filter results. If both + slicing_column and slicing_spec are None, show all eval results. + slicing_spec: The slicing spec to filter results. If both slicing_column and + slicing_spec are None, show all eval results. + output_name: The output name associated with metric (for multi-output + models). + multi_class_key: The multi-class key associated with metric (for multi-class + models). + + Returns: + A list of dicts for each slice, where each dict contains keys 'sliceValue', + 'slice', and 'metrics'. + + Raises: + ValueError if no related eval result found or both slicing_column and + slicing_spec are not None. + """ + if slicing_column and slicing_spec: + raise ValueError( + 'Only one of the "slicing_column" and "slicing_spec" parameters ' + 'can be set.' + ) + if slicing_column: + slicing_spec = tfma.slicer.slicer_lib.SingleSliceSpec( + columns=[slicing_column] + ) + + data = [] + for slice_key, metric_value in slicing_metrics: + if ( + metric_value is not None + and output_name in metric_value + and multi_class_key in metric_value[output_name] + ): + metrics = metric_value[output_name][multi_class_key] + # To add evaluation data for cross slice comparison. + if tfma.slicer.slicer_lib.is_cross_slice_key(slice_key): + _add_cross_slice_key_data(slice_key, metrics, data) + # To add evaluation data for regular slices. + elif ( + slicing_spec is None + or not slice_key + or slicing_spec.is_slice_applicable(slice_key) + ): + data.append({ + 'sliceValue': stringify_slice_key_value(slice_key), + 'slice': tfma.slicer.slicer_lib.stringify_slice_key(slice_key), + 'metrics': metrics, + }) + if not data: + raise ValueError( + 'No eval result found for output_name:"%s" and ' + 'multi_class_key:"%s" and slicing_column:"%s" and slicing_spec:"%s".' + % (output_name, multi_class_key, slicing_column, slicing_spec) + ) + return data + + class FairnessIndicatorsPlugin(base_plugin.TBPlugin): """A plugin to visualize Fairness Indicators.""" @@ -122,8 +281,7 @@ def _get_evaluation_result(self, request): eval_result = tfma.load_eval_result(output_path=eval_result_output_dir) # TODO(b/141283811): Allow users to choose different model output names # and class keys in case of multi-output and multi-class model. - data = widget_view.convert_slicing_metrics_to_ui_input( - eval_result.slicing_metrics) + data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics) except (KeyError, json_format.ParseError) as error: logging.info('Error while fetching evaluation data, %s', error) return http_util.Respond(request, data, content_type='application/json') @@ -147,8 +305,7 @@ def _get_evaluation_result_from_remote_path(self, request): os.path.dirname(evaluation_output_path), output_file_format=self._get_output_file_format( evaluation_output_path)) - data = widget_view.convert_slicing_metrics_to_ui_input( - eval_result.slicing_metrics) + data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics) except (KeyError, json_format.ParseError) as error: logging.info('Error while fetching evaluation data, %s', error) data = []