diff --git a/cryptomite/pycryptomite.cpp b/cryptomite/pycryptomite.cpp index d412649..f770e86 100644 --- a/cryptomite/pycryptomite.cpp +++ b/cryptomite/pycryptomite.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include @@ -27,7 +28,13 @@ PYBIND11_MODULE(_cryptomite, m) { py::class_(m, "NTT") .def(py::init()) - .def("ntt", &NTT::ntt); + .def("ntt", &NTT::ntt) + .def("mul_vec", &NTT::mul_vec) + .def("conv", &NTT::conv); - m.def("mul_vec", &mul_vec); + py::class_(m, "BigNTT") + .def(py::init()) + .def("ntt", &BigNTT::ntt) + .def("mul_vec", &BigNTT::mul_vec) + .def("conv", &BigNTT::conv); } diff --git a/cryptomite/utils.py b/cryptomite/utils.py index 8b53d9b..1b2740e 100644 --- a/cryptomite/utils.py +++ b/cryptomite/utils.py @@ -6,7 +6,7 @@ from math import sqrt from typing import Literal, Sequence -from cryptomite._cryptomite import NTT, mul_vec +from cryptomite._cryptomite import BigNTT, NTT __all__ = ['is_prime', 'prime_facto', 'previous_prime', 'next_prime', 'closest_prime', 'previous_na_set', 'next_na_set', @@ -45,12 +45,12 @@ def conv(l: int, source1: Sequence[int], source2: Sequence[int]) -> list[int]: """ L = 1 << l assert len(source1) == len(source2) == L - ntt = NTT(l) - ntt_source1 = ntt.ntt(source1, False) - ntt_source2 = ntt.ntt(source2, False) - mul_source = mul_vec(ntt_source1, ntt_source2) - conv_output = ntt.ntt(mul_source, True) - return conv_output + ntt = BigNTT(l) if l > 30 else NTT(l) + # ntt_source1 = ntt.ntt(source1, False) + # ntt_source2 = ntt.ntt(source2, False) + # mul_source = ntt.mul_vec(ntt_source1, ntt_source2) + # conv_output = ntt.ntt(mul_source, True) + return ntt.conv(source1, source2) def is_prime(n: int) -> bool: diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9a45a7b..5e205c8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -add_library(trevisan trevisan.cpp irreducible_poly.cpp ntt.cpp) +add_library(trevisan trevisan.cpp irreducible_poly.cpp ntt.cpp bigntt.cpp) target_include_directories(trevisan PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/bigntt.cpp b/src/bigntt.cpp new file mode 100644 index 0000000..d22c8b2 --- /dev/null +++ b/src/bigntt.cpp @@ -0,0 +1,154 @@ +#include "bigntt.h" + +#include +#include + +#define P ((9ull<<42) + 1) +#define G 5 // primitive root mod P + +uint64_t add(uint64_t a, uint64_t b) { + __int128_t c = a; + c += b; + __int128_t d = c - P; + __int128_t e = d >> 64; + // return d if d >= 0 else c + return (c&e) | (d&~e); +} + +/** + * Subtraction: a-b mod P + * + * @pre a,b < P + */ +static uint64_t sub(uint64_t a, uint64_t b) { + __int128_t c = a; + c -= b; + __int128_t d = c + P; + __int128_t e = c >> 64; + // return c if c >= 0 else d + return (c&~e) | (d&e); +} + +uint64_t mul(uint64_t a, uint64_t b) { + __int128_t n = a; n*= b; + return n % P; +} + +/** + * Modular exponentiation: a^e mod P + */ +static uint64_t modexp(uint64_t a, uint64_t e) { + // e is not secret, no need to make constant time + uint64_t r = 1; + while (e) { + if (e&1) { + r = mul(r, a); + } + e >>= 1; + a = mul(a, a); + } + return r; +} + +/** + * Reverse the bits of x (an l-bit number) + */ +static uint64_t reverse_bits(unsigned l, uint64_t x) { + uint64_t y = 0; + while (l) { + l--; + y |= ((x&1) << l); + x >>= 1; + } + return y; +} + +BigNTT::BigNTT(unsigned l) : L(1ll< 40) { + throw std::runtime_error("Must have 1 <= l <= 40."); + } + + Linv = modexp(L, P-2); + + uint64_t half_L = L/2; + + R = std::vector(half_L); + Rinv = std::vector(half_L); + revbits = std::vector(L); + + uint64_t r = modexp(G, (P - 1) >> l); // primitive L'th root of unity + + { + { + uint64_t t = 1; + for (uint64_t i = 0; i < half_L; i++) { + R[i] = t; + t = mul(t, r); + } + } + + { + // r^(L/2) = -1 + uint64_t t = P - 1; + for (uint64_t i = 1; i <= half_L; i++) { + t = mul(t, r); + Rinv[half_L - i] = t; + } + } + } + + for (uint64_t i = 0; i < L; i++) { + revbits[i] = reverse_bits(l, i); + } + +} + +std::vector BigNTT::ntt(const std::vector &x, bool inverse) { + const std::vector& U = inverse ? Rinv : R; + + std::vector y(L, 0); + + // Bit inversion + for (uint64_t i = 0; i < L; i++) { + y[revbits[i]] = x[i]; + } + + // Main loop + for ( + uint64_t h = 2, k = 1, u = L/2; + h <= L; + k = h, h <<= 1, u >>= 1) + { + for (uint64_t i = 0; i < L; i += h) { + for (uint64_t j = 0, v = 0; j < k; j++, v += u) { + uint64_t r = i + j; + uint64_t s = r + k; + uint64_t a = y[r]; + uint64_t b = mul(y[s], U[v]); + y[r] = add(a, b); + y[s] = sub(a, b); + } + } + } + + // Normalization for inverse + if (inverse) { + for (uint64_t i = 0; i < L; i++) { + y[i] = mul(Linv, y[i]); + } + } + return y; +} + +std::vector BigNTT::mul_vec(const std::vector &a, const std::vector &b) { + std::vector c(a.size()); + for (uint64_t i = 0; i < a.size(); i++) { + c[i] = mul(a[i], b[i]); + } + return c; +} + +std::vector BigNTT::conv(const std::vector &a, const std::vector &b) { + std::vector c = mul_vec(ntt(a, false), ntt(b, false)); + return ntt(c, true); +} \ No newline at end of file diff --git a/src/bigntt.h b/src/bigntt.h new file mode 100644 index 0000000..23c06c6 --- /dev/null +++ b/src/bigntt.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +class BigNTT { + private: + /** Sequence length (power of 2) */ + uint64_t L; + + /** Inverse of L mod p */ + uint64_t Linv; + + /** + * Powers 1, r, r^2, ..., r^(L/2-1) mod p, where r is a primitive L'th + * root of unity mod p + */ + std::vector R; + + /** + * Inverse powers 1, r^{-1}, r^{-2}, ..., r{-^(L/2-1)} mod p + */ + std::vector Rinv; + + /** + * Lookup table for bit reversals + */ + std::vector revbits; + + public: + explicit BigNTT(unsigned l); + + std::vector ntt(const std::vector &x, bool inverse); + + std::vector mul_vec(const std::vector &a, const std::vector &b); + std::vector conv(const std::vector &a, const std::vector &b); +}; diff --git a/src/ntt.cpp b/src/ntt.cpp index 9156f69..56fa1b8 100644 --- a/src/ntt.cpp +++ b/src/ntt.cpp @@ -11,6 +11,7 @@ uint32_t add(uint32_t a, uint32_t b) { c += b; uint64_t d = c - P; uint64_t e = d >> 32; + // return d if d >= 0 else c return (c&e) | (d&~e); } @@ -24,6 +25,7 @@ static uint32_t sub(uint32_t a, uint32_t b) { c -= b; uint64_t d = c + P; uint64_t e = c >> 32; + // return c if c >= 0 else d return (c&~e) | (d&e); } @@ -32,14 +34,6 @@ uint32_t mul(uint32_t a, uint32_t b) { return n % P; } -std::vector mul_vec(const std::vector &a, const std::vector &b) { - std::vector c(a.size()); - for (uint32_t i = 0; i < a.size(); i++) { - c[i] = mul(a[i], b[i]); - } - return c; -} - /** * Modular exponentiation: a^e mod P */ @@ -145,3 +139,16 @@ std::vector NTT::ntt(const std::vector &x, bool inverse) { } return y; } + +std::vector NTT::mul_vec(const std::vector &a, const std::vector &b) { + std::vector c(a.size()); + for (uint32_t i = 0; i < a.size(); i++) { + c[i] = mul(a[i], b[i]); + } + return c; +} + +std::vector NTT::conv(const std::vector &a, const std::vector &b) { + std::vector c = mul_vec(ntt(a, false), ntt(b, false)); + return ntt(c, true); +} \ No newline at end of file diff --git a/src/ntt.h b/src/ntt.h index 30d3b1f..8134f2a 100644 --- a/src/ntt.h +++ b/src/ntt.h @@ -3,8 +3,6 @@ #include #include -std::vector mul_vec(const std::vector &a, const std::vector &b); - class NTT { private: /** Sequence length (power of 2) */ @@ -33,4 +31,7 @@ class NTT { explicit NTT(unsigned l); std::vector ntt(const std::vector &x, bool inverse); + + std::vector mul_vec(const std::vector &a, const std::vector &b); + std::vector conv(const std::vector &a, const std::vector &b); }; diff --git a/test/test_ntt.py b/test/test_ntt.py new file mode 100644 index 0000000..8262eb8 --- /dev/null +++ b/test/test_ntt.py @@ -0,0 +1,47 @@ +import pytest +from cryptomite._cryptomite import BigNTT, NTT +import numpy as np + +test_range = list(range(2, 21)) + + +def slow_conv(a, b): + """ direct implementation """ + c = [0] * len(a) + for i in range(len(a)): + for j in range(len(b)): + c[(i + j) % len(c)] += a[i] * b[j] + return c + + +@pytest.mark.parametrize('n', test_range) +def test_ntt_inv(n): + ntt = NTT(n) + for _ in range(10): + v = np.random.randint(0, 1 << n, 1 << n).tolist() + assert ntt.ntt(ntt.ntt(v, False), True) == v + + +@pytest.mark.parametrize('n', test_range) +def test_big_ntt_inv(n): + ntt = BigNTT(n) + for _ in range(10): + v = np.random.randint(0, 1 << n, 1 << n).tolist() + assert ntt.ntt(ntt.ntt(v, False), True) == v + + +@pytest.mark.parametrize('n', list(range(2, 11))) +def test_ntt_conv(n): + ntt = NTT(n) + a = np.random.randint(0, 2, 1 << n).tolist() + b = np.random.randint(0, 2, 1 << n).tolist() + assert ntt.conv(a, b) == slow_conv(a, b) + + +@pytest.mark.parametrize('n', test_range) +def test_big_ntt_conv(n): + ntt = NTT(n) + big_ntt = BigNTT(n) + a = np.random.randint(0, 2, 1 << n).tolist() + b = np.random.randint(0, 2, 1 << n).tolist() + assert ntt.conv(a, b) == big_ntt.conv(a, b)