Skip to content

Commit

Permalink
Rewrite the deprecated functions in widget_view
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713451970
  • Loading branch information
zhouhao138 authored and Responsible ML Infra Team committed Jan 9, 2025
1 parent e5f9c73 commit 7b35943
Showing 1 changed file with 162 additions and 5 deletions.
167 changes: 162 additions & 5 deletions tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,187 @@
# limitations under the License.
# ==============================================================================
"""TensorBoard Fairnss Indicators plugin."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from typing import Any

from absl import logging
from tensorboard_plugin_fairness_indicators import metadata
import six
import tensorflow as tf
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.addons.fairness.view import widget_view
from werkzeug import wrappers

from google.protobuf import json_format
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin


_TEMPLATE_LOCATION = os.path.normpath(
os.path.join(
__file__, '../../'
'tensorflow_model_analysis/static/vulcanized_tfma.js'))


def stringify_slice_key_value(
slice_key: tfma.slicer.slicer_lib.SliceKeyType,
) -> str:
"""Stringifies a slice key value.
The string representation of a SingletonSliceKeyType is "feature:value". This
function returns value.
When
multiple columns / features are specified, the string representation of a
SliceKeyType's value is "v1_X_v2_X_..." where v1, v2, ... are values. For
example,
('gender, 'f'), ('age', 5) becomes f_X_5. If no columns / feature
specified, return "Overall".
Note that we do not perform special escaping for slice values that contain
'_X_'. This stringified representation is meant to be human-readbale rather
than a reversible encoding.
The columns will be in the same order as in SliceKeyType. If they are
generated using SingleSliceSpec.generate_slices, they will be in sorted order,
ascending.
Technically float values are not supported, but we don't check for them here.
Args:
slice_key: Slice key to stringify. The constituent SingletonSliceKeyTypes
should be sorted in ascending order.
Returns:
String representation of the slice key's value.
"""
if not slice_key:
return 'Overall'

# Since this is meant to be a human-readable string, we assume that the
# feature values are valid UTF-8 strings (might not be true in cases where
# people store serialised protos in the features for instance).
# We need to call as_str_any to convert non-string (e.g. integer) values to
# string first before converting to text.
# We use u'{}' instead of '{}' here to avoid encoding a unicode character with
# ascii codec.
values = [
'{}'.format(tf.compat.as_text(tf.compat.as_str_any(value)))
for _, value in slice_key
]
return '_X_'.join(values)


def _add_cross_slice_key_data(
slice_key: tfma.slicer.slicer_lib.CrossSliceKeyType,
metrics: tfma.view.view_types.MetricsByTextKey,
data: list[Any],
):
"""Adds data for cross slice key.
Baseline and comparison slice keys are joined by '__XX__'.
Args:
slice_key: Cross slice key.
metrics: Metrics data for the cross slice key.
data: List where UI data is to be appended.
"""
baseline_key = slice_key[0]
comparison_key = slice_key[1]
stringify_slice_value = (
stringify_slice_key_value(baseline_key)
+ '__XX__'
+ stringify_slice_key_value(comparison_key)
)
stringify_slice = (
tfma.slicer.slicer_lib.stringify_slice_key(baseline_key)
+ '__XX__'
+ tfma.slicer.slicer_lib.stringify_slice_key(comparison_key)
)
data.append({
'sliceValue': stringify_slice_value,
'slice': stringify_slice,
'metrics': metrics,
})


def convert_slicing_metrics_to_ui_input(
slicing_metrics: list[
tuple[
tfma.slicer.slicer_lib.SliceKeyOrCrossSliceKeyType,
tfma.view.view_types.MetricsByOutputName,
]
],
slicing_column: str | None = None,
slicing_spec: tfma.slicer.slicer_lib.SingleSliceSpec | None = None,
output_name: str = '',
multi_class_key: str = '',
) -> list[dict[str, Any]] | None:
"""Renders the Fairness Indicator view.
Args:
slicing_metrics: tfma.EvalResult.slicing_metrics.
slicing_column: The slicing column to to filter results. If both
slicing_column and slicing_spec are None, show all eval results.
slicing_spec: The slicing spec to filter results. If both slicing_column and
slicing_spec are None, show all eval results.
output_name: The output name associated with metric (for multi-output
models).
multi_class_key: The multi-class key associated with metric (for multi-class
models).
Returns:
A list of dicts for each slice, where each dict contains keys 'sliceValue',
'slice', and 'metrics'.
Raises:
ValueError if no related eval result found or both slicing_column and
slicing_spec are not None.
"""
if slicing_column and slicing_spec:
raise ValueError(
'Only one of the "slicing_column" and "slicing_spec" parameters '
'can be set.'
)
if slicing_column:
slicing_spec = tfma.slicer.slicer_lib.SingleSliceSpec(
columns=[slicing_column]
)

data = []
for slice_key, metric_value in slicing_metrics:
if (
metric_value is not None
and output_name in metric_value
and multi_class_key in metric_value[output_name]
):
metrics = metric_value[output_name][multi_class_key]
# To add evaluation data for cross slice comparison.
if tfma.slicer.slicer_lib.is_cross_slice_key(slice_key):
_add_cross_slice_key_data(slice_key, metrics, data)
# To add evaluation data for regular slices.
elif (
slicing_spec is None
or not slice_key
or slicing_spec.is_slice_applicable(slice_key)
):
data.append({
'sliceValue': stringify_slice_key_value(slice_key),
'slice': tfma.slicer.slicer_lib.stringify_slice_key(slice_key),
'metrics': metrics,
})
if not data:
raise ValueError(
'No eval result found for output_name:"%s" and '
'multi_class_key:"%s" and slicing_column:"%s" and slicing_spec:"%s".'
% (output_name, multi_class_key, slicing_column, slicing_spec)
)
return data


class FairnessIndicatorsPlugin(base_plugin.TBPlugin):
"""A plugin to visualize Fairness Indicators."""

Expand Down Expand Up @@ -122,8 +281,7 @@ def _get_evaluation_result(self, request):
eval_result = tfma.load_eval_result(output_path=eval_result_output_dir)
# TODO(b/141283811): Allow users to choose different model output names
# and class keys in case of multi-output and multi-class model.
data = widget_view.convert_slicing_metrics_to_ui_input(
eval_result.slicing_metrics)
data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics)
except (KeyError, json_format.ParseError) as error:
logging.info('Error while fetching evaluation data, %s', error)
return http_util.Respond(request, data, content_type='application/json')
Expand All @@ -147,8 +305,7 @@ def _get_evaluation_result_from_remote_path(self, request):
os.path.dirname(evaluation_output_path),
output_file_format=self._get_output_file_format(
evaluation_output_path))
data = widget_view.convert_slicing_metrics_to_ui_input(
eval_result.slicing_metrics)
data = convert_slicing_metrics_to_ui_input(eval_result.slicing_metrics)
except (KeyError, json_format.ParseError) as error:
logging.info('Error while fetching evaluation data, %s', error)
data = []
Expand Down

0 comments on commit 7b35943

Please sign in to comment.