From accf7a5a1a26516ab36288aa7e757355971905bc Mon Sep 17 00:00:00 2001
From: Fabian Boemer
Date: Mon, 1 Nov 2021 11:34:34 -0700
Subject: [PATCH] Fix EltwiseReduceMod (#90)

* Fix AVX512DQ EltwiseReduceMod
---
 hexl/eltwise/eltwise-reduce-mod.cpp | 19 ++++++++-----
 test/test-eltwise-reduce-mod.cpp    | 44 +++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/hexl/eltwise/eltwise-reduce-mod.cpp b/hexl/eltwise/eltwise-reduce-mod.cpp
index aebdf805..48164cd9 100644
--- a/hexl/eltwise/eltwise-reduce-mod.cpp
+++ b/hexl/eltwise/eltwise-reduce-mod.cpp
@@ -97,18 +97,23 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n,
     }
     return;
   }
+
+#ifdef HEXL_HAS_AVX512IFMA
+  if (has_avx512ifma && modulus < (1ULL << 52)) {
+    EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
+                               output_mod_factor);
+    return;
+  }
+#endif
+
 #ifdef HEXL_HAS_AVX512DQ
   if (has_avx512dq) {
-    if (modulus < (1ULL << 52)) {
-      EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
-                                 output_mod_factor);
-    } else {
-      EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor,
-                                 output_mod_factor);
-    }
+    EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor,
+                               output_mod_factor);
     return;
   }
 #endif
+
   HEXL_VLOG(3, "Calling EltwiseReduceModNative");
   EltwiseReduceModNative(result, operand, n, modulus, input_mod_factor,
                          output_mod_factor);
diff --git a/test/test-eltwise-reduce-mod.cpp b/test/test-eltwise-reduce-mod.cpp
index 9e5480ef..31f55072 100644
--- a/test/test-eltwise-reduce-mod.cpp
+++ b/test/test-eltwise-reduce-mod.cpp
@@ -10,6 +10,7 @@
 #include "hexl/logging/logging.hpp"
 #include "hexl/number-theory/number-theory.hpp"
 #include "test-util.hpp"
+#include "util/util-internal.hpp"
 
 namespace intel {
 namespace hexl {
@@ -79,5 +80,48 @@ TEST(EltwiseReduceMod, 4_2) {
   CheckEqual(result, exp_out);
 }
 
+// First parameter is the number of bits in the modulus
+// Second parameter is whether or not to prefer small moduli
+class EltwiseReduceModTest
+    : public ::testing::TestWithParam<std::tuple<uint64_t, bool>> {
+ protected:
+  void SetUp() override {
+    m_modulus_bits = std::get<0>(GetParam());
+    m_prefer_small_primes = std::get<1>(GetParam());
+    m_modulus = GeneratePrimes(1, m_modulus_bits, m_prefer_small_primes)[0];
+  }
+
+  void TearDown() override {}
+
+ public:
+  uint64_t m_N{1024 + 7};  // m_N % 8 = 7 to test AVX512 boundary case
+  uint64_t m_modulus_bits;
+  bool m_prefer_small_primes;
+  uint64_t m_modulus;
+};
+
+// Test public API matches Native implementation on random values
+TEST_P(EltwiseReduceModTest, Random) {
+  uint64_t upper_bound =
+      m_modulus < (1ULL << 32) ? m_modulus * m_modulus : 1ULL << 63;
+
+  auto input = GenerateInsecureUniformRandomValues(m_N, 0, upper_bound);
+  std::vector<uint64_t> result_native(m_N, 0);
+  std::vector<uint64_t> result_public_api(m_N, 0);
+
+  EltwiseReduceModNative(result_native.data(), input.data(), m_N, m_modulus,
+                         m_modulus, 1);
+  EltwiseReduceMod(result_public_api.data(), input.data(), m_N, m_modulus,
+                   m_modulus, 1);
+  AssertEqual(result_native, result_public_api);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    EltwiseReduceMod, EltwiseReduceModTest,
+    ::testing::Combine(::testing::ValuesIn(AlignedVector64<uint64_t>{
+                           20, 25, 30, 31, 32, 33, 35, 40, 48, 49, 50, 51, 52,
+                           55, 58, 59, 60}),
+                       ::testing::ValuesIn(std::vector<bool>{false, true})));
+
 }  // namespace hexl
 }  // namespace intel