forked from google/differential-privacy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.py
325 lines (269 loc) · 11.2 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
# Copyright 2020 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common classes and functions for the accounting library."""
import math
import typing
import dataclasses
import numpy as np
from scipy import fft
from scipy import signal
@dataclasses.dataclass
class DifferentialPrivacyParameters(object):
    """Representation of the differential privacy parameters of a mechanism.

    Attributes:
      epsilon: the epsilon in (epsilon, delta)-differential privacy. Must be
        nonnegative (zero is allowed).
      delta: the delta in (epsilon, delta)-differential privacy. Must lie in
        [0, 1]; defaults to 0.
    """
    epsilon: float
    delta: float = 0

    def __post_init__(self):
        """Validates epsilon and delta.

        Raises:
          ValueError: If epsilon is negative or NaN, or if delta is outside
            [0, 1] or NaN.
        """
        # Negated range checks: NaN fails every comparison, so writing the
        # checks this way rejects NaN instead of silently accepting it.
        if not self.epsilon >= 0:
            raise ValueError(f'epsilon should be nonnegative: {self.epsilon}')
        if not 0 <= self.delta <= 1:
            raise ValueError(f'delta should be between 0 and 1: {self.delta}')
@dataclasses.dataclass
class BinarySearchParameters:
    """Settings that control a binary search over a numeric range.

    Attributes:
      lower_bound: Lower end of the range that is searched.
      upper_bound: Upper end of the range that is searched.
      initial_guess: Optional starting point for the search; expected to be
        positive. A guess close to the true answer can shorten the search.
      tolerance: Acceptable error on the returned value.
      discrete: Whether the search runs over integers rather than reals.
    """
    lower_bound: float
    upper_bound: float
    initial_guess: typing.Optional[float] = dataclasses.field(default=None)
    tolerance: float = dataclasses.field(default=1e-7)
    discrete: bool = dataclasses.field(default=False)
def inverse_monotone_function(
    func: typing.Callable[[float], float],
    value: float,
    search_parameters: BinarySearchParameters,
    increasing: bool = False) -> typing.Optional[float]:
    """Inverts a monotone function by binary search.

    Args:
      func: The function to be inverted; assumed monotone over the search range.
      value: The desired value of the function.
      search_parameters: Parameters used for binary search.
      increasing: Whether the function is monotonically increasing (otherwise it
        is treated as monotonically decreasing).

    Returns:
      x such that func(x) is no more than value, when such x exists. The
      returned x is within search_parameters.tolerance (within 1 when
      search_parameters.discrete is True) of the smallest (for monotonically
      decreasing func) or the largest (for monotonically increasing func) such
      x. When no such x exists within the given range, returns None.

    Raises:
      ValueError: If the (range-clamped) initial guess is non-positive and the
        search needs to grow it: doubling cannot make progress from such a
        guess, and previously this looped forever.
    """
    lower_x = search_parameters.lower_bound
    upper_x = search_parameters.upper_bound
    initial_guess_x = search_parameters.initial_guess

    # check(func(x), value) is True when the answer is not below x, so a True
    # result moves lower_x up and a False result moves upper_x down.
    if increasing:
        check = lambda func_value, target_value: func_value <= target_value
        if lower_x != -math.inf and func(lower_x) > value:
            return None
    else:
        check = lambda func_value, target_value: func_value > target_value
        if upper_x != math.inf and func(upper_x) > value:
            return None

    if initial_guess_x is not None:
        # Never probe below the search range: a guess under lower_x would both
        # evaluate func outside the range and pull lower_x below lower_bound.
        initial_guess_x = max(initial_guess_x, lower_x)
        # Grow the guess geometrically until it brackets the answer; this is
        # what lets an infinite upper_bound be searched.
        while initial_guess_x < upper_x and check(func(initial_guess_x), value):
            if initial_guess_x <= 0:
                raise ValueError(
                    'initial_guess should be positive: '
                    f'{search_parameters.initial_guess}')
            lower_x = initial_guess_x
            initial_guess_x *= 2
        upper_x = min(upper_x, initial_guess_x)

    # A search over integers stops once the bracket has width one.
    if search_parameters.discrete:
        tolerance = 1
    else:
        tolerance = search_parameters.tolerance

    while upper_x - lower_x > tolerance:
        if search_parameters.discrete:
            mid_x = (upper_x + lower_x) // 2
        else:
            mid_x = (upper_x + lower_x) / 2
        if check(func(mid_x), value):
            lower_x = mid_x
        else:
            upper_x = mid_x

    # For increasing func, lower_x only moves to points where check held
    # (func <= value); for decreasing func, upper_x only moves to points where
    # it failed (func <= value).
    if increasing:
        return lower_x
    else:
        return upper_x
