diff --git a/colab_utils.py b/colab_utils.py index 38a3337..06a7dee 100644 --- a/colab_utils.py +++ b/colab_utils.py @@ -16,15 +16,16 @@ """Utils for evaluation in Colabs or notebooks.""" from typing import Tuple, Callable, Dict, Any, List +from absl import logging import jax import jax.numpy as jnp import numpy as np import pandas as pd import sklearn.metrics -import conformal_training.conformal_prediction as cp -import conformal_training.evaluation as cpeval -import conformal_training.open_source_utils as cpstaging +import conformal_prediction as cp +import evaluation as cpeval +import open_source_utils as cpstaging _CalibrateFn = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], float] @@ -504,11 +505,10 @@ def evaluate_conformal_prediction( test_results_t = pd.concat([tau_t] + test_results_t, axis=1) test_results = pd.concat((test_results, test_results_t), axis=0) - print(f'\t trial {t}: {tau}', flush=True) + logging.info('Trial %d: %f', t, tau) results = { 'mean': {'val': val_results.mean(0), 'test': test_results.mean(0)}, 'std': {'val': val_results.std(0), 'test': test_results.std(0)}, } - print('\t reduced', flush=True) return results diff --git a/colab_utils_test.py b/colab_utils_test.py index c8a3516..9dc87dc 100644 --- a/colab_utils_test.py +++ b/colab_utils_test.py @@ -22,9 +22,9 @@ import ml_collections as collections import numpy as np -import conformal_training.colab_utils as cpcolab -import conformal_training.data_utils as cpdatautils -import conformal_training.test_utils as cptutils +import colab_utils as cpcolab +import data_utils as cpdatautils +import test_utils as cptutils class ColabUtilsTest(parameterized.TestCase): diff --git a/conformal_prediction_test.py b/conformal_prediction_test.py index 6793df3..975608b 100644 --- a/conformal_prediction_test.py +++ b/conformal_prediction_test.py @@ -23,8 +23,8 @@ import jax.numpy as jnp import numpy as np -import conformal_training.conformal_prediction as cp -import conformal_training.test_utils as cptutils +import conformal_prediction as cp +import test_utils as cptutils class ConformalPredictionTest(parameterized.TestCase): diff --git a/data.py b/data.py index dcf7776..502e74a 100644 --- a/data.py +++ b/data.py @@ -20,7 +20,7 @@ import tensorflow as tf import tensorflow_datasets as tfds -import conformal_training.auto_augment as augment +import auto_augment as augment def load_data_split( diff --git a/data_test.py b/data_test.py index fbc8855..a691789 100644 --- a/data_test.py +++ b/data_test.py @@ -22,7 +22,7 @@ import numpy as np import tensorflow_datasets as tfds -import conformal_training.data as cpdata +import data as cpdata DATA_DIR = './data/' diff --git a/data_utils.py b/data_utils.py index cc4c3bf..7b87a27 100644 --- a/data_utils.py +++ b/data_utils.py @@ -22,7 +22,7 @@ import ml_collections as collections import tensorflow as tf -import conformal_training.data as cpdata +import data as cpdata def apply_cifar_augmentation( diff --git a/data_utils_test.py b/data_utils_test.py index 776e6e4..a22bc57 100644 --- a/data_utils_test.py +++ b/data_utils_test.py @@ -20,8 +20,8 @@ import chex import ml_collections as collections -import conformal_training.data as cpdata -import conformal_training.data_utils as cpdatautils +import data as cpdata +import data_utils as cpdatautils DATA_DIR = './data/' diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..953d84c --- /dev/null +++ b/environment.yml @@ -0,0 +1,137 @@ +name: conformal_training +channels: +- conda-forge +- defaults +dependencies: +- _libgcc_mutex=0.1=conda_forge +- _openmp_mutex=4.5=2_gnu +- _tflow_select=2.3.0=mkl +- abseil-cpp=20211102.0=h27087fc_1 +- absl-py=0.15.0=pyhd3eb1b0_0 +- aiohttp=3.8.1=py39hb9d737c_1 +- aiosignal=1.2.0=pyhd8ed1ab_0 +- astor=0.8.1=pyh9f0ad1d_0 +- astunparse=1.6.3=pyhd8ed1ab_0 +- async-timeout=4.0.2=pyhd8ed1ab_0 +- attrs=21.4.0=pyhd8ed1ab_0 +- blas=1.0=openblas +- blinker=1.4=py_1 +- bottleneck=1.3.5=py39h7deecbd_0 +- brotlipy=0.7.0=py39hb9d737c_1004 +- bzip2=1.0.8=h7f98852_4 +- c-ares=1.18.1=h7f98852_0 +- ca-certificates=2022.6.15=ha878542_0 +- cachetools=4.2.4=pyhd8ed1ab_0 +- certifi=2022.6.15=py39hf3d152e_0 +- cffi=1.15.1=py39he91dace_0 +- charset-normalizer=2.1.0=pyhd8ed1ab_0 +- click=8.1.3=py39hf3d152e_0 +- cryptography=37.0.1=py39h9ce1e76_0 +- dataclasses=0.8=pyhc8e2a94_3 +- dm-haiku=0.0.7=pyhd8ed1ab_0 +- etils=0.6.0=pyhd8ed1ab_0 +- frozenlist=1.3.0=py39hb9d737c_1 +- gast=0.4.0=pyh9f0ad1d_0 +- google-auth=1.35.0=pyh6c4a22f_0 +- google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 +- google-pasta=0.2.0=pyh8c360ce_0 +- grpc-cpp=1.46.3=h00ec82a_2 +- grpcio=1.46.3=py39h2edfe15_2 +- h5py=2.10.0=nompi_py39h98ba4bc_106 +- hdf5=1.10.6=h3ffc7dd_1 +- idna=3.3=pyhd8ed1ab_0 +- importlib-metadata=4.11.4=py39hf3d152e_0 +- importlib_resources=5.8.0=pyhd8ed1ab_0 +- jax=0.3.14=pyhd8ed1ab_1 +- jaxlib=0.3.14=cpu_py39h79d7c74_0 +- jmp=0.0.2=pyhd8ed1ab_0 +- joblib=1.1.0=pyhd3eb1b0_0 +- keras-preprocessing=1.1.2=pyhd8ed1ab_0 +- ld_impl_linux-64=2.36.1=hea4e1c9_2 +- libblas=3.9.0=15_linux64_openblas +- libcblas=3.9.0=15_linux64_openblas +- libffi=3.4.2=h7f98852_5 +- libgcc-ng=12.1.0=h8d9b700_16 +- libgfortran-ng=12.1.0=h69a702a_16 +- libgfortran5=12.1.0=hdcd56e2_16 +- libgomp=12.1.0=h8d9b700_16 +- liblapack=3.9.0=15_linux64_openblas +- libnsl=2.0.0=h7f98852_0 +- libopenblas=0.3.20=pthreads_h78a6416_0 +- libprotobuf=3.20.1=h6239696_0 +- libstdcxx-ng=12.1.0=ha89aaad_16 +- libuuid=2.32.1=h7f98852_1000 +- libzlib=1.2.12=h166bdaf_1 +- markdown=3.3.7=pyhd8ed1ab_0 +- multidict=6.0.2=py39hb9d737c_1 +- ncurses=6.3=h27087fc_1 +- numexpr=2.8.3=py39hd2a5715_0 +- numpy=1.19.5=py39hd249d9e_3 +- oauthlib=3.2.0=pyhd8ed1ab_0 +- openssl=3.0.5=h166bdaf_0 +- opt_einsum=3.3.0=pyhd8ed1ab_1 +- packaging=21.3=pyhd3eb1b0_0 +- pandas=1.4.2=py39h295c915_0 +- pip=22.1.2=pyhd8ed1ab_0 +- protobuf=3.20.1=py39h5a03fae_0 +- pyasn1=0.4.8=py_0 +- pyasn1-modules=0.2.7=py_0 +- pycparser=2.21=pyhd8ed1ab_0 +- pyjwt=2.4.0=pyhd8ed1ab_0 +- pyopenssl=22.0.0=pyhd8ed1ab_0 +- pysocks=1.7.1=py39hf3d152e_5 +- python=3.9.13=h2660328_0_cpython +- python-dateutil=2.8.2=pyhd3eb1b0_0 +- python-flatbuffers=2.0=pyhd8ed1ab_0 +- python_abi=3.9=2_cp39 +- pytz=2022.1=py39h06a4308_0 +- pyu2f=0.1.5=pyhd8ed1ab_0 +- re2=2022.06.01=h27087fc_0 +- readline=8.1.2=h0f457ee_0 +- requests=2.28.1=pyhd8ed1ab_0 +- requests-oauthlib=1.3.1=pyhd8ed1ab_0 +- rsa=4.8=pyhd8ed1ab_0 +- scikit-learn=1.0.2=py39h51133e4_1 +- scipy=1.8.1=py39he49c0e8_0 +- setuptools=63.1.0=py39hf3d152e_0 +- six=1.16.0=pyh6c4a22f_0 +- sqlite=3.39.0=h4ff8645_0 +- tabulate=0.8.10=pyhd8ed1ab_0 +- tensorboard=2.4.1=pyhd8ed1ab_1 +- tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 +- tensorflow=2.4.1=mkl_py39h4683426_0 +- tensorflow-base=2.4.1=mkl_py39h43e0292_0 +- tensorflow-estimator=2.6.0=py39he80948d_0 +- termcolor=1.1.0=pyhd8ed1ab_3 +- threadpoolctl=2.2.0=pyh0d69192_0 +- tk=8.6.12=h27826a3_0 +- typing-extensions=4.3.0=hd8ed1ab_0 +- typing_extensions=4.3.0=pyha770c72_0 +- tzdata=2022a=h191b570_0 +- urllib3=1.26.9=pyhd8ed1ab_0 +- werkzeug=2.1.2=pyhd8ed1ab_1 +- wheel=0.37.1=pyhd8ed1ab_0 +- wrapt=1.14.1=py39hb9d737c_0 +- xz=5.2.5=h516909a_1 +- yarl=1.7.2=py39hb9d737c_2 +- zipp=3.8.0=pyhd8ed1ab_0 +- zlib=1.2.12=h166bdaf_1 +- pip: + - chex==0.1.3 + - contextlib2==21.6.0 + - dill==0.3.5.1 + - dm-tree==0.1.7 + - googleapis-common-protos==1.56.3 + - install==1.3.5 + - ml-collections==0.1.1 + - optax==0.1.2 + - promise==2.3 + - pyparsing==3.0.9 + - pyyaml==6.0 + - tensorflow-addons==0.17.1 + - tensorflow-datasets==4.6.0 + - tensorflow-metadata==1.9.0 + - toml==0.10.2 + - toolz==0.11.2 + - tqdm==4.64.0 + - typeguard==2.13.3 diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..8dc9cb4 --- /dev/null +++ b/eval.py @@ -0,0 +1,118 @@ +# Copyright 2022 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Evaluate experiment.""" +import os +import sys + +from absl import flags +from absl import logging +import jax + +from absl import app +import colab_utils as cbutils + +FLAGS = flags.FLAGS +flags.DEFINE_string('experiment_path', './', 'base path for experiments') +flags.DEFINE_string('experiment_dataset', '', 'dataset to evaluate') +flags.DEFINE_string( + 'experiment_method', 'thr', 'conformal predictor to use, thr or apr') +flags.DEFINE_boolean('experiment_logfile', False, + 'log results to file in experiment_path') + + +def main(argv): + del argv + + if FLAGS.experiment_logfile: + logging.get_absl_handler().use_absl_log_file( + f'eval_{FLAGS.experiment_method}', FLAGS.experiment_path) + else: + logging.get_absl_handler().python_handler.stream = sys.stdout + + if not os.path.exists(FLAGS.experiment_path): + logging.error('could not find experiment path %s', FLAGS.experiment_path) + return + + alpha = 0.01 + if FLAGS.experiment_method == 'thr': + calibrate_fn, predict_fn = cbutils.get_threshold_fns(alpha) + elif FLAGS.experiment_method == 'aps': + calibrate_fn, predict_fn = cbutils.get_raps_fns(alpha, 0, 0) + else: + raise ValueError('Invalid conformal predictor, choose thr or aps.') + + if FLAGS.experiment_dataset == 'mnist': + num_classes = 10 + groups = ['singleton', 'groups'] + elif FLAGS.experiment_dataset == 'emnist_byclass': + num_classes = 52 + groups = ['groups'] + elif FLAGS.experiment_dataset == 'fashion_mnist': + num_classes = 10 + groups = ['singleton'] + elif FLAGS.experiment_dataset == 'cifar10': + num_classes = 10 + groups = ['singleton', 'groups'] + elif FLAGS.experiment_dataset == 'cifar100': + num_classes = 100 + groups = ['groups', 'hierarchy'] + else: + raise ValueError('Invalid dataset %s.' % FLAGS.experiment_dataset) + + model = cbutils.load_predictions(FLAGS.experiment_path, val_examples=5000) + + for group in groups: + model['data']['groups'][group] = cbutils.get_groups( + FLAGS.experiment_dataset, group) + + results = cbutils.evaluate_conformal_prediction( + model, calibrate_fn, predict_fn, trials=10, rng=jax.random.PRNGKey(0)) + + logging.info('Accuracy: %f', results['mean']['test']['accuracy']) + logging.info('Coverage: %f', results['mean']['test']['coverage']) + logging.info('Size: %f', results['mean']['test']['size']) + + for k in range(num_classes): + logging.info( + 'Class size %d: %f', k, results['mean']['test'][f'class_size_{k}']) + + for group in groups: + k = 0 + key = f'{group}_size_{k}' + while key in results['mean']['test'].keys(): + logging.info( + 'Group %s size %d: %f', group, k, results['mean']['test'][key]) + k += 1 + key = f'{group}_size_{k}' + + logging.info( + 'Group %s miscoverage 0: %f', + group, results['mean']['test'][f'{group}_miscoverage_0']) + logging.info( + 'Group %s miscoverage 1: %f', + group, results['mean']['test'][f'{group}_miscoverage_1']) + + # Selected coverage confusion combinations: + logging.info( + 'Coverage confusion 4-6: %f', + results['mean']['test']['coverage_confusion_4_6']) + logging.info( + 'Coverage confusion 6-4: %f', + results['mean']['test']['coverage_confusion_6_4']) + + +if __name__ == '__main__': + app.run(main) diff --git a/evaluation_test.py b/evaluation_test.py index 7327add..4585a46 100644 --- a/evaluation_test.py +++ b/evaluation_test.py @@ -20,8 +20,8 @@ import jax.numpy as jnp import numpy as np -import conformal_training.evaluation as cpeval -import conformal_training.test_utils as cptutils +import evaluation as cpeval +import test_utils as cptutils class EvaluationTest(parameterized.TestCase): diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000..2e66709 --- /dev/null +++ b/experiments/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Experiments configuration.""" diff --git a/experiments/run_cifar10.py b/experiments/run_cifar10.py index 9943c56..aecbd64 100644 --- a/experiments/run_cifar10.py +++ b/experiments/run_cifar10.py @@ -18,7 +18,7 @@ import ml_collections as collections -import conformal_training.experiments.experiment_utils as cpeutils +import experiments.experiment_utils as cpeutils def get_parameters( @@ -52,7 +52,7 @@ def get_parameters( else: config.epochs = 50 config.finetune.enabled = True - config.finetune.path = './cifar10_models_seed0/' + config.finetune.path = 'cifar10_models_seed0/' config.finetune.model_state = False config.finetune.layers = 'res_net/~/logits' config.finetune.reinitialize = True diff --git a/experiments/run_cifar100.py b/experiments/run_cifar100.py index 6481e3a..e4b558e 100644 --- a/experiments/run_cifar100.py +++ b/experiments/run_cifar100.py @@ -19,7 +19,7 @@ import ml_collections as collections import numpy as np -import conformal_training.experiments.experiment_utils as cpeutils +import experiments.experiment_utils as cpeutils def get_parameters( @@ -71,12 +71,12 @@ def get_parameters( config.learning_rate = 0.05 config.batch_size = 100 else: - config.finetune.enabled = True config.epochs = 50 + config.finetune.enabled = True + config.finetune.path = 'cifar100_models_seed0/' config.finetune.model_state = False config.finetune.layers = 'res_net/~/logits' config.finetune.reinitialize = True - config.cifar_augmentation = 'standard+autoaugment+cutout' if experiment == 'baseline_trials': config.mode = 'normal' diff --git a/experiments/run_fashion_mnist.py b/experiments/run_fashion_mnist.py index 0b6cb8d..59a9655 100644 --- a/experiments/run_fashion_mnist.py +++ b/experiments/run_fashion_mnist.py @@ -18,7 +18,7 @@ import ml_collections as collections -import conformal_training.experiments.experiment_utils as cpeutils +import experiments.experiment_utils as cpeutils def get_parameters( diff --git a/experiments/run_mnist.py b/experiments/run_mnist.py index c840571..850d4fd 100644 --- a/experiments/run_mnist.py +++ b/experiments/run_mnist.py @@ -18,7 +18,7 @@ import ml_collections as collections -import conformal_training.experiments.experiment_utils as cpeutils +import experiments.experiment_utils as cpeutils def get_parameters( diff --git a/experiments/run_wine_quality.py b/experiments/run_wine_quality.py index d8e8c3e..a7b9c4e 100644 --- a/experiments/run_wine_quality.py +++ b/experiments/run_wine_quality.py @@ -18,7 +18,7 @@ import ml_collections as collections -import conformal_training.experiments.experiment_utils as cpeutils +import experiments.experiment_utils as cpeutils def get_parameters( diff --git a/main.py b/main.py index 28a750b..27540fc 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ from ml_collections import config_flags from absl import app -from conformal_training.train import train +from train import train FLAGS = flags.FLAGS diff --git a/models_test.py b/models_test.py index 78ad41b..53393e4 100644 --- a/models_test.py +++ b/models_test.py @@ -21,7 +21,7 @@ import jax.numpy as jnp import numpy as np -import conformal_training.models as cpmodels +import models as cpmodels class ModelsTest(parameterized.TestCase): diff --git a/run.py b/run.py index 6bbe0f3..3a657f4 100644 --- a/run.py +++ b/run.py @@ -23,16 +23,16 @@ import ml_collections as collections from absl import app -from conformal_training.config import get_config +from config import get_config # pylint: disable=unused-import -from conformal_training.experiments.run_cifar10 import get_parameters as get_cifar10_parameters -from conformal_training.experiments.run_cifar100 import get_parameters as get_cifar100_parameters -from conformal_training.experiments.run_emnist_byclass import get_parameters as get_emnist_byclass_parameters -from conformal_training.experiments.run_fashion_mnist import get_parameters as get_fashion_mnist_parameters -from conformal_training.experiments.run_mnist import get_parameters as get_mnist_parameters -from conformal_training.experiments.run_wine_quality import get_parameters as get_wine_quality_parameters -from conformal_training.train import train +from experiments.run_cifar10 import get_parameters as get_cifar10_parameters +from experiments.run_cifar100 import get_parameters as get_cifar100_parameters +from experiments.run_emnist_byclass import get_parameters as get_emnist_byclass_parameters +from experiments.run_fashion_mnist import get_parameters as get_fashion_mnist_parameters +from experiments.run_mnist import get_parameters as get_mnist_parameters +from experiments.run_wine_quality import get_parameters as get_wine_quality_parameters +from train import train FLAGS = flags.FLAGS flags.DEFINE_string('experiment_dataset', 'cifar10', 'dataset to use') diff --git a/smooth_conformal_prediction.py b/smooth_conformal_prediction.py index b53b171..516b03b 100644 --- a/smooth_conformal_prediction.py +++ b/smooth_conformal_prediction.py @@ -32,7 +32,7 @@ import jax import jax.numpy as jnp -from conformal_training import variational_sorting_net +import variational_sorting_net _SmoothQuantileFn = Callable[[Any, float], float] diff --git a/smooth_conformal_prediction_test.py b/smooth_conformal_prediction_test.py index 2ef089c..b5aaa9e 100644 --- a/smooth_conformal_prediction_test.py +++ b/smooth_conformal_prediction_test.py @@ -21,11 +21,11 @@ import jax.numpy as jnp import numpy as np -from conformal_training import sorting_nets -from conformal_training import variational_sorting_net -import conformal_training.conformal_prediction as cp -import conformal_training.smooth_conformal_prediction as scp -import conformal_training.test_utils as cptutils +import sorting_nets +import variational_sorting_net +import conformal_prediction as cp +import smooth_conformal_prediction as scp +import test_utils as cptutils class SmoothConformalPredictionTest(parameterized.TestCase): diff --git a/sorting_nets_test.py b/sorting_nets_test.py index 4d91dc0..dbfa1c3 100644 --- a/sorting_nets_test.py +++ b/sorting_nets_test.py @@ -20,8 +20,8 @@ import jax import jax.numpy as jnp -from conformal_training import sorting_nets -from conformal_training import variational_sorting_net +import sorting_nets +import variational_sorting_net class SortingNetsTest(parameterized.TestCase): diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..72bb9db --- /dev/null +++ b/test.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright 2022 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +rm -rf ./data +python3 colab_utils_test.py +python3 conformal_prediction_test.py +python3 data_test.py +python3 data_utils_test.py +python3 evaluation_test.py +python3 models_test.py +python3 smooth_conformal_prediction_test.py +python3 sorting_nets_test.py +python3 train_utils_test.py +python3 variational_sorting_net_test.py diff --git a/train.py b/train.py index a77ce5d..66d961c 100644 --- a/train.py +++ b/train.py @@ -18,11 +18,11 @@ import haiku as hk import ml_collections as collections -import conformal_training.data_utils as cpdatautils -import conformal_training.train_conformal as cpconformal -import conformal_training.train_coverage as cpcoverage -import conformal_training.train_normal as cpnormal -import conformal_training.train_utils as cputils +import data_utils as cpdatautils +import train_conformal as cpconformal +import train_coverage as cpcoverage +import train_normal as cpnormal +import train_utils as cputils def train(config: collections.ConfigDict): diff --git a/train_conformal.py b/train_conformal.py index 6fcd521..070f467 100644 --- a/train_conformal.py +++ b/train_conformal.py @@ -23,10 +23,10 @@ import ml_collections as collections -import conformal_training.evaluation as cpeval -import conformal_training.smooth_conformal_prediction as scp -import conformal_training.train_coverage as cpcoverage -import conformal_training.train_utils as cputils +import evaluation as cpeval +import smooth_conformal_prediction as scp +import train_coverage as cpcoverage +import train_utils as cputils SmoothCalibrateFn = Callable[ diff --git a/train_coverage.py b/train_coverage.py index 1b08fa8..81e1241 100644 --- a/train_coverage.py +++ b/train_coverage.py @@ -24,14 +24,14 @@ import jax.numpy as jnp import ml_collections as collections -from conformal_training import sorting_nets -from conformal_training import variational_sorting_net -import conformal_training.conformal_prediction as cp -import conformal_training.data as cpdata -import conformal_training.evaluation as cpeval -import conformal_training.smooth_conformal_prediction as scp -import conformal_training.train_normal as cpnormal -import conformal_training.train_utils as cputils +import sorting_nets +import variational_sorting_net +import conformal_prediction as cp +import data as cpdata +import evaluation as cpeval +import smooth_conformal_prediction as scp +import train_normal as cpnormal +import train_utils as cputils SizeLossFn = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray] diff --git a/train_normal.py b/train_normal.py index 7a2dc3c..9a73f0a 100644 --- a/train_normal.py +++ b/train_normal.py @@ -24,10 +24,10 @@ import numpy as np import tensorflow as tf -import conformal_training.data as cpdata -import conformal_training.evaluation as cpeval -import conformal_training.open_source_utils as cpstaging -import conformal_training.train_utils as cputils +import data as cpdata +import evaluation as cpeval +import open_source_utils as cpstaging +import train_utils as cputils ShiftFn = Callable[[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]] diff --git a/train_utils.py b/train_utils.py index cbd812a..69345a9 100644 --- a/train_utils.py +++ b/train_utils.py @@ -22,9 +22,9 @@ import ml_collections as collections import optax -import conformal_training.data as cpdata -import conformal_training.models as cpmodels -import conformal_training.open_source_utils as cpstaging +import data as cpdata +import models as cpmodels +import open_source_utils as cpstaging FlatMapping = Union[hk.Params, hk.State] diff --git a/train_utils_test.py b/train_utils_test.py index 5c87200..f47a17c 100644 --- a/train_utils_test.py +++ b/train_utils_test.py @@ -19,7 +19,7 @@ import jax.numpy as jnp import numpy as np -import conformal_training.train_utils as cputils +import train_utils as cputils class TrainUtilsTest(parameterized.TestCase): diff --git a/variational_sorting_net_test.py b/variational_sorting_net_test.py index ae42185..8fe3a45 100644 --- a/variational_sorting_net_test.py +++ b/variational_sorting_net_test.py @@ -19,8 +19,8 @@ import jax import jax.numpy as jnp -from conformal_training import sorting_nets -from conformal_training import variational_sorting_net +import sorting_nets +import variational_sorting_net class VariationalSortingNetTest(parameterized.TestCase):