colab_utils_test.py
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for evaluation utilities."""
import os
from absl.testing import absltest
from absl.testing import parameterized
import jax
import ml_collections as collections
import numpy as np
import colab_utils as cpcolab
import data_utils as cpdatautils
import test_utils as cptutils
class ColabUtilsTest(parameterized.TestCase):

  def _get_model(self, num_examples, num_classes):
    val_examples = num_examples//2
    labels = cptutils.get_labels(num_examples, num_classes)
    logits = cptutils.get_probabilities(labels, dominance=0.5)
    config = collections.ConfigDict()
    config.dataset = 'cifar10'
    config.val_examples = val_examples
    data = cpdatautils.get_data_stats(config)
    data['groups'] = {'groups': cpcolab.get_groups(config.dataset, 'groups')}
    model = {
        'val_logits': logits[:val_examples],
        'val_labels': labels[:val_examples],
        'test_logits': logits[val_examples:],
        'test_labels': labels[val_examples:],
        'data': data,
    }
    return model

  def _check_results(self, results):
    self.assertIn('mean', results.keys())
    self.assertIn('std', results.keys())
    if os.getenv('EVAL_VAL', '0') == '1':
      self.assertIn('val', results['mean'].keys())
    self.assertIn('test', results['mean'].keys())
    # Just test whether some basic metrics are there and not NaN or so.
    metrics_to_check = [
        'size', 'coverage', 'accuracy',
        'class_size_0', 'class_coverage_0',
        'size_0', 'cumulative_size_0',
        'groups_miscoverage',
    ]
    if os.getenv('EVAL_CONFUSION') == '1':
      metrics_to_check += [
          'classification_confusion_0_0', 'coverage_confusion_0_0'
      ]
    for metric in metrics_to_check:
      mean = results['mean']['test'][metric]
      std = results['std']['test'][metric]
      self.assertFalse(np.isnan(mean))
      self.assertFalse(np.isinf(mean))
      self.assertGreaterEqual(mean, 0.)
      self.assertFalse(np.isnan(std))
      self.assertFalse(np.isinf(std))
      self.assertGreaterEqual(std, 0.)
    # Extra check for cumulative size
    self.assertAlmostEqual(results['mean']['test']['cumulative_size_9'], 1)

  def test_evaluate_conformal_prediction(self):
    num_examples = 1000
    num_classes = 10
    model = self._get_model(num_examples, num_classes)
    calibrate_fn, predict_fn = cpcolab.get_threshold_fns(0.05, jit=True)
    rng = jax.random.PRNGKey(0)
    results = cpcolab.evaluate_conformal_prediction(
        model, calibrate_fn, predict_fn, trials=2, rng=rng)
    self._check_results(results)


if __name__ == '__main__':
  absltest.main()
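
For context, here is a minimal sketch of how the evaluation flow exercised by this test might be driven directly, outside of absltest. It mirrors `_get_model` and `test_evaluate_conformal_prediction` above and assumes the repository modules (`colab_utils`, `data_utils`, `test_utils`) are importable from the working directory; the miscoverage level 0.01 and `trials=10` are illustrative values, not defaults taken from the code.

```python
# Illustrative driver mirroring the test above; all constants are examples.
import jax
import ml_collections as collections

import colab_utils as cpcolab
import data_utils as cpdatautils
import test_utils as cptutils

# Synthetic labels and logits, split evenly into validation and test halves.
num_examples, num_classes = 1000, 10
labels = cptutils.get_labels(num_examples, num_classes)
logits = cptutils.get_probabilities(labels, dominance=0.5)

config = collections.ConfigDict()
config.dataset = 'cifar10'
config.val_examples = num_examples // 2
data = cpdatautils.get_data_stats(config)
data['groups'] = {'groups': cpcolab.get_groups(config.dataset, 'groups')}

model = {
    'val_logits': logits[:config.val_examples],
    'val_labels': labels[:config.val_examples],
    'test_logits': logits[config.val_examples:],
    'test_labels': labels[config.val_examples:],
    'data': data,
}

# Calibrate and evaluate threshold conformal prediction at level 0.01.
calibrate_fn, predict_fn = cpcolab.get_threshold_fns(0.01, jit=True)
results = cpcolab.evaluate_conformal_prediction(
    model, calibrate_fn, predict_fn, trials=10, rng=jax.random.PRNGKey(0))
print(results['mean']['test']['coverage'], results['mean']['test']['size'])
```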