Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add is_in_subgroup for G2 and make G2 ops generic. #255

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
8 changes: 8 additions & 0 deletions hydra/garaga/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
E12TFrobeniusSquareCircuit,
E12TInverseCircuit,
E12TMulCircuit,
FP2MulCircuit,
FP6NegCircuit,
TowerMillerBit0,
TowerMillerBit1,
Expand Down Expand Up @@ -136,6 +137,7 @@ class CircuitID(Enum):
)
ADD_EC_POINT_G2 = int.from_bytes(b"add_ec_point_g2", "big")
DOUBLE_EC_POINT_G2 = int.from_bytes(b"double_ec_point_g2", "big")
FP2_MUL = int.from_bytes(b"fp2_mul", "big")
FULL_ECIP_BATCHED = int.from_bytes(b"full_ecip__batched", "big")


Expand Down Expand Up @@ -403,6 +405,12 @@ class CircuitID(Enum):
"filename": "tower_circuits",
"curve_ids": [CurveID.BLS12_381],
},
CircuitID.FP2_MUL: {
"class": FP2MulCircuit,
"params": None,
"filename": "tower_circuits",
"curve_ids": [CurveID.BLS12_381, CurveID.BN254],
},
# CircuitID.HONK_SUMCHECK_CIRCUIT: {
# "class": SumCheckCircuit,
# "params": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -953,3 +953,37 @@ def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit:
circuit.extend_struct_output(structs.u384(name="zc1b1a1", elmts=[zc1b1[1]]))

return circuit


class FP2MulCircuit(BaseEXTFCircuit):
def __init__(
self,
curve_id: int,
auto_run: bool = True,
init_hash: int = None,
compilation_mode: int = 1,
):
super().__init__("fp2_mul", curve_id, auto_run, init_hash, compilation_mode)

def build_input(self) -> list[PyFelt]:
input = []
input.extend([self.field.random() for _ in range(4)])

return input

def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit:
circuit = ModuloCircuit(
name=self.name,
curve_id=self.curve_id,
compilation_mode=self.compilation_mode,
)
a0 = circuit.write_struct(structs.u384(name="a0", elmts=[input.pop(0)]))
a1 = circuit.write_struct(structs.u384(name="a1", elmts=[input.pop(0)]))
b0 = circuit.write_struct(structs.u384(name="b0", elmts=[input.pop(0)]))
b1 = circuit.write_struct(structs.u384(name="b1", elmts=[input.pop(0)]))
res = circuit.fp2_mul([a0, a1], [b0, b1])

circuit.extend_struct_output(structs.u384(name="res0", elmts=[res[0]]))
circuit.extend_struct_output(structs.u384(name="res1", elmts=[res[1]]))

