Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 466673111
Change-Id: If18d6db1bb8065c92c2cb81636b3629df623ad0b
  • Loading branch information
davidstutz authored and copybara-github committed Aug 10, 2022
1 parent a098422 commit 8cad46b
Show file tree
Hide file tree
Showing 30 changed files with 370 additions and 70 deletions.
10 changes: 5 additions & 5 deletions colab_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
"""Utils for evaluation in Colabs or notebooks."""
from typing import Tuple, Callable, Dict, Any, List

from absl import logging
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import sklearn.metrics

import conformal_training.conformal_prediction as cp
import conformal_training.evaluation as cpeval
import conformal_training.open_source_utils as cpstaging
import conformal_prediction as cp
import evaluation as cpeval
import open_source_utils as cpstaging


_CalibrateFn = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], float]
Expand Down Expand Up @@ -504,11 +505,10 @@ def evaluate_conformal_prediction(
test_results_t = pd.concat([tau_t] + test_results_t, axis=1)

test_results = pd.concat((test_results, test_results_t), axis=0)
print(f'\t trial {t}: {tau}', flush=True)
logging.info('Trial %d: %f', t, tau)

results = {
'mean': {'val': val_results.mean(0), 'test': test_results.mean(0)},
'std': {'val': val_results.std(0), 'test': test_results.std(0)},
}
print('\t reduced', flush=True)
return results
6 changes: 3 additions & 3 deletions colab_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import ml_collections as collections
import numpy as np

import conformal_training.colab_utils as cpcolab
import conformal_training.data_utils as cpdatautils
import conformal_training.test_utils as cptutils
import colab_utils as cpcolab
import data_utils as cpdatautils
import test_utils as cptutils


class ColabUtilsTest(parameterized.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions conformal_prediction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import jax.numpy as jnp
import numpy as np

import conformal_training.conformal_prediction as cp
import conformal_training.test_utils as cptutils
import conformal_prediction as cp
import test_utils as cptutils


class ConformalPredictionTest(parameterized.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import tensorflow as tf
import tensorflow_datasets as tfds

import conformal_training.auto_augment as augment
import auto_augment as augment


def load_data_split(
Expand Down
2 changes: 1 addition & 1 deletion data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np
import tensorflow_datasets as tfds

import conformal_training.data as cpdata
import data as cpdata
DATA_DIR = './data/'


Expand Down
2 changes: 1 addition & 1 deletion data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import ml_collections as collections
import tensorflow as tf

import conformal_training.data as cpdata
import data as cpdata


def apply_cifar_augmentation(
Expand Down
4 changes: 2 additions & 2 deletions data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import chex
import ml_collections as collections

import conformal_training.data as cpdata
import conformal_training.data_utils as cpdatautils
import data as cpdata
import data_utils as cpdatautils
DATA_DIR = './data/'


Expand Down
137 changes: 137 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
name: conformal_training
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- _tflow_select=2.3.0=mkl
- abseil-cpp=20211102.0=h27087fc_1
- absl-py=0.15.0=pyhd3eb1b0_0
- aiohttp=3.8.1=py39hb9d737c_1
- aiosignal=1.2.0=pyhd8ed1ab_0
- astor=0.8.1=pyh9f0ad1d_0
- astunparse=1.6.3=pyhd8ed1ab_0
- async-timeout=4.0.2=pyhd8ed1ab_0
- attrs=21.4.0=pyhd8ed1ab_0
- blas=1.0=openblas
- blinker=1.4=py_1
- bottleneck=1.3.5=py39h7deecbd_0
- brotlipy=0.7.0=py39hb9d737c_1004
- bzip2=1.0.8=h7f98852_4
- c-ares=1.18.1=h7f98852_0
- ca-certificates=2022.6.15=ha878542_0
- cachetools=4.2.4=pyhd8ed1ab_0
- certifi=2022.6.15=py39hf3d152e_0
- cffi=1.15.1=py39he91dace_0
- charset-normalizer=2.1.0=pyhd8ed1ab_0
- click=8.1.3=py39hf3d152e_0
- cryptography=37.0.1=py39h9ce1e76_0
- dataclasses=0.8=pyhc8e2a94_3
- dm-haiku=0.0.7=pyhd8ed1ab_0
- etils=0.6.0=pyhd8ed1ab_0
- frozenlist=1.3.0=py39hb9d737c_1
- gast=0.4.0=pyh9f0ad1d_0
- google-auth=1.35.0=pyh6c4a22f_0
- google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
- google-pasta=0.2.0=pyh8c360ce_0
- grpc-cpp=1.46.3=h00ec82a_2
- grpcio=1.46.3=py39h2edfe15_2
- h5py=2.10.0=nompi_py39h98ba4bc_106
- hdf5=1.10.6=h3ffc7dd_1
- idna=3.3=pyhd8ed1ab_0
- importlib-metadata=4.11.4=py39hf3d152e_0
- importlib_resources=5.8.0=pyhd8ed1ab_0
- jax=0.3.14=pyhd8ed1ab_1
- jaxlib=0.3.14=cpu_py39h79d7c74_0
- jmp=0.0.2=pyhd8ed1ab_0
- joblib=1.1.0=pyhd3eb1b0_0
- keras-preprocessing=1.1.2=pyhd8ed1ab_0
- ld_impl_linux-64=2.36.1=hea4e1c9_2
- libblas=3.9.0=15_linux64_openblas
- libcblas=3.9.0=15_linux64_openblas
- libffi=3.4.2=h7f98852_5
- libgcc-ng=12.1.0=h8d9b700_16
- libgfortran-ng=12.1.0=h69a702a_16
- libgfortran5=12.1.0=hdcd56e2_16
- libgomp=12.1.0=h8d9b700_16
- liblapack=3.9.0=15_linux64_openblas
- libnsl=2.0.0=h7f98852_0
- libopenblas=0.3.20=pthreads_h78a6416_0
- libprotobuf=3.20.1=h6239696_0
- libstdcxx-ng=12.1.0=ha89aaad_16
- libuuid=2.32.1=h7f98852_1000
- libzlib=1.2.12=h166bdaf_1
- markdown=3.3.7=pyhd8ed1ab_0
- multidict=6.0.2=py39hb9d737c_1
- ncurses=6.3=h27087fc_1
- numexpr=2.8.3=py39hd2a5715_0
- numpy=1.19.5=py39hd249d9e_3
- oauthlib=3.2.0=pyhd8ed1ab_0
- openssl=3.0.5=h166bdaf_0
- opt_einsum=3.3.0=pyhd8ed1ab_1
- packaging=21.3=pyhd3eb1b0_0
- pandas=1.4.2=py39h295c915_0
- pip=22.1.2=pyhd8ed1ab_0
- protobuf=3.20.1=py39h5a03fae_0
- pyasn1=0.4.8=py_0
- pyasn1-modules=0.2.7=py_0
- pycparser=2.21=pyhd8ed1ab_0
- pyjwt=2.4.0=pyhd8ed1ab_0
- pyopenssl=22.0.0=pyhd8ed1ab_0
- pysocks=1.7.1=py39hf3d152e_5
- python=3.9.13=h2660328_0_cpython
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-flatbuffers=2.0=pyhd8ed1ab_0
- python_abi=3.9=2_cp39
- pytz=2022.1=py39h06a4308_0
- pyu2f=0.1.5=pyhd8ed1ab_0
- re2=2022.06.01=h27087fc_0
- readline=8.1.2=h0f457ee_0
- requests=2.28.1=pyhd8ed1ab_0
- requests-oauthlib=1.3.1=pyhd8ed1ab_0
- rsa=4.8=pyhd8ed1ab_0
- scikit-learn=1.0.2=py39h51133e4_1
- scipy=1.8.1=py39he49c0e8_0
- setuptools=63.1.0=py39hf3d152e_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.39.0=h4ff8645_0
- tabulate=0.8.10=pyhd8ed1ab_0
- tensorboard=2.4.1=pyhd8ed1ab_1
- tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
- tensorflow=2.4.1=mkl_py39h4683426_0
- tensorflow-base=2.4.1=mkl_py39h43e0292_0
- tensorflow-estimator=2.6.0=py39he80948d_0
- termcolor=1.1.0=pyhd8ed1ab_3
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.12=h27826a3_0
- typing-extensions=4.3.0=hd8ed1ab_0
- typing_extensions=4.3.0=pyha770c72_0
- tzdata=2022a=h191b570_0
- urllib3=1.26.9=pyhd8ed1ab_0
- werkzeug=2.1.2=pyhd8ed1ab_1
- wheel=0.37.1=pyhd8ed1ab_0
- wrapt=1.14.1=py39hb9d737c_0
- xz=5.2.5=h516909a_1
- yarl=1.7.2=py39hb9d737c_2
- zipp=3.8.0=pyhd8ed1ab_0
- zlib=1.2.12=h166bdaf_1
- pip:
- chex==0.1.3
- contextlib2==21.6.0
- dill==0.3.5.1
- dm-tree==0.1.7
- googleapis-common-protos==1.56.3
- install==1.3.5
- ml-collections==0.1.1
- optax==0.1.2
- promise==2.3
- pyparsing==3.0.9
- pyyaml==6.0
- tensorflow-addons==0.17.1
- tensorflow-datasets==4.6.0
- tensorflow-metadata==1.9.0
- toml==0.10.2
- toolz==0.11.2
- tqdm==4.64.0
- typeguard==2.13.3
118 changes: 118 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Evaluate experiment."""
import os
import sys

from absl import flags
from absl import logging
import jax

from absl import app
import colab_utils as cbutils

FLAGS = flags.FLAGS
flags.DEFINE_string('experiment_path', './', 'base path for experiments')
flags.DEFINE_string('experiment_dataset', '', 'dataset to evaluate')
flags.DEFINE_string(
'experiment_method', 'thr', 'conformal predictor to use, thr or apr')
flags.DEFINE_boolean('experiment_logfile', False,
'log results to file in experiment_path')


def main(argv):
del argv

if FLAGS.experiment_logfile:
logging.get_absl_handler().use_absl_log_file(
f'eval_{FLAGS.experiment_method}', FLAGS.experiment_path)
else:
logging.get_absl_handler().python_handler.stream = sys.stdout

if not os.path.exists(FLAGS.experiment_path):
logging.error('could not find experiment path %s', FLAGS.experiment_path)
return

alpha = 0.01
if FLAGS.experiment_method == 'thr':
calibrate_fn, predict_fn = cbutils.get_threshold_fns(alpha)
elif FLAGS.experiment_method == 'aps':
calibrate_fn, predict_fn = cbutils.get_raps_fns(alpha, 0, 0)
else:
raise ValueError('Invalid conformal predictor, choose thr or aps.')

if FLAGS.experiment_dataset == 'mnist':
num_classes = 10
groups = ['singleton', 'groups']
elif FLAGS.experiment_dataset == 'emnist_byclass':
num_classes = 52
groups = ['groups']
elif FLAGS.experiment_dataset == 'fashion_mnist':
num_classes = 10
groups = ['singleton']
elif FLAGS.experiment_dataset == 'cifar10':
num_classes = 10
groups = ['singleton', 'groups']
elif FLAGS.experiment_dataset == 'cifar100':
num_classes = 100
groups = ['groups', 'hierarchy']
else:
raise ValueError('Invalid dataset %s.' % FLAGS.experiment_dataset)

model = cbutils.load_predictions(FLAGS.experiment_path, val_examples=5000)

for group in groups:
model['data']['groups'][group] = cbutils.get_groups(
FLAGS.experiment_dataset, group)

results = cbutils.evaluate_conformal_prediction(
model, calibrate_fn, predict_fn, trials=10, rng=jax.random.PRNGKey(0))

logging.info('Accuracy: %f', results['mean']['test']['accuracy'])
logging.info('Coverage: %f', results['mean']['test']['coverage'])
logging.info('Size: %f', results['mean']['test']['size'])

for k in range(num_classes):
logging.info(
'Class size %d: %f', k, results['mean']['test'][f'class_size_{k}'])

for group in groups:
k = 0
key = f'{group}_size_{k}'
while key in results['mean']['test'].keys():
logging.info(
'Group %s size %d: %f', group, k, results['mean']['test'][key])
k += 1
key = f'{group}_size_{k}'

logging.info(
'Group %s miscoverage 0: %f',
group, results['mean']['test'][f'{group}_miscoverage_0'])
logging.info(
'Group %s miscoverage 1: %f',
group, results['mean']['test'][f'{group}_miscoverage_1'])

# Selected coverage confusion combinations:
logging.info(
'Coverage confusion 4-6: %f',
results['mean']['test']['coverage_confusion_4_6'])
logging.info(
'Coverage confusion 6-4: %f',
results['mean']['test']['coverage_confusion_6_4'])


if __name__ == '__main__':
app.run(main)
4 changes: 2 additions & 2 deletions evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import jax.numpy as jnp
import numpy as np

import conformal_training.evaluation as cpeval
import conformal_training.test_utils as cptutils
import evaluation as cpeval
import test_utils as cptutils


class EvaluationTest(parameterized.TestCase):
Expand Down
16 changes: 16 additions & 0 deletions experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Experiments configuration."""
4 changes: 2 additions & 2 deletions experiments/run_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import ml_collections as collections

import conformal_training.experiments.experiment_utils as cpeutils
import experiments.experiment_utils as cpeutils


def get_parameters(
Expand Down Expand Up @@ -52,7 +52,7 @@ def get_parameters(
else:
config.epochs = 50
config.finetune.enabled = True
config.finetune.path = './cifar10_models_seed0/'
config.finetune.path = 'cifar10_models_seed0/'
config.finetune.model_state = False
config.finetune.layers = 'res_net/~/logits'
config.finetune.reinitialize = True
Expand Down
Loading

0 comments on commit 8cad46b

Please sign in to comment.