diff --git a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py index e465bef..e3c49a3 100644 --- a/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py +++ b/tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py @@ -28,10 +28,11 @@ import tensorflow.compat.v1 as tf import tensorflow.compat.v2 as tf2 import tensorflow_model_analysis as tfma -from tensorflow_model_analysis.eval_saved_model.example_trainers import linear_classifier +from tensorflow_model_analysis.utils import example_keras_model from werkzeug import test as werkzeug_test from werkzeug import wrappers +from google.protobuf import text_format from tensorboard.backend import application from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer from tensorboard.plugins import base_plugin @@ -74,19 +75,20 @@ def tearDown(self): super(PluginTest, self).tearDown() shutil.rmtree(self._log_dir, ignore_errors=True) - def _exportEvalSavedModel(self, classifier): + def _export_keras_model(self, classifier): temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir") - _, eval_export_dir = classifier(None, temp_eval_export_dir) - return eval_export_dir + classifier.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse") + tf.saved_model.save(classifier, temp_eval_export_dir) + return temp_eval_export_dir - def _writeTFExamplesToTFRecords(self, examples): + def _write_tf_examples_to_tfrecords(self, examples): data_location = os.path.join(self.get_temp_dir(), "input_data.rio") with tf.io.TFRecordWriter(data_location) as writer: for example in examples: writer.write(example.SerializeToString()) return data_location - def _makeExample(self, age, language, label): + def _make_example(self, age, language, label): example = tf.train.Example() example.features.feature["age"].float_list.value[:] = [age] example.features.feature["language"].bytes_list.value[:] = [ @@ -95,6 +97,27 @@ def _makeExample(self, age, language, label): example.features.feature["label"].float_list.value[:] = [label] return example + def _make_eval_config(self): + return text_format.Parse( + """ + model_specs { + signature_name: "serving_default" + prediction_key: "predictions" # placeholder + label_key: "label" # placeholder + } + slicing_specs {} + metrics_specs { + metrics { + class_name: "ExampleCount" + } + metrics { + class_name: "Accuracy" + } + } + """, + tfma.EvalConfig(), + ) + def testRoutes(self): self.assertIsInstance(self._routes["/get_evaluation_result"], abc.Callable) @@ -112,14 +135,14 @@ def testRoutes(self): "foo": "".encode("utf-8") }}, ) - def testIsActive(self, get_random_stub): + def testIsActive(self, get_random_stub): # pylint: disable=unused-argument self.assertTrue(self._plugin.is_active()) @mock.patch.object( event_multiplexer.EventMultiplexer, "PluginRunToTagToContent", return_value={}) - def testIsInactive(self, get_random_stub): + def testIsInactive(self, get_random_stub): # pylint: disable=unused-argument self.assertFalse(self._plugin.is_active()) def testIndexJsRoute(self): @@ -130,57 +153,75 @@ def testIndexJsRoute(self): def testVulcanizedTemplateRoute(self): """Tests that the /tags route offers the correct run to tag mapping.""" response = self._server.get( - "/data/plugin/fairness_indicators/vulcanized_tfma.js") + "/data/plugin/fairness_indicators/vulcanized_tfma.js" + ) self.assertEqual(200, response.status_code) def testGetEvalResultsRoute(self): - model_location = self._exportEvalSavedModel( - linear_classifier.simple_linear_classifier) + model_location = self._export_keras_model( + example_keras_model.get_example_classifier_model( + input_feature_key="language" + ) + ) examples = [ - self._makeExample(age=3.0, language="english", label=1.0), - self._makeExample(age=3.0, language="chinese", label=0.0), - self._makeExample(age=4.0, language="english", label=1.0), - self._makeExample(age=5.0, language="chinese", label=1.0), - self._makeExample(age=5.0, language="hindi", label=1.0) + self._make_example(age=3.0, language="english", label=1.0), + self._make_example(age=3.0, language="chinese", label=0.0), + self._make_example(age=4.0, language="english", label=1.0), + self._make_example(age=5.0, language="chinese", label=1.0), + self._make_example(age=5.0, language="hindi", label=1.0), ] - data_location = self._writeTFExamplesToTFRecords(examples) + eval_config = self._make_eval_config() + data_location = self._write_tf_examples_to_tfrecords(examples) _ = tfma.run_model_analysis( eval_shared_model=tfma.default_eval_shared_model( - eval_saved_model_path=model_location, example_weight_key="age"), + eval_saved_model_path=model_location, eval_config=eval_config + ), + eval_config=eval_config, data_location=data_location, - output_path=self._eval_result_output_dir) + output_path=self._eval_result_output_dir, + ) response = self._server.get( - "/data/plugin/fairness_indicators/get_evaluation_result?run=.") + "/data/plugin/fairness_indicators/get_evaluation_result?run=." + ) self.assertEqual(200, response.status_code) def testGetEvalResultsFromURLRoute(self): - model_location = self._exportEvalSavedModel( - linear_classifier.simple_linear_classifier) + model_location = self._export_keras_model( + example_keras_model.get_example_classifier_model( + input_feature_key="language" + ) + ) examples = [ - self._makeExample(age=3.0, language="english", label=1.0), - self._makeExample(age=3.0, language="chinese", label=0.0), - self._makeExample(age=4.0, language="english", label=1.0), - self._makeExample(age=5.0, language="chinese", label=1.0), - self._makeExample(age=5.0, language="hindi", label=1.0) + self._make_example(age=3.0, language="english", label=1.0), + self._make_example(age=3.0, language="chinese", label=0.0), + self._make_example(age=4.0, language="english", label=1.0), + self._make_example(age=5.0, language="chinese", label=1.0), + self._make_example(age=5.0, language="hindi", label=1.0), ] - data_location = self._writeTFExamplesToTFRecords(examples) + eval_config = self._make_eval_config() + data_location = self._write_tf_examples_to_tfrecords(examples) _ = tfma.run_model_analysis( eval_shared_model=tfma.default_eval_shared_model( - eval_saved_model_path=model_location, example_weight_key="age"), + eval_saved_model_path=model_location, eval_config=eval_config + ), + eval_config=eval_config, data_location=data_location, - output_path=self._eval_result_output_dir) + output_path=self._eval_result_output_dir, + ) response = self._server.get( - "/data/plugin/fairness_indicators/" + - "get_evaluation_result_from_remote_path?evaluation_output_path=" + - os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY)) + "/data/plugin/fairness_indicators/" + + "get_evaluation_result_from_remote_path?evaluation_output_path=" + + os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY) + ) self.assertEqual(200, response.status_code) def testGetOutputFileFormat(self): self.assertEqual("", self._plugin._get_output_file_format("abc_path")) - self.assertEqual("tfrecord", - self._plugin._get_output_file_format("abc_path.tfrecord")) + self.assertEqual( + "tfrecord", self._plugin._get_output_file_format("abc_path.tfrecord") + ) if __name__ == "__main__":