return circuit
70 changes: 63 additions & 7 deletions src/src/circuits/tower_circuits.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,34 @@ pub fn run_BLS12_381_E12T_MUL_circuit(X: E12T, Y: E12T) -> (E12T,) {
return (res,);
}
#[inline(always)]
pub fn run_BLS12_381_FP2_MUL_circuit(a0: u384, a1: u384, b0: u384, b1: u384) -> (u384, u384) {
// INPUT stack
let (in0, in1, in2) = (CE::<CI<0>> {}, CE::<CI<1>> {}, CE::<CI<2>> {});
let in3 = CE::<CI<3>> {};
let t0 = circuit_mul(in0, in2); // Fp2 mul start
let t1 = circuit_mul(in1, in3);
let t2 = circuit_sub(t0, t1); // Fp2 mul real part end
let t3 = circuit_mul(in0, in3);
let t4 = circuit_mul(in1, in2);
let t5 = circuit_add(t3, t4); // Fp2 mul imag part end

let modulus = get_BLS12_381_modulus(); // BLS12_381 prime field modulus

let mut circuit_inputs = (t2, t5).new_inputs();
// Prefill constants:

// Fill inputs:
circuit_inputs = circuit_inputs.next_2(a0); // in0
circuit_inputs = circuit_inputs.next_2(a1); // in1
circuit_inputs = circuit_inputs.next_2(b0); // in2
circuit_inputs = circuit_inputs.next_2(b1); // in3

let outputs = circuit_inputs.done_2().eval(modulus).unwrap();
let res0: u384 = outputs.get_output(t2);
let res1: u384 = outputs.get_output(t5);
return (res0, res1);
}
#[inline(always)]
pub fn run_BLS12_381_TOWER_MILLER_BIT0_1P_circuit(
yInv_0: u384, xNegOverY_0: u384, Q_0: G2Point, M_i: E12T,
) -> (G2Point, E12T) {
Expand Down Expand Up @@ -4308,6 +4336,34 @@ pub fn run_BN254_E12T_MUL_circuit(X: E12T, Y: E12T) -> (E12T,) {
return (res,);
}
#[inline(always)]
pub fn run_BN254_FP2_MUL_circuit(a0: u384, a1: u384, b0: u384, b1: u384) -> (u384, u384) {
// INPUT stack
let (in0, in1, in2) = (CE::<CI<0>> {}, CE::<CI<1>> {}, CE::<CI<2>> {});
let in3 = CE::<CI<3>> {};
let t0 = circuit_mul(in0, in2); // Fp2 mul start
let t1 = circuit_mul(in1, in3);
let t2 = circuit_sub(t0, t1); // Fp2 mul real part end
let t3 = circuit_mul(in0, in3);
let t4 = circuit_mul(in1, in2);
let t5 = circuit_add(t3, t4); // Fp2 mul imag part end

let modulus = get_BN254_modulus(); // BN254 prime field modulus

let mut circuit_inputs = (t2, t5).new_inputs();
// Prefill constants:

// Fill inputs:
circuit_inputs = circuit_inputs.next_2(a0); // in0
circuit_inputs = circuit_inputs.next_2(a1); // in1
circuit_inputs = circuit_inputs.next_2(b0); // in2
circuit_inputs = circuit_inputs.next_2(b1); // in3

let outputs = circuit_inputs.done_2().eval(modulus).unwrap();
let res0: u384 = outputs.get_output(t2);
let res1: u384 = outputs.get_output(t5);
return (res0, res1);
}
#[inline(always)]
pub fn run_BN254_TOWER_MILLER_BIT0_1P_circuit(
yInv_0: u384, xNegOverY_0: u384, Q_0: G2Point, M_i: E12T,
) -> (G2Point, E12T) {
Expand Down Expand Up @@ -6045,12 +6101,12 @@ mod tests {
run_BLS12_381_E12T_DECOMP_KARABINA_I_Z_circuit, run_BLS12_381_E12T_FROBENIUS_CUBE_circuit,
run_BLS12_381_E12T_FROBENIUS_SQUARE_circuit, run_BLS12_381_E12T_FROBENIUS_circuit,
run_BLS12_381_E12T_INVERSE_circuit, run_BLS12_381_E12T_MUL_circuit,
run_BLS12_381_TOWER_MILLER_BIT0_1P_circuit, run_BLS12_381_TOWER_MILLER_BIT1_1P_circuit,
run_BLS12_381_TOWER_MILLER_INIT_BIT_1P_circuit, run_BN254_E12T_CYCLOTOMIC_SQUARE_circuit,
run_BN254_E12T_FROBENIUS_CUBE_circuit, run_BN254_E12T_FROBENIUS_SQUARE_circuit,
run_BN254_E12T_FROBENIUS_circuit, run_BN254_E12T_INVERSE_circuit,
run_BN254_E12T_MUL_circuit, run_BN254_TOWER_MILLER_BIT0_1P_circuit,
run_BN254_TOWER_MILLER_BIT1_1P_circuit, run_BN254_TOWER_MILLER_FINALIZE_BN_1P_circuit,
run_FP6_NEG_circuit,
run_BLS12_381_FP2_MUL_circuit, run_BLS12_381_TOWER_MILLER_BIT0_1P_circuit,
run_BLS12_381_TOWER_MILLER_BIT1_1P_circuit, run_BLS12_381_TOWER_MILLER_INIT_BIT_1P_circuit,
run_BN254_E12T_CYCLOTOMIC_SQUARE_circuit, run_BN254_E12T_FROBENIUS_CUBE_circuit,
run_BN254_E12T_FROBENIUS_SQUARE_circuit, run_BN254_E12T_FROBENIUS_circuit,
run_BN254_E12T_INVERSE_circuit, run_BN254_E12T_MUL_circuit, run_BN254_FP2_MUL_circuit,
run_BN254_TOWER_MILLER_BIT0_1P_circuit, run_BN254_TOWER_MILLER_BIT1_1P_circuit,
run_BN254_TOWER_MILLER_FINALIZE_BN_1P_circuit, run_FP6_NEG_circuit,
};
}
1 change: 1 addition & 0 deletions src/src/definitions.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ pub fn get_modulus(curve_index: usize) -> CircuitModulus {
}
}


