Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sharding of global challenge phase commitment and opcode proving #695

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub struct ZKVMProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
pub raw_pi: Vec<Vec<E::BaseField>>,
// the evaluation of raw_pi.
pub pi_evals: Vec<E>,
opcode_proofs: BTreeMap<String, (usize, ZKVMOpcodeProof<E, PCS>)>,
opcode_proofs: BTreeMap<String, (usize, Vec<ZKVMOpcodeProof<E, PCS>>)>,
table_proofs: BTreeMap<String, (usize, ZKVMTableProof<E, PCS>)>,
}

Expand Down
106 changes: 70 additions & 36 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::assert_eq;
use ff_ext::ExtensionField;
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
Expand All @@ -8,7 +9,7 @@ use ff::Field;
use itertools::{Itertools, enumerate, izip};
use mpcs::PolynomialCommitmentScheme;
use multilinear_extensions::{
mle::{IntoMLE, MultilinearExtension},
mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension},
util::ceil_log2,
virtual_poly::build_eq_x_r_vec,
virtual_poly_v2::ArcMultilinearExtension,
Expand Down Expand Up @@ -36,6 +37,7 @@ use crate::{
},
utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads},
virtual_polys::VirtualPolynomials,
witness::RowMajorMatrix,
};

use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof};
Expand Down Expand Up @@ -90,33 +92,59 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
}
exit_span!(span);

// TODO: is it better to set different size of different opcode?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think so. Why choosing 1048576 here?

let shard_size = 1048576;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be easier to read in hex. (Because it's not just a random number.)


// commit to main traces
let mut commitments = BTreeMap::new();
let mut wits = BTreeMap::new();
// TODO: (1) is it ok to store mle? (2) replace tuple with struct?
#[allow(clippy::type_complexity)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps use some type synonym or so?

let mut wits_and_commitments: BTreeMap<
String,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the meaning of the key here? A type synonym might be useful here?

Vec<(
RowMajorMatrix<_>,
Vec<DenseMultilinearExtension<_>>,
PCS::CommitmentWithWitness,
)>,
> = BTreeMap::new();

let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true);
// commit to opcode circuits first and then commit to table circuits, sorted by name
for (circuit_name, witness) in witnesses.into_iter_sorted() {
let num_instances = witness.num_instances();
tracing::debug!(
"committing {} witnesses of size {}..",
circuit_name,
num_instances
);
if num_instances == 0 {
wits_and_commitments.insert(circuit_name.clone(), Vec::new());
continue;
}
let span = entered_span!(
"commit to iteration",
circuit_name = circuit_name,
profiling_2 = true
);
let witness = match num_instances {
0 => vec![],
_ => {
let witness = witness.into_mles();
commitments.insert(
circuit_name.clone(),
PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript)
.map_err(ZKVMError::PCSError)?,
);
witness
}
};

let witness_shards = witness.shard_by_rows(shard_size);
if witness_shards.len() > 1 {
tracing::info!(
"split {circuit_name} witness into {} shards",
witness_shards.len()
);
}
let witness_and_commitment: Vec<_> = witness_shards
.into_iter()
.map(|witness| -> Result<_, ZKVMError> {
let witness_mles = witness.clone().into_mles();
let commitment =
PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript)
.map_err(ZKVMError::PCSError)?;
Ok((witness, witness_mles, commitment))
})
.collect::<Result<Vec<_>, _>>()?;
wits_and_commitments.insert(circuit_name, witness_and_commitment);
exit_span!(span);
wits.insert(circuit_name, (witness, num_instances));
}
exit_span!(commit_to_traces_span);

