Skip to content

Commit

Permalink
Fboemer/cleaner ntt
Browse files Browse the repository at this point in the history
  • Loading branch information
fboemer committed Jun 7, 2021
1 parent 97e4b8d commit 23a1c2b
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 36 deletions.
2 changes: 1 addition & 1 deletion cmake/ExternalIntelHEXL.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ FetchContent_Declare(
hexl
PREFIX hexl
GIT_REPOSITORY https://github.com/intel/hexl
GIT_TAG c28943d # v1.1.0
GIT_TAG 2dc1db # v1.1.0
)
FetchContent_GetProperties(hexl)

Expand Down
1 change: 1 addition & 0 deletions native/src/seal/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ set(SEAL_SOURCE_FILES ${SEAL_SOURCE_FILES}
${CMAKE_CURRENT_LIST_DIR}/galois.cpp
${CMAKE_CURRENT_LIST_DIR}/hash.cpp
${CMAKE_CURRENT_LIST_DIR}/iterator.cpp
${CMAKE_CURRENT_LIST_DIR}/intel_seal_ext.cpp
${CMAKE_CURRENT_LIST_DIR}/mempool.cpp
${CMAKE_CURRENT_LIST_DIR}/numth.cpp
${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.cpp
Expand Down
51 changes: 51 additions & 0 deletions native/src/seal/util/intel_seal_ext.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "seal/util/defines.h"

#ifdef SEAL_USE_INTEL_HEXL
#include "seal/memorymanager.h"
#include "seal/util/intel_seal_ext.h"
#include "seal/util/locks.h"
#include <unordered_map>
#include "hexl/hexl.hpp"

namespace intel
{
namespace seal_ext
{
intel::hexl::NTT get_ntt(size_t N, uint64_t modulus, uint64_t root)
{
static std::unordered_map<std::pair<uint64_t, uint64_t>, intel::hexl::NTT, seal_ext::HashPair> ntt_cache_;

static seal::util::ReaderWriterLocker ntt_cache_locker_;

std::pair<uint64_t, uint64_t> key{ N, modulus };

// Enable shared access of NTT already present
{
seal::util::ReaderLock reader_lock(ntt_cache_locker_.acquire_read());
auto ntt_it = ntt_cache_.find(key);
if (ntt_it != ntt_cache_.end())
{
return ntt_it->second;
}
}

// Deal with NTT not yet present
seal::util::WriterLock write_lock(ntt_cache_locker_.acquire_write());

// Check ntt_cache for value (maybe added by another thread)
auto ntt_it = ntt_cache_.find(key);
if (ntt_it == ntt_cache_.end())
{
intel::hexl::NTT ntt(
N, modulus, root, seal::MemoryManager::GetPool(), intel::hexl::SimpleThreadSafePolicy{});
ntt_it = ntt_cache_.emplace(std::move(key), std::move(ntt)).first;
}
return ntt_it->second;
}
} // namespace seal_ext
} // namespace intel

#endif
45 changes: 10 additions & 35 deletions native/src/seal/util/intel_seal_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

#pragma once

#include "seal/util/defines.h"

#ifdef SEAL_USE_INTEL_HEXL
#include "seal/memorymanager.h"
#include "seal/util/iterator.h"
#include "seal/util/locks.h"
#include "seal/util/pointer.h"
#include <unordered_map>
#include "hexl/hexl.hpp"

Expand Down Expand Up @@ -140,47 +144,17 @@ namespace intel
}
};

static std::unordered_map<std::pair<uint64_t, uint64_t>, intel::hexl::NTT, HashPair> ntt_cache_;

static seal::util::ReaderWriterLocker ntt_cache_locker_;

/**
Returns a HEXL NTT object corresponding to the given parameters.
@param[in] N The polynomial modulus degree
@param[in] modulus The modulus
@param[in] root The root of unity
*/
static intel::hexl::NTT get_ntt(size_t N, uint64_t modulus, uint64_t root)
{
std::pair<uint64_t, uint64_t> key{ N, modulus };

// Enable shared access of NTT already present
{
seal::util::ReaderLock reader_lock(ntt_cache_locker_.acquire_read());
auto ntt_it = ntt_cache_.find(key);
if (ntt_it != ntt_cache_.end())
{
return ntt_it->second;
}
}

// Deal with NTT not yet present
seal::util::WriterLock write_lock(ntt_cache_locker_.acquire_write());

// Check ntt_cache for value (maybe added by another thread)
auto ntt_it = ntt_cache_.find(key);
if (ntt_it == ntt_cache_.end())
{
intel::hexl::NTT ntt(
N, modulus, root, seal::MemoryManager::GetPool(), intel::hexl::SimpleThreadSafePolicy{});
ntt_it = ntt_cache_.emplace(std::move(key), std::move(ntt)).first;
}
return ntt_it->second;
}
intel::hexl::NTT get_ntt(size_t N, uint64_t modulus, uint64_t root);

/**
Computes for forward negacyclic NTT from the given parameters.
Computes the forward negacyclic NTT from the given parameters.
@param[in,out] operand The data on which to compute the NTT.
@param[in] N The polynomial modulus degree
Expand All @@ -189,15 +163,15 @@ namespace intel
@param[in] input_mod_factor Bounds the input data to the range [0, input_mod_factor * modulus)
@param[in] output_mod_factor Bounds the output data to the range [0, output_mod_factor * modulus)
*/
static void compute_forward_ntt(
inline void compute_forward_ntt(
seal::util::CoeffIter operand, size_t N, uint64_t modulus, uint64_t root, uint64_t input_mod_factor,
uint64_t output_mod_factor)
{
get_ntt(N, modulus, root).ComputeForward(operand, operand, input_mod_factor, output_mod_factor);
}

/**
Computes for inverse negacyclic NTT from the given parameters.
Computes the inverse negacyclic NTT from the given parameters.
@param[in,out] operand The data on which to compute the NTT.
@param[in] N The polynomial modulus degree
Expand All @@ -206,7 +180,7 @@ namespace intel
@param[in] input_mod_factor Bounds the input data to the range [0, input_mod_factor * modulus)
@param[in] output_mod_factor Bounds the output data to the range [0, output_mod_factor * modulus)
*/
static void compute_inverse_ntt(
inline void compute_inverse_ntt(
seal::util::CoeffIter operand, size_t N, uint64_t modulus, uint64_t root, uint64_t input_mod_factor,
uint64_t output_mod_factor)
{
Expand All @@ -215,4 +189,5 @@ namespace intel

} // namespace seal_ext
} // namespace intel

#endif
5 changes: 5 additions & 0 deletions native/src/seal/util/ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ namespace seal
throw invalid_argument("invalid modulus");
}

#ifdef SEAL_USE_INTEL_HEXL
// Pre-compute HEXL NTT object
intel::seal_ext::get_ntt(coeff_count_, modulus.value(), root_);
#endif

// Populate tables with powers of root in specific orders.
root_powers_ = allocate<MultiplyUIntModOperand>(coeff_count_, pool_);
MultiplyUIntModOperand root;
Expand Down

0 comments on commit 23a1c2b

Please sign in to comment.