Skip to content

Commit

Permalink
Fix EltwiseCmpSubMod (#84)
Browse files Browse the repository at this point in the history
* Fix EltwiseCmpSubMod

* Fix Windows build by fixing cpu-features commit
  • Loading branch information
fboemer authored Oct 22, 2021
1 parent ccb6063 commit 343acab
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 34 deletions.
8 changes: 4 additions & 4 deletions benchmark/bench-eltwise-cmp-sub-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,20 @@ BENCHMARK(BM_EltwiseCmpSubModNative)

#ifdef HEXL_HAS_AVX512DQ
// state[0] is the degree
static void BM_EltwiseCmpSubModAVX512(benchmark::State& state) { // NOLINT
static void BM_EltwiseCmpSubModAVX512_64(benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
uint64_t modulus = 100;
uint64_t bound = GenerateInsecureUniformRandomValue(0, modulus);
uint64_t diff = GenerateInsecureUniformRandomValue(1, modulus);
auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, modulus);

for (auto _ : state) {
EltwiseCmpSubModAVX512(input1.data(), input1.data(), input_size, modulus,
CMPINT::NLT, bound, diff);
EltwiseCmpSubModAVX512<64>(input1.data(), input1.data(), input_size,
modulus, CMPINT::NLT, bound, diff);
}
}

BENCHMARK(BM_EltwiseCmpSubModAVX512)
BENCHMARK(BM_EltwiseCmpSubModAVX512_64)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
Expand Down
2 changes: 1 addition & 1 deletion cmake/third-party/cpu-features/CMakeLists.txt.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ project(cpu-features-download NONE)
include(ExternalProject)
ExternalProject_Add(cpu_features
GIT_REPOSITORY https://github.com/google/cpu_features.git
GIT_TAG master
GIT_TAG 32b49eb5e7809052a28422cfde2f2745fbb0eb76 # master branch on Oct 20, 2021
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu-features-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu-features-build"
CONFIGURE_COMMAND ""
Expand Down
17 changes: 15 additions & 2 deletions hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512DQ
template <int BitShift = 64>
template <int BitShift>
void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
uint64_t n, uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff) {
Expand Down Expand Up @@ -51,12 +51,25 @@ void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
uint64_t prod_right_shift = ceil_log_mod + beta;
__m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));

uint64_t alpha = BitShift - 2;
uint64_t mu_64 =
MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift,
modulus)
.BarrettFactor();

if (BitShift == 64) {
// Single-worded Barrett reduction.
mu_64 = MultiplyFactor(1, 64, modulus).BarrettFactor();
}

__m512i v_mu_64 = _mm512_set1_epi64(static_cast<int64_t>(mu_64));

for (size_t i = n / 8; i > 0; --i) {
__m512i v_op = _mm512_loadu_si512(v_op_ptr);
__mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));

v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
v_op, v_modulus, v_mu, v_mu, prod_right_shift, v_neg_mod);
v_op, v_modulus, v_mu_64, v_mu, prod_right_shift, v_neg_mod);

__m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
v_to_add = _mm512_sub_epi64(v_to_add, v_diff);
Expand Down
2 changes: 0 additions & 2 deletions hexl/eltwise/eltwise-cmp-sub-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1,

for (size_t i = 0; i < n; ++i) {
uint64_t op = operand1[i];

bool op_cmp = Compare(cmp, op, bound);
op %= modulus;

if (op_cmp) {
op = SubUIntMod(op, diff, modulus);
}
Expand Down
3 changes: 2 additions & 1 deletion hexl/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ namespace hexl {
/// @param[in] bound Scalar to compare against
/// @param[in] diff Scalar to subtract by
/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? (\p
/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1
/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0,
/// ..., n-1
void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n,
uint64_t modulus, CMPINT cmp, uint64_t bound,
uint64_t diff);
Expand Down
19 changes: 5 additions & 14 deletions hexl/util/avx512-util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <vector>

#include "hexl/logging/logging.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "hexl/util/defines.hpp"
Expand Down Expand Up @@ -389,6 +390,7 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
__mmask8 mask =
_mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT);
if (mask != 0) {
// values above 2^52
__m512i x_hi = _mm512_srli_epi64(x, static_cast<unsigned int>(52ULL));
__m512i x_intr = _mm512_slli_epi64(x, static_cast<unsigned int>(12ULL));
__m512i x_lo =
Expand All @@ -408,33 +410,22 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod);
} else {
__m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52);
// Barrett subtraction
// tmp[0] = input - tmp[1] * q;
__m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q);
x = _mm512_sub_epi64(x, tmp1_times_mod);
}
}
#endif
if (BitShift == 64) {
__m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64);
// Barrett subtraction
// tmp[0] = input - tmp[1] * q;
__m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q);
x = _mm512_sub_epi64(x, tmp1_times_mod);
}

// Correction
if (OutputModFactor == 2) {
return x;
} else {
if (BitShift == 64) {
x = _mm512_hexl_small_mod_epu64(x, q);
}
if (BitShift == 52) {
x = _mm512_hexl_small_mod_epu64<2>(x, q);
}
return x;
if (OutputModFactor == 1) {
x = _mm512_hexl_small_mod_epu64<2>(x, q);
}
return x;
}

