-
Notifications
You must be signed in to change notification settings - Fork 714
/
Copy pathbaseconverter.h
291 lines (218 loc) · 9.78 KB
/
baseconverter.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#include <stdexcept>
#include <vector>
#include <memory>
#include "seal/util/pointer.h"
#include "seal/memorymanager.h"
#include "seal/smallmodulus.h"
#include "seal/util/smallntt.h"
#include "seal/biguint.h"
namespace seal
{
namespace util
{
/**
Stores the pre-computed tables needed for RNS base conversion between the
coefficient base q, the auxiliary base, Bsk (auxiliary base U {m_sk}),
m_tilde, and {plain_modulus, gamma}, and declares the routines that use them.

Instances are neither copyable nor movable (all four copy/move operations
are deleted). Only the single-argument constructor and the accessors are
defined in this header; every other member function is defined out of line.
*/
class BaseConverter
{
public:
    /**
    Creates an empty converter (is_generated() returns false) that keeps the
    given memory pool handle.

    NOTE(review): this constructor is not marked explicit, so a
    MemoryPoolHandle converts implicitly to a BaseConverter. Left unchanged
    to stay source-compatible with existing callers.

    @param pool Memory pool handle; it is moved into the object
    @throws std::invalid_argument if pool is uninitialized
    */
    BaseConverter(MemoryPoolHandle pool) : pool_(std::move(pool))
    {
        // The check runs on the member because the parameter has already
        // been moved from at this point.
        if (!pool_)
        {
            throw std::invalid_argument("pool is uninitialized");
        }
    }

    /**
    Creates a converter for the given parameters. Defined out of line.
    NOTE(review): presumably stores the pool and then performs the same work
    as generate(); confirm against the implementation.

    @param coeff_base The coefficient moduli
    @param coeff_count NOTE(review): presumably the number of coefficients
    per polynomial; confirm against the implementation
    @param small_plain_mod The plain modulus
    @param pool Memory pool handle
    */
    BaseConverter(const std::vector<SmallModulus> &coeff_base,
        std::size_t coeff_count, const SmallModulus &small_plain_mod,
        MemoryPoolHandle pool);

    /**
    Generates the pre-computations for the given parameters.

    @param coeff_base The coefficient moduli
    @param coeff_count NOTE(review): presumably the number of coefficients
    per polynomial; confirm against the implementation
    @param small_plain_mod The plain modulus
    */
    void generate(const std::vector<SmallModulus> &coeff_base,
        std::size_t coeff_count, const SmallModulus &small_plain_mod);

    /**
    In-place operation on rns_poly related to removing the last coefficient
    modulus. NOTE(review): the name suggests a flooring division by the last
    coeff modulus; confirm against the implementation.
    */
    void floor_last_coeff_modulus_inplace(
        std::uint64_t *rns_poly,
        MemoryPoolHandle pool) const;

    /**
    Variant of floor_last_coeff_modulus_inplace that additionally takes the
    NTT tables. NOTE(review): presumably expects rns_poly in NTT form;
    confirm against the implementation.
    */
    void floor_last_coeff_modulus_ntt_inplace(
        std::uint64_t *rns_poly,
        const Pointer<SmallNTTTables> &rns_ntt_tables,
        MemoryPoolHandle pool) const;

    /**
    In-place operation on rns_poly related to removing the last coefficient
    modulus. NOTE(review): the name suggests a rounding division by the last
    coeff modulus; confirm against the implementation.
    */
    void round_last_coeff_modulus_inplace(
        std::uint64_t *rns_poly,
        MemoryPoolHandle pool) const;

    /**
    Variant of round_last_coeff_modulus_inplace that additionally takes the
    NTT tables. NOTE(review): presumably expects rns_poly in NTT form;
    confirm against the implementation.
    */
    void round_last_coeff_modulus_ntt_inplace(
        std::uint64_t *rns_poly,
        const Pointer<SmallNTTTables> &rns_ntt_tables,
        MemoryPoolHandle pool) const;

    /**
    Fast base converter from q to Bsk
    */
    void fastbconv(const std::uint64_t *input,
        std::uint64_t *destination, MemoryPoolHandle pool) const;

    /**
    Fast base converter from Bsk to q
    */
    void fastbconv_sk(const std::uint64_t *input,
        std::uint64_t *destination, MemoryPoolHandle pool) const;

    /**
    Reduction from Bsk U {m_tilde} to Bsk
    */
    void mont_rq(const std::uint64_t *input,
        std::uint64_t *destination) const;

    /**
    Fast base converter from q U Bsk to Bsk
    */
    void fast_floor(const std::uint64_t *input,
        std::uint64_t *destination, MemoryPoolHandle pool) const;

