Skip to content

Commit

Permalink
Add BigNTT and do conv in one step
Browse files Browse the repository at this point in the history
  • Loading branch information
y-richie-y committed Jan 20, 2024
1 parent d0e0165 commit 25ad7d6
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 20 deletions.
11 changes: 9 additions & 2 deletions cryptomite/pycryptomite.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <ntt.h>
#include <bigntt.h>
#include <trevisan.cpp>


Expand All @@ -27,7 +28,13 @@ PYBIND11_MODULE(_cryptomite, m) {

py::class_<NTT>(m, "NTT")
.def(py::init<int>())
.def("ntt", &NTT::ntt);
.def("ntt", &NTT::ntt)
.def("mul_vec", &NTT::mul_vec)
.def("conv", &NTT::conv);

m.def("mul_vec", &mul_vec);
py::class_<BigNTT>(m, "BigNTT")
.def(py::init<int>())
.def("ntt", &BigNTT::ntt)
.def("mul_vec", &BigNTT::mul_vec)
.def("conv", &BigNTT::conv);
}
14 changes: 7 additions & 7 deletions cryptomite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from math import sqrt
from typing import Literal, Sequence

from cryptomite._cryptomite import NTT, mul_vec
from cryptomite._cryptomite import BigNTT, NTT

__all__ = ['is_prime', 'prime_facto', 'previous_prime', 'next_prime',
'closest_prime', 'previous_na_set', 'next_na_set',
Expand Down Expand Up @@ -45,12 +45,12 @@ def conv(l: int, source1: Sequence[int], source2: Sequence[int]) -> list[int]:
"""
L = 1 << l
assert len(source1) == len(source2) == L
ntt = NTT(l)
ntt_source1 = ntt.ntt(source1, False)
ntt_source2 = ntt.ntt(source2, False)
mul_source = mul_vec(ntt_source1, ntt_source2)
conv_output = ntt.ntt(mul_source, True)
return conv_output
ntt = BigNTT(l) if l > 30 else NTT(l)
# ntt_source1 = ntt.ntt(source1, False)
# ntt_source2 = ntt.ntt(source2, False)
# mul_source = ntt.mul_vec(ntt_source1, ntt_source2)
# conv_output = ntt.ntt(mul_source, True)
return ntt.conv(source1, source2)


def is_prime(n: int) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_library(trevisan trevisan.cpp irreducible_poly.cpp ntt.cpp)
add_library(trevisan trevisan.cpp irreducible_poly.cpp ntt.cpp bigntt.cpp)

target_include_directories(trevisan PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

Expand Down
154 changes: 154 additions & 0 deletions src/bigntt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#include "bigntt.h"

#include <stdexcept>
#include <vector>

#define P ((9ull<<42) + 1)
#define G 5 // primitive root mod P

uint64_t add(uint64_t a, uint64_t b) {
__int128_t c = a;
c += b;
__int128_t d = c - P;
__int128_t e = d >> 64;
// return d if d >= 0 else c
return (c&e) | (d&~e);
}

/**
* Subtraction: a-b mod P
*
* @pre a,b < P
*/
static uint64_t sub(uint64_t a, uint64_t b) {
__int128_t c = a;
c -= b;
__int128_t d = c + P;
__int128_t e = c >> 64;
// return c if c >= 0 else d
return (c&~e) | (d&e);
}

uint64_t mul(uint64_t a, uint64_t b) {
__int128_t n = a; n*= b;
return n % P;
}

/**
* Modular exponentiation: a^e mod P
*/
static uint64_t modexp(uint64_t a, uint64_t e) {
// e is not secret, no need to make constant time
uint64_t r = 1;
while (e) {
if (e&1) {
r = mul(r, a);
}
e >>= 1;
a = mul(a, a);
}
return r;
}

/**
* Reverse the bits of x (an l-bit number)
*/
static uint64_t reverse_bits(unsigned l, uint64_t x) {
uint64_t y = 0;
while (l) {
l--;
y |= ((x&1) << l);
x >>= 1;
}
return y;
}

BigNTT::BigNTT(unsigned l) : L(1ll<<l) {
if (l < 1 || l > 40) {
throw std::runtime_error("Must have 1 <= l <= 40.");
}

Linv = modexp(L, P-2);

uint64_t half_L = L/2;

R = std::vector<uint64_t>(half_L);
Rinv = std::vector<uint64_t>(half_L);
revbits = std::vector<uint64_t>(L);

uint64_t r = modexp(G, (P - 1) >> l); // primitive L'th root of unity

{
{
uint64_t t = 1;
for (uint64_t i = 0; i < half_L; i++) {
R[i] = t;
t = mul(t, r);
}
}

{
// r^(L/2) = -1
uint64_t t = P - 1;
for (uint64_t i = 1; i <= half_L; i++) {
t = mul(t, r);
Rinv[half_L - i] = t;
}
}
}

for (uint64_t i = 0; i < L; i++) {
revbits[i] = reverse_bits(l, i);
}

}

std::vector<uint64_t> BigNTT::ntt(const std::vector<uint64_t> &x, bool inverse) {
const std::vector<uint64_t>& U = inverse ? Rinv : R;

std::vector<uint64_t> y(L, 0);

// Bit inversion
for (uint64_t i = 0; i < L; i++) {
y[revbits[i]] = x[i];
}

// Main loop
for (
uint64_t h = 2, k = 1, u = L/2;
h <= L;
k = h, h <<= 1, u >>= 1)
{
for (uint64_t i = 0; i < L; i += h) {
for (uint64_t j = 0, v = 0; j < k; j++, v += u) {
uint64_t r = i + j;
uint64_t s = r + k;
uint64_t a = y[r];
uint64_t b = mul(y[s], U[v]);
y[r] = add(a, b);
y[s] = sub(a, b);
}
}
}

// Normalization for inverse
if (inverse) {
for (uint64_t i = 0; i < L; i++) {
y[i] = mul(Linv, y[i]);
}
}
return y;
}

std::vector<uint64_t> BigNTT::mul_vec(const std::vector<uint64_t> &a, const std::vector<uint64_t> &b) {
std::vector<uint64_t> c(a.size());
for (uint64_t i = 0; i < a.size(); i++) {
c[i] = mul(a[i], b[i]);
}
return c;
}

std::vector<uint64_t> BigNTT::conv(const std::vector<uint64_t> &a, const std::vector<uint64_t> &b) {
std::vector<uint64_t> c = mul_vec(ntt(a, false), ntt(b, false));
return ntt(c, true);
}
37 changes: 37 additions & 0 deletions src/bigntt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <cstdint>
#include <vector>

class BigNTT {
private:
/** Sequence length (power of 2) */
uint64_t L;

/** Inverse of L mod p */
uint64_t Linv;

/**
* Powers 1, r, r^2, ..., r^(L/2-1) mod p, where r is a primitive L'th
* root of unity mod p
*/
std::vector<uint64_t> R;

/**
* Inverse powers 1, r^{-1}, r^{-2}, ..., r{-^(L/2-1)} mod p
*/
std::vector<uint64_t> Rinv;

/**
* Lookup table for bit reversals
*/
std::vector<uint64_t> revbits;

public:
explicit BigNTT(unsigned l);

std::vector<uint64_t> ntt(const std::vector<uint64_t> &x, bool inverse);

std::vector<uint64_t> mul_vec(const std::vector<uint64_t> &a, const std::vector<uint64_t> &b);
std::vector<uint64_t> conv(const std::vector<uint64_t> &a, const std::vector<uint64_t> &b);
};
23 changes: 15 additions & 8 deletions src/ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ uint32_t add(uint32_t a, uint32_t b) {
c += b;
uint64_t d = c - P;
uint64_t e = d >> 32;
// return d if d >= 0 else c
return (c&e) | (d&~e);
}

Expand All @@ -24,6 +25,7 @@ static uint32_t sub(uint32_t a, uint32_t b) {
c -= b;
uint64_t d = c + P;
uint64_t e = c >> 32;
// return c if c >= 0 else d
return (c&~e) | (d&e);
}

Expand All @@ -32,14 +34,6 @@ uint32_t mul(uint32_t a, uint32_t b) {
return n % P;
}

std::vector<uint32_t> mul_vec(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b) {
std::vector<uint32_t> c(a.size());
for (uint32_t i = 0; i < a.size(); i++) {
c[i] = mul(a[i], b[i]);
}
return c;
}

/**
* Modular exponentiation: a^e mod P
*/
Expand Down Expand Up @@ -145,3 +139,16 @@ std::vector<uint32_t> NTT::ntt(const std::vector<uint32_t> &x, bool inverse) {
}
return y;
}

std::vector<uint32_t> NTT::mul_vec(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b) {
std::vector<uint32_t> c(a.size());
for (uint32_t i = 0; i < a.size(); i++) {
c[i] = mul(a[i], b[i]);
}
return c;
}

std::vector<uint32_t> NTT::conv(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b) {
std::vector<uint32_t> c = mul_vec(ntt(a, false), ntt(b, false));
return ntt(c, true);
}
5 changes: 3 additions & 2 deletions src/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include <cstdint>
#include <vector>

std::vector<uint32_t> mul_vec(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b);

class NTT {
private:
/** Sequence length (power of 2) */
Expand Down Expand Up @@ -33,4 +31,7 @@ class NTT {
explicit NTT(unsigned l);

std::vector<uint32_t> ntt(const std::vector<uint32_t> &x, bool inverse);

std::vector<uint32_t> mul_vec(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b);
std::vector<uint32_t> conv(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b);
};
47 changes: 47 additions & 0 deletions test/test_ntt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
from cryptomite._cryptomite import BigNTT, NTT
import numpy as np

test_range = list(range(2, 21))


def slow_conv(a, b):
""" direct implementation """
c = [0] * len(a)
for i in range(len(a)):
for j in range(len(b)):
c[(i + j) % len(c)] += a[i] * b[j]
return c


@pytest.mark.parametrize('n', test_range)
def test_ntt_inv(n):
ntt = NTT(n)
for _ in range(10):
v = np.random.randint(0, 1 << n, 1 << n).tolist()
assert ntt.ntt(ntt.ntt(v, False), True) == v


@pytest.mark.parametrize('n', test_range)
def test_big_ntt_inv(n):
ntt = BigNTT(n)
for _ in range(10):
v = np.random.randint(0, 1 << n, 1 << n).tolist()
assert ntt.ntt(ntt.ntt(v, False), True) == v


@pytest.mark.parametrize('n', list(range(2, 11)))
def test_ntt_conv(n):
ntt = NTT(n)
a = np.random.randint(0, 2, 1 << n).tolist()
b = np.random.randint(0, 2, 1 << n).tolist()
assert ntt.conv(a, b) == slow_conv(a, b)


@pytest.mark.parametrize('n', test_range)
def test_big_ntt_conv(n):
ntt = NTT(n)
big_ntt = BigNTT(n)
a = np.random.randint(0, 2, 1 << n).tolist()
b = np.random.randint(0, 2, 1 << n).tolist()
assert ntt.conv(a, b) == big_ntt.conv(a, b)

0 comments on commit 25ad7d6

Please sign in to comment.