Skip to content

Commit

Permalink
Merge pull request #2062 from privacy-scaling-explorations/feature/po…
Browse files Browse the repository at this point in the history
…ll-join-circuit

feat(circuits): add poll joined circuit
  • Loading branch information
0xmad authored Jan 24, 2025
2 parents aac1ce7 + 30c4f6a commit 8b7ce16
Show file tree
Hide file tree
Showing 28 changed files with 716 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,35 @@ template PollJoining(stateTreeDepth) {

calculatedRoot === stateRoot;
}

template PollJoined(stateTreeDepth) {
// Constants defining the tree structure
var STATE_TREE_ARITY = 2;

// User's private key
signal input privKey;
// User's voice credits balance
signal input voiceCreditsBalance;
// Poll's joined timestamp
signal input joinTimestamp;
// Path elements
signal input pathElements[stateTreeDepth][STATE_TREE_ARITY - 1];
// Path indices
signal input pathIndices[stateTreeDepth];
// Poll State tree root which proves the user is joined
signal input stateRoot;

// User private to public key
var derivedPubKey[2] = PrivToPubKey()(privKey);

var stateLeaf = PoseidonHasher(4)([derivedPubKey[0], derivedPubKey[1], voiceCreditsBalance, joinTimestamp]);

// Inclusion proof
var stateLeafQip = MerkleTreeInclusionProof(stateTreeDepth)(
stateLeaf,
pathIndices,
pathElements
);

stateLeafQip === stateRoot;
}
8 changes: 7 additions & 1 deletion packages/circuits/circom/circuits.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
{
"PollJoining_10_test": {
"file": "./anon/pollJoining",
"file": "./anon/poll",
"template": "PollJoining",
"params": [10],
"pubs": ["nullifier", "stateRoot", "pollPubKey", "pollId"]
},
"PollJoined_10_test": {
"file": "./anon/poll",
"template": "PollJoined",
"params": [10],
"pubs": ["stateRoot"]
},
"ProcessMessages_10-20-2_test": {
"file": "./core/qv/processMessages",
"template": "ProcessMessages",
Expand Down
3 changes: 2 additions & 1 deletion packages/circuits/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
"test:tallyVotes": "pnpm run mocha-test ts/__tests__/TallyVotes.test.ts",
"test:ceremonyParams": "pnpm run mocha-test ts/__tests__/CeremonyParams.test.ts",
"test:incrementalQuinaryTree": "pnpm run mocha-test ts/__tests__/IncrementalQuinaryTree.test.ts",
"test:pollJoining": "pnpm run mocha-test ts/__tests__/PollJoining.test.ts"
"test:pollJoining": "pnpm run mocha-test ts/__tests__/PollJoining.test.ts",
"test:pollJoined": "pnpm run mocha-test ts/__tests__/PollJoined.test.ts"
},
"dependencies": {
"@zk-kit/circuits": "^0.4.0",
Expand Down
126 changes: 126 additions & 0 deletions packages/circuits/ts/__tests__/PollJoined.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { type WitnessTester } from "circomkit";
import { MaciState, Poll } from "maci-core";
import { poseidon } from "maci-crypto";
import { Keypair, Message, PCommand } from "maci-domainobjs";

import type { IPollJoinedInputs } from "../types";

import { STATE_TREE_DEPTH, duration, messageBatchSize, treeDepths, voiceCreditBalance } from "./utils/constants";
import { circomkitInstance } from "./utils/utils";

describe("Poll Joined circuit", function test() {
this.timeout(900000);
const NUM_USERS = 50;

const coordinatorKeypair = new Keypair();

type PollJoinedCircuitInputs = [
"privKey",
"voiceCreditsBalance",
"joinTimestamp",
"stateLeaf",
"pathElements",
"pathIndices",
"stateRoot",
];

let circuit: WitnessTester<PollJoinedCircuitInputs>;

before(async () => {
circuit = await circomkitInstance.WitnessTester("pollJoined", {
file: "./anon/poll",
template: "PollJoined",
params: [STATE_TREE_DEPTH],
});
});

describe(`${NUM_USERS} users, 1 join`, () => {
const maciState = new MaciState(STATE_TREE_DEPTH);
let pollId: bigint;
let poll: Poll;
let users: Keypair[];
const messages: Message[] = [];
const commands: PCommand[] = [];

const timestamp = BigInt(Math.floor(Date.now() / 1000));

before(() => {
// Sign up
users = new Array(NUM_USERS).fill(0).map(() => new Keypair());

users.forEach((userKeypair) => {
maciState.signUp(userKeypair.pubKey);
});

pollId = maciState.deployPoll(timestamp + BigInt(duration), treeDepths, messageBatchSize, coordinatorKeypair);

poll = maciState.polls.get(pollId)!;
poll.updatePoll(BigInt(maciState.pubKeys.length));

// Join the poll
const { privKey, pubKey: pollPubKey } = users[0];

const nullifier = poseidon([BigInt(privKey.rawPrivKey.toString())]);

const stateIndex = BigInt(poll.joinPoll(nullifier, pollPubKey, voiceCreditBalance, timestamp));

// First command (valid)
const command = new PCommand(
stateIndex,
pollPubKey,
BigInt(0), // voteOptionIndex,
BigInt(9), // vote weight
BigInt(1), // nonce
BigInt(pollId),
);

const signature = command.sign(privKey);

const ecdhKeypair = new Keypair();
const sharedKey = Keypair.genEcdhSharedKey(ecdhKeypair.privKey, coordinatorKeypair.pubKey);
const message = command.encrypt(signature, sharedKey);
messages.push(message);
commands.push(command);

poll.publishMessage(message, ecdhKeypair.pubKey);

// Process messages
poll.processMessages(pollId);
});

it("should produce a proof", async () => {
const { privKey: privateKey, pubKey: pollPubKey } = users[0];
const nullifier = poseidon([BigInt(privateKey.asCircuitInputs()), poll.pollId]);

const stateLeafIndex = poll.joinPoll(nullifier, pollPubKey, voiceCreditBalance, timestamp);

const inputs = poll.joinedCircuitInputs({
maciPrivKey: privateKey,
stateLeafIndex: BigInt(stateLeafIndex),
voiceCreditsBalance: voiceCreditBalance,
joinTimestamp: timestamp,
}) as unknown as IPollJoinedInputs;

const witness = await circuit.calculateWitness(inputs);
await circuit.expectConstraintPass(witness);
});

it("should fail for fake witness", async () => {
const { privKey: privateKey, pubKey: pollPubKey } = users[1];
const nullifier = poseidon([BigInt(privateKey.asCircuitInputs()), poll.pollId]);

const stateLeafIndex = poll.joinPoll(nullifier, pollPubKey, voiceCreditBalance, timestamp);

const inputs = poll.joinedCircuitInputs({
maciPrivKey: privateKey,
stateLeafIndex: BigInt(stateLeafIndex),
voiceCreditsBalance: voiceCreditBalance,
joinTimestamp: timestamp,
}) as unknown as IPollJoinedInputs;
const witness = await circuit.calculateWitness(inputs);

const fakeWitness = Array(witness.length).fill(1n) as bigint[];
await circuit.expectConstraintFail(fakeWitness);
});
});
});
2 changes: 1 addition & 1 deletion packages/circuits/ts/__tests__/PollJoining.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ describe("Poll Joining circuit", function test() {

before(async () => {
circuit = await circomkitInstance.WitnessTester("pollJoining", {
file: "./anon/pollJoining",
file: "./anon/poll",
template: "PollJoining",
params: [STATE_TREE_DEPTH],
});
Expand Down
14 changes: 14 additions & 0 deletions packages/circuits/ts/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ export interface IPollJoiningInputs {
pollId: bigint;
}

/**
* Inputs for circuit PollJoined
*/
export interface IPollJoinedInputs {
privKey: bigint;
voiceCreditsBalance: bigint;
joinTimestamp: bigint;
stateLeaf: bigint[];
pathElements: bigint[][];
pathIndices: bigint[];
credits: bigint;
stateRoot: bigint;
}

/**
* Inputs for circuit ProcessMessages
*/
Expand Down
2 changes: 2 additions & 0 deletions packages/cli/tests/ceremony-params/ceremonyParams.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
ceremonyProcessMessagesNonQvWasmPath,
ceremonyTallyVotesNonQvWasmPath,
ceremonyPollJoiningZkeyPath,
ceremonyPollJoinedZkeyPath,
} from "../constants";
import { clean, isArm } from "../utils";

Expand All @@ -68,6 +69,7 @@ describe("Stress tests with ceremony params (6,3,2,20)", function test() {
voteOptionTreeDepth,
messageBatchSize,
pollJoiningZkeyPath: ceremonyPollJoiningZkeyPath,
pollJoinedZkeyPath: ceremonyPollJoinedZkeyPath,
processMessagesZkeyPathQv: ceremonyProcessMessagesZkeyPath,
tallyVotesZkeyPathQv: ceremonyTallyVotesZkeyPath,
processMessagesZkeyPathNonQv: ceremonyProcessMessagesNonQvZkeyPath,
Expand Down
5 changes: 5 additions & 0 deletions packages/cli/tests/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export const coordinatorPubKey = coordinatorKeypair.pubKey.serialize();
export const coordinatorPrivKey = coordinatorKeypair.privKey.serialize();

export const pollJoiningTestZkeyPath = "./zkeys/PollJoining_10_test/PollJoining_10_test.0.zkey";
export const pollJoinedTestZkeyPath = "./zkeys/PollJoined_10_test/PollJoined_10_test.0.zkey";
export const processMessageTestZkeyPath = "./zkeys/ProcessMessages_10-20-2_test/ProcessMessages_10-20-2_test.0.zkey";
export const tallyVotesTestZkeyPath = "./zkeys/TallyVotes_10-1-2_test/TallyVotes_10-1-2_test.0.zkey";
export const processMessageTestNonQvZkeyPath =
Expand All @@ -47,6 +48,7 @@ export const testTallyVotesWasmPath =
"./zkeys/TallyVotes_10-1-2_test/TallyVotes_10-1-2_test_js/TallyVotes_10-1-2_test.wasm";
export const testRapidsnarkPath = `${homedir()}/rapidsnark/build/prover`;
export const ceremonyPollJoiningZkeyPath = "./zkeys/PollJoining_10_test/PollJoining_10_test.0.zkey";
export const ceremonyPollJoinedZkeyPath = "./zkeys/PollJoined_10_test/PollJoined_10_test.0.zkey";
export const ceremonyProcessMessagesZkeyPath = "./zkeys/ProcessMessages_6-9-2-3/processMessages_6-9-2-3.zkey";
export const ceremonyProcessMessagesNonQvZkeyPath =
"./zkeys/ProcessMessagesNonQv_6-9-2-3/processMessagesNonQv_6-9-2-3.zkey";
Expand Down Expand Up @@ -98,6 +100,7 @@ export const setVerifyingKeysArgs: Omit<SetVerifyingKeysArgs, "signer"> = {
voteOptionTreeDepth: VOTE_OPTION_TREE_DEPTH,
messageBatchSize: MESSAGE_BATCH_SIZE,
pollJoiningZkeyPath: pollJoiningTestZkeyPath,
pollJoinedZkeyPath: pollJoinedTestZkeyPath,
processMessagesZkeyPathQv: processMessageTestZkeyPath,
tallyVotesZkeyPathQv: tallyVotesTestZkeyPath,
};
Expand All @@ -109,6 +112,7 @@ export const setVerifyingKeysNonQvArgs: Omit<SetVerifyingKeysArgs, "signer"> = {
voteOptionTreeDepth: VOTE_OPTION_TREE_DEPTH,
messageBatchSize: MESSAGE_BATCH_SIZE,
pollJoiningZkeyPath: pollJoiningTestZkeyPath,
pollJoinedZkeyPath: pollJoinedTestZkeyPath,
processMessagesZkeyPathNonQv: processMessageTestNonQvZkeyPath,
tallyVotesZkeyPathNonQv: tallyVotesTestNonQvZkeyPath,
};
Expand All @@ -119,6 +123,7 @@ export const checkVerifyingKeysArgs: Omit<CheckVerifyingKeysArgs, "signer"> = {
voteOptionTreeDepth: VOTE_OPTION_TREE_DEPTH,
messageBatchSize: MESSAGE_BATCH_SIZE,
pollJoiningZkeyPath: pollJoiningTestZkeyPath,
pollJoinedZkeyPath: pollJoinedTestZkeyPath,
processMessagesZkeyPath: processMessageTestZkeyPath,
tallyVotesZkeyPath: tallyVotesTestZkeyPath,
};
Expand Down
31 changes: 15 additions & 16 deletions packages/cli/ts/commands/checkVerifyingKeys.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export const checkVerifyingKeys = async ({
processMessagesZkeyPath,
tallyVotesZkeyPath,
pollJoiningZkeyPath,
pollJoinedZkeyPath,
vkRegistry,
signer,
useQuadraticVoting = true,
Expand Down Expand Up @@ -69,30 +70,28 @@ export const checkVerifyingKeys = async ({
// extract the verification keys from the zkey files
const processVk = VerifyingKey.fromObj(await extractVk(processMessagesZkeyPath));
const tallyVk = VerifyingKey.fromObj(await extractVk(tallyVotesZkeyPath));
const pollVk = VerifyingKey.fromObj(await extractVk(pollJoiningZkeyPath));
const pollJoiningVk = VerifyingKey.fromObj(await extractVk(pollJoiningZkeyPath));
const pollJoinedVk = VerifyingKey.fromObj(await extractVk(pollJoinedZkeyPath));

try {
logYellow(quiet, info("Retrieving verifying keys from the contract..."));
// retrieve the verifying keys from the contract

const pollVkOnChain = await vkRegistryContractInstance.getPollVk(stateTreeDepth, voteOptionTreeDepth);
const mode = useQuadraticVoting ? EMode.QV : EMode.NON_QV;

const processVkOnChain = await vkRegistryContractInstance.getProcessVk(
stateTreeDepth,
voteOptionTreeDepth,
messageBatchSize,
useQuadraticVoting ? EMode.QV : EMode.NON_QV,
);

const tallyVkOnChain = await vkRegistryContractInstance.getTallyVk(
stateTreeDepth,
intStateTreeDepth,
voteOptionTreeDepth,
useQuadraticVoting ? EMode.QV : EMode.NON_QV,
);
const [pollJoiningVkOnChain, pollJoinedVkOnChain, processVkOnChain, tallyVkOnChain] = await Promise.all([
vkRegistryContractInstance.getPollJoiningVk(stateTreeDepth, voteOptionTreeDepth),
vkRegistryContractInstance.getPollJoinedVk(stateTreeDepth, voteOptionTreeDepth),
vkRegistryContractInstance.getProcessVk(stateTreeDepth, voteOptionTreeDepth, messageBatchSize, mode),
vkRegistryContractInstance.getTallyVk(stateTreeDepth, intStateTreeDepth, voteOptionTreeDepth, mode),
]);

// do the actual validation
if (!compareVks(pollVk, pollVkOnChain)) {
if (!compareVks(pollJoiningVk, pollJoiningVkOnChain)) {
logError("Poll verifying keys do not match");
}

if (!compareVks(pollJoinedVk, pollJoinedVkOnChain)) {
logError("Poll verifying keys do not match");
}

Expand Down
8 changes: 5 additions & 3 deletions packages/cli/ts/commands/extractVkToFile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@ export const extractVkToFile = async ({
processMessagesZkeyPathQv,
tallyVotesZkeyPathQv,
processMessagesZkeyPathNonQv,
tallyVotesZkeyPathNonQv,
pollJoinedZkeyPath,
pollJoiningZkeyPath,
tallyVotesZkeyPathNonQv,
outputFilePath,
}: ExtractVkToFileArgs): Promise<void> => {
const [processVkQv, tallyVkQv, processVkNonQv, tallyVkNonQv, pollVk] = await Promise.all([
const [processVkQv, tallyVkQv, processVkNonQv, tallyVkNonQv, pollJoiningVk, pollJoinedVk] = await Promise.all([
extractVk(processMessagesZkeyPathQv),
extractVk(tallyVotesZkeyPathQv),
extractVk(processMessagesZkeyPathNonQv),
extractVk(tallyVotesZkeyPathNonQv),
extractVk(pollJoiningZkeyPath),
extractVk(pollJoinedZkeyPath),
]);

await fs.promises.writeFile(
outputFilePath,
JSON.stringify({ processVkQv, tallyVkQv, processVkNonQv, tallyVkNonQv, pollVk }),
JSON.stringify({ processVkQv, tallyVkQv, processVkNonQv, tallyVkNonQv, pollJoiningVk, pollJoinedVk }),
);
};
6 changes: 3 additions & 3 deletions packages/cli/ts/commands/joinPoll.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const getStateIndex = (pubKeys: PubKey[], userMaciPubKey: PubKey, stateIndex?: b
* @param pollVk - Poll verifying key
* @returns proof - an array of strings
*/
const generateAndVerifyProof = async (
export const generateAndVerifyProof = async (
inputs: CircuitInputs,
zkeyPath: string,
useWasm: boolean | undefined,
Expand Down Expand Up @@ -288,7 +288,7 @@ export const joinPoll = async ({
) as unknown as CircuitInputs;
}

const pollVk = await extractVk(pollJoiningZkey);
const pollJoiningVk = await extractVk(pollJoiningZkey);

let pollStateIndex = "";
let receipt: ContractTransactionReceipt | null = null;
Expand All @@ -305,7 +305,7 @@ export const joinPoll = async ({
rapidsnark,
pollWitgen,
pollWasm,
pollVk,
pollJoiningVk,
);

// submit the message onchain as well as the encryption public key
Expand Down
Loading

0 comments on commit 8b7ce16

Please sign in to comment.