Skip to content

Commit

Permalink
Merge pull request #346 from fboemer/fboemer/faster-ckks-mult
Browse files Browse the repository at this point in the history
Faster CKKS multiply
  • Loading branch information
Wei Dai authored Jun 8, 2021
2 parents 97e4b8d + 1ee4165 commit a44ca9a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ The optional dependencies and their tested versions (other versions may work as

#### Intel HEXL

Intel HEXL is a library providing efficient implementations of cryptographic primitives common in homomorphic encryption. The acceleration is particularly evident on Intel processors with the Intel AVX512-IMA52 instruction set.
Intel HEXL is a library providing efficient implementations of cryptographic primitives common in homomorphic encryption. The acceleration is particularly evident on Intel processors with the Intel AVX512-IFMA52 instruction set.

#### Microsoft GSL

Expand Down
90 changes: 80 additions & 10 deletions native/src/seal/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ namespace seal
#endif
}

void Evaluator::bfv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
void Evaluator::bfv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
{
if (encrypted1.is_ntt_form() || encrypted2.is_ntt_form())
{
Expand Down Expand Up @@ -506,6 +506,74 @@ namespace seal
// Prepare destination
encrypted1.resize(context_, context_data.parms_id(), dest_size);

if (dest_size == 3)
{
// We want to keep six polynomials in the L1 cache: x[0], x[1], x[2], y[0], y[1], temp.
// For a 32KiB cache, which can store 32768 / 8 = 4096 coefficients, = 682.67 coefficients per polynomial,
// we should keep the tile size at 682 or below. The tile size must divide coeff_count, i.e. be a power of
// two. Some testing shows similar performance with tile size 256 and 512, and worse performance on smaller
// tiles. We pick the smaller of the two to prevent L1 cache misses on processors with < 32 KiB L1 cache.
size_t tile_size = min<size_t>(coeff_count, size_t(256));
size_t num_tiles = coeff_count / tile_size;
#ifdef SEAL_DEBUG
if (coeff_count % tile_size != 0)
{
throw invalid_argument("tile_size does not divide coeff_count");
}
#endif

// Set up iterators for input ciphertexts
PolyIter encrypted1_iter = iter(encrypted1);
ConstPolyIter encrypted2_iter = iter(encrypted2);

// Semantic misuse of RNSIter; each is really pointing to the data for each RNS factor in sequence
ConstRNSIter encrypted2_0_iter(*encrypted2_iter[0], tile_size);
ConstRNSIter encrypted2_1_iter(*encrypted2_iter[1], tile_size);
RNSIter encrypted1_0_iter(*encrypted1_iter[0], tile_size);
RNSIter encrypted1_1_iter(*encrypted1_iter[1], tile_size);
RNSIter encrypted1_2_iter(*encrypted1_iter[2], tile_size);

// Temporary buffer to store intermediate results
SEAL_ALLOCATE_GET_COEFF_ITER(temp, tile_size, pool);

// Computes the output tile_size coefficients at a time
// Given input tuples of polynomials x = (x[0], x[1], x[2]), y = (y[0], y[1]), computes
// x = (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1])
// with appropriate modular reduction
SEAL_ITERATE(coeff_modulus, coeff_modulus_size, [&](auto I) {
SEAL_ITERATE(iter(size_t(0)), num_tiles, [&](auto J) {
// Compute third output polynomial, overwriting input
// x[2] = x[1] * y[1]
dyadic_product_coeffmod(
encrypted1_1_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_2_iter[0]);

// Compute second output polynomial, overwriting input
// temp = x[1] * y[0]
dyadic_product_coeffmod(encrypted1_1_iter[0], encrypted2_0_iter[0], tile_size, I, temp);
// x[1] = x[0] * y[1]
dyadic_product_coeffmod(
encrypted1_0_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_1_iter[0]);
// x[1] += temp
add_poly_coeffmod(encrypted1_1_iter[0], temp, tile_size, I, encrypted1_1_iter[0]);

// Compute first output polynomial, overwriting input
// x[0] = x[0] * y[0]
dyadic_product_coeffmod(
encrypted1_0_iter[0], encrypted2_0_iter[0], tile_size, I, encrypted1_0_iter[0]);

// Manually increment iterators
++encrypted1_0_iter;
++encrypted1_1_iter;
++encrypted1_2_iter;
++encrypted2_0_iter;
++encrypted2_1_iter;
});
});

encrypted1.scale() = new_scale;
return;
}

// Set up iterators for input ciphertexts
auto encrypted1_iter = iter(encrypted1);
auto encrypted2_iter = iter(encrypted2);
Expand Down Expand Up @@ -921,7 +989,8 @@ namespace seal
}
}

