Skip to content

Commit

Permalink
Re-implement encrypt.circom and poseidon decrypt
Browse files Browse the repository at this point in the history
Signed-off-by: Jim Zhang <[email protected]>
  • Loading branch information
jimthematrix committed Aug 22, 2024
1 parent be9e994 commit 257f2ab
Show file tree
Hide file tree
Showing 10 changed files with 348 additions and 35 deletions.
66 changes: 57 additions & 9 deletions zkp/circuits/lib/encrypt.circom
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// limitations under the License.
pragma circom 2.1.4;

include "../node_modules/circomlib/circuits/comparators.circom";
include "../node_modules/circomlib/circuits/poseidon.circom";

// encrypts a message by appending it to Poseidon hashes
Expand All @@ -24,16 +25,63 @@ template SymmetricEncrypt(length) {
signal input plainText[length]; // private
signal input key[2]; // private
signal input nonce; // public
signal output cipherText[length];

component hashers[length];
for (var i = 0; i < length; i++) {
hashers[i] = Poseidon(4);
hashers[i].inputs[0] <== key[0];
hashers[i].inputs[1] <== key[1];
hashers[i].inputs[2] <== nonce;
hashers[i].inputs[3] <== i;
var two128 = 2 ** 128;

cipherText[i] <== hashers[i].out + plainText[i];
// nonce must be < 2^128
component lt = LessThan(252);
lt.in[0] <== nonce;
lt.in[1] <== two128;
lt.out === 1;

// the number of plain text messages must be multiple of 3
// pad the array with zeros if necessary.
// e.g. if length == 4, l == 6
var l = length;
while (l % 3 != 0) {
l += 1;
}
var messages[l];
for (var i = 0; i < l; i++) {
if (i < length) {
messages[i] = plainText[i];
} else {
messages[i] = 0;
}
}

signal output cipherText[l + 1];

// calculate the number of iterations needed for the encryption
// process. "\"" is integer division
var n = l \ 3;

// create the initial state: [0, key[0], key[1], nonce + (length * 2^128)]
component rounds[n + 1];
rounds[0] = PoseidonEx(4, 4);
rounds[0].initialState <== 0;
rounds[0].inputs[0] <== 0;
rounds[0].inputs[1] <== key[0];
rounds[0].inputs[2] <== key[1];
rounds[0].inputs[3] <== nonce + (length * two128);

for (var i = 0; i < n; i++) {
rounds[i + 1] = PoseidonEx(4, 4);
rounds[i + 1].initialState <== 0;
rounds[i + 1].inputs[0] <== rounds[i].out[0];

// Absorb three elements of message, setting them to the
// corresponding inputs of the next round
rounds[i + 1].inputs[1] <== rounds[i].out[1] + messages[i * 3];
rounds[i + 1].inputs[2] <== rounds[i].out[2] + messages[i * 3 + 1];
rounds[i + 1].inputs[3] <== rounds[i].out[3] + messages[i * 3 + 2];

// release three elements of the ciphertext
cipherText[i * 3] <== rounds[i + 1].inputs[1];
cipherText[i * 3 + 1] <== rounds[i + 1].inputs[2];
cipherText[i * 3 + 2] <== rounds[i + 1].inputs[3];
}

// Iterate Poseidon on the state one last time
cipherText[l] <== rounds[n].out[1];
}
20 changes: 20 additions & 0 deletions zkp/circuits/lib/poseidon-ex.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
pragma circom 2.1.4;

include "../node_modules/circomlib/circuits/poseidon.circom";

template TestPoseidonEx() {
signal input inputs[4];
signal output out[4];

component poseidon = PoseidonEx(4, 4);
poseidon.initialState <== 0;
poseidon.inputs[0] <== inputs[0];
poseidon.inputs[1] <== inputs[1];
poseidon.inputs[2] <== inputs[2];
poseidon.inputs[3] <== inputs[3];

out[0] <== poseidon.out[0];
out[1] <== poseidon.out[1];
out[2] <== poseidon.out[2];
out[3] <== poseidon.out[3];
}
17 changes: 17 additions & 0 deletions zkp/circuits/lib/poseidon.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pragma circom 2.1.4;

include "../node_modules/circomlib/circuits/poseidon.circom";

template TestPoseidon() {
signal input a;
signal input b;
signal input c;
signal output out;

component poseidon = Poseidon(3);
poseidon.inputs[0] <== a;
poseidon.inputs[1] <== b;
poseidon.inputs[2] <== c;

out <== poseidon.out;
}
147 changes: 131 additions & 16 deletions zkp/js/lib/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,146 @@
// limitations under the License.

const { genRandomSalt } = require('maci-crypto');
const Poseidon = require('poseidon-lite');
const { buildPoseidon } = require('circomlibjs');
const { solidityPackedKeccak256 } = require('ethers');
const { createHash } = require('crypto');

const F = BigInt('21888242871839275222246405745257275088548364400416034343698204186575808495617');

function newSalt() {
return genRandomSalt();
}

function poseidonDecrypt(cipherText, sharedSecret, nonce) {
const plainText = [];
cipherText.forEach((msg, index) => {
const hash = Poseidon.poseidon4([sharedSecret[0], sharedSecret[1], BigInt(nonce), BigInt(index)]);
// subtract the hash from the cipherText to get the plainText
plainText[index] = addMod([BigInt(msg), -hash], F);
while (plainText[index] < 0n) {
plainText[index] += F;
// Implements the encryption and decryption functions using Poseidon hash
// as described: https://drive.google.com/file/d/1EVrP3DzoGbmzkRmYnyEDcIQcXVU7GlOd/view
// The encryption and decryption functions are compatible with the circom implementations,
// meaning the cipher texts encrypted by the circuit in circuits/lib/encrypt.circom can
// be decrypted by the poseidonDecrypt function. And vice versa.
class PoseidonCipher {
constructor() {}

async init() {
this.poseidon = await buildPoseidon();
this.Fr = this.poseidon.F;
this.two128 = this.Fr.e(BigInt('340282366920938463463374607431768211456'));
}

async encrypt(msg, key, nonce) {
validateInputs(msg, key, nonce);

const Fr = this.Fr;
msg = msg.map((x) => Fr.e(x));

// the size of the message array must be a multiple of 3
const message = [...msg];
while (message.length % 3 > 0) {
// pad with zeros if necessary
message.push(Fr.zero);
}

// Create the initial state
// S = (0, kS[0], kS[1], N + l * 2^128)
let state = [Fr.zero, Fr.e(key[0]), Fr.e(key[1]), Fr.add(Fr.e(BigInt(nonce)), Fr.mul(Fr.e(BigInt(msg.length)), this.two128))];

const ciphertext = [];

const n = Math.floor(message.length / 3);
for (let i = 0; i < n; i += 1) {
// Iterate Poseidon on the state
state = this.poseidon(state, 0, 4);

// Modify the state for the next round
state[1] = Fr.add(message[i * 3], state[1]);
state[2] = Fr.add(message[i * 3 + 1], state[2]);
state[3] = Fr.add(message[i * 3 + 2], state[3]);

// Record the three elements of the encrypted message
ciphertext.push(state[1]);
ciphertext.push(state[2]);
ciphertext.push(state[3]);
}

// Iterate Poseidon on the state one last time
state = this.poseidon(state, 0, 4);

// Record the last ciphertext element
ciphertext.push(Fr.add(Fr.zero, state[1]));

return ciphertext.map((t) => Fr.toObject(t));
}

async decrypt(ciphertext, key, nonce, length) {
validateInputs(ciphertext, key, nonce, length);

const Fr = this.Fr;

// Create the initial state
// S = (0, kS[0], kS[1], N + l ∗ 2^128).
let state = [Fr.zero, Fr.e(key[0]), Fr.e(key[1]), Fr.add(Fr.e(BigInt(nonce)), Fr.mul(Fr.e(BigInt(length)), this.two128))];

const message = [];

const n = Math.floor(ciphertext.length / 3);
for (let i = 0; i < n; i += 1) {
// Iterate Poseidon on the state
state = this.poseidon(state, 0, 4);

// Release three elements of the decrypted message
message.push(Fr.sub(Fr.e(ciphertext[i * 3]), state[1]));
message.push(Fr.sub(Fr.e(ciphertext[i * 3 + 1]), state[2]));
message.push(Fr.sub(Fr.e(ciphertext[i * 3 + 2]), state[3]));

// Modify the state for the next round
state[1] = ciphertext[i * 3];
state[2] = ciphertext[i * 3 + 1];
state[3] = ciphertext[i * 3 + 2];
}
});
return plainText;

// If length > 3, check if the last (3 - (l mod 3)) elements of the message are 0
if (length > 3) {
if (length % 3 === 2) {
this.checkEqual(message[message.length - 1], Fr.zero, 'The last element of the message must be 0');
} else if (length % 3 === 1) {
this.checkEqual(message[message.length - 1], Fr.zero, 'The last element of the message must be 0');
this.checkEqual(message[message.length - 2], Fr.zero, 'The second to last element of the message must be 0');
}
}

// Iterate Poseidon on the state one last time
state = this.poseidon(state, 0, 4);

// Check the last ciphertext element
this.checkEqual(Fr.e(ciphertext[ciphertext.length - 1]), Fr.e(state[1]), 'The last ciphertext element must match the second item of the permuted state');

return message.slice(0, length).map((t) => Fr.toObject(t));
}

checkEqual(a, b, error) {
if (!this.Fr.eq(a, b)) {
throw new Error(error);
}
}
}

function addMod(addMe, m) {
return addMe.reduce((e, acc) => (((e + m) % m) + acc) % m, BigInt(0));
function validateInputs(msg, key, nonce, length) {
if (!Array.isArray(msg)) {
throw new Error('The message must be an array');
}
for (let i = 0; i < msg.length; i += 1) {
if (typeof msg[i] !== 'bigint') {
throw new Error('Each message element must be a BigInt');
}
}
if (key.length !== 2) {
throw new Error('The key must be an array of two elements');
}
if (typeof key[0] !== 'bigint' || typeof key[1] !== 'bigint') {
throw new Error('The key must be an array of two BigInts');
}
if (typeof nonce !== 'bigint') {
throw new Error('The nonce must be a BigInt');
}
if (length && length < 1) {
throw new Error('The length must be at least 1');
}
}

// convert the proof json to the format that the Solidity verifier expects
Expand Down Expand Up @@ -84,7 +199,7 @@ function kycHash(bjjPublicKey) {

module.exports = {
newSalt,
poseidonDecrypt,
PoseidonCipher,
encodeProof,
getProofHash,
tokenUriHash,
Expand Down
1 change: 1 addition & 0 deletions zkp/js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"@iden3/js-merkletree": "^1.2.0",
"chai": "^4.4.1",
"circom_tester": "^0.0.20",
"circomlibjs": "^0.1.7",
"mocha": "^10.2.0"
}
}
5 changes: 5 additions & 0 deletions zkp/js/test/circuits/poseidon-ex.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pragma circom 2.1.4;

include "../../../circuits/lib/poseidon-ex.circom";

component main = TestPoseidonEx();
5 changes: 5 additions & 0 deletions zkp/js/test/circuits/poseidon.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pragma circom 2.1.4;

include "../../../circuits/lib/poseidon.circom";

component main = TestPoseidon();
37 changes: 27 additions & 10 deletions zkp/js/test/lib/encrypt.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

const { expect } = require('chai');
const { join } = require('path');
const crypto = require('crypto');
const { wasm: wasm_tester } = require('circom_tester');
const { genRandomSalt, genKeypair, genEcdhSharedKey, stringifyBigInts } = require('maci-crypto');
const { poseidonDecrypt } = require('../../lib/util.js');
const { genKeypair, genEcdhSharedKey, stringifyBigInts } = require('maci-crypto');
const { PoseidonCipher } = require('../../lib/util.js');

describe('Encryption circuit tests', () => {
let circuit;
Expand All @@ -37,24 +38,40 @@ describe('Encryption circuit tests', () => {
receiverPubKey = keypair.pubKey;
});

it('should generate the cipher text in the proof circuit, which can be decrypted by the receiver', async () => {
const messageAndSalt = [1234567890, 24680135791234567890].map((x) => BigInt(x));
it('using poseidonEncrypt() to generate the cipher text, and poseidonDecrypt() to recover the plain text', async () => {
const key = genEcdhSharedKey(senderPrivKey, receiverPubKey);
const hex = crypto.randomBytes(16).toString('hex');
const nonce = BigInt(`0x${hex}`);

const cipher = new PoseidonCipher();
await cipher.init();

const result = await cipher.encrypt([123n, 4567890n], key, nonce);

const recoveredKey = genEcdhSharedKey(receiverPrivKey, senderPubKey);
const plainText = await cipher.decrypt(result, recoveredKey, nonce, 2);
expect(plainText).to.deep.equal([123n, 4567890n]);
});

it('should generate the cipher text in the proof circuit, which can be decrypted by the receiver', async () => {
const key = genEcdhSharedKey(senderPrivKey, receiverPubKey);
const nonce = genRandomSalt();
const hex = crypto.randomBytes(16).toString('hex');
const nonce = BigInt(`0x${hex}`);

const circuitInputs = stringifyBigInts({
plainText: messageAndSalt,
plainText: [123, 4567890],
nonce,
key,
});

const witness = await circuit.calculateWitness(circuitInputs);

const encryptedValue = witness[1];
const encryptedNonce = witness[2];
const ciphertext = witness.slice(1, 5).map((x) => BigInt(x));
const recoveredKey = genEcdhSharedKey(receiverPrivKey, senderPubKey);
const plainText = poseidonDecrypt([encryptedValue, encryptedNonce], recoveredKey, nonce);
expect(plainText).to.deep.equal(messageAndSalt);

const cipher = new PoseidonCipher();
await cipher.init();
const plainText = await cipher.decrypt(ciphertext, recoveredKey, nonce, 2);
expect(plainText).to.deep.equal([123n, 4567890n]);
});
});
Loading

0 comments on commit 257f2ab

Please sign in to comment.