diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index b5d6486..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index 640ba97..88ec6ff 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ db/ *.lock *.tgz .vscode/ + +.DS_Store \ No newline at end of file diff --git a/conf/config.json b/conf/config.json deleted file mode 100644 index 657207b..0000000 --- a/conf/config.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "prover_name_prefix": "cloud_prover_", - "keys_dir": "keys", - "coordinator": { - "base_url": "https://coordinator-api.scrollsdk", - "retry_count": 3, - "retry_wait_time_sec": 5, - "connection_timeout_sec": 60 - }, - "l2geth": { - "endpoint": "https://l2-rpc.scrollsdk" - }, - "prover": { - "circuit_type": 3, - "circuit_version": "v0.13.1", - "n_workers": 1, - "cloud": { - "base_url": "", - "api_key": "", - "retry_count": 3, - "retry_wait_time_sec": 5, - "connection_timeout_sec": 60 - } - }, - "db_path": "db" -} diff --git a/conf/config_cloud.json b/conf/config_cloud.json new file mode 100644 index 0000000..baf5777 --- /dev/null +++ b/conf/config_cloud.json @@ -0,0 +1,27 @@ +{ + "prover_name_prefix": "cloud_prover_", + "keys_dir": "keys", + "coordinator": { + "base_url": "https://coordinator-api.scrollsdk", + "retry_count": 3, + "retry_wait_time_sec": 5, + "connection_timeout_sec": 60 + }, + "l2geth": { + "endpoint": "https://l2-rpc.scrollsdk" + }, + "prover": { + "circuit_type": [1,2,3], + "circuit_version": "v0.13.1", + "n_workers": 1, + "cloud": { + "base_url": "", + "api_key": "", + "retry_count": 3, + "retry_wait_time_sec": 5, + "connection_timeout_sec": 60 + } + }, + "db_path": "db" + } + \ No newline at end of file diff --git a/conf/config_local.json b/conf/config_local.json new file mode 100644 index 0000000..6fc1b62 --- /dev/null +++ b/conf/config_local.json @@ -0,0 +1,31 @@ +{ + "prover_name_prefix": "local_prover_", + "keys_dir": "keys", + "coordinator": { + "base_url": "https://coordinator-api.scrollsdk", + "retry_count": 3, + "retry_wait_time_sec": 5, + "connection_timeout_sec": 60 + }, + "l2geth": { + "endpoint": "https://l2-rpc.scrollsdk" + }, + "prover": { + "circuit_type": [1,2,3], + "circuit_version": "v0.13.1", + "local": { + "low_version_circuit": { + "hard_fork_name": "bernoulli", + "params_path": "params", + "assets_path": "assets" + }, + "high_version_circuit": { + "hard_fork_name": "curie", + "params_path": "params", + "assets_path": "assets" + } + } + }, + "db_path": "db" + } + \ No newline at end of file diff --git a/examples/cloud.rs b/examples/cloud.rs index 8150089..a2d8116 100644 --- a/examples/cloud.rs +++ b/examples/cloud.rs @@ -32,7 +32,7 @@ impl ProvingService for CloudProver { fn is_local(&self) -> bool { false } - async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse { + async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse { todo!() } async fn prove(&self, req: ProveRequest) -> ProveResponse { diff --git a/examples/local.rs b/examples/local.rs index 0757a0e..7b020b2 100644 --- a/examples/local.rs +++ b/examples/local.rs @@ -28,7 +28,7 @@ impl ProvingService for LocalProver { fn is_local(&self) -> bool { true } - async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse { + async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse { todo!() } async fn prove(&self, req: ProveRequest) -> ProveResponse { @@ -50,7 +50,7 @@ async fn main() -> anyhow::Result<()> { init_tracing(); let args = Args::parse(); - let cfg: Config = Config::from_file(args.config_file)?; + let cfg: Config = Config::from_file_and_env(args.config_file)?; let local_prover = LocalProver::new(cfg.prover.local.clone().unwrap()); let prover = ProverBuilder::new(cfg) .with_proving_service(Box::new(local_prover)) diff --git a/src/.DS_Store b/src/.DS_Store deleted file mode 100644 index 595abd6..0000000 Binary files a/src/.DS_Store and /dev/null differ diff --git a/src/config.rs b/src/config.rs index 8829b02..6910b53 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,5 @@ use crate::prover::CircuitType; +use anyhow::{anyhow, Result}; use dotenv::dotenv; use serde::{Deserialize, Serialize}; use serde_json; @@ -16,10 +17,6 @@ pub struct Config { pub health_listener_addr: String, } -fn default_health_listener_addr() -> String { - "0.0.0.0:80".to_string() -} - #[derive(Debug, Serialize, Deserialize, Clone)] pub struct CoordinatorConfig { pub base_url: String, @@ -35,8 +32,9 @@ pub struct L2GethConfig { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ProverConfig { - pub circuit_type: CircuitType, + pub circuit_types: Vec, pub circuit_version: String, + #[serde(default = "default_n_workers")] pub n_workers: usize, pub cloud: Option, pub local: Option, @@ -53,45 +51,58 @@ pub struct CloudProverConfig { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct LocalProverConfig { - // TODO: - // params path - // assets path - // DB config + pub low_version_circuit: CircuitConfig, + pub high_version_circuit: CircuitConfig, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CircuitConfig { + pub hard_fork_name: String, + pub params_path: String, + pub assets_path: String, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct DbConfig {} +fn default_health_listener_addr() -> String { + "0.0.0.0:80".to_string() +} + +fn default_n_workers() -> usize { + 1 +} + impl Config { - pub fn from_reader(reader: R) -> anyhow::Result + pub fn from_reader(reader: R) -> Result where R: std::io::Read, { - serde_json::from_reader(reader).map_err(|e| anyhow::anyhow!(e)) + serde_json::from_reader(reader).map_err(|e| anyhow!(e)) } - pub fn from_file(file_name: String) -> anyhow::Result { + pub fn from_file(file_name: String) -> Result { let file = File::open(file_name)?; Config::from_reader(&file) } - pub fn from_file_and_env(file_name: String) -> anyhow::Result { + pub fn from_file_and_env(file_name: String) -> Result { let mut cfg = Config::from_file(file_name)?; cfg.override_with_env()?; Ok(cfg) } - fn get_env_var(key: &str) -> anyhow::Result> { - Ok(std::env::var_os(key) + fn get_env_var(key: &str) -> Result> { + std::env::var_os(key) .map(|val| { val.to_str() - .ok_or_else(|| anyhow::anyhow!("{key} env var is not valid UTF-8")) + .ok_or_else(|| anyhow!("{key} env var is not valid UTF-8")) .map(String::from) }) - .transpose()?) + .transpose() } - fn override_with_env(&mut self) -> anyhow::Result<()> { + fn override_with_env(&mut self) -> Result<()> { dotenv().ok(); if let Some(val) = Self::get_env_var("PROVER_NAME_PREFIX")? { @@ -108,11 +119,22 @@ impl Config { l2geth.endpoint = val; } } - if let Some(val) = Self::get_env_var("CIRCUIT_TYPE")? { - self.prover.circuit_type = CircuitType::from_u8(val.parse()?); - } - if let Some(val) = Self::get_env_var("N_WORKERS")? { - self.prover.n_workers = val.parse()?; + if let Some(val) = Self::get_env_var("CIRCUIT_TYPES")? { + let values_vec: Vec<&str> = val + .trim_matches(|c| c == '[' || c == ']') + .split(',') + .map(|s| s.trim()) + .collect(); + + self.prover.circuit_types = values_vec + .iter() + .map(|value| match value.parse::() { + Ok(num) => CircuitType::from_u8(num), + Err(e) => { + panic!("Failed to parse circuit type: {}", e); + } + }) + .collect::>(); } if let Some(val) = Self::get_env_var("PROVING_SERVICE_BASE_URL")? { if let Some(cloud) = &mut self.prover.cloud { @@ -131,3 +153,17 @@ impl Config { Ok(()) } } + +impl LocalProverConfig { + pub fn from_reader(reader: R) -> Result + where + R: std::io::Read, + { + serde_json::from_reader(reader).map_err(|e| anyhow!(e)) + } + + pub fn from_file(file_name: String) -> Result { + let file = File::open(file_name)?; + LocalProverConfig::from_reader(&file) + } +} diff --git a/src/coordinator_handler/coordinator_client.rs b/src/coordinator_handler/coordinator_client.rs index 82c8cbb..de216d4 100644 --- a/src/coordinator_handler/coordinator_client.rs +++ b/src/coordinator_handler/coordinator_client.rs @@ -6,9 +6,8 @@ use crate::{config::CoordinatorConfig, prover::CircuitType, utils::get_version}; use tokio::sync::{Mutex, MutexGuard}; pub struct CoordinatorClient { - circuit_type: CircuitType, + circuit_types: Vec, vks: Vec, - circuit_version: String, pub prover_name: String, pub key_signer: KeySigner, api: Api, @@ -18,17 +17,15 @@ pub struct CoordinatorClient { impl CoordinatorClient { pub fn new( cfg: CoordinatorConfig, - circuit_type: CircuitType, + circuit_types: Vec, vks: Vec, - circuit_version: String, prover_name: String, key_signer: KeySigner, ) -> anyhow::Result { let api = Api::new(cfg)?; let client = Self { - circuit_type, + circuit_types, vks, - circuit_version, prover_name, key_signer, api, @@ -107,15 +104,20 @@ impl CoordinatorClient { .as_ref() .ok_or_else(|| anyhow::anyhow!("Missing challenge token"))?; - let prover_types = match self.circuit_type { - CircuitType::Batch | CircuitType::Bundle => vec![CircuitType::Batch], // to conform to coordinator logic - _ => vec![self.circuit_type], - }; + let mut prover_types = vec![]; + if self.circuit_types.contains(&CircuitType::Bundle) + || self.circuit_types.contains(&CircuitType::Batch) + { + prover_types.push(CircuitType::Batch) + } + if self.circuit_types.contains(&CircuitType::Chunk) { + prover_types.push(CircuitType::Chunk) + } let login_message = LoginMessage { challenge: login_response_data.token.clone(), prover_name: self.prover_name.clone(), - prover_version: get_version(&self.circuit_version).to_string(), + prover_version: get_version().to_string(), prover_types, vks: self.vks.clone(), }; diff --git a/src/db.rs b/src/db.rs index ad90f76..c1785ac 100644 --- a/src/db.rs +++ b/src/db.rs @@ -19,16 +19,14 @@ impl Db { .get(fmt_coordinator_task_key(public_key)) .ok()? .as_ref() - .map(|v| serde_json::from_slice(v).ok()) - .flatten() + .and_then(|v| serde_json::from_slice(v).ok()) } pub fn get_proving_task_id_by_public_key(&self, public_key: String) -> Option { self.db .get(fmt_proving_task_id_key(public_key)) .ok()? - .map(|v| String::from_utf8(v).ok()) - .flatten() + .and_then(|v| String::from_utf8(v).ok()) } pub fn set_coordinator_task_by_public_key( diff --git a/src/prover/builder.rs b/src/prover/builder.rs index 8a9cb4a..6a74764 100644 --- a/src/prover/builder.rs +++ b/src/prover/builder.rs @@ -40,19 +40,20 @@ impl ProverBuilder { anyhow::bail!("cannot use multiple workers with local proving service"); } - if self.cfg.prover.circuit_type == CircuitType::Chunk && self.cfg.l2geth.is_none() { + if self.cfg.prover.circuit_types.contains(&CircuitType::Chunk) && self.cfg.l2geth.is_none() + { anyhow::bail!("circuit_type is chunk but l2geth config is not provided"); } let get_vk_request = GetVkRequest { - circuit_type: self.cfg.prover.circuit_type, + circuit_types: self.cfg.prover.circuit_types.clone(), circuit_version: self.cfg.prover.circuit_version.clone(), }; let get_vk_response = self .proving_service .as_ref() .unwrap() - .get_vk(get_vk_request) + .get_vks(get_vk_request) .await; if let Some(error) = get_vk_response.error { anyhow::bail!("failed to get vk: {}", error); @@ -69,12 +70,17 @@ impl ProverBuilder { let coordinator_clients: Result, _> = (0..self.cfg.prover.n_workers) .map(|i| { + let prover_name = if self.proving_service.as_ref().unwrap().is_local() { + self.cfg.prover_name_prefix.clone() + } else { + format!("{}{}", self.cfg.prover_name_prefix, i) + }; + CoordinatorClient::new( self.cfg.coordinator.clone(), - self.cfg.prover.circuit_type, - vec![get_vk_response.vk.clone()], - self.cfg.prover.circuit_version.clone(), - format!("{}{}", self.cfg.prover_name_prefix, i), + self.cfg.prover.circuit_types.clone(), + get_vk_response.vks.clone(), + prover_name, key_signers[i].clone(), ) }) @@ -91,7 +97,7 @@ impl ProverBuilder { }); Ok(Prover { - circuit_type: self.cfg.prover.circuit_type, + circuit_types: self.cfg.prover.circuit_types.clone(), circuit_version: self.cfg.prover.circuit_version, coordinator_clients, l2geth_client, diff --git a/src/prover/mod.rs b/src/prover/mod.rs index 9c8ce73..28b5d45 100644 --- a/src/prover/mod.rs +++ b/src/prover/mod.rs @@ -22,7 +22,7 @@ pub use {builder::ProverBuilder, proving_service::ProvingService, types::*}; const WORKER_SLEEP_SEC: u64 = 20; pub struct Prover { - circuit_type: CircuitType, + circuit_types: Vec, circuit_version: String, coordinator_clients: Vec, l2geth_client: Option, @@ -35,7 +35,7 @@ pub struct Prover { impl Prover { pub async fn run(self) { assert!(self.n_workers == self.coordinator_clients.len()); - if self.circuit_type == CircuitType::Chunk { + if self.circuit_types.contains(&CircuitType::Chunk) { assert!(self.l2geth_client.is_some()); } @@ -276,7 +276,7 @@ impl Prover { }; GetTaskRequest { - task_types: vec![self.circuit_type], + task_types: self.circuit_types.clone(), prover_height, } } @@ -286,9 +286,9 @@ impl Prover { task: &GetTaskResponseData, ) -> anyhow::Result { anyhow::ensure!( - task.task_type == self.circuit_type, - "task type mismatch. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}", - self.circuit_type, + self.circuit_types.contains(&task.task_type), + "unsupported task type. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}", + self.circuit_types, task.task_type, task.uuid, task.task_id diff --git a/src/prover/proving_service.rs b/src/prover/proving_service.rs index 2625828..4eea80e 100644 --- a/src/prover/proving_service.rs +++ b/src/prover/proving_service.rs @@ -4,21 +4,22 @@ use async_trait::async_trait; #[async_trait] pub trait ProvingService { fn is_local(&self) -> bool; - async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse; + async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse; async fn prove(&self, req: ProveRequest) -> ProveResponse; async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse; } pub struct GetVkRequest { - pub circuit_type: CircuitType, + pub circuit_types: Vec, pub circuit_version: String, } pub struct GetVkResponse { - pub vk: String, + pub vks: Vec, pub error: Option, } +#[derive(Clone)] pub struct ProveRequest { pub circuit_type: CircuitType, pub circuit_version: String, diff --git a/src/prover/types.rs b/src/prover/types.rs index e00e912..ba0bf19 100644 --- a/src/prover/types.rs +++ b/src/prover/types.rs @@ -1,7 +1,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum CircuitType { + #[default] Undefined, Chunk, Batch, diff --git a/src/utils.rs b/src/utils.rs index 760e018..c67643e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,9 +1,22 @@ use tracing_subscriber::filter::{EnvFilter, LevelFilter}; -const SDK_VERSION: &str = env!("CARGO_PKG_VERSION"); +use std::cell::OnceCell; -pub fn get_version(circuit_version: &str) -> String { - format!("sdk-v{}-{}", SDK_VERSION, circuit_version) +static DEFAULT_COMMIT: &str = "unknown"; +static mut VERSION: OnceCell = OnceCell::new(); + +pub const TAG: &str = "v0.0.0"; +pub const DEFAULT_ZK_VERSION: &str = "000000-000000"; + +fn init_version() -> String { + let commit = option_env!("GIT_REV").unwrap_or(DEFAULT_COMMIT); + let tag = option_env!("GO_TAG").unwrap_or(TAG); + let zk_version = option_env!("ZK_VERSION").unwrap_or(DEFAULT_ZK_VERSION); + format!("{tag}-{commit}-{zk_version}") +} + +pub fn get_version() -> String { + unsafe { VERSION.get_or_init(init_version).clone() } } pub fn init_tracing() {