    /**
    Fast base converter from q to Bsk U {m_tilde}
    */
    void fastbconv_mtilde(const std::uint64_t *input,
        std::uint64_t *destination, MemoryPoolHandle pool) const;

    /**
    Fast base converter from q to plain_modulus U {gamma}
    */
    void fastbconv_plain_gamma(const std::uint64_t *input,
        std::uint64_t *destination, MemoryPoolHandle pool) const;

    /**
    Defined out of line. NOTE(review): presumably releases the pre-computed
    tables and clears the generated flag; confirm against the implementation.
    */
    void reset() noexcept;

    /**
    Returns the value of the generated_ flag (false for a freshly
    constructed object).
    */
    SEAL_NODISCARD inline auto is_generated() const noexcept
    {
        return generated_;
    }

    /**
    Returns coeff_base_mod_count_ by value.
    */
    SEAL_NODISCARD inline auto coeff_base_mod_count() const noexcept
    {
        return coeff_base_mod_count_;
    }

    /**
    Returns aux_base_mod_count_ by value.
    */
    SEAL_NODISCARD inline auto aux_base_mod_count() const noexcept
    {
        return aux_base_mod_count_;
    }

    /**
    Returns a const reference to the array of plain_gamma_product mod coeff
    base moduli.
    */
    SEAL_NODISCARD inline auto &get_plain_gamma_product() const noexcept
    {
        return plain_gamma_product_mod_coeff_array_;
    }

    /**
    Returns a const reference to the array of negative inverses of the full
    coeff base product mod plain modulus and gamma.
    */
    SEAL_NODISCARD inline auto &get_neg_inv_coeff() const noexcept
    {
        return neg_inv_coeff_products_all_mod_plain_gamma_array_;
    }

    /**
    Returns a const reference to the array holding plain modulus U gamma.
    */
    SEAL_NODISCARD inline auto &get_plain_gamma_array() const noexcept
    {
        return plain_gamma_array_;
    }

    /**
    Returns the raw pointer held by coeff_products_array_ (punctured
    products of the coeff moduli), as obtained from its get().
    */
    SEAL_NODISCARD inline auto get_coeff_products_array() const noexcept
        -> const std::uint64_t *
    {
        return coeff_products_array_.get();
    }

    /**
    Returns gamma inverse mod plain modulus (0 until it has been computed).
    */
    SEAL_NODISCARD inline std::uint64_t get_inv_gamma() const noexcept
    {
        return inv_gamma_mod_plain_;
    }

    /**
    Returns a const reference to the small NTT tables for the moduli in Bsk.
    */
    SEAL_NODISCARD inline auto &get_bsk_small_ntt_tables() const noexcept
    {
        return bsk_small_ntt_tables_;
    }

    /**
    Returns bsk_base_mod_count_ by value.
    */
    SEAL_NODISCARD inline auto bsk_base_mod_count() const noexcept
    {
        return bsk_base_mod_count_;
    }

    /**
    Returns a const reference to the array of auxiliary U {m_sk_} moduli.
    */
    SEAL_NODISCARD inline auto &get_bsk_mod_array() const noexcept
    {
        return bsk_base_array_;
    }

    /**
    Returns a const reference to m_sk_.
    */
    SEAL_NODISCARD inline auto &get_msk() const noexcept
    {
        return m_sk_;
    }

    /**
    Returns a const reference to m_tilde_.
    */
    SEAL_NODISCARD inline auto &get_m_tilde() const noexcept
    {
        return m_tilde_;
    }

    /**
    Returns a const reference to
    mtilde_inv_coeff_base_products_mod_coeff_array_.
    */
    SEAL_NODISCARD inline auto &get_mtilde_inv_coeff_products_mod_coeff() const noexcept
    {
        return mtilde_inv_coeff_base_products_mod_coeff_array_;
    }

    /**
    Returns a const reference to the inverse of the coeff moduli product
    mod m_tilde.
    */
    SEAL_NODISCARD inline auto &get_inv_coeff_mod_mtilde() const noexcept
    {
        return inv_coeff_products_mod_mtilde_;
    }

    /**
    Returns a const reference to the array of inverse coeff modulus products
    mod each small coeff modulus.
    */
    SEAL_NODISCARD inline auto &get_inv_coeff_mod_coeff_array() const noexcept
    {
        return inv_coeff_base_products_mod_coeff_array_;
    }

    /**
    Returns a const reference to the array of inverses of the last coeff
    base modulus (used for modulus switching).
    */
    SEAL_NODISCARD inline auto &get_inv_last_coeff_mod_array() const noexcept
    {
        return inv_last_coeff_mod_array_;
    }