def dictionary_to_list(
    input_dictionary: typing.Mapping[int, float]
) -> typing.Tuple[int, typing.List[float]]:
    """Converts an integer-keyed dictionary into a dense list.

    Args:
      input_dictionary: A non-empty dictionary whose keys are integers.

    Returns:
      A tuple (offset, result_list). offset is the smallest key of
      input_dictionary. result_list has length max_key - offset + 1, where
      max_key is the largest key. result_list[i] is equal to
      input_dictionary[offset + i], and is zero if offset + i is not a key in
      input_dictionary.

    Raises:
      ValueError: If input_dictionary is empty.
    """
    # min()/max() would also raise ValueError here, but with an opaque message.
    if not input_dictionary:
        raise ValueError('input_dictionary should be non-empty.')
    offset = min(input_dictionary)
    max_val = max(input_dictionary)
    # Fill the gaps between keys with zeros so the list is contiguous.
    result_list = [input_dictionary.get(i, 0) for i in range(offset, max_val + 1)]
    return (offset, result_list)
def list_to_dictionary(
    input_list: typing.List[float],
    offset: int,
    tail_mass_truncation: float = 0) -> typing.Mapping[int, float]:
    """Converts a dense list into an integer-keyed dictionary.

    Entry i of input_list is stored under key i + offset. Both ends of the list
    are trimmed first. Leading entries are dropped up to (not including) the
    first index at which their running sum exceeds tail_mass_truncation / 2.
    Trailing entries are dropped the same way, scanning from the back. Entries
    that are zero or negative are never stored.

    Args:
      input_list: The dense values to convert.
      offset: The key assigned to input_list[0].
      tail_mass_truncation: Intended upper bound on the total mass dropped
        from the two ends of the list combined.

    Returns:
      A dictionary mapping i + offset to input_list[i] for every kept index i
      whose value is positive.
    """
    half_tail = tail_mass_truncation / 2
    size = len(input_list)

    # Scan from the front for the first index where the prefix mass exceeds
    # half the allowed tail; with no such index nothing is kept.
    first_kept = size
    prefix_mass = 0
    for index, entry in enumerate(input_list):
        prefix_mass += entry
        if prefix_mass > half_tail:
            first_kept = index
            break

    # Same scan from the back for the suffix mass.
    last_kept = -1
    suffix_mass = 0
    for index in reversed(range(size)):
        suffix_mass += input_list[index]
        if suffix_mass > half_tail:
            last_kept = index
            break

    return {
        index + offset: input_list[index]
        for index in range(first_kept, last_kept + 1)
        if input_list[index] > 0
    }
def convolve_dictionary(
    dictionary1: typing.Mapping[int, float],
    dictionary2: typing.Mapping[int, float],
    tail_mass_truncation: float = 0) -> typing.Mapping[int, float]:
    """Convolves two integer-keyed dictionaries.

    Args:
      dictionary1: The first dictionary whose keys are integers.
      dictionary2: The second dictionary whose keys are integers.
      tail_mass_truncation: an upper bound on the tails of the output that might
        be truncated.

    Returns:
      A dictionary mapping each key to the sum, over all key1, key2 with
      key1 + key2 = key, of dictionary1[key1] * dictionary2[key2]. Entries that
      are not positive, or that fall in the truncated tails, are omitted.
    """
    # Densify both inputs; each list starts at its dictionary's smallest key.
    offset1, dense1 = dictionary_to_list(dictionary1)
    offset2, dense2 = dictionary_to_list(dictionary2)
    # Entry k of the dense convolution belongs to key offset1 + offset2 + k.
    dense_product = signal.fftconvolve(dense1, dense2)
    return list_to_dictionary(
        dense_product,
        offset1 + offset2,
        tail_mass_truncation=tail_mass_truncation)
