diff --git a/benchmark/bench-eltwise-cmp-sub-mod.cpp b/benchmark/bench-eltwise-cmp-sub-mod.cpp index f08bec24..c1be6b12 100644 --- a/benchmark/bench-eltwise-cmp-sub-mod.cpp +++ b/benchmark/bench-eltwise-cmp-sub-mod.cpp @@ -42,7 +42,7 @@ BENCHMARK(BM_EltwiseCmpSubModNative) #ifdef HEXL_HAS_AVX512DQ // state[0] is the degree -static void BM_EltwiseCmpSubModAVX512(benchmark::State& state) { // NOLINT +static void BM_EltwiseCmpSubModAVX512_64(benchmark::State& state) { // NOLINT size_t input_size = state.range(0); uint64_t modulus = 100; uint64_t bound = GenerateInsecureUniformRandomValue(0, modulus); @@ -50,12 +50,12 @@ static void BM_EltwiseCmpSubModAVX512(benchmark::State& state) { // NOLINT auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, modulus); for (auto _ : state) { - EltwiseCmpSubModAVX512(input1.data(), input1.data(), input_size, modulus, - CMPINT::NLT, bound, diff); + EltwiseCmpSubModAVX512<64>(input1.data(), input1.data(), input_size, + modulus, CMPINT::NLT, bound, diff); } } -BENCHMARK(BM_EltwiseCmpSubModAVX512) +BENCHMARK(BM_EltwiseCmpSubModAVX512_64) ->Unit(benchmark::kMicrosecond) ->Args({1024}) ->Args({4096}) diff --git a/cmake/third-party/cpu-features/CMakeLists.txt.in b/cmake/third-party/cpu-features/CMakeLists.txt.in index 95c83bed..d0e37176 100644 --- a/cmake/third-party/cpu-features/CMakeLists.txt.in +++ b/cmake/third-party/cpu-features/CMakeLists.txt.in @@ -8,7 +8,7 @@ project(cpu-features-download NONE) include(ExternalProject) ExternalProject_Add(cpu_features GIT_REPOSITORY https://github.com/google/cpu_features.git - GIT_TAG master + GIT_TAG 32b49eb5e7809052a28422cfde2f2745fbb0eb76 # master branch on Oct 20, 2021 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu-features-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu-features-build" CONFIGURE_COMMAND "" diff --git a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp index 349c0e04..66c44a0d 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp @@ -15,7 +15,7 @@ namespace intel { namespace hexl { #ifdef HEXL_HAS_AVX512DQ -template +template void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, uint64_t n, uint64_t modulus, CMPINT cmp, uint64_t bound, uint64_t diff) { @@ -51,12 +51,25 @@ void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, uint64_t prod_right_shift = ceil_log_mod + beta; __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + uint64_t alpha = BitShift - 2; + uint64_t mu_64 = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift, + modulus) + .BarrettFactor(); + + if (BitShift == 64) { + // Single-worded Barrett reduction. + mu_64 = MultiplyFactor(1, 64, modulus).BarrettFactor(); + } + + __m512i v_mu_64 = _mm512_set1_epi64(static_cast(mu_64)); + for (size_t i = n / 8; i > 0; --i) { __m512i v_op = _mm512_loadu_si512(v_op_ptr); __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp)); v_op = _mm512_hexl_barrett_reduce64( - v_op, v_modulus, v_mu, v_mu, prod_right_shift, v_neg_mod); + v_op, v_modulus, v_mu_64, v_mu, prod_right_shift, v_neg_mod); __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus); v_to_add = _mm512_sub_epi64(v_to_add, v_diff); diff --git a/hexl/eltwise/eltwise-cmp-sub-mod.cpp b/hexl/eltwise/eltwise-cmp-sub-mod.cpp index 96074502..49619a16 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod.cpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod.cpp @@ -52,10 +52,8 @@ void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1, for (size_t i = 0; i < n; ++i) { uint64_t op = operand1[i]; - bool op_cmp = Compare(cmp, op, bound); op %= modulus; - if (op_cmp) { op = SubUIntMod(op, diff, modulus); } diff --git a/hexl/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp b/hexl/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp index cf2963c3..3d4f132e 100644 --- a/hexl/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp +++ b/hexl/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp @@ -19,7 +19,8 @@ namespace hexl { /// @param[in] bound Scalar to compare against /// @param[in] diff Scalar to subtract by /// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? (\p -/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1 +/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0, +/// ..., n-1 void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, uint64_t modulus, CMPINT cmp, uint64_t bound, uint64_t diff); diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index 43718429..1c791214 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -7,6 +7,7 @@ #include +#include "hexl/logging/logging.hpp" #include "hexl/number-theory/number-theory.hpp" #include "hexl/util/check.hpp" #include "hexl/util/defines.hpp" @@ -389,6 +390,7 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, __mmask8 mask = _mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT); if (mask != 0) { + // values above 2^52 __m512i x_hi = _mm512_srli_epi64(x, static_cast(52ULL)); __m512i x_intr = _mm512_slli_epi64(x, static_cast(12ULL)); __m512i x_lo = @@ -408,8 +410,6 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); } else { __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52); - // Barrett subtraction - // tmp[0] = input - tmp[1] * q; __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); x = _mm512_sub_epi64(x, tmp1_times_mod); } @@ -417,24 +417,15 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, #endif if (BitShift == 64) { __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64); - // Barrett subtraction - // tmp[0] = input - tmp[1] * q; __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); x = _mm512_sub_epi64(x, tmp1_times_mod); } // Correction - if (OutputModFactor == 2) { - return x; - } else { - if (BitShift == 64) { - x = _mm512_hexl_small_mod_epu64(x, q); - } - if (BitShift == 52) { - x = _mm512_hexl_small_mod_epu64<2>(x, q); - } - return x; + if (OutputModFactor == 1) { + x = _mm512_hexl_small_mod_epu64<2>(x, q); } + return x; } // Concatenate packed 64-bit integers in x and y, producing an intermediate diff --git a/test/test-eltwise-cmp-sub-mod-avx512.cpp b/test/test-eltwise-cmp-sub-mod-avx512.cpp index 75697397..7114c9d8 100644 --- a/test/test-eltwise-cmp-sub-mod-avx512.cpp +++ b/test/test-eltwise-cmp-sub-mod-avx512.cpp @@ -18,14 +18,45 @@ namespace intel { namespace hexl { // Checks AVX512 and native implementations match -#ifdef HEXL_HAS_AVX512DQ +#ifdef HEXL_HAS_AVX512IFMA +TEST(EltwiseCmpSubMod, AVX512_52) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + uint64_t length = 9; + uint64_t modulus = 1125896819525633; + + for (size_t trial = 0; trial < 200; ++trial) { + auto op1 = std::vector(length, 1106601337915084531); + uint64_t bound = 576460751967876096; + uint64_t diff = 3160741504001; + + auto op1_native = op1; + auto op1_avx512 = op1; + std::vector op1_out(op1.size(), 0); + std::vector op1_native_out(op1.size(), 0); + std::vector op1_avx512_out(op1.size(), 0); + + EltwiseCmpSubMod(op1_out.data(), op1.data(), op1.size(), modulus, + intel::hexl::CMPINT::NLE, bound, diff); + EltwiseCmpSubModNative(op1_native_out.data(), op1.data(), op1.size(), + modulus, intel::hexl::CMPINT::NLE, bound, diff); + EltwiseCmpSubModAVX512<52>(op1_avx512_out.data(), op1.data(), op1.size(), + modulus, intel::hexl::CMPINT::NLE, bound, diff); + + ASSERT_EQ(op1_out, op1_native_out); + ASSERT_EQ(op1_native_out, op1_avx512_out); + } +} +#endif + +#ifdef HEXL_HAS_AVX512IFMA TEST(EltwiseCmpSubMod, AVX512) { if (!has_avx512dq) { GTEST_SKIP(); } uint64_t length = 172; - for (size_t cmp = 0; cmp < 8; ++cmp) { for (size_t bits = 48; bits <= 51; ++bits) { uint64_t modulus = GeneratePrimes(1, bits, true, 1024)[0]; @@ -48,8 +79,9 @@ TEST(EltwiseCmpSubMod, AVX512) { static_cast(cmp), bound, diff); EltwiseCmpSubModNative(op1a_out.data(), op1a.data(), op1a.size(), modulus, static_cast(cmp), bound, diff); - EltwiseCmpSubModAVX512(op1b_out.data(), op1b.data(), op1b.size(), - modulus, static_cast(cmp), bound, diff); + EltwiseCmpSubModAVX512<52>(op1b_out.data(), op1b.data(), op1b.size(), + modulus, static_cast(cmp), bound, + diff); ASSERT_EQ(op1_out, op1a_out); ASSERT_EQ(op1_out, op1b_out); @@ -57,6 +89,37 @@ TEST(EltwiseCmpSubMod, AVX512) { } } } + +TEST(EltwiseCmpSubMod, AVX512_64) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + uint64_t length = 9; + uint64_t modulus = 1152921504606748673; + + for (size_t trial = 0; trial < 200; ++trial) { + auto op1 = std::vector(length, 64961); + uint64_t bound = 576460752303415296; + uint64_t diff = 81920; + + auto op1_native = op1; + auto op1_avx512 = op1; + std::vector op1_out(op1.size(), 0); + std::vector op1_native_out(op1.size(), 0); + std::vector op1_avx512_out(op1.size(), 0); + + EltwiseCmpSubMod(op1_out.data(), op1.data(), op1.size(), modulus, + intel::hexl::CMPINT::NLE, bound, diff); + EltwiseCmpSubModNative(op1_native_out.data(), op1.data(), op1.size(), + modulus, intel::hexl::CMPINT::NLE, bound, diff); + EltwiseCmpSubModAVX512<64>(op1_avx512_out.data(), op1.data(), op1.size(), + modulus, intel::hexl::CMPINT::NLE, bound, diff); + + ASSERT_EQ(op1_out, op1_native_out); + ASSERT_EQ(op1_native_out, op1_avx512_out); + } +} + #endif } // namespace hexl } // namespace intel diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index 6b6b9853..e8bfff53 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -134,9 +134,6 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) { GTEST_SKIP(); } - std::random_device rd; - std::mt19937 gen(rd()); - size_t length = 1024; for (size_t bits = 50; bits <= 62; ++bits) { @@ -170,9 +167,6 @@ TEST(EltwiseReduceMod, AVX512Big_4_1) { GTEST_SKIP(); } - std::random_device rd; - std::mt19937 gen(rd()); - size_t length = 1024; for (size_t bits = 50; bits <= 62; ++bits) { @@ -263,6 +257,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) { } } } + #endif } // namespace hexl