    /**
    Returns the entry of coeff_base_products_mod_aux_bsk_array_ at index
    bsk_base_mod_count_ - 1.

    Precondition: bsk_base_mod_count_ > 0. The count is an unsigned
    std::size_t that starts at 0, so calling this before the tables exist
    makes the index wrap around to the maximum std::size_t value.
    */
    SEAL_NODISCARD inline auto &get_coeff_base_products_mod_msk() const noexcept
    {
        return coeff_base_products_mod_aux_bsk_array_[bsk_base_mod_count_ - 1];
    }

private:
    // Copying and moving are disabled.
    BaseConverter(const BaseConverter &copy) = delete;

    BaseConverter(BaseConverter &&source) = delete;

    BaseConverter &operator =(const BaseConverter &assign) = delete;

    BaseConverter &operator =(BaseConverter &&assign) = delete;

    MemoryPoolHandle pool_;

    bool generated_ = false;

    std::size_t coeff_count_ = 0;

    std::size_t coeff_base_mod_count_ = 0;

    std::size_t aux_base_mod_count_ = 0;

    std::size_t bsk_base_mod_count_ = 0;

    std::size_t plain_gamma_count_ = 0;

    // Array of coefficient small moduli
    Pointer<SmallModulus> coeff_base_array_;

    // Array of auxiliary moduli
    Pointer<SmallModulus> aux_base_array_;

    // Array of auxiliary U {m_sk_} moduli
    Pointer<SmallModulus> bsk_base_array_;

    // Array of plain modulus U gamma
    Pointer<SmallModulus> plain_gamma_array_;

    // Punctured products of the coeff moduli
    Pointer<std::uint64_t> coeff_products_array_;

    // Matrix which contains the products of coeff moduli mod aux
    Pointer<Pointer<std::uint64_t>> coeff_base_products_mod_aux_bsk_array_;

    // Array of inverse coeff modulus products mod each small coeff modulus
    Pointer<std::uint64_t> inv_coeff_base_products_mod_coeff_array_;

    // Array of coeff moduli products mod m_tilde
    Pointer<std::uint64_t> coeff_base_products_mod_mtilde_array_;

    // Array of coeff modulus products times m_tilde mod each coeff modulus
    Pointer<std::uint64_t> mtilde_inv_coeff_base_products_mod_coeff_array_;

    // Inverses of the coeff modulus products mod each auxiliary modulus
    Pointer<std::uint64_t> inv_coeff_products_all_mod_aux_bsk_array_;

    // Matrix of auxiliary mods products mod each coeff modulus
    Pointer<Pointer<std::uint64_t>> aux_base_products_mod_coeff_array_;

    // Array of inverse auxiliary mod products mod each auxiliary modulus
    Pointer<std::uint64_t> inv_aux_base_products_mod_aux_array_;

    // Array of auxiliary bases products mod m_sk_
    Pointer<std::uint64_t> aux_base_products_mod_msk_array_;

    // Coeff moduli products inverse mod m_tilde
    std::uint64_t inv_coeff_products_mod_mtilde_ = 0;

    // Inverse of the auxiliary base product mod m_sk_:
    // (m1*m2*...*ml)^(-1) mod m_sk
    std::uint64_t inv_aux_products_mod_msk_ = 0;

    // Gamma inverse mod plain modulus
    std::uint64_t inv_gamma_mod_plain_ = 0;

    // Auxiliary base products mod coeff moduli (m1*m2*...*ml) mod qi
    Pointer<std::uint64_t> aux_products_all_mod_coeff_array_;

    // Array of m_tilde inverse mod Bsk = m U {msk}
    Pointer<std::uint64_t> inv_mtilde_mod_bsk_array_;

    // Array of all coeff base products mod Bsk
    Pointer<std::uint64_t> coeff_products_all_mod_bsk_array_;

    // Matrix of coeff base product mod plain modulus and gamma
    Pointer<Pointer<std::uint64_t>> coeff_products_mod_plain_gamma_array_;

    // Array of negative inverse all coeff base product mod plain modulus and gamma
    Pointer<std::uint64_t> neg_inv_coeff_products_all_mod_plain_gamma_array_;

    // Array of plain_gamma_product mod coeff base moduli
    Pointer<std::uint64_t> plain_gamma_product_mod_coeff_array_;

    // Array of small NTT tables for moduli in Bsk
    Pointer<SmallNTTTables> bsk_small_ntt_tables_;

    // For modulus switching: inverses of the last coeff base modulus
    Pointer<std::uint64_t> inv_last_coeff_mod_array_;

    SmallModulus m_tilde_;

    SmallModulus m_sk_;

    SmallModulus small_plain_mod_;

    SmallModulus gamma_;
};
}
}