// Returns the modulus of BLS12_381
#[inline(always)]
pub fn get_BLS12_381_modulus() -> CircuitModulus {
Expand Down
160 changes: 153 additions & 7 deletions src/src/ec_ops_g2.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,68 @@ use core::circuit::{
u384, circuit_add, circuit_sub, circuit_mul, circuit_inverse, EvalCircuitTrait,
CircuitOutputsTrait, CircuitModulus, CircuitInputs,
};
use garaga::circuits::tower_circuits::{run_BLS12_381_FP2_MUL_circuit, run_BN254_FP2_MUL_circuit};

use core::option::Option;
use garaga::core::circuit::AddInputResultTrait2;
use garaga::definitions::{G2Point, G2PointZero, get_BLS12_381_modulus, get_b2, get_a, get_modulus};
use garaga::definitions::{
G2Point, G2PointZero, get_BLS12_381_modulus, get_b2, get_a, get_p, get_modulus,
};
use garaga::circuits::ec;
use garaga::utils::u384_assert_zero;
use garaga::basic_field_ops::neg_mod_p;


const X_SEED_BN254: u256 = 0x44E992B44A6909F1;
const X_SEED_BLS12_381: u256 = 0xD201000000010000; // negated .


const ENDO_U_A0_BN254: u384 = u384 {
limb0: 0xc2c3330c99e39557176f553d,
limb1: 0x4c0bec3cf559b143b78cc310,
limb2: 0x2fb347984f7911f7,
limb3: 0x0,
};
const ENDO_U_A1_BN254: u384 = u384 {
limb0: 0xb7c9dce1665d51c640fcba2,
limb1: 0x4ba4cc8bd75a079432ae2a1d,
limb2: 0x16c9e55061ebae20,
limb3: 0x0,
};
const ENDO_V_A0_BN254: u384 = u384 {
limb0: 0xa9c95998dc54014671a0135a,
limb1: 0xdc5ec698b6e2f9b9dbaae0ed,
limb2: 0x63cf305489af5dc,
limb3: 0x0,
};
const ENDO_V_A1_BN254: u384 = u384 {
limb0: 0x8fa25bd282d37f632623b0e3,
limb1: 0x704b5a7ec796f2b21807dc9,
limb2: 0x7c03cbcac41049a,
limb3: 0x0,
};

const ENDO_U_A0_BLS12_381: u384 = u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 };
const ENDO_U_A1_BLS12_381: u384 = u384 {
limb0: 0x4f49fffd8bfd00000000aaad,
limb1: 0x897d29650fb85f9b409427eb,
limb2: 0x63d4de85aa0d857d89759ad4,
limb3: 0x1a0111ea397fe699ec024086,
};
const ENDO_V_A0_BLS12_381: u384 = u384 {
limb0: 0x3e67fa0af1ee7b04121bdea2,
limb1: 0xef396489f61eb45e304466cf,
limb2: 0xd77a2cd91c3dedd930b1cf60,
limb3: 0x135203e60180a68ee2e9c448,
};
const ENDO_V_A1_BLS12_381: u384 = u384 {
limb0: 0x72ec05f4c81084fbede3cc09,
limb1: 0x77f76e17009241c5ee67992f,
limb2: 0x6bd17ffe48395dabc2d3435e,
limb3: 0x6af0e0437ff400b6831e36d,
};