def compute_self_convolve_bounds(
    input_list: typing.List[float],
    num_times: int,
    tail_mass_truncation: float = 0,
    orders: typing.Optional[typing.List[float]] = None
) -> typing.Tuple[int, int]:
    """Computes truncation bounds for a repeated self-convolution.

    Treats input_list as weights on {0, ..., len(input_list) - 1} and uses the
    Chernoff bound to find a range holding the num_times-fold self-convolution,
    up to tail_mass_truncation / 2 of mass on each side.

    Args:
      input_list: The input list to be convolved.
      num_times: The number of times the list is to be convolved with itself.
      tail_mass_truncation: an upper bound on the tails of the output that might
        be truncated.
      orders: a list of orders on which the Chernoff bound is applied. Orders
        equal to zero, and orders whose bound is not finite (e.g. because the
        moment generating function overflows), are ignored.

    Returns:
      A pair (lower_bound, upper_bound) such that the mass of the result of the
      convolution outside [lower_bound, upper_bound] is at most
      tail_mass_truncation.
    """
    upper_bound = (len(input_list) - 1) * num_times
    lower_bound = 0

    if tail_mass_truncation == 0:
        return lower_bound, upper_bound

    if orders is None:
        # Keep |order| * (len(input_list) - 1) below 20 so the exponentials
        # below stay well inside floating point range.
        orders = (np.concatenate((np.arange(-20, 0), np.arange(1, 21)))
                  / len(input_list))

    support = np.arange(len(input_list))
    # Log of the moment generating function at each order. Overflow (possible
    # with large user-supplied orders) yields inf/nan, which is skipped below,
    # so the corresponding floating point warnings are silenced.
    with np.errstate(all='ignore'):
        log_mgfs = np.log([
            np.dot(np.exp(support * order), input_list) for order in orders
        ])

    log_inverse_half_tail = math.log(2 / tail_mass_truncation)
    for order, log_mgf_value in zip(orders, log_mgfs):
        if order == 0:
            # A zero order gives no tail bound (and would divide by zero).
            continue
        # Use Chernoff bound to update the upper/lower bound. See equation (5)
        # in the supplementary material.
        bound = (num_times * log_mgf_value + log_inverse_half_tail) / order
        if not math.isfinite(bound):
            # math.ceil/math.floor raise on inf and nan; such a bound carries
            # no usable information, so keep the current bounds.
            continue
        if order > 0:
            upper_bound = min(upper_bound, math.ceil(bound))
        else:
            lower_bound = max(lower_bound, math.floor(bound))
    return lower_bound, upper_bound
def self_convolve(
    input_list: typing.List[float],
    num_times: int,
    tail_mass_truncation: float = 0) -> typing.Tuple[int, typing.List[float]]:
    """Computes a convolution of the input list with itself num_times times.

    Args:
      input_list: The input list to be convolved.
      num_times: The number of times the list is to be convolved with itself.
      tail_mass_truncation: an upper bound on the tails of the output that might
        be truncated.

    Returns:
      A pair of truncation_lower_bound, output_list, where the i-th entry of
      output_list is approximately the sum, over all i_1, i_2, ..., i_num_times
      such that i_1 + i_2 + ... + i_num_times = i + truncation_lower_bound,
      of input_list[i_1] * input_list[i_2] * ... * input_list[i_num_times].
    """
    truncation_lower_bound, truncation_upper_bound = compute_self_convolve_bounds(
        input_list, num_times, tail_mass_truncation)

    # The FFT length must cover the output window. It must also be at least
    # the input length: fft.fft crops (rather than wraps) inputs longer than the
    # requested length, which would silently drop mass when the truncation
    # window is narrower than input_list.
    fast_len = fft.next_fast_len(
        max(truncation_upper_bound - truncation_lower_bound + 1,
            len(input_list)))
    truncated_convolution_output = np.real(
        fft.ifft(fft.fft(input_list, fast_len)**num_times))

    # The DFT yields a circular convolution modulo fast_len. Since the window
    # is no wider than fast_len, each index in it maps to a distinct residue;
    # mass from outside the window that aliases in is bounded by the tail.
    output_list = [
        truncated_convolution_output[i % fast_len]
        for i in range(truncation_lower_bound, truncation_upper_bound + 1)
    ]
    return truncation_lower_bound, output_list
def self_convolve_dictionary(
    input_dictionary: typing.Mapping[int, float],
    num_times: int,
    tail_mass_truncation: float = 0) -> typing.Mapping[int, float]:
    """Convolves an integer-keyed dictionary with itself num_times times.

    Args:
      input_dictionary: The input dictionary whose keys are integers.
      num_times: The number of times the dictionary is to be convolved with
        itself.
      tail_mass_truncation: an upper bound on the tails of the output that might
        be truncated.

    Returns:
      A dictionary mapping each key to (approximately) the sum, over all
      key1, ..., key_num_times adding up to key, of
      input_dictionary[key1] * ... * input_dictionary[key_num_times].
    """
    key_offset, dense_input = dictionary_to_list(input_dictionary)
    window_start, dense_output = self_convolve(
        dense_input, num_times, tail_mass_truncation=tail_mass_truncation)
    # Each of the num_times summands is shifted by key_offset, so output keys
    # shift by key_offset * num_times on top of the truncation window start.
    # tail_mass_truncation is consumed by self_convolve; list_to_dictionary is
    # called with its default, so here it only drops non-positive entries.
    first_key = key_offset * num_times + window_start
    return list_to_dictionary(dense_output, first_key)