diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b62a82d..f315a9d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,9 +21,11 @@ jobs: - name: Build default run: cargo build - name: Build non-zk # check without the zk feature enabled - run: cargo build --no-default-features --features backend_plonky2 + run: cargo build --no-default-features --features backend_plonky2,mem_cache - name: Build metrics run: cargo build --features metrics - name: Build time run: cargo build --features time + - name: Build disk_cache + run: cargo build --no-default-features --features backend_plonky2,zk,disk_cache diff --git a/Cargo.toml b/Cargo.toml index e941c1e..d8a4626 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ name = "pod2" version = "0.1.0" edition = "2021" +build = "build.rs" [lib] name = "pod2" @@ -19,7 +20,7 @@ env_logger = "0.11" lazy_static = "1.5.0" thiserror = { version = "2.0.12" } # enabled by features: -plonky2 = { git = "https://github.com/0xPolygonZero/plonky2", optional = true } +plonky2 = { git = "https://github.com/0xPARC/plonky2.git", rev = "3defd60532c8693cf5e9d2e6a8412c77ca58760f", optional = true } serde = "1.0.219" serde_json = "1.0.140" base64 = "0.22.1" @@ -32,6 +33,11 @@ rand = "0.8.5" hashbrown = { version = "0.14.3", default-features = false, features = ["serde"] } pest = "2.8.0" pest_derive = "2.8.0" +directories = { version = "6.0.0", optional = true } +minicbor-serde = { version = "0.5.0", features = ["std"], optional = true } +serde_bytes = "0.11" +serde_arrays = "0.2.0" +sha2 = { version = "0.10.9" } # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. # [patch."https://github.com/0xPolygonZero/plonky2"] @@ -42,10 +48,15 @@ pretty_assertions = "1.4.1" # Used only for testing JSON Schema generation and validation. jsonschema = "0.30.0" +[build-dependencies] +vergen-gitcl = { version = "1.0.0", features = ["build"] } + [features] -default = ["backend_plonky2", "zk"] +default = ["backend_plonky2", "zk", "mem_cache"] backend_plonky2 = ["plonky2"] zk = [] metrics = [] time = [] examples = [] +disk_cache = ["directories", "minicbor-serde"] +mem_cache = [] diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..9529b4a --- /dev/null +++ b/build.rs @@ -0,0 +1,21 @@ +#[cfg(feature = "mem_cache")] +fn main() {} + +#[cfg(feature = "disk_cache")] +fn main() -> Result<(), Box> { + use vergen_gitcl::{Emitter, GitclBuilder}; + // Example of injected vars: + // cargo:rustc-env=VERGEN_GIT_BRANCH=master + // cargo:rustc-env=VERGEN_GIT_COMMIT_AUTHOR_EMAIL=emitter@vergen.com + // cargo:rustc-env=VERGEN_GIT_COMMIT_AUTHOR_NAME=Jason Ozias + // cargo:rustc-env=VERGEN_GIT_COMMIT_COUNT=44 + // cargo:rustc-env=VERGEN_GIT_COMMIT_DATE=2024-01-30 + // cargo:rustc-env=VERGEN_GIT_COMMIT_MESSAGE=depsup + // cargo:rustc-env=VERGEN_GIT_COMMIT_TIMESTAMP=2024-01-30T21:43:43.000000000Z + // cargo:rustc-env=VERGEN_GIT_DESCRIBE=0.1.0-beta.1-15-g728e25c + // cargo:rustc-env=VERGEN_GIT_SHA=728e25ca5bb7edbbc505f12b28c66b2b27883cf1 + let gitcl = GitclBuilder::all_git()?; + Emitter::default().add_instructions(&gitcl)?.emit()?; + + Ok(()) +} diff --git a/examples/main_pod_points.rs b/examples/main_pod_points.rs index d7679f6..78206af 100644 --- a/examples/main_pod_points.rs +++ b/examples/main_pod_points.rs @@ -1,3 +1,4 @@ +#![allow(clippy::uninlined_format_args)] // TODO: Remove this in another PR //! Example of building main pods that verify signed pods and other main pods using custom //! predicates //! @@ -21,6 +22,7 @@ use pod2::{ }; fn main() -> Result<(), Box> { + env_logger::init(); let args: Vec = env::args().collect(); let mock = args.get(1).is_some_and(|arg1| arg1 == "--mock"); if mock { diff --git a/examples/signed_pod.rs b/examples/signed_pod.rs index ec84b4d..6b09661 100644 --- a/examples/signed_pod.rs +++ b/examples/signed_pod.rs @@ -1,3 +1,4 @@ +#![allow(clippy::uninlined_format_args)] // TODO: Remove this in another PR //! Simple example of building a signed pod and verifying it //! //! Run: `cargo run --release --example signed_pod` diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 905a9f6..f7579a0 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "nightly-2025-01-20" +channel = "nightly-2025-07-02" components = ["clippy", "rustfmt"] diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index 813fbfa..180d1cd 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -42,20 +42,29 @@ use std::{collections::HashMap, sync::LazyLock}; use crate::{ backends::plonky2::{ - emptypod::STANDARD_EMPTY_POD_DATA, primitives::merkletree::MerkleClaimAndProof, - DEFAULT_PARAMS, STANDARD_REC_MAIN_POD_CIRCUIT_DATA, + emptypod::cache_get_standard_empty_pod_verifier_circuit_data, + mainpod::cache_get_rec_main_pod_verifier_circuit_data, + primitives::merkletree::MerkleClaimAndProof, }, - middleware::{containers::Array, Hash, RawValue, Result, Value}, + middleware::{containers::Array, Hash, Params, RawValue, Result, Value}, }; -pub static DEFAULT_VD_SET: LazyLock = LazyLock::new(|| { - let params = &*DEFAULT_PARAMS; +pub static DEFAULT_VD_LIST: LazyLock> = LazyLock::new(|| { + let params = Params::default(); + vec![ + cache_get_rec_main_pod_verifier_circuit_data(¶ms) + .verifier_only + .clone(), + cache_get_standard_empty_pod_verifier_circuit_data() + .verifier_only + .clone(), + ] +}); - let vds = vec![ - STANDARD_REC_MAIN_POD_CIRCUIT_DATA.verifier_only.clone(), - STANDARD_EMPTY_POD_DATA.1.verifier_only.clone(), - ]; - VDSet::new(params.max_depth_mt_vds, &vds).unwrap() +pub static DEFAULT_VD_SET: LazyLock = LazyLock::new(|| { + let params = Params::default(); + let vds = &*DEFAULT_VD_LIST; + VDSet::new(params.max_depth_mt_vds, vds).unwrap() }); /// VDSet is the set of the allowed verifier_data hashes. When proving a diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index ca5bb0e..5b5db73 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -19,6 +19,7 @@ use plonky2::{ }, util::serialization::{Buffer, IoResult, Read, Write}, }; +use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ @@ -39,7 +40,7 @@ use crate::{ pub const CODE_SIZE: usize = HASH_SIZE + 2; const NUM_BITS: usize = 32; -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct ValueTarget { pub elements: [Target; VALUE_SIZE], } @@ -75,8 +76,9 @@ impl ValueTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct StatementArgTarget { + #[serde(with = "serde_arrays")] pub elements: [Target; STATEMENT_ARG_F_LEN], } @@ -128,7 +130,7 @@ impl StatementArgTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct StatementTarget { pub predicate: PredicateTarget, pub args: Vec, @@ -201,8 +203,9 @@ impl StatementTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct OperationTypeTarget { + #[serde(with = "serde_arrays")] pub elements: [Target; Params::operation_type_size()], } @@ -249,10 +252,11 @@ impl OperationTypeTarget { } // TODO: Implement Operation::to_field to determine the size of each element -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct OperationTarget { pub op_type: OperationTypeTarget, pub args: Vec<[Target; OPERATION_ARG_F_LEN]>, + #[serde(with = "serde_arrays")] pub aux: [Target; OPERATION_AUX_F_LEN], } @@ -304,8 +308,9 @@ impl NativePredicateTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct PredicateTarget { + #[serde(with = "serde_arrays")] pub(crate) elements: [Target; Params::predicate_size()], } @@ -386,8 +391,9 @@ impl LiteralOrWildcardTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct StatementTmplArgTarget { + #[serde(with = "serde_arrays")] pub elements: [Target; Params::statement_tmpl_arg_size()], } @@ -432,7 +438,7 @@ impl StatementTmplArgTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct StatementTmplTarget { pub pred: PredicateTarget, pub args: Vec, @@ -449,7 +455,7 @@ impl StatementTmplTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct CustomPredicateTarget { pub conjunction: BoolTarget, // len = params.max_custom_predicate_arity @@ -468,7 +474,7 @@ impl CustomPredicateTarget { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct CustomPredicateBatchTarget { pub predicates: Vec, } @@ -500,6 +506,7 @@ impl CustomPredicateBatchTarget { } /// Custom predicate table entry +#[derive(Clone, Serialize, Deserialize)] pub struct CustomPredicateEntryTarget { pub id: HashOutTarget, pub index: Target, @@ -575,6 +582,7 @@ impl CustomPredicateEntryTarget { } // Custom predicate verification table entry +#[derive(Clone, Serialize, Deserialize)] pub struct CustomPredicateVerifyEntryTarget { pub custom_predicate_table_index: Target, pub custom_predicate: CustomPredicateEntryTarget, @@ -631,6 +639,7 @@ impl CustomPredicateVerifyEntryTarget { } /// Query for the custom predicate verification table +#[derive(Clone, Serialize, Deserialize)] pub struct CustomPredicateVerifyQueryTarget { pub statement: StatementTarget, pub op_type: OperationTypeTarget, @@ -1386,7 +1395,7 @@ impl CircuitBuilderPod for CircuitBuilder { } } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct LtMaskGenerator { pub(crate) n: Target, pub(crate) mask: Vec, diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 326a3eb..d896356 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -14,6 +14,7 @@ use plonky2::{ }, plonk::config::AlgebraicHasher, }; +use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ @@ -28,7 +29,7 @@ use crate::{ }, signedpod::{verify_signed_pod_circuit, SignedPodVerifyTarget}, }, - emptypod::{EmptyPod, STANDARD_EMPTY_POD_DATA}, + emptypod::{cache_get_standard_empty_pod_circuit_data, EmptyPod}, error::Result, mainpod::{self, pad_statement}, primitives::merkletree::{ @@ -1303,6 +1304,7 @@ fn verify_main_pod_circuit( Ok(id) } +#[derive(Clone, Serialize, Deserialize)] pub struct MainPodVerifyTarget { params: Params, vds_root: HashOutTarget, @@ -1427,9 +1429,13 @@ impl InnerCircuit for MainPodVerifyTarget { self.vd_mt_proofs[i].set_targets(pw, true, vd_mt_proof)?; } // the rest of vd_mt_proofs set them to the empty_pod vd_mt_proof - let vd_emptypod_mt_proof = input - .vds_set - .get_vds_proofs(&[STANDARD_EMPTY_POD_DATA.1.verifier_only.clone()])?; + let vd_emptypod_mt_proof = + input + .vds_set + .get_vds_proofs(&[cache_get_standard_empty_pod_circuit_data() + .1 + .verifier_only + .clone()])?; let vd_emptypod_mt_proof = vd_emptypod_mt_proof[0].clone(); for i in input.vd_mt_proofs.len()..self.vd_mt_proofs.len() { self.vd_mt_proofs[i].set_targets(pw, true, &vd_emptypod_mt_proof)?; diff --git a/src/backends/plonky2/circuits/signedpod.rs b/src/backends/plonky2/circuits/signedpod.rs index 901c06f..7bd9107 100644 --- a/src/backends/plonky2/circuits/signedpod.rs +++ b/src/backends/plonky2/circuits/signedpod.rs @@ -6,6 +6,7 @@ use plonky2::{ iop::witness::{PartialWitness, WitnessWrite}, plonk::circuit_builder::CircuitBuilder, }; +use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ @@ -71,6 +72,7 @@ pub fn verify_signed_pod_circuit( Ok(()) } +#[derive(Clone, Serialize, Deserialize)] pub struct SignedPodVerifyTarget { params: Params, id: HashOutTarget, diff --git a/src/backends/plonky2/circuits/utils.rs b/src/backends/plonky2/circuits/utils.rs index fa4c5ad..bd29127 100644 --- a/src/backends/plonky2/circuits/utils.rs +++ b/src/backends/plonky2/circuits/utils.rs @@ -21,7 +21,7 @@ use plonky2::{ /// vec![v1, v2, v3], /// )); /// ``` -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct DebugGenerator { pub(crate) name: String, pub(crate) xs: Vec, diff --git a/src/backends/plonky2/emptypod.rs b/src/backends/plonky2/emptypod.rs index 7585f31..15f75fa 100644 --- a/src/backends/plonky2/emptypod.rs +++ b/src/backends/plonky2/emptypod.rs @@ -1,8 +1,3 @@ -use std::{ - collections::HashMap, - sync::{LazyLock, Mutex}, -}; - use itertools::Itertools; use plonky2::{ hash::hash_types::HashOutTarget, @@ -18,6 +13,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ basetypes::{Proof, C, D}, + cache_get_standard_rec_main_pod_common_circuit_data, circuits::{ common::{Flattenable, StatementTarget}, mainpod::{calculate_id_circuit, PI_OFFSET_ID}, @@ -26,8 +22,10 @@ use crate::{ error::{Error, Result}, mainpod::{self, calculate_id}, recursion::pad_circuit, - serialize_proof, DEFAULT_PARAMS, STANDARD_REC_MAIN_POD_CIRCUIT_DATA, + serialization::{CircuitDataSerializer, VerifierCircuitDataSerializer}, + serialize_proof, }, + cache::{self, CacheEntry}, middleware::{ self, AnchoredKey, Hash, Params, Pod, PodId, PodType, RecursivePod, Statement, ToFields, VDSet, Value, VerifierOnlyCircuitData, F, HASH_SIZE, KEY_TYPE, SELF, @@ -60,6 +58,7 @@ impl EmptyPodVerifyCircuit { } } +#[derive(Clone, Serialize, Deserialize)] pub struct EmptyPodVerifyTarget { vds_root: HashOutTarget, } @@ -70,7 +69,7 @@ impl EmptyPodVerifyTarget { } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct EmptyPod { params: Params, id: PodId, @@ -80,11 +79,26 @@ pub struct EmptyPod { type CircuitData = circuit_data::CircuitData; -pub static STANDARD_EMPTY_POD_DATA: LazyLock<(EmptyPodVerifyTarget, CircuitData)> = - LazyLock::new(|| build().expect("successful build")); +pub fn cache_get_standard_empty_pod_circuit_data( +) -> CacheEntry<(EmptyPodVerifyTarget, CircuitDataSerializer)> { + cache::get("standard_empty_pod_circuit_data", &(), |_| { + let (target, circuit_data) = build().expect("successful build"); + (target, CircuitDataSerializer(circuit_data)) + }) + .expect("cache ok") +} + +pub fn cache_get_standard_empty_pod_verifier_circuit_data( +) -> CacheEntry { + cache::get("standard_empty_pod_verifier_circuit_data", &(), |_| { + let (_, standard_empty_pod_circuit_data) = &*cache_get_standard_empty_pod_circuit_data(); + VerifierCircuitDataSerializer(standard_empty_pod_circuit_data.verifier_data().clone()) + }) + .expect("cache ok") +} fn build() -> Result<(EmptyPodVerifyTarget, CircuitData)> { - let params = &*DEFAULT_PARAMS; + let params = Params::default(); #[cfg(not(feature = "zk"))] let config = CircuitConfig::standard_recursion_config(); @@ -96,20 +110,18 @@ fn build() -> Result<(EmptyPodVerifyTarget, CircuitData)> { params: params.clone(), } .eval(&mut builder)?; - let circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; - pad_circuit(&mut builder, &circuit_data.common); + let common_circuit_data = &*cache_get_standard_rec_main_pod_common_circuit_data(); + pad_circuit(&mut builder, common_circuit_data); let data = timed!("EmptyPod build", builder.build::()); - assert_eq!(circuit_data.common, data.common); + assert_eq!(common_circuit_data.0, data.common); Ok((empty_pod_verify_target, data)) } -static EMPTY_POD_CACHE: LazyLock>> = - LazyLock::new(|| Mutex::new(HashMap::new())); - impl EmptyPod { pub fn new(params: &Params, vd_set: VDSet) -> Result { - let (empty_pod_verify_target, data) = &*STANDARD_EMPTY_POD_DATA; + let standard_empty_pod_data = cache_get_standard_empty_pod_circuit_data(); + let (empty_pod_verify_target, data) = &*standard_empty_pod_data; let mut pw = PartialWitness::::new(); empty_pod_verify_target.set_targets(&mut pw, vd_set.root())?; @@ -124,16 +136,16 @@ impl EmptyPod { }) } pub fn new_boxed(params: &Params, vd_set: VDSet) -> Box { - let default_params = &*DEFAULT_PARAMS; + let default_params = Params::default(); assert_eq!(default_params.id_params(), params.id_params()); - let empty_pod = EMPTY_POD_CACHE - .lock() - .unwrap() - .entry(vd_set.root()) - .or_insert_with(|| Self::new(params, vd_set).expect("prove EmptyPod")) - .clone(); - Box::new(empty_pod) + let empty_pod = cache::get( + "empty_pod", + &(default_params, vd_set), + |(params, vd_set)| Self::new(params, vd_set.clone()).expect("prove EmptyPod"), + ) + .expect("cache ok"); + Box::new(empty_pod.clone()) } } @@ -164,12 +176,13 @@ impl Pod for EmptyPod { .cloned() .collect_vec(); - let (_, data) = &*STANDARD_EMPTY_POD_DATA; - data.verify(ProofWithPublicInputs { - proof: self.proof.clone(), - public_inputs, - }) - .map_err(|e| Error::plonky2_proof_fail("EmptyPod", e)) + let standard_empty_pod_verifier_data = cache_get_standard_empty_pod_verifier_circuit_data(); + standard_empty_pod_verifier_data + .verify(ProofWithPublicInputs { + proof: self.proof.clone(), + public_inputs, + }) + .map_err(|e| Error::plonky2_proof_fail("EmptyPod", e)) } fn id(&self) -> PodId { @@ -193,8 +206,11 @@ impl Pod for EmptyPod { impl RecursivePod for EmptyPod { fn verifier_data(&self) -> VerifierOnlyCircuitData { - let (_, data) = &*STANDARD_EMPTY_POD_DATA; - data.verifier_only.clone() + let standard_empty_pod_verifier_circuit_data = + cache_get_standard_empty_pod_verifier_circuit_data(); + standard_empty_pod_verifier_circuit_data + .verifier_only + .clone() } fn proof(&self) -> Proof { self.proof.clone() @@ -209,8 +225,8 @@ impl RecursivePod for EmptyPod { id: PodId, ) -> Result> { let data: Data = serde_json::from_value(data)?; - let circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; - let proof = deserialize_proof(&circuit_data.common, &data.proof)?; + let common_circuit_data = cache_get_standard_rec_main_pod_common_circuit_data(); + let proof = deserialize_proof(&common_circuit_data, &data.proof)?; Ok(Box::new(Self { params, id, diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 3ed8fe3..4ea7280 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -4,31 +4,34 @@ use std::{any::Any, iter, sync::Arc}; use itertools::Itertools; pub use operation::*; -use plonky2::{ - hash::poseidon::PoseidonHash, - plonk::{circuit_data::CommonCircuitData, config::Hasher}, -}; +use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher}; use serde::{Deserialize, Serialize}; pub use statement::*; use crate::{ backends::plonky2::{ - basetypes::{Proof, ProofWithPublicInputs, VerifierOnlyCircuitData, D}, + basetypes::{CircuitData, Proof, ProofWithPublicInputs, VerifierOnlyCircuitData}, + cache::{self, CacheEntry}, + cache_get_standard_rec_main_pod_common_circuit_data, circuits::mainpod::{CustomPredicateVerification, MainPodVerifyInput, MainPodVerifyTarget}, deserialize_proof, emptypod::EmptyPod, error::{Error, Result}, mock::emptypod::MockEmptyPod, primitives::merkletree::MerkleClaimAndProof, - recursion::{hash_verifier_data, RecursiveCircuit, RecursiveParams}, + recursion::{ + hash_verifier_data, prove_rec_circuit, RecursiveCircuit, RecursiveCircuitTarget, + }, + serialization::{ + CircuitDataSerializer, CommonCircuitDataSerializer, VerifierCircuitDataSerializer, + }, serialize_proof, signedpod::SignedPod, - STANDARD_REC_MAIN_POD_CIRCUIT_DATA, }, middleware::{ self, resolve_wildcard_values, value_from_op, AnchoredKey, CustomPredicateBatch, Hash, MainPodInputs, NativeOperation, OperationType, Params, Pod, PodId, PodProver, PodType, - RecursivePod, StatementArg, ToFields, VDSet, F, KEY_TYPE, SELF, + RecursivePod, StatementArg, ToFields, VDSet, KEY_TYPE, SELF, }, timed, }; @@ -434,24 +437,6 @@ impl PodProver for Prover { vd_set: &VDSet, inputs: MainPodInputs, ) -> Result> { - let rec_circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; - let (main_pod_target, circuit_data) = - RecursiveCircuit::::target_and_circuit_data_padded( - params.max_input_recursive_pods, - &rec_circuit_data.common, - params, - )?; - let rec_params = RecursiveParams { - arity: params.max_input_recursive_pods, - common_data: circuit_data.common.clone(), - verifier_data: circuit_data.verifier_data(), - }; - let main_pod = RecursiveCircuit { - params: rec_params, - prover: circuit_data.prover_data(), - target: main_pod_target, - }; - let signed_pods_input: Vec = inputs .signed_pods .iter() @@ -541,9 +526,17 @@ impl PodProver for Prover { custom_predicate_batches, custom_predicate_verifications, }; + + let (main_pod_target, circuit_data) = &*cache_get_rec_main_pod_circuit_data(params); let proof_with_pis = timed!( "MainPod::prove", - main_pod.prove(&input, proofs, verifier_datas)? + prove_rec_circuit( + main_pod_target, + circuit_data, + &input, + proofs, + verifier_datas + )? ); Ok(Box::new(MainPod { @@ -570,21 +563,59 @@ pub struct MainPod { proof: Proof, } -// This is a helper function to get the CommonCircuitData necessary to decode -// a serialized proof. At some point in the future, this data may be available -// as a constant or with static initialization, but in the meantime we can -// generate it on-demand. -pub fn get_common_data(params: &Params) -> Result, Error> { - // TODO: Cache this somehow - // https://github.com/0xPARC/pod2/issues/247 - let rec_circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; - let (_, circuit_data) = +pub(crate) fn rec_main_pod_circuit_data( + params: &Params, +) -> (RecursiveCircuitTarget, CircuitData) { + let rec_common_circuit_data = cache_get_standard_rec_main_pod_common_circuit_data(); + timed!( + "recursive MainPod circuit_data padded", RecursiveCircuit::::target_and_circuit_data_padded( params.max_input_recursive_pods, - &rec_circuit_data.common, + &rec_common_circuit_data, params, - )?; - Ok(circuit_data.common.clone()) + ) + .expect("calculate target_and_circuit_data_padded") + ) +} + +fn cache_get_rec_main_pod_circuit_data( + params: &Params, +) -> CacheEntry<( + RecursiveCircuitTarget, + CircuitDataSerializer, +)> { + // TODO(Edu): I believe that the standard_rec_main_pod_circuit data is the same as this when + // the params are Default: we're padding the circuit to itself, so we get the original one? + // If this is true we can deduplicate this cache entry because both rec_main_pod_circuit_data + // and standard_rec_main_pod_circuit_data are indexed by Params. This can be easily tested by + // comparing the cached artifacts on disk :) + cache::get("rec_main_pod_circuit_data", params, |params| { + let (target, circuit_data) = rec_main_pod_circuit_data(params); + (target, CircuitDataSerializer(circuit_data)) + }) + .expect("cache ok") +} + +pub fn cache_get_rec_main_pod_verifier_circuit_data( + params: &Params, +) -> CacheEntry { + cache::get("rec_main_pod_verifier_circuit_data", params, |params| { + let (_, rec_main_pod_circuit_data_padded) = &*cache_get_rec_main_pod_circuit_data(params); + VerifierCircuitDataSerializer(rec_main_pod_circuit_data_padded.verifier_data().clone()) + }) + .expect("cache ok") +} + +// This is a helper function to get the CommonCircuitData necessary to decode +// a serialized proof. +pub fn cache_get_rec_main_pod_common_circuit_data( + params: &Params, +) -> CacheEntry { + cache::get("rec_main_pod_common_circuit_data", params, |params| { + let (_, rec_main_pod_circuit_data_padded) = &*cache_get_rec_main_pod_circuit_data(params); + CommonCircuitDataSerializer(rec_main_pod_circuit_data_padded.common.clone()) + }) + .expect("cache ok") } #[derive(Serialize, Deserialize)] @@ -626,22 +657,15 @@ impl Pod for MainPod { } // 1, 3, 4, 5 verification via the zkSNARK proof - let rec_circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; - // TODO: cache these artefacts - // https://github.com/0xPARC/pod2/issues/247 - let (_, circuit_data) = - RecursiveCircuit::::target_and_circuit_data_padded( - self.params.max_input_recursive_pods, - &rec_circuit_data.common, - &self.params, - )?; + let rec_main_pod_verifier_circuit_data = + &*cache_get_rec_main_pod_verifier_circuit_data(&self.params); let public_inputs = id .to_fields(&self.params) .iter() .chain(self.vd_set.root().0.iter()) .cloned() .collect_vec(); - circuit_data + rec_main_pod_verifier_circuit_data .verify(ProofWithPublicInputs { proof: self.proof.clone(), public_inputs, @@ -675,8 +699,9 @@ impl Pod for MainPod { impl RecursivePod for MainPod { fn verifier_data(&self) -> VerifierOnlyCircuitData { - let data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; - data.verifier_only.clone() + let rec_main_pod_verifier_circuit_data = + cache_get_rec_main_pod_verifier_circuit_data(&self.params); + rec_main_pod_verifier_circuit_data.verifier_only.clone() } fn proof(&self) -> Proof { self.proof.clone() @@ -691,7 +716,7 @@ impl RecursivePod for MainPod { id: PodId, ) -> Result> { let data: Data = serde_json::from_value(data)?; - let common = get_common_data(¶ms)?; + let common = cache_get_rec_main_pod_common_circuit_data(¶ms); let proof = deserialize_proof(&common, &data.proof)?; Ok(Box::new(Self { params, @@ -719,7 +744,8 @@ pub mod tests { self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB, }, middleware::{ - self, containers::Set, CustomPredicateRef, NativePredicate as NP, DEFAULT_VD_SET, + self, containers::Set, CustomPredicateRef, NativePredicate as NP, DEFAULT_VD_LIST, + DEFAULT_VD_SET, }, op, }; @@ -735,7 +761,9 @@ pub mod tests { ..Default::default() }; println!("{:#?}", params); - let vd_set = &*DEFAULT_VD_SET; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); let (gov_id_builder, pay_stub_builder, sanction_list_builder) = zu_kyc_sign_pod_builders(¶ms); @@ -747,7 +775,7 @@ pub mod tests { let sanction_list_pod = sanction_list_builder.sign(&signer)?; let kyc_builder = zu_kyc_pod_builder( ¶ms, - vd_set, + &vd_set, &gov_id_pod, &pay_stub_pod, &sanction_list_pod, @@ -772,7 +800,9 @@ pub mod tests { max_input_pods_public_statements: 10, ..Default::default() }; - let vd_set = &*DEFAULT_VD_SET; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); let mut gov_id_builder = frontend::SignedPodBuilder::new(¶ms); gov_id_builder.insert("idNumber", "4242424242"); @@ -781,7 +811,7 @@ pub mod tests { let signer = Signer(SecretKey(42u64.into())); let gov_id = gov_id_builder.sign(&signer).unwrap(); let now_minus_18y: i64 = 1169909388; - let mut kyc_builder = frontend::MainPodBuilder::new(¶ms, vd_set); + let mut kyc_builder = frontend::MainPodBuilder::new(¶ms, &vd_set); kyc_builder.add_signed_pod(&gov_id); kyc_builder .pub_op(op!(lt, (&gov_id, "dateOfBirth"), now_minus_18y)) @@ -827,9 +857,11 @@ pub mod tests { max_depth_mt_containers: 4, max_depth_mt_vds: 6, }; - let vd_set = &*DEFAULT_VD_SET; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); - let pod_builder = frontend::MainPodBuilder::new(¶ms, vd_set); + let pod_builder = frontend::MainPodBuilder::new(¶ms, &vd_set); // Mock let prover = MockProver {}; @@ -889,7 +921,9 @@ pub mod tests { ..Default::default() }; println!("{:#?}", params); - let vd_set = &*DEFAULT_VD_SET; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into()); let stb0 = STB::new(NP::Equal).arg(("id", "score")).arg(literal(42)); @@ -908,7 +942,7 @@ pub mod tests { let cpb_and = CustomPredicateRef::new(cpb.clone(), 0); let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1); - let mut pod_builder = MainPodBuilder::new(¶ms, vd_set); + let mut pod_builder = MainPodBuilder::new(¶ms, &vd_set); let st0 = pod_builder.priv_op(op!(new_entry, "score", 42))?; let st1 = pod_builder.priv_op(op!(new_entry, "key", 42))?; diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index 80587f6..cedb8b0 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -69,7 +69,7 @@ impl fmt::Display for MockMainPod { for (i, st) in self.statements.iter().enumerate() { if self.params.max_input_signed_pods > 0 && (i >= offset_input_signed_pods && i < offset_input_recursive_pods) - && ((i - offset_input_signed_pods) % self.params.max_signed_pod_values == 0) + && (i - offset_input_signed_pods).is_multiple_of(self.params.max_signed_pod_values) { let index = (i - offset_input_signed_pods) / self.params.max_signed_pod_values; let pod = &self.input_signed_pods[index]; @@ -84,9 +84,8 @@ impl fmt::Display for MockMainPod { if self.params.max_input_recursive_pods > 0 && (i >= offset_input_recursive_pods) && (i < offset_input_statements) - && ((i - offset_input_recursive_pods) - % self.params.max_input_pods_public_statements - == 0) + && (i - offset_input_recursive_pods) + .is_multiple_of(self.params.max_input_pods_public_statements) { let index = (i - offset_input_recursive_pods) / self.params.max_input_pods_public_statements; diff --git a/src/backends/plonky2/mod.rs b/src/backends/plonky2/mod.rs index cbdaa9f..e15a953 100644 --- a/src/backends/plonky2/mod.rs +++ b/src/backends/plonky2/mod.rs @@ -6,39 +6,46 @@ pub mod mainpod; pub mod mock; pub mod primitives; pub mod recursion; +mod serialization; pub mod signedpod; -use std::sync::LazyLock; - use base64::{prelude::BASE64_STANDARD, Engine}; pub use error::*; use plonky2::util::serialization::{Buffer, Read}; use crate::{ backends::plonky2::{ - basetypes::{CircuitData, CommonCircuitData, Proof}, + basetypes::{CommonCircuitData, Proof}, circuits::mainpod::{MainPodVerifyTarget, NUM_PUBLIC_INPUTS}, recursion::RecursiveCircuit, + serialization::CommonCircuitDataSerializer, }, + cache::{self, CacheEntry}, middleware::Params, timed, }; -pub static DEFAULT_PARAMS: LazyLock = LazyLock::new(Params::default); - -pub static STANDARD_REC_MAIN_POD_CIRCUIT_DATA: LazyLock = LazyLock::new(|| { - let params = &*DEFAULT_PARAMS; - timed!( - "recursive MainPod circuit_data", - RecursiveCircuit::::target_and_circuit_data( - params.max_input_recursive_pods, - NUM_PUBLIC_INPUTS, - params - ) - .expect("calculate circuit_data") - .1 +pub fn cache_get_standard_rec_main_pod_common_circuit_data( +) -> CacheEntry { + let params = Params::default(); + cache::get( + "standard_rec_main_pod_common_circuit_data", + ¶ms, + |params| { + let circuit_data = timed!( + "recursive MainPod circuit_data", + RecursiveCircuit::::target_and_circuit_data( + params.max_input_recursive_pods, + NUM_PUBLIC_INPUTS, + params + ) + .expect("calculate circuit_data") + ); + CommonCircuitDataSerializer(circuit_data.1.common) + }, ) -}); + .expect("cache ok") +} pub fn serialize_bytes(bytes: &[u8]) -> String { BASE64_STANDARD.encode(bytes) diff --git a/src/backends/plonky2/primitives/ec/bits.rs b/src/backends/plonky2/primitives/ec/bits.rs index b03dc9d..66e1fb0 100644 --- a/src/backends/plonky2/primitives/ec/bits.rs +++ b/src/backends/plonky2/primitives/ec/bits.rs @@ -16,11 +16,12 @@ use plonky2::{ plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData}, util::serialization::{Buffer, IoResult, Read, Write}, }; +use serde::{Deserialize, Serialize}; use crate::backends::plonky2::basetypes::{D, F}; -#[derive(Debug)] -struct ConditionalZeroGenerator, const D: usize> { +#[derive(Debug, Default, Clone)] +pub(crate) struct ConditionalZeroGenerator, const D: usize> { if_zero: Target, then_zero: Target, quot: Target, @@ -78,9 +79,11 @@ impl, const D: usize> SimpleGenerator /// A big integer, represented in base `2^32` with 10 digits, in little endian /// form. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct BigUInt320Target { + #[serde(with = "serde_arrays")] pub limbs: [Target; 10], + #[serde(with = "serde_arrays")] pub bits: [BoolTarget; 320], } @@ -313,7 +316,7 @@ fn biguint_limbs_to_bits(builder: &mut CircuitBuilder, limbs: &[Target]) - Copied from https://github.com/0xPolygonZero/plonky2/blob/82791c4809d6275682c34b926390ecdbdc2a5297/plonky2/src/gadgets/range_check.rs#L62 */ -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct LowHighGenerator { integer: Target, n_log: usize, diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index 714d298..caf82c0 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -57,6 +57,7 @@ pub fn ec_field_sqrt(x: &ECField) -> Option { ]); // Compute x^((r-1)/2) = x^(p*((1+p)/2)*(1+p^2)) let x1 = x.frobenius(); + #[allow(clippy::manual_div_ceil)] let x2 = x1.exp_u64((1 + GoldilocksField::ORDER) / 2); let den = x2 * x2.repeated_frobenius(2); Some(num / den) @@ -440,7 +441,7 @@ impl Mul for &BigUint { type FieldTarget = OEFTarget<5, QuinticExtension>; -#[derive(Clone, Debug)] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] pub struct PointTarget { pub x: FieldTarget, pub u: FieldTarget, @@ -470,8 +471,8 @@ impl PointTarget { } } -#[derive(Clone, Debug)] -struct PointSquareRootGenerator { +#[derive(Clone, Default, Debug)] +pub(crate) struct PointSquareRootGenerator { pub orig: PointTarget, pub sqrt: PointTarget, } diff --git a/src/backends/plonky2/primitives/ec/field.rs b/src/backends/plonky2/primitives/ec/field.rs index 4b7ac08..8234a38 100644 --- a/src/backends/plonky2/primitives/ec/field.rs +++ b/src/backends/plonky2/primitives/ec/field.rs @@ -15,6 +15,7 @@ use plonky2::{ plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData}, util::serialization::{Buffer, IoError, Read, Write}, }; +use serde::{Deserialize, Serialize}; //use super::gates::field::NNFMulGate; use crate::{ @@ -83,8 +84,9 @@ pub trait CircuitBuilderNNF< } /// Target type modelled on OEF. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct OEFTarget> { + #[serde(with = "serde_arrays")] pub components: [Target; DEG], _phantom_data: PhantomData, } @@ -106,8 +108,8 @@ impl> Default for OEFTarget { /// Quotient generator for OEF targets. Allows us to automagically /// generate quotients as witnesses. -#[derive(Debug, Default)] -struct QuotientGeneratorOEF> { +#[derive(Debug, Default, Clone)] +pub(crate) struct QuotientGeneratorOEF> { numerator: OEFTarget, denominator: OEFTarget, quotient: OEFTarget, @@ -121,7 +123,11 @@ impl< > SimpleGenerator for QuotientGeneratorOEF { fn id(&self) -> String { - "QuotientGeneratorOEF".to_string() + format!( + "QuotientGeneratorOEF<{}, {}>", + DEG, + std::any::type_name::() + ) } fn dependencies(&self) -> Vec { diff --git a/src/backends/plonky2/primitives/ec/gates/curve.rs b/src/backends/plonky2/primitives/ec/gates/curve.rs index b405d1e..ecce193 100644 --- a/src/backends/plonky2/primitives/ec/gates/curve.rs +++ b/src/backends/plonky2/primitives/ec/gates/curve.rs @@ -25,7 +25,7 @@ use crate::backends::plonky2::primitives::ec::{ /// operation when all its witness wire values are zero (so that when the gate is partially used, /// the unused slots still pass the constraints). This is the reason why this gate doesn't add the /// final offset: if it did, the constraints wouldn't pass on the zero witness values. -#[derive(Debug, Clone)] +#[derive(Debug, Default, Clone)] pub struct ECAddHomogOffset; impl SimpleGate for ECAddHomogOffset { diff --git a/src/backends/plonky2/primitives/ec/gates/generic.rs b/src/backends/plonky2/primitives/ec/gates/generic.rs index d81c965..e6652d2 100644 --- a/src/backends/plonky2/primitives/ec/gates/generic.rs +++ b/src/backends/plonky2/primitives/ec/gates/generic.rs @@ -66,7 +66,7 @@ pub struct RecursiveGateAdapter { _gate: PhantomData, } -#[derive(Debug)] +#[derive(Debug, Default, Clone)] pub struct RecursiveGenerator { row: usize, index: usize, @@ -175,7 +175,7 @@ where G::F: RichField + Extendable + Extendable<1>, { fn id(&self) -> String { - G::ID.to_string() + format!("GateAdapter<{}>", std::any::type_name::()) } fn serialize( @@ -336,7 +336,7 @@ where } fn id(&self) -> String { - format!("Generator<{},{}>", D, G::ID) + format!("RecursiveGenerator<{}, {}>", D, std::any::type_name::()) } fn dependencies(&self) -> Vec { @@ -374,7 +374,11 @@ where F: RichField + Extendable, { fn id(&self) -> String { - format!("Recursive<{},{}>", D, G::ID) + format!( + "RecursiveGateAdapter<{}, {}>", + D, + std::any::type_name::() + ) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/src/backends/plonky2/primitives/ec/schnorr.rs b/src/backends/plonky2/primitives/ec/schnorr.rs index ca63d2e..8d2c83d 100644 --- a/src/backends/plonky2/primitives/ec/schnorr.rs +++ b/src/backends/plonky2/primitives/ec/schnorr.rs @@ -101,7 +101,7 @@ impl<'de> Deserialize<'de> for Signature { } /// Targets for Schnorr signature over ecGFp5. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct SignatureTarget { pub s: BigUInt320Target, pub e: BigUInt320Target, diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index aa147f5..f9a1c74 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -23,6 +23,7 @@ use plonky2::{ }, plonk::circuit_builder::CircuitBuilder, }; +use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ @@ -35,7 +36,7 @@ use crate::{ middleware::{EMPTY_HASH, EMPTY_VALUE, F, HASH_SIZE}, }; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct MerkleClaimAndProofTarget { pub(crate) max_depth: usize, // `enabled` determines if the merkleproof verification is enabled @@ -194,6 +195,7 @@ impl MerkleClaimAndProofTarget { } } +#[derive(Clone, Serialize, Deserialize)] pub struct MerkleProofExistenceTarget { max_depth: usize, // `enabled` determines if the merkleproof verification is enabled diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index d59e406..afb118e 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -289,7 +289,7 @@ impl MerkleTree { } /// returns an iterator over the leaves of the tree - pub fn iter(&self) -> Iter { + pub fn iter(&self) -> Iter<'_> { Iter { state: vec![&self.root], } diff --git a/src/backends/plonky2/primitives/signature/circuit.rs b/src/backends/plonky2/primitives/signature/circuit.rs index 6903be9..ff7e60d 100644 --- a/src/backends/plonky2/primitives/signature/circuit.rs +++ b/src/backends/plonky2/primitives/signature/circuit.rs @@ -20,6 +20,7 @@ use plonky2::{ proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}, }, }; +use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ @@ -38,6 +39,7 @@ use crate::{ // TODO: This is a very simple wrapper over the signature verification implemented on // `SignatureTarget`. I think we can remove this and use it directly. Also we're not using the // `enabled` flag, so it should be straight-forward to remove this. +#[derive(Clone, Serialize, Deserialize)] pub struct SignatureVerifyTarget { // `enabled` determines if the signature verification is enabled pub(crate) enabled: BoolTarget, diff --git a/src/backends/plonky2/recursion/circuit.rs b/src/backends/plonky2/recursion/circuit.rs index daea999..4ed079f 100644 --- a/src/backends/plonky2/recursion/circuit.rs +++ b/src/backends/plonky2/recursion/circuit.rs @@ -32,6 +32,7 @@ use plonky2::{ }, util::log2_ceil, }; +use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ @@ -133,18 +134,59 @@ pub fn new_params_padded( /// RecursiveCircuit defines the circuit that verifies `arity` proofs. pub struct RecursiveCircuit { - pub(crate) params: RecursiveParams, pub(crate) prover: ProverCircuitData, pub(crate) target: RecursiveCircuitTarget, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct RecursiveCircuitTarget { innercircuit_targ: I, proofs_targ: Vec>, verifier_datas_targ: Vec, } +impl RecursiveCircuitTarget { + fn set_targets( + &self, + pw: &mut PartialWitness, + innercircuit_input: &I::Input, + recursive_proofs: Vec>, + verifier_datas: Vec>, + ) -> Result<()> { + let n = self.proofs_targ.len(); + assert_eq!(n, recursive_proofs.len()); + assert_eq!(n, verifier_datas.len()); + + // set the InnerCircuit related values + self.innercircuit_targ.set_targets(pw, innercircuit_input)?; + + #[allow(clippy::needless_range_loop)] + for i in 0..n { + pw.set_verifier_data_target(&self.verifier_datas_targ[i], &verifier_datas[i])?; + pw.set_proof_with_pis_target(&self.proofs_targ[i], &recursive_proofs[i])?; + } + + Ok(()) + } +} + +pub fn prove_rec_circuit( + target: &RecursiveCircuitTarget, + circuit_data: &CircuitData, + inner_inputs: &I::Input, + proofs: Vec>, + verifier_datas: Vec>, +) -> Result> { + let mut pw = PartialWitness::new(); + target.set_targets( + &mut pw, + inner_inputs, // innercircuit_input + proofs, + verifier_datas, + )?; + Ok(circuit_data.prove(pw)?) +} + impl RecursiveCircuit { pub fn prove( &self, @@ -153,7 +195,7 @@ impl RecursiveCircuit { verifier_datas: Vec>, ) -> Result> { let mut pw = PartialWitness::new(); - self.set_targets( + self.target.set_targets( &mut pw, inner_inputs, // innercircuit_input proofs, @@ -182,7 +224,6 @@ impl RecursiveCircuit { let prover: ProverCircuitData = builder.build_prover::(); Ok(Self { - params: params.clone(), prover, target: targets, }) @@ -242,31 +283,6 @@ impl RecursiveCircuit { }) } - fn set_targets( - &self, - pw: &mut PartialWitness, - innercircuit_input: &I::Input, - recursive_proofs: Vec>, - verifier_datas: Vec>, - ) -> Result<()> { - let n = recursive_proofs.len(); - assert_eq!(n, self.params.arity); - assert_eq!(n, verifier_datas.len()); - - // set the InnerCircuit related values - self.target - .innercircuit_targ - .set_targets(pw, innercircuit_input)?; - - #[allow(clippy::needless_range_loop)] - for i in 0..self.params.arity { - pw.set_verifier_data_target(&self.target.verifier_datas_targ[i], &verifier_datas[i])?; - pw.set_proof_with_pis_target(&self.target.proofs_targ[i], &recursive_proofs[i])?; - } - - Ok(()) - } - /// returns the target full-recursive circuit and its CircuitData pub fn target_and_circuit_data( arity: usize, diff --git a/src/backends/plonky2/recursion/mod.rs b/src/backends/plonky2/recursion/mod.rs index e9492a7..e4db61c 100644 --- a/src/backends/plonky2/recursion/mod.rs +++ b/src/backends/plonky2/recursion/mod.rs @@ -1,5 +1,6 @@ pub mod circuit; pub use circuit::{ common_data_for_recursion, hash_verifier_data, new_params, new_params_padded, pad_circuit, - InnerCircuit, RecursiveCircuit, RecursiveParams, VerifiedProofTarget, + prove_rec_circuit, InnerCircuit, RecursiveCircuit, RecursiveCircuitTarget, RecursiveParams, + VerifiedProofTarget, }; diff --git a/src/backends/plonky2/serialization.rs b/src/backends/plonky2/serialization.rs new file mode 100644 index 0000000..cdeb4b8 --- /dev/null +++ b/src/backends/plonky2/serialization.rs @@ -0,0 +1,252 @@ +use std::ops::Deref; + +use plonky2::{ + field::extension::quintic::QuinticExtension, + gates::{ + arithmetic_base::ArithmeticGate, arithmetic_extension::ArithmeticExtensionGate, + base_sum::BaseSumGate, constant::ConstantGate, coset_interpolation::CosetInterpolationGate, + exponentiation::ExponentiationGate, lookup::LookupGate, lookup_table::LookupTableGate, + multiplication_extension::MulExtensionGate, noop::NoopGate, poseidon::PoseidonGate, + poseidon_mds::PoseidonMdsGate, public_input::PublicInputGate, + random_access::RandomAccessGate, reducing::ReducingGate, + reducing_extension::ReducingExtensionGate, + }, + get_gate_tag_impl, impl_gate_serializer, read_gate_impl, + util::serialization::GateSerializer, +}; +use serde::{de, ser, Deserialize, Serialize}; + +use crate::backends::plonky2::{ + basetypes::{CircuitData, CommonCircuitData, VerifierCircuitData, C, D, F}, + circuits::{common::LtMaskGenerator, utils::DebugGenerator}, + primitives::ec::{ + bits::ConditionalZeroGenerator, + curve::PointSquareRootGenerator, + field::QuotientGeneratorOEF, + gates::{ + curve::ECAddHomogOffset, + field::NNFMulSimple, + generic::{GateAdapter, RecursiveGateAdapter, RecursiveGenerator}, + }, + }, +}; + +#[derive(Debug)] +pub(crate) struct Pod2GateSerializer; +impl GateSerializer for Pod2GateSerializer { + impl_gate_serializer! { + Pod2GateSerializer, + ArithmeticGate, + ArithmeticExtensionGate, + BaseSumGate<2>, + ConstantGate, + CosetInterpolationGate, + ExponentiationGate, + LookupGate, + LookupTableGate, + MulExtensionGate, + NoopGate, + PoseidonMdsGate, + PoseidonGate, + PublicInputGate, + RandomAccessGate, + ReducingExtensionGate, + ReducingGate, + // pod2 custom gates + GateAdapter::>>, + RecursiveGateAdapter::>>, + GateAdapter::, + RecursiveGateAdapter:: + } +} + +use plonky2::{ + gadgets::{ + arithmetic::EqualityGenerator, + arithmetic_extension::QuotientGeneratorExtension, + range_check::LowHighGenerator, + split_base::BaseSumGenerator, + split_join::{SplitGenerator, WireSplitGenerator}, + }, + gates::{ + arithmetic_base::ArithmeticBaseGenerator, + arithmetic_extension::ArithmeticExtensionGenerator, base_sum::BaseSplitGenerator, + coset_interpolation::InterpolationGenerator, exponentiation::ExponentiationGenerator, + lookup::LookupGenerator, lookup_table::LookupTableGenerator, + multiplication_extension::MulExtensionGenerator, poseidon::PoseidonGenerator, + poseidon_mds::PoseidonMdsGenerator, random_access::RandomAccessGenerator, + reducing::ReducingGenerator, + reducing_extension::ReducingGenerator as ReducingExtensionGenerator, + }, + get_generator_tag_impl, impl_generator_serializer, + iop::generator::{ + ConstantGenerator, CopyGenerator, NonzeroTestGenerator, RandomValueGenerator, + }, + read_generator_impl, + recursion::dummy_circuit::DummyProofGenerator, + util::serialization::WitnessGeneratorSerializer, +}; + +#[derive(Debug)] +pub(crate) struct Pod2GeneratorSerializer {} + +// TODO: Add pod2 custom generators +impl WitnessGeneratorSerializer for Pod2GeneratorSerializer { + impl_generator_serializer! { + Pod2GeneratorSerializer, + ArithmeticBaseGenerator, + ArithmeticExtensionGenerator, + BaseSplitGenerator<2>, + BaseSumGenerator<2>, + ConstantGenerator, + CopyGenerator, + DummyProofGenerator, + EqualityGenerator, + ExponentiationGenerator, + InterpolationGenerator, + LookupGenerator, + LookupTableGenerator, + LowHighGenerator, + MulExtensionGenerator, + NonzeroTestGenerator, + PoseidonGenerator, + PoseidonMdsGenerator, + QuotientGeneratorExtension, + RandomAccessGenerator, + RandomValueGenerator, + ReducingGenerator, + ReducingExtensionGenerator, + SplitGenerator, + WireSplitGenerator, + // pod2 custom generators + DebugGenerator, + LtMaskGenerator, + QuotientGeneratorOEF<5, QuinticExtension>, + PointSquareRootGenerator, + ConditionalZeroGenerator, + RecursiveGenerator>>, + RecursiveGenerator<1, NNFMulSimple<5, QuinticExtension>>, + RecursiveGenerator, + RecursiveGenerator<1, ECAddHomogOffset> + } +} + +/// Helper type to serialize and deserialize the pod2 `CircuitData` using serde traits. +#[derive(Clone)] +pub struct CircuitDataSerializer(pub(crate) CircuitData); + +impl Deref for CircuitDataSerializer { + type Target = CircuitData; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Serialize for CircuitDataSerializer { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let gate_serializer = Pod2GateSerializer {}; + let generator_serializer = Pod2GeneratorSerializer {}; + let bytes = self + .0 + .to_bytes(&gate_serializer, &generator_serializer) + .map_err(ser::Error::custom)?; + serde_bytes::ByteBuf::from(bytes).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for CircuitDataSerializer { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes = <&'de serde_bytes::Bytes>::deserialize(deserializer)?; + let gate_serializer = Pod2GateSerializer {}; + let generator_serializer = Pod2GeneratorSerializer {}; + let circuit_data = CircuitData::from_bytes(bytes, &gate_serializer, &generator_serializer) + .map_err(de::Error::custom)?; + Ok(CircuitDataSerializer(circuit_data)) + } +} + +/// Helper type to serialize and deserialize the pod2 `CommonCircuitData` using serde traits. +#[derive(Clone)] +pub struct CommonCircuitDataSerializer(pub(crate) CommonCircuitData); + +impl Deref for CommonCircuitDataSerializer { + type Target = CommonCircuitData; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Serialize for CommonCircuitDataSerializer { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let gate_serializer = Pod2GateSerializer {}; + let bytes = self + .0 + .to_bytes(&gate_serializer) + .map_err(ser::Error::custom)?; + serde_bytes::ByteBuf::from(bytes).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for CommonCircuitDataSerializer { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes = <&'de serde_bytes::Bytes>::deserialize(deserializer)?; + let gate_serializer = Pod2GateSerializer {}; + let circuit_data = + CommonCircuitData::from_bytes(bytes, &gate_serializer).map_err(de::Error::custom)?; + Ok(CommonCircuitDataSerializer(circuit_data)) + } +} + +/// Helper type to serialize and deserialize the pod2 `VerifierCircuitData` using serde traits. +#[derive(Clone)] +pub struct VerifierCircuitDataSerializer(pub(crate) VerifierCircuitData); + +impl Deref for VerifierCircuitDataSerializer { + type Target = VerifierCircuitData; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Serialize for VerifierCircuitDataSerializer { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let gate_serializer = Pod2GateSerializer {}; + let bytes = self + .0 + .to_bytes(&gate_serializer) + .map_err(ser::Error::custom)?; + serde_bytes::ByteBuf::from(bytes).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for VerifierCircuitDataSerializer { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes = <&'de serde_bytes::Bytes>::deserialize(deserializer)?; + + let gate_serializer = Pod2GateSerializer {}; + let circuit_data = + VerifierCircuitData::from_bytes(bytes, &gate_serializer).map_err(de::Error::custom)?; + Ok(VerifierCircuitDataSerializer(circuit_data)) + } +} diff --git a/src/cache/disk.rs b/src/cache/disk.rs new file mode 100644 index 0000000..dcec44b --- /dev/null +++ b/src/cache/disk.rs @@ -0,0 +1,120 @@ +use std::{ + fs::{create_dir_all, rename, File, TryLockError}, + io::{Error, ErrorKind, Read, Write}, + ops::Deref, + thread, time, +}; + +use directories::BaseDirs; +use serde::{de::DeserializeOwned, Serialize}; +use sha2::{Digest, Sha256}; + +pub struct CacheEntry { + value: T, +} + +impl Deref for CacheEntry { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.value + } +} + +/// Get the artifact named `name` from the disk cache. If it doesn't exist, it will be built by +/// calling `build_fn` and stored. +/// The artifact is indexed by git commit first and then by `params: P` second. +pub(crate) fn get( + name: &str, + params: &P, + build_fn: fn(&P) -> T, +) -> Result, Box> { + let commit_hash_str = env!("VERGEN_GIT_SHA"); + let params_json = serde_json::to_string(params)?; + let params_json_hash = Sha256::digest(¶ms_json); + let params_json_hash_str_long = format!("{:x}", params_json_hash); + let params_json_hash_str = format!("{}", ¶ms_json_hash_str_long[..32]); + let log_name = format!("{}/{}/{}.cbor", commit_hash_str, params_json_hash_str, name); + log::debug!("getting {} from the disk cache", log_name); + + let base_dirs = + BaseDirs::new().ok_or(Error::new(ErrorKind::Other, "no valid home directory"))?; + let user_cache_dir = base_dirs.cache_dir(); + let pod2_cache_dir = user_cache_dir.join("pod2"); + let commit_cache_dir = pod2_cache_dir.join(&commit_hash_str); + create_dir_all(&commit_cache_dir)?; + + let cache_dir = commit_cache_dir.join(¶ms_json_hash_str); + create_dir_all(&cache_dir)?; + + // Store the params.json if it doesn't exist for better debuggability + let params_path = cache_dir.join("params.json"); + if !params_path.try_exists()? { + // First write the file to .tmp and then rename to avoid a corrupted file if we crash in + // the middle of the write. + let params_path_tmp = cache_dir.join("params.json.tmp"); + let mut file = File::create(¶ms_path_tmp)?; + file.write_all(params_json.as_bytes())?; + rename(params_path_tmp, params_path)?; + } + + let cache_path = cache_dir.join(format!("{}.cbor", name)); + let cache_path_tmp = cache_dir.join(format!("{}.cbor.tmp", name)); + + // First try to open the cached file. If it exists we assume a previous build+cache succeeded + // so we read, deserialize it and return it. + // If it doesn't exist we open a corresponding tmp file and try to acquire it exclusively. If + // we can't acquire it means another process is building the artifact so we retry again in 100 + // ms. If we acquire the lock we build the artifact store it in the tmp file and finally + // rename it to the final cached file. This way the final cached file either exists and is + // complete or doesn't exist at all (in case of a crash the corruputed file will be tmp). + + loop { + let mut file = match File::open(&cache_path) { + Ok(file) => file, + Err(err) => { + if err.kind() == ErrorKind::NotFound { + let mut file_tmp = File::create(&cache_path_tmp)?; + match file_tmp.try_lock() { + Ok(_) => (), + Err(TryLockError::WouldBlock) => { + // Lock not acquired. Another process is building the artifact, let's + // try again in 100 ms. + thread::sleep(time::Duration::from_millis(100)); + continue; + } + Err(TryLockError::Error(err)) => return Err(Box::new(err)), + } + // Exclusive lock acquired, build the artifact, serialize it and store it. + log::info!("building {} and storing to the disk cache", log_name); + let start = std::time::Instant::now(); + let data = build_fn(params); + let elapsed = std::time::Instant::now() - start; + log::debug!("built {} in {:?}", log_name, elapsed); + let data_cbor = minicbor_serde::to_vec(&data)?; + // First write the file to .tmp and then rename to avoid a corrupted file if we + // crash in the middle of the write. + file_tmp.write_all(&data_cbor)?; + rename(cache_path_tmp, cache_path)?; + return Ok(CacheEntry { value: data }); + } else { + return Err(Box::new(err)); + } + } + }; + log::debug!("found {} in the disk cache", log_name); + + let start = std::time::Instant::now(); + let mut data_cbor = Vec::new(); + file.read_to_end(&mut data_cbor)?; + let elapsed = std::time::Instant::now() - start; + log::debug!("read {} from disk in {:?}", log_name, elapsed); + + let start = std::time::Instant::now(); + let data: T = minicbor_serde::from_slice(&data_cbor)?; + let elapsed = std::time::Instant::now() - start; + log::debug!("deserialized {} in {:?}", log_name, elapsed); + + return Ok(CacheEntry { value: data }); + } +} diff --git a/src/cache/mem.rs b/src/cache/mem.rs new file mode 100644 index 0000000..6a3804b --- /dev/null +++ b/src/cache/mem.rs @@ -0,0 +1,83 @@ +use std::{ + any::Any, + collections::HashMap, + ops::Deref, + sync::{LazyLock, Mutex}, + thread, time, +}; + +use serde::{de::DeserializeOwned, Serialize}; +use sha2::{Digest, Sha256}; + +#[allow(clippy::type_complexity)] +static CACHE: LazyLock>>>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +pub struct CacheEntry { + value: &'static T, +} + +impl Deref for CacheEntry { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +/// Get the artifact named `name` from the memory cache. If it doesn't exist, it will be built by +/// calling `build_fn` and stored. +/// The artifact is indexed by `params: P`. +pub(crate) fn get( + name: &str, + params: &P, + build_fn: fn(&P) -> T, +) -> Result, Box> { + let params_json = serde_json::to_string(params)?; + let params_json_hash = Sha256::digest(¶ms_json); + let params_json_hash_str_long = format!("{:x}", params_json_hash); + let key = format!("{}/{}", ¶ms_json_hash_str_long[..32], name); + log::debug!("getting {} from the mem cache", name); + + loop { + let mut cache = CACHE.lock()?; + if let Some(entry) = cache.get(&key) { + if let Some(boxed_data) = entry { + if let Some(data) = boxed_data.downcast_ref::() { + log::debug!("found {} in the mem cache", name); + // The data is now in the heap (boxed), and will never go away because we can + // only insert into the CACHE if there's no entry, we can't delete nor update. + // Since it's not going away, not moving, and the CACHE is 'static, it's safe + // to extend the lifetime of data to 'static. + let data_static = unsafe { std::mem::transmute::<&T, &'static T>(data) }; + return Ok(CacheEntry { value: data_static }); + } else { + panic!( + "type={} doesn't match the type in the cached boxed value with name={}", + std::any::type_name::(), + name + ); + } + } else { + // Another thread is building this entry, let's retry again in 100 ms + drop(cache); // release the lock + thread::sleep(time::Duration::from_millis(100)); + continue; + } + } + // No entry in the cache, let's put a `None` to signal that we're building the + // artifact, release the lock, build the artifact and insert it. We do this to avoid + // locking for a long time. + cache.insert(key.clone(), None); + drop(cache); // release the lock + log::info!("building {} and storing to the mem cache", name); + let start = std::time::Instant::now(); + let data = build_fn(params); + let elapsed = std::time::Instant::now() - start; + log::debug!("built {} in {:?}", name, elapsed); + + CACHE.lock()?.insert(key, Some(Box::new(data))); + // Call `get` again and this time we'll retrieve the data from the cache + return get(name, params, build_fn); + } +} diff --git a/src/cache/mod.rs b/src/cache/mod.rs new file mode 100644 index 0000000..26a1caa --- /dev/null +++ b/src/cache/mod.rs @@ -0,0 +1,9 @@ +#[cfg(feature = "disk_cache")] +mod disk; +#[cfg(feature = "disk_cache")] +pub(crate) use disk::{get, CacheEntry}; + +#[cfg(feature = "mem_cache")] +mod mem; +#[cfg(feature = "mem_cache")] +pub(crate) use mem::{get, CacheEntry}; diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 22f8591..2090a67 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -119,7 +119,9 @@ mod tests { use super::*; use crate::{ backends::plonky2::{ - mainpod::Prover, mock::mainpod::MockProver, primitives::ec::schnorr::SecretKey, + mainpod::{rec_main_pod_circuit_data, Prover}, + mock::mainpod::MockProver, + primitives::ec::schnorr::SecretKey, signedpod::Signer, }, examples::{ @@ -130,7 +132,7 @@ mod tests { middleware::{ self, containers::{Array, Dictionary, Set}, - Params, TypedValue, DEFAULT_VD_SET, + Params, TypedValue, DEFAULT_VD_LIST, }, }; @@ -300,7 +302,9 @@ mod tests { max_input_recursive_pods: 1, ..Default::default() }; - let vd_set = &*DEFAULT_VD_SET; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); let (gov_id_builder, pay_stub_builder, sanction_list_builder) = zu_kyc_sign_pod_builders(¶ms); @@ -312,7 +316,7 @@ mod tests { let sanction_list_pod = sanction_list_builder.sign(&signer)?; let kyc_builder = zu_kyc_pod_builder( ¶ms, - vd_set, + &vd_set, &gov_id_pod, &pay_stub_pod, &sanction_list_pod, diff --git a/src/lang/parser.rs b/src/lang/parser.rs index f8995f7..bdbf535 100644 --- a/src/lang/parser.rs +++ b/src/lang/parser.rs @@ -40,12 +40,11 @@ mod tests { } fn assert_fails(rule: Rule, input: &str) { - match PodlangParser::parse(rule, input) { - Ok(pairs) => panic!( + if let Ok(pairs) = PodlangParser::parse(rule, input) { + panic!( "Expected parse to fail, but it succeeded. Parsed:\n{:#?}", pairs - ), - Err(_) => (), // Failed as expected + ) } } diff --git a/src/lang/processor.rs b/src/lang/processor.rs index 3ac244e..7024273 100644 --- a/src/lang/processor.rs +++ b/src/lang/processor.rs @@ -975,7 +975,7 @@ mod processor_tests { middleware::Params, }; - fn get_document_content_pairs(input: &str) -> Result, ProcessorError> { + fn get_document_content_pairs(input: &str) -> Result, ProcessorError> { let full_parse_tree = parse_podlang(input) .map_err(|e| ProcessorError::Internal(format!("Test parsing failed: {:?}", e)))?; diff --git a/src/lib.rs b/src/lib.rs index 11b20bc..4b7de65 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,11 @@ #![allow(clippy::get_first)] -#![feature(trait_upcasting)] +#![allow(clippy::uninlined_format_args)] // TODO: Remove this in another PR +#![allow(clippy::manual_repeat_n)] // TODO: Remove this in another PR +#![allow(clippy::large_enum_variant)] // TODO: Remove this in another PR #![feature(mapped_lock_guards)] pub mod backends; +mod cache; pub mod frontend; pub mod lang; pub mod middleware; diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 2c7f53a..8b43ab7 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -358,7 +358,7 @@ impl Eq for Value {} impl PartialOrd for Value { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.raw.cmp(&other.raw)) + Some(self.cmp(other)) } }