void Evaluator::mod_switch_drop_to_next(const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
void Evaluator::mod_switch_drop_to_next(
const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
{
// Assuming at this point encrypted is already validated.
auto context_data_ptr = context_.get_context_data(encrypted.parms_id());
Expand Down Expand Up @@ -1020,7 +1089,8 @@ namespace seal
plain.parms_id() = next_context_data.parms_id();
}

void Evaluator::mod_switch_to_next(const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
void Evaluator::mod_switch_to_next(
const Ciphertext &encrypted, Ciphertext &destination, MemoryPoolHandle pool) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
Expand Down Expand Up @@ -1627,7 +1697,7 @@ namespace seal
encrypted.scale() = new_scale;
}

void Evaluator::multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt) const
void Evaluator::multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt) const
{
// Verify parameters.
if (!plain_ntt.is_ntt_form())
Expand Down Expand Up @@ -1668,7 +1738,7 @@ namespace seal
encrypted_ntt.scale() = new_scale;
}

void Evaluator::transform_to_ntt_inplace(Plaintext &plain, parms_id_type parms_id, MemoryPoolHandle pool) const
void Evaluator::transform_to_ntt_inplace(Plaintext &plain, parms_id_type parms_id, MemoryPoolHandle pool) const
{
// Verify parameters.
if (!is_valid_for(plain, context_))
Expand Down Expand Up @@ -1761,7 +1831,7 @@ namespace seal
plain.parms_id() = parms_id;
}

void Evaluator::transform_to_ntt_inplace(Ciphertext &encrypted) const
void Evaluator::transform_to_ntt_inplace(Ciphertext &encrypted) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
Expand Down Expand Up @@ -1809,7 +1879,7 @@ namespace seal
#endif
}

void Evaluator::transform_from_ntt_inplace(Ciphertext &encrypted_ntt) const
void Evaluator::transform_from_ntt_inplace(Ciphertext &encrypted_ntt) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted_ntt, context_) || !is_buffer_valid(encrypted_ntt))
Expand Down Expand Up @@ -1857,7 +1927,7 @@ namespace seal
}

void Evaluator::apply_galois_inplace(
Ciphertext &encrypted, uint32_t galois_elt, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
Ciphertext &encrypted, uint32_t galois_elt, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
{
// Verify parameters.
if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted))
Expand Down Expand Up @@ -1961,7 +2031,7 @@ namespace seal
}

void Evaluator::rotate_internal(
Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys, MemoryPoolHandle pool) const
{
auto context_data_ptr = context_.get_context_data(encrypted.parms_id());
if (!context_data_ptr)
Expand Down Expand Up @@ -2019,7 +2089,7 @@ namespace seal

void Evaluator::switch_key_inplace(
Ciphertext &encrypted, ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys, size_t kswitch_keys_index,
MemoryPoolHandle pool) const
MemoryPoolHandle pool) const
{
auto parms_id = encrypted.parms_id();
auto &context_data = *context_.get_context_data(parms_id);
Expand Down

0 comments on commit a44ca9a

Please sign in to comment.