#[generate_trait]
impl G2PointImpl of G2PointTrait {
fn assert_on_curve(self: @G2Point, curve_index: usize) {
Expand All @@ -30,6 +84,50 @@ impl G2PointImpl of G2PointTrait {
);
return check0.is_zero() && check1.is_zero();
}
// Will fail (with ec_mul) if point is not on the curve.
fn is_in_subgroup(self: @G2Point, curve_index: usize) -> bool {
let pt = *self;
match curve_index {
0 => {
// https://github.com/Consensys/gnark-crypto/blob/37b2cbd0023e53386258750a3e0dd16d45edc2cf/ecc/bn254/g2.go#L494
let a = ec_mul(pt, X_SEED_BN254, curve_index).unwrap();
let b = psi(a, curve_index);
let a = match a.is_zero() {
true => Option::None,
false => Option::Some(a),
};

let a = ec_safe_add_with_options(a, Option::Some(pt), curve_index);
let res = psi(b, curve_index);
let c = ec_safe_add(res, b, curve_index);

let c = ec_safe_add_with_options(c, a, curve_index);
let res = psi(res, curve_index);
let (res) = ec::run_DOUBLE_EC_POINT_G2_A_EQ_0_circuit(res, curve_index);

let neg_c = match c {
Option::Some(c) => Option::Some(c.negate(curve_index)),
Option::None => Option::None,
};
let res = ec_safe_add_with_options(Option::Some(res), neg_c, curve_index);
match res {
Option::Some(r) => Self::is_on_curve(@r, curve_index),
Option::None => false,
}
},
1 => {
// https://github.com/Consensys/gnark-crypto/blob/37b2cbd0023e53386258750a3e0dd16d45edc2cf/ecc/bls12-381/g2.go#L495
let tmp = psi(pt, curve_index);
let res = ec_mul(pt, X_SEED_BLS12_381, curve_index).unwrap();
let res = ec_safe_add(res, tmp, curve_index);
match res {
Option::Some(r) => Self::is_on_curve(@r, curve_index),
Option::None => false,
}
},
_ => { false },
}
}
fn negate(self: @G2Point, curve_index: usize) -> G2Point {
let modulus = get_modulus(curve_index);
G2Point {
Expand Down Expand Up @@ -72,6 +170,37 @@ fn ec_mul(pt: G2Point, s: u256, curve_index: usize) -> Option<G2Point> {
}
}


// // psi sets p to ψ(q) = u o π o u⁻¹ where u:E'→E is the isomorphism from the twist to the
// curve E and π is the Frobenius map.
// Source gnark.
fn psi(pt: G2Point, curve_index: usize) -> G2Point {
match curve_index {
0 => {
let modulus = get_modulus(curve_index);
let (px0, px1) = run_BN254_FP2_MUL_circuit(
pt.x0, neg_mod_p(pt.x1, modulus), ENDO_U_A0_BN254, ENDO_U_A1_BN254,
);
let (py0, py1) = run_BN254_FP2_MUL_circuit(
pt.y0, neg_mod_p(pt.y1, modulus), ENDO_V_A0_BN254, ENDO_V_A1_BN254,
);
return G2Point { x0: px0, x1: px1, y0: py0, y1: py1 };
},
1 => {
let modulus = get_modulus(curve_index);
let (px0, px1) = run_BLS12_381_FP2_MUL_circuit(
pt.x0, neg_mod_p(pt.x1, modulus), ENDO_U_A0_BLS12_381, ENDO_U_A1_BLS12_381,
);
let (py0, py1) = run_BLS12_381_FP2_MUL_circuit(
pt.y0, neg_mod_p(pt.y1, modulus), ENDO_V_A0_BLS12_381, ENDO_V_A1_BLS12_381,
);
return G2Point { x0: px0, x1: px1, y0: py0, y1: py1 };
},
_ => { core::panic_with_felt252('invalid curve id fp2mul') },
}
}


// Returns the bits of the 256 bit number in little endian format.
fn get_bits_little(s: u256) -> Array<felt252> {
let mut bits = ArrayTrait::new();
Expand All @@ -95,14 +224,31 @@ fn get_bits_little(s: u256) -> Array<felt252> {
bits
}


#[inline]
fn ec_safe_add_with_options(
P: Option<G2Point>, Q: Option<G2Point>, curve_index: usize,
) -> Option<G2Point> {
// assumes that the points are on the curve and not the point at infinity.
// Returns None if the points are the same and opposite y coordinates (Point at infinity)
if P.is_none() {
return Q;
}
if Q.is_none() {
return P;
}

return ec_safe_add(P.unwrap(), Q.unwrap(), curve_index);
}

#[inline]
fn ec_safe_add(P: G2Point, Q: G2Point, curve_index: usize) -> Option<G2Point> {
// assumes that the points are on the curve and not the point at infinity.
// Returns None if the points are the same and opposite y coordinates (Point at infinity)
let same_x = eq_mod_p(P.x0, P.x1, Q.x0, Q.x1);
let same_x = eq_mod_p(P.x0, P.x1, Q.x0, Q.x1, curve_index);

if same_x {
let opposite_y = eq_neg_mod_p(P.y0, P.y1, Q.y0, Q.y1);
let opposite_y = eq_neg_mod_p(P.y0, P.y1, Q.y0, Q.y1, curve_index);

if opposite_y {
return Option::None;
Expand Down Expand Up @@ -138,15 +284,15 @@ fn ec_mul_inner(pt: G2Point, mut bits: Array<felt252>, curve_index: usize) -> Op

// returns true if a == b mod p bls12-381
#[inline]
pub fn eq_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool {
pub fn eq_mod_p(a0: u384, a1: u384, b0: u384, b1: u384, curve_index: usize) -> bool {
let _a0 = CE::<CI<0>> {};
let _a1 = CE::<CI<1>> {};
let _b0 = CE::<CI<2>> {};
let _b1 = CE::<CI<3>> {};
let sub0 = circuit_sub(_a0, _b0);
let sub1 = circuit_sub(_a1, _b1);

let modulus = get_BLS12_381_modulus();
let modulus = get_modulus(curve_index);

let outputs = (sub0, sub1)
.new_inputs()
Expand All @@ -163,15 +309,15 @@ pub fn eq_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool {

// returns true if a == -b mod p bls12-381
#[inline]
pub fn eq_neg_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool {
pub fn eq_neg_mod_p(a0: u384, a1: u384, b0: u384, b1: u384, curve_index: usize) -> bool {
let _a0 = CE::<CI<0>> {};
let _a1 = CE::<CI<1>> {};
let _b0 = CE::<CI<2>> {};
let _b1 = CE::<CI<3>> {};
let check0 = circuit_add(_a0, _b0);
let check1 = circuit_add(_a1, _b1);

let modulus = get_BLS12_381_modulus();
let modulus = get_modulus(curve_index);
let outputs = (check0, check1)
.new_inputs()
.next_2(a0)
Expand Down
Loading