// Concatenate packed 64-bit integers in x and y, producing an intermediate
Expand Down
71 changes: 67 additions & 4 deletions test/test-eltwise-cmp-sub-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,45 @@ namespace intel {
namespace hexl {

// Checks AVX512 and native implementations match
#ifdef HEXL_HAS_AVX512DQ
#ifdef HEXL_HAS_AVX512IFMA
// Regression test for the 52-bit AVX512 kernel: on one fixed input, the
// dispatching EltwiseCmpSubMod, the native reference implementation and
// EltwiseCmpSubModAVX512<52> must all produce the same output.
//
// NOTE(review): this test is compiled under the IFMA guard but only checks
// has_avx512dq at runtime — confirm whether a runtime IFMA check is needed
// before calling the <52> kernel.
TEST(EltwiseCmpSubMod, AVX512_52) {
  if (!has_avx512dq) {
    GTEST_SKIP();
  }

  // Every trial uses the same operands, so they are built once up front.
  // length is deliberately not a multiple of 8 (presumably to cover both the
  // vectorized loop and the leftover element — verify against the kernel).
  const uint64_t length = 9;
  const uint64_t modulus = 1125896819525633;
  const uint64_t bound = 576460751967876096;
  const uint64_t diff = 3160741504001;
  // The input value is larger than both 2^52 and the modulus.
  const std::vector<uint64_t> op1(length, 1106601337915084531);

  for (size_t trial = 0; trial < 200; ++trial) {
    // Fresh zeroed outputs per trial so a stale result cannot mask a failure.
    std::vector<uint64_t> op1_out(op1.size(), 0);
    std::vector<uint64_t> op1_native_out(op1.size(), 0);
    std::vector<uint64_t> op1_avx512_out(op1.size(), 0);

    EltwiseCmpSubMod(op1_out.data(), op1.data(), op1.size(), modulus,
                     intel::hexl::CMPINT::NLE, bound, diff);
    EltwiseCmpSubModNative(op1_native_out.data(), op1.data(), op1.size(),
                           modulus, intel::hexl::CMPINT::NLE, bound, diff);
    EltwiseCmpSubModAVX512<52>(op1_avx512_out.data(), op1.data(), op1.size(),
                               modulus, intel::hexl::CMPINT::NLE, bound, diff);

    ASSERT_EQ(op1_out, op1_native_out);
    ASSERT_EQ(op1_native_out, op1_avx512_out);
  }
}
#endif

#ifdef HEXL_HAS_AVX512IFMA
TEST(EltwiseCmpSubMod, AVX512) {
if (!has_avx512dq) {
GTEST_SKIP();
}

uint64_t length = 172;

for (size_t cmp = 0; cmp < 8; ++cmp) {
for (size_t bits = 48; bits <= 51; ++bits) {
uint64_t modulus = GeneratePrimes(1, bits, true, 1024)[0];
Expand All @@ -48,15 +79,47 @@ TEST(EltwiseCmpSubMod, AVX512) {
static_cast<CMPINT>(cmp), bound, diff);
EltwiseCmpSubModNative(op1a_out.data(), op1a.data(), op1a.size(),
modulus, static_cast<CMPINT>(cmp), bound, diff);
EltwiseCmpSubModAVX512(op1b_out.data(), op1b.data(), op1b.size(),
modulus, static_cast<CMPINT>(cmp), bound, diff);
EltwiseCmpSubModAVX512<52>(op1b_out.data(), op1b.data(), op1b.size(),
modulus, static_cast<CMPINT>(cmp), bound,
diff);

ASSERT_EQ(op1_out, op1a_out);
ASSERT_EQ(op1_out, op1b_out);
}
}
}
}

// Regression test for the 64-bit AVX512 kernel: on one fixed input, the
// dispatching EltwiseCmpSubMod, the native reference implementation and
// EltwiseCmpSubModAVX512<64> must all produce the same output.
TEST(EltwiseCmpSubMod, AVX512_64) {
  if (!has_avx512dq) {
    GTEST_SKIP();
  }

  // Every trial uses the same operands, so they are built once up front.
  // length is deliberately not a multiple of 8 (presumably to cover both the
  // vectorized loop and the leftover element — verify against the kernel).
  const uint64_t length = 9;
  const uint64_t modulus = 1152921504606748673;
  const uint64_t bound = 576460752303415296;
  const uint64_t diff = 81920;
  const std::vector<uint64_t> op1(length, 64961);

  for (size_t trial = 0; trial < 200; ++trial) {
    // Fresh zeroed outputs per trial so a stale result cannot mask a failure.
    std::vector<uint64_t> op1_out(op1.size(), 0);
    std::vector<uint64_t> op1_native_out(op1.size(), 0);
    std::vector<uint64_t> op1_avx512_out(op1.size(), 0);

    EltwiseCmpSubMod(op1_out.data(), op1.data(), op1.size(), modulus,
                     intel::hexl::CMPINT::NLE, bound, diff);
    EltwiseCmpSubModNative(op1_native_out.data(), op1.data(), op1.size(),
                           modulus, intel::hexl::CMPINT::NLE, bound, diff);
    EltwiseCmpSubModAVX512<64>(op1_avx512_out.data(), op1.data(), op1.size(),
                               modulus, intel::hexl::CMPINT::NLE, bound, diff);

    ASSERT_EQ(op1_out, op1_native_out);
    ASSERT_EQ(op1_native_out, op1_avx512_out);
  }
}

#endif
} // namespace hexl
} // namespace intel
7 changes: 1 addition & 6 deletions test/test-eltwise-reduce-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,6 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) {
GTEST_SKIP();
}

std::random_device rd;
std::mt19937 gen(rd());

size_t length = 1024;

for (size_t bits = 50; bits <= 62; ++bits) {
Expand Down Expand Up @@ -170,9 +167,6 @@ TEST(EltwiseReduceMod, AVX512Big_4_1) {
GTEST_SKIP();
}

std::random_device rd;
std::mt19937 gen(rd());

size_t length = 1024;

for (size_t bits = 50; bits <= 62; ++bits) {
Expand Down Expand Up @@ -263,6 +257,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {
}
}
}

#endif

} // namespace hexl
Expand Down

0 comments on commit 343acab

Please sign in to comment.