Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support local prover && multiple task types #54

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
14 changes: 13 additions & 1 deletion conf/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"endpoint": "https://l2-rpc.scrollsdk"
},
"prover": {
"circuit_type": 3,
"circuit_type": [3],
"circuit_version": "v0.13.1",
"n_workers": 1,
"cloud": {
Expand All @@ -20,6 +20,18 @@
"retry_count": 3,
"retry_wait_time_sec": 5,
"connection_timeout_sec": 60
},
"local": {
"low_version_circuit": {
"hard_fork_name": "bernoulli",
"params_path": "params",
"assets_path": "assets"
},
"high_version_circuit": {
"hard_fork_name": "curie",
"params_path": "params",
"assets_path": "assets"
}
}
},
"db_path": "db"
Expand Down
2 changes: 1 addition & 1 deletion examples/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl ProvingService for CloudProver {
fn is_local(&self) -> bool {
false
}
async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse {
async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse {
todo!()
}
async fn prove(&self, req: ProveRequest) -> ProveResponse {
Expand Down
2 changes: 1 addition & 1 deletion examples/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl ProvingService for LocalProver {
fn is_local(&self) -> bool {
true
}
async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse {
async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse {
todo!()
}
async fn prove(&self, req: ProveRequest) -> ProveResponse {
Expand Down
112 changes: 98 additions & 14 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::prover::CircuitType;
use anyhow::{bail, Result};
use dotenv::dotenv;
use serde::{Deserialize, Serialize};
use serde_json;
Expand Down Expand Up @@ -35,7 +36,7 @@ pub struct L2GethConfig {

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProverConfig {
pub circuit_type: CircuitType,
pub circuit_types: Vec<CircuitType>,
pub circuit_version: String,
pub n_workers: usize,
pub cloud: Option<CloudProverConfig>,
Expand All @@ -53,45 +54,50 @@ pub struct CloudProverConfig {

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LocalProverConfig {
// TODO:
// params path
// assets path
// DB config
pub low_version_circuit: CircuitConfig,
pub high_version_circuit: CircuitConfig,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CircuitConfig {
pub hard_fork_name: String,
pub params_path: String,
pub assets_path: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DbConfig {}

impl Config {
pub fn from_reader<R>(reader: R) -> anyhow::Result<Self>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's better keep anyhow::Result, rust Instructs us to write it this way

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean change anyhow::Result<Self> to Result<Self>?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, still keep anyhow::Result

pub fn from_reader<R>(reader: R) -> Result<Self>
where
R: std::io::Read,
{
serde_json::from_reader(reader).map_err(|e| anyhow::anyhow!(e))
}

pub fn from_file(file_name: String) -> anyhow::Result<Self> {
pub fn from_file(file_name: String) -> Result<Self> {
let file = File::open(file_name)?;
Config::from_reader(&file)
}

pub fn from_file_and_env(file_name: String) -> anyhow::Result<Self> {
pub fn from_file_and_env(file_name: String) -> Result<Self> {
let mut cfg = Config::from_file(file_name)?;
cfg.override_with_env()?;
Ok(cfg)
}

fn get_env_var(key: &str) -> anyhow::Result<Option<String>> {
Ok(std::env::var_os(key)
fn get_env_var(key: &str) -> Result<Option<String>> {
std::env::var_os(key)
.map(|val| {
val.to_str()
.ok_or_else(|| anyhow::anyhow!("{key} env var is not valid UTF-8"))
.map(String::from)
})
.transpose()?)
.transpose()
}

fn override_with_env(&mut self) -> anyhow::Result<()> {
fn override_with_env(&mut self) -> Result<()> {
dotenv().ok();

if let Some(val) = Self::get_env_var("PROVER_NAME_PREFIX")? {
Expand All @@ -108,8 +114,22 @@ impl Config {
l2geth.endpoint = val;
}
}
if let Some(val) = Self::get_env_var("CIRCUIT_TYPE")? {
self.prover.circuit_type = CircuitType::from_u8(val.parse()?);
if let Some(val) = Self::get_env_var("CIRCUIT_TYPES")? {
let values_vec: Vec<&str> = val
.trim_matches(|c| c == '[' || c == ']')
.split(',')
yiweichi marked this conversation as resolved.
Show resolved Hide resolved
.collect();

self.prover.circuit_types = values_vec
.iter()
.map(move |value| match value.parse::<u8>() {
yiweichi marked this conversation as resolved.
Show resolved Hide resolved
Ok(num) => CircuitType::from_u8(num),
Err(e) => {
eprintln!("Failed to parse circuit type: {}", e);
std::process::exit(1);
yiweichi marked this conversation as resolved.
Show resolved Hide resolved
}
})
.collect::<Vec<CircuitType>>();
}
if let Some(val) = Self::get_env_var("N_WORKERS")? {
self.prover.n_workers = val.parse()?;
Expand All @@ -131,3 +151,67 @@ impl Config {
Ok(())
}
}

impl LocalProverConfig {
pub fn from_reader<R>(reader: R) -> Result<Self>
where
R: std::io::Read,
{
serde_json::from_reader(reader).map_err(|e| anyhow::anyhow!(e))
}

pub fn from_file(file_name: String) -> Result<Self> {
let file = File::open(file_name)?;
LocalProverConfig::from_reader(&file)
}
}

static SCROLL_PROVER_ASSETS_DIR_ENV_NAME: &str = "SCROLL_PROVER_ASSETS_DIR";
static mut SCROLL_PROVER_ASSETS_DIRS: Vec<String> = vec![];

#[derive(Debug)]
pub struct AssetsDirEnvConfig {}
yiweichi marked this conversation as resolved.
Show resolved Hide resolved

impl AssetsDirEnvConfig {
pub fn init() -> Result<()> {
let value = std::env::var(SCROLL_PROVER_ASSETS_DIR_ENV_NAME)?;
let dirs: Vec<&str> = value.split(',').collect();
if dirs.len() != 2 {
bail!("env variable SCROLL_PROVER_ASSETS_DIR value must be 2 parts seperated by comma.")
}
unsafe {
SCROLL_PROVER_ASSETS_DIRS = dirs.into_iter().map(|s| s.to_string()).collect();
log::info!(
"init SCROLL_PROVER_ASSETS_DIRS: {:?}",
SCROLL_PROVER_ASSETS_DIRS
);
}
Ok(())
}

pub fn enable_first() {
unsafe {
log::info!(
"set env {SCROLL_PROVER_ASSETS_DIR_ENV_NAME} to {}",
&SCROLL_PROVER_ASSETS_DIRS[0]
);
std::env::set_var(
SCROLL_PROVER_ASSETS_DIR_ENV_NAME,
&SCROLL_PROVER_ASSETS_DIRS[0],
);
}
}

pub fn enable_second() {
unsafe {
log::info!(
"set env {SCROLL_PROVER_ASSETS_DIR_ENV_NAME} to {}",
&SCROLL_PROVER_ASSETS_DIRS[1]
);
std::env::set_var(
SCROLL_PROVER_ASSETS_DIR_ENV_NAME,
&SCROLL_PROVER_ASSETS_DIRS[1],
);
}
}
}
19 changes: 12 additions & 7 deletions src/coordinator_handler/coordinator_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{config::CoordinatorConfig, prover::CircuitType, utils::get_version};
use tokio::sync::{Mutex, MutexGuard};

pub struct CoordinatorClient {
circuit_type: CircuitType,
circuit_types: Vec<CircuitType>,
vks: Vec<String>,
circuit_version: String,
pub prover_name: String,
Expand All @@ -18,15 +18,15 @@ pub struct CoordinatorClient {
impl CoordinatorClient {
pub fn new(
cfg: CoordinatorConfig,
circuit_type: CircuitType,
circuit_types: Vec<CircuitType>,
vks: Vec<String>,
circuit_version: String,
prover_name: String,
key_signer: KeySigner,
) -> anyhow::Result<Self> {
let api = Api::new(cfg)?;
let client = Self {
circuit_type,
circuit_types,
vks,
circuit_version,
prover_name,
Expand Down Expand Up @@ -107,10 +107,15 @@ impl CoordinatorClient {
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Missing challenge token"))?;

let prover_types = match self.circuit_type {
CircuitType::Batch | CircuitType::Bundle => vec![CircuitType::Batch], // to conform to coordinator logic
_ => vec![self.circuit_type],
};
let mut prover_types = vec![];
if self.circuit_types.contains(&CircuitType::Bundle)
|| self.circuit_types.contains(&CircuitType::Batch)
{
prover_types.push(CircuitType::Batch)
}
if self.circuit_types.contains(&CircuitType::Chunk) {
prover_types.push(CircuitType::Chunk)
}

let login_message = LoginMessage {
challenge: login_response_data.token.clone(),
Expand Down
6 changes: 2 additions & 4 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@ impl Db {
.get(fmt_coordinator_task_key(public_key))
.ok()?
.as_ref()
.map(|v| serde_json::from_slice(v).ok())
.flatten()
.and_then(|v| serde_json::from_slice(v).ok())
}

pub fn get_proving_task_id_by_public_key(&self, public_key: String) -> Option<String> {
self.db
.get(fmt_proving_task_id_key(public_key))
.ok()?
.map(|v| String::from_utf8(v).ok())
.flatten()
.and_then(|v| String::from_utf8(v).ok())
}

pub fn set_coordinator_task_by_public_key(
Expand Down
14 changes: 8 additions & 6 deletions src/prover/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@ impl ProverBuilder {
anyhow::bail!("cannot use multiple workers with local proving service");
}

if self.cfg.prover.circuit_type == CircuitType::Chunk && self.cfg.l2geth.is_none() {
if self.cfg.prover.circuit_types.contains(&CircuitType::Chunk) && self.cfg.l2geth.is_none()
{
anyhow::bail!("circuit_type is chunk but l2geth config is not provided");
}

let get_vk_request = GetVkRequest {
circuit_type: self.cfg.prover.circuit_type,
circuit_types: self.cfg.prover.circuit_types.clone(),
circuit_version: self.cfg.prover.circuit_version.clone(),
};
let get_vk_response = self
.proving_service
.as_ref()
.unwrap()
.get_vk(get_vk_request)
.get_vks(get_vk_request)
.await;
if let Some(error) = get_vk_response.error {
anyhow::bail!("failed to get vk: {}", error);
Expand All @@ -67,12 +68,13 @@ impl ProverBuilder {
let key_signers =
key_signers.map_err(|e| anyhow::anyhow!("cannot create key_signer, err: {e}"))?;

let circuit_types_cloned = self.cfg.prover.circuit_types.clone();
yiweichi marked this conversation as resolved.
Show resolved Hide resolved
let coordinator_clients: Result<Vec<_>, _> = (0..self.cfg.prover.n_workers)
.map(|i| {
CoordinatorClient::new(
self.cfg.coordinator.clone(),
self.cfg.prover.circuit_type,
vec![get_vk_response.vk.clone()],
circuit_types_cloned.clone(),
get_vk_response.vks.clone(),
self.cfg.prover.circuit_version.clone(),
format!("{}{}", self.cfg.prover_name_prefix, i),
key_signers[i].clone(),
Expand All @@ -91,7 +93,7 @@ impl ProverBuilder {
});

Ok(Prover {
circuit_type: self.cfg.prover.circuit_type,
circuit_types: self.cfg.prover.circuit_types.clone(),
circuit_version: self.cfg.prover.circuit_version,
coordinator_clients,
l2geth_client,
Expand Down
12 changes: 6 additions & 6 deletions src/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub use {builder::ProverBuilder, proving_service::ProvingService, types::*};
const WORKER_SLEEP_SEC: u64 = 20;

pub struct Prover {
circuit_type: CircuitType,
circuit_types: Vec<CircuitType>,
circuit_version: String,
coordinator_clients: Vec<CoordinatorClient>,
l2geth_client: Option<L2gethClient>,
Expand All @@ -35,7 +35,7 @@ pub struct Prover {
impl Prover {
pub async fn run(self) {
assert!(self.n_workers == self.coordinator_clients.len());
if self.circuit_type == CircuitType::Chunk {
if self.circuit_types.contains(&CircuitType::Chunk) {
assert!(self.l2geth_client.is_some());
}

Expand Down Expand Up @@ -276,7 +276,7 @@ impl Prover {
};

GetTaskRequest {
task_types: vec![self.circuit_type],
task_types: self.circuit_types.clone(),
prover_height,
}
}
Expand All @@ -286,9 +286,9 @@ impl Prover {
task: &GetTaskResponseData,
) -> anyhow::Result<ProveRequest> {
anyhow::ensure!(
task.task_type == self.circuit_type,
"task type mismatch. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}",
self.circuit_type,
self.circuit_types.contains(&task.task_type),
"unsupported task type. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}",
self.circuit_types,
task.task_type,
task.uuid,
task.task_id
Expand Down
7 changes: 4 additions & 3 deletions src/prover/proving_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@ use async_trait::async_trait;
#[async_trait]
pub trait ProvingService {
fn is_local(&self) -> bool;
async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse;
async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse;
async fn prove(&self, req: ProveRequest) -> ProveResponse;
async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse;
}

pub struct GetVkRequest {
pub circuit_type: CircuitType,
pub circuit_types: Vec<CircuitType>,
pub circuit_version: String,
}

pub struct GetVkResponse {
pub vk: String,
pub vks: Vec<String>,
pub error: Option<String>,
}

#[derive(Clone)]
pub struct ProveRequest {
pub circuit_type: CircuitType,
pub circuit_version: String,
Expand Down
Loading
Loading