diff --git a/cmake/ExternalIntelHEXL.cmake b/cmake/ExternalIntelHEXL.cmake index 8a1f971d8..da1a9227d 100644 --- a/cmake/ExternalIntelHEXL.cmake +++ b/cmake/ExternalIntelHEXL.cmake @@ -5,7 +5,7 @@ FetchContent_Declare( hexl PREFIX hexl GIT_REPOSITORY https://github.com/intel/hexl - GIT_TAG c28943d # v1.1.0 + GIT_TAG 2dc1db # v1.1.0 ) FetchContent_GetProperties(hexl) diff --git a/native/src/seal/util/CMakeLists.txt b/native/src/seal/util/CMakeLists.txt index b0885bc6d..98d045059 100644 --- a/native/src/seal/util/CMakeLists.txt +++ b/native/src/seal/util/CMakeLists.txt @@ -13,6 +13,7 @@ set(SEAL_SOURCE_FILES ${SEAL_SOURCE_FILES} ${CMAKE_CURRENT_LIST_DIR}/galois.cpp ${CMAKE_CURRENT_LIST_DIR}/hash.cpp ${CMAKE_CURRENT_LIST_DIR}/iterator.cpp + ${CMAKE_CURRENT_LIST_DIR}/intel_seal_ext.cpp ${CMAKE_CURRENT_LIST_DIR}/mempool.cpp ${CMAKE_CURRENT_LIST_DIR}/numth.cpp ${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.cpp diff --git a/native/src/seal/util/intel_seal_ext.cpp b/native/src/seal/util/intel_seal_ext.cpp new file mode 100644 index 000000000..74fd48d78 --- /dev/null +++ b/native/src/seal/util/intel_seal_ext.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/defines.h" + +#ifdef SEAL_USE_INTEL_HEXL +#include "seal/memorymanager.h" +#include "seal/util/intel_seal_ext.h" +#include "seal/util/locks.h" +#include +#include "hexl/hexl.hpp" + +namespace intel +{ + namespace seal_ext + { + intel::hexl::NTT get_ntt(size_t N, uint64_t modulus, uint64_t root) + { + static std::unordered_map, intel::hexl::NTT, seal_ext::HashPair> ntt_cache_; + + static seal::util::ReaderWriterLocker ntt_cache_locker_; + + std::pair key{ N, modulus }; + + // Enable shared access of NTT already present + { + seal::util::ReaderLock reader_lock(ntt_cache_locker_.acquire_read()); + auto ntt_it = ntt_cache_.find(key); + if (ntt_it != ntt_cache_.end()) + { + return ntt_it->second; + } + } + + // Deal with NTT not yet present + seal::util::WriterLock write_lock(ntt_cache_locker_.acquire_write()); + + // Check ntt_cache for value (maybe added by another thread) + auto ntt_it = ntt_cache_.find(key); + if (ntt_it == ntt_cache_.end()) + { + intel::hexl::NTT ntt( + N, modulus, root, seal::MemoryManager::GetPool(), intel::hexl::SimpleThreadSafePolicy{}); + ntt_it = ntt_cache_.emplace(std::move(key), std::move(ntt)).first; + } + return ntt_it->second; + } + } // namespace seal_ext +} // namespace intel + +#endif diff --git a/native/src/seal/util/intel_seal_ext.h b/native/src/seal/util/intel_seal_ext.h index 8bb78ef4a..567a0cf5f 100644 --- a/native/src/seal/util/intel_seal_ext.h +++ b/native/src/seal/util/intel_seal_ext.h @@ -3,9 +3,13 @@ #pragma once +#include "seal/util/defines.h" + #ifdef SEAL_USE_INTEL_HEXL #include "seal/memorymanager.h" +#include "seal/util/iterator.h" #include "seal/util/locks.h" +#include "seal/util/pointer.h" #include #include "hexl/hexl.hpp" @@ -140,10 +144,6 @@ namespace intel } }; - static std::unordered_map, intel::hexl::NTT, HashPair> ntt_cache_; - - static seal::util::ReaderWriterLocker ntt_cache_locker_; - /** Returns a HEXL NTT object corresponding to the given parameters. @@ -151,36 +151,10 @@ namespace intel @param[in] modulus The modulus @param[in] root The root of unity */ - static intel::hexl::NTT get_ntt(size_t N, uint64_t modulus, uint64_t root) - { - std::pair key{ N, modulus }; - - // Enable shared access of NTT already present - { - seal::util::ReaderLock reader_lock(ntt_cache_locker_.acquire_read()); - auto ntt_it = ntt_cache_.find(key); - if (ntt_it != ntt_cache_.end()) - { - return ntt_it->second; - } - } - - // Deal with NTT not yet present - seal::util::WriterLock write_lock(ntt_cache_locker_.acquire_write()); - - // Check ntt_cache for value (maybe added by another thread) - auto ntt_it = ntt_cache_.find(key); - if (ntt_it == ntt_cache_.end()) - { - intel::hexl::NTT ntt( - N, modulus, root, seal::MemoryManager::GetPool(), intel::hexl::SimpleThreadSafePolicy{}); - ntt_it = ntt_cache_.emplace(std::move(key), std::move(ntt)).first; - } - return ntt_it->second; - } + intel::hexl::NTT get_ntt(size_t N, uint64_t modulus, uint64_t root); /** - Computes for forward negacyclic NTT from the given parameters. + Computes the forward negacyclic NTT from the given parameters. @param[in,out] operand The data on which to compute the NTT. @param[in] N The polynomial modulus degree @@ -189,7 +163,7 @@ namespace intel @param[in] input_mod_factor Bounds the input data to the range [0, input_mod_factor * modulus) @param[in] output_mod_factor Bounds the output data to the range [0, output_mod_factor * modulus) */ - static void compute_forward_ntt( + inline void compute_forward_ntt( seal::util::CoeffIter operand, size_t N, uint64_t modulus, uint64_t root, uint64_t input_mod_factor, uint64_t output_mod_factor) { @@ -197,7 +171,7 @@ namespace intel } /** - Computes for inverse negacyclic NTT from the given parameters. + Computes the inverse negacyclic NTT from the given parameters. @param[in,out] operand The data on which to compute the NTT. @param[in] N The polynomial modulus degree @@ -206,7 +180,7 @@ namespace intel @param[in] input_mod_factor Bounds the input data to the range [0, input_mod_factor * modulus) @param[in] output_mod_factor Bounds the output data to the range [0, output_mod_factor * modulus) */ - static void compute_inverse_ntt( + inline void compute_inverse_ntt( seal::util::CoeffIter operand, size_t N, uint64_t modulus, uint64_t root, uint64_t input_mod_factor, uint64_t output_mod_factor) { @@ -215,4 +189,5 @@ namespace intel } // namespace seal_ext } // namespace intel + #endif diff --git a/native/src/seal/util/ntt.cpp b/native/src/seal/util/ntt.cpp index 621df3873..bc36b1467 100644 --- a/native/src/seal/util/ntt.cpp +++ b/native/src/seal/util/ntt.cpp @@ -49,6 +49,11 @@ namespace seal throw invalid_argument("invalid modulus"); } +#ifdef SEAL_USE_INTEL_HEXL + // Pre-compute HEXL NTT object + intel::seal_ext::get_ntt(coeff_count_, modulus.value(), root_); +#endif + // Populate tables with powers of root in specific orders. root_powers_ = allocate(coeff_count_, pool_); MultiplyUIntModOperand root;