Expand All @@ -135,13 +163,14 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
.iter() // Sorted by key.
.zip_eq(transcripts.iter_mut().enumerate())
{
let (witness, num_instances) = wits
let mut witness_and_wit: Vec<_> = wits_and_commitments
.remove(circuit_name)
.ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?;
if witness.is_empty() {

if witness_and_wit.is_empty() {
continue;
}
let wits_commit = commitments.remove(circuit_name).unwrap();

// TODO: add an enum for circuit type either in constraint_system or vk
let cs = pk.get_cs();
let is_opcode_circuit = cs.lk_table_expressions.is_empty()
Expand All @@ -157,31 +186,36 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
cs.w_expressions.len(),
cs.lk_expressions.len(),
);
let opcode_proof = self.create_opcode_proof(
circuit_name,
&self.pk.pp,
pk,
witness.into_iter().map(|w| w.into()).collect_vec(),
wits_commit,
&pi,
num_instances,
transcript,
&challenges,
)?;
tracing::info!(
"generated proof for opcode {} with num_instances={}",
circuit_name,
num_instances
);
let opcode_proof: Vec<_> = witness_and_wit.into_iter().enumerate().map(|(idx, (witness, mles, wits_commit))| -> Result<_, ZKVMError> {
let num_instances = witness.num_instances();
let proof = self.create_opcode_proof(
circuit_name,
&self.pk.pp,
pk,
mles.into_iter().map(|v| v.into()).collect_vec(),
wits_commit,
&pi,
num_instances,
transcript,
&challenges,
)?;
tracing::info!(
"generated proof for opcode {circuit_name} with num_instances={num_instances}, shard idx {idx}"
);
Ok(proof)
}).collect::<Result<Vec<_>, _>>()?;
vm_proof
.opcode_proofs
.insert(circuit_name.clone(), (i, opcode_proof));
} else {
assert_eq!(witness_and_wit.len(), 1);
let (witness, mles, wits_commit) = witness_and_wit.remove(0);
let num_instances = witness.num_instances();
let (table_proof, pi_in_evals) = self.create_table_proof(
circuit_name,
&self.pk.pp,
pk,
witness.into_iter().map(|v| v.into()).collect_vec(),
mles.into_iter().map(|v| v.into()).collect_vec(),
wits_commit,
&pi,
transcript,
Expand Down
82 changes: 50 additions & 32 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
does_halt: bool,
) -> Result<bool, ZKVMError> {
// require ecall/halt proof to exist, depending whether we expect a halt.
// TODO: make it less adhoc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Especially once we support more ecalls. Or perhaps we should introduce a specific 'halt-successfully' introduction.

let num_instances = vm_proof
.opcode_proofs
.get(&HaltInstruction::<E>::name())
.map(|(_, p)| p.num_instances)
.map(|(_, p)| p[0].num_instances)
.unwrap_or(0);
if num_instances != (does_halt as usize) {
return Err(ZKVMError::VerifyError(format!(
Expand Down Expand Up @@ -119,8 +120,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>

for (name, (_, proof)) in vm_proof.opcode_proofs.iter() {
tracing::debug!("read {}'s commit", name);
PCS::write_commitment(&proof.wits_commit, &mut transcript)
.map_err(ZKVMError::PCSError)?;
for p in proof {
PCS::write_commitment(&p.wits_commit, &mut transcript)
.map_err(ZKVMError::PCSError)?;
}
}
for (name, (_, proof)) in vm_proof.table_proofs.iter() {
tracing::debug!("read {}'s commit", name);
Expand All @@ -140,43 +143,47 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
let point_eval = PointAndEval::default();
let mut transcripts = transcript.fork(self.vk.circuit_vks.len());

for (name, (i, opcode_proof)) in vm_proof.opcode_proofs {
for (name, (i, opcode_proofs)) in vm_proof.opcode_proofs {
let transcript = &mut transcripts[i];

let circuit_vk = self
.vk
.circuit_vks
.get(&name)
.ok_or(ZKVMError::VKNotFound(name.clone()))?;
let _rand_point = self.verify_opcode_proof(
&name,
&self.vk.vp,
circuit_vk,
&opcode_proof,
pi_evals,
transcript,
NUM_FANIN,
&point_eval,
&challenges,
)?;
tracing::info!("verified proof for opcode {}", name);
for opcode_proof in &opcode_proofs {
let _rand_point = self.verify_opcode_proof(
&name,
&self.vk.vp,
circuit_vk,
opcode_proof,
pi_evals,
transcript,
NUM_FANIN,
&point_eval,
&challenges,
)?;
}

// getting the number of dummy padding item that we used in this opcode circuit
let num_lks = circuit_vk.get_cs().lk_expressions.len();
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks;
let num_padded_instance =
next_pow2_instance_padding(opcode_proof.num_instances) - opcode_proof.num_instances;
dummy_table_item_multiplicity += num_padded_lks_per_instance
* opcode_proof.num_instances
+ num_lks.next_power_of_two() * num_padded_instance;

prod_r *= opcode_proof.record_r_out_evals.iter().product::<E>();
prod_w *= opcode_proof.record_w_out_evals.iter().product::<E>();

logup_sum +=
opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap();
logup_sum +=
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap();
tracing::info!("verified proof for opcode {}", name);
for opcode_proof in &opcode_proofs {
// getting the number of dummy padding item that we used in this opcode circuit
let num_lks = circuit_vk.get_cs().lk_expressions.len();
let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks;
let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances)
- opcode_proof.num_instances;
dummy_table_item_multiplicity += num_padded_lks_per_instance
* opcode_proof.num_instances
+ num_lks.next_power_of_two() * num_padded_instance;

prod_r *= opcode_proof.record_r_out_evals.iter().product::<E>();
prod_w *= opcode_proof.record_w_out_evals.iter().product::<E>();

logup_sum +=
opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap();
logup_sum +=
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap();
}
}

for (name, (i, table_proof)) in vm_proof.table_proofs {
Expand Down Expand Up @@ -471,6 +478,17 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
}

// verify zero expression (degree = 1) statement, thus no sumcheck
for (expr, name) in cs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this logic seems some left over. probably we combine with line 492?

.assert_zero_expressions
.iter()
.zip_eq(cs.assert_zero_expressions_namespace_map.iter())
{
if eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr)
!= E::ZERO
{
tracing::error!("checking zero expression {name} failed.");
}
}
if cs.assert_zero_expressions.iter().any(|expr| {
eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO
}) {
Expand Down
27 changes: 27 additions & 0 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,33 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + From<u64>> RowMajorMatrix
.chain(padding_iter)
.collect::<Vec<_>>()
}

pub fn shard_by_rows(&self, shard_rows: usize) -> Vec<Self> {
let padded_row_num = self.num_instances() + self.num_padding_instances();
if padded_row_num <= shard_rows {
return vec![self.clone()];
}
// padded_row_num and chunk_rows should both be pow of 2.
assert_eq!(padded_row_num % shard_rows, 0);
let shard_num = self.num_instances().div_ceil(shard_rows);
let mut shards = Vec::new();
for i in 0..shard_num {
let start = i * shard_rows * self.num_col;
let end = ((i + 1) * shard_rows * self.num_col).min(self.values.len());
let values: Vec<_> = self.values[start..end].to_vec();

shards.push(Self {
num_col: self.num_col,
values,
padding_strategy: self.padding_strategy.clone(),
});
}
assert_eq!(
self.num_instances(),
shards.iter().map(|c| { c.num_instances() }).sum::<usize>()
);
shards
}
}

impl<F: Field + From<u64>> RowMajorMatrix<F> {
Expand Down