From c9fb519786f1acfcd0456e6db0c96b4e457ed137 Mon Sep 17 00:00:00 2001 From: Hao Zhou Date: Thu, 2 Jan 2025 11:26:53 -0800 Subject: [PATCH] NA PiperOrigin-RevId: 711478162 --- fairness_indicators/example_model.py | 16 ++++++++-- fairness_indicators/example_model_test.py | 8 +++-- setup.py | 39 ++++++++++++----------- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/fairness_indicators/example_model.py b/fairness_indicators/example_model.py index 2b7390a..d1f3dfd 100644 --- a/fairness_indicators/example_model.py +++ b/fairness_indicators/example_model.py @@ -83,8 +83,20 @@ def get_example_model(input_feature_key: str): text_vectorization.adapt( ['nontoxic', 'toxic comment', 'test comment', 'abc', 'abcdef', 'random'] ) - dense1 = keras.layers.Dense(32, activation='relu') - dense2 = keras.layers.Dense(1) + dense1 = keras.layers.Dense( + 32, + activation=None, + use_bias=True, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + ) + dense2 = keras.layers.Dense( + 1, + activation=None, + use_bias=False, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + ) inputs = tf.keras.Input(shape=(), dtype=tf.string) parsed_example = parser(inputs) diff --git a/fairness_indicators/example_model_test.py b/fairness_indicators/example_model_test.py index 09266a2..ff10d7c 100644 --- a/fairness_indicators/example_model_test.py +++ b/fairness_indicators/example_model_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for example_model.""" +"""Tests for example_model.py. + +It also serves as an example of how to use fairness indicators with a Keras +model. +""" from __future__ import absolute_import from __future__ import division @@ -91,7 +95,7 @@ def test_example_model(self): ]), batch_size=1, ) - classifier.save(self._model_dir, save_format='tf') + tf.saved_model.save(classifier, self._model_dir) eval_config = text_format.Parse( """ diff --git a/setup.py b/setup.py index 0451434..2321349 100644 --- a/setup.py +++ b/setup.py @@ -17,8 +17,7 @@ import os import sys -from setuptools import find_packages -from setuptools import setup +import setuptools if sys.version_info >= (3, 11): @@ -36,32 +35,32 @@ def select_constraint(default, nightly=None, git_master=None): return git_master else: return default - REQUIRED_PACKAGES = [ - 'tensorflow>=2.15,<2.16', + 'tensorflow>=2.16,<2.17', 'tensorflow-hub>=0.16.1,<1.0.0', - 'tensorflow-data-validation' + select_constraint( - default='>=1.15.1,<2.0.0', - nightly='>=1.16.0.dev', - git_master='@git+https://github.com/tensorflow/data-validation@master'), - 'tensorflow-model-analysis' + select_constraint( - default='>=0.46,<0.47', - nightly='>=0.47.0.dev', - git_master='@git+https://github.com/tensorflow/model-analysis@master'), + 'tensorflow-data-validation' + + select_constraint( + default='>=1.16.1,<2.0.0', + nightly='>=1.17.0.dev', + git_master='@git+https://github.com/tensorflow/data-validation@master', + ), + 'tensorflow-model-analysis' + + select_constraint( + default='>=0.47.0,<0.48.0', + nightly='>=0.48.0.dev', + git_master='@git+https://github.com/tensorflow/model-analysis@master', + ), 'witwidget>=1.4.4,<2', 'protobuf>=3.20.3,<5', ] - # Get version from version module. with open('fairness_indicators/version.py') as fp: globals_dict = {} exec(fp.read(), globals_dict) # pylint: disable=exec-used __version__ = globals_dict['__version__'] - with open('README.md', 'r', encoding='utf-8') as fh: long_description = fh.read() - -setup( +setuptools.setup( name='fairness_indicators', version=__version__, description='Fairness Indicators', @@ -70,7 +69,7 @@ def select_constraint(default, nightly=None, git_master=None): url='https://github.com/tensorflow/fairness-indicators', author='Google LLC', author_email='packages@tensorflow.org', - packages=find_packages(exclude=['tensorboard_plugin']), + packages=setuptools.find_packages(exclude=['tensorboard_plugin']), package_data={ 'fairness_indicators': ['documentation/*'], }, @@ -96,6 +95,8 @@ def select_constraint(default, nightly=None, git_master=None): 'Topic :: Software Development :: Libraries :: Python Modules', ], license='Apache 2.0', - keywords='tensorflow model analysis fairness indicators tensorboard machine' - ' learning', + keywords=( + 'tensorflow model analysis fairness indicators tensorboard machine' + ' learning' + ), )