From 88a75986b88a2ad413e4d6f78f1822d4bdc44399 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Thu, 29 May 2025 17:10:19 +0200 Subject: [PATCH] Integrate recursion into MainPod (#243) * calculate MainPod id in a dynamic-friendly way The MainPod id is now calculated with front padding and a fixed size independent of max_public_statements so that introduction gadgets can be verified by a MainPod while paying only for the number of statements they use. This is because with front padding of none-statements we can precompute the poseidon state corresponding to absorbing all the padding statements and only pay constraints for the non-padding statements. The id is calculated as follows: `id = hash(serialize(reverse(statements || none-statements)))` * add time feature and disable timing by default * apply suggestions from @arnaucube * link issues in todos --- .github/workflows/build.yml | 2 + Cargo.toml | 1 + src/backends/plonky2/basetypes.rs | 21 +- src/backends/plonky2/circuits/common.rs | 70 +- src/backends/plonky2/circuits/mainpod.rs | 278 ++++-- src/backends/plonky2/circuits/signedpod.rs | 2 +- src/backends/plonky2/emptypod.rs | 207 +++++ src/backends/plonky2/mainpod/mod.rs | 261 +++--- src/backends/plonky2/mainpod/statement.rs | 14 +- src/backends/plonky2/mock/emptypod.rs | 93 ++ src/backends/plonky2/mock/mainpod.rs | 42 +- src/backends/plonky2/mock/mod.rs | 1 + src/backends/plonky2/mock/signedpod.rs | 10 +- src/backends/plonky2/mod.rs | 29 + .../plonky2/primitives/signature/circuit.rs | 4 +- src/backends/plonky2/recursion/circuit.rs | 839 +++++++++--------- src/backends/plonky2/recursion/mod.rs | 5 +- src/backends/plonky2/signedpod.rs | 10 +- src/frontend/mod.rs | 10 +- src/frontend/serialization.rs | 32 +- src/lib.rs | 28 + src/middleware/basetypes.rs | 69 +- src/middleware/mod.rs | 106 ++- 23 files changed, 1405 insertions(+), 729 deletions(-) create mode 100644 src/backends/plonky2/emptypod.rs create mode 100644 src/backends/plonky2/mock/emptypod.rs diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3bb217a..899621f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,4 +22,6 @@ jobs: run: cargo build - name: Build metrics run: cargo build --features metrics + - name: Build time + run: cargo build --features time diff --git a/Cargo.toml b/Cargo.toml index ee8b7e0..9acc509 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,3 +39,4 @@ jsonschema = "0.30.0" default = ["backend_plonky2"] backend_plonky2 = ["plonky2"] metrics = [] +time = [] diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index 7a4bbf5..3144ed4 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -1,15 +1,16 @@ -//! This file exposes the middleware::basetypes to be used in the middleware when the -//! `backend_plonky2` feature is enabled. +//! This file exposes the basetypes to be used in the middleware when the `backend_plonky2` feature +//! is enabled. //! See src/middleware/basetypes.rs for more details. use plonky2::{ - field::extension::quadratic::QuadraticExtension, + field::{extension::quadratic::QuadraticExtension, goldilocks_field::GoldilocksField}, hash::poseidon::PoseidonHash, - plonk::{config::GenericConfig, proof::Proof as Plonky2Proof}, + plonk::{circuit_builder, circuit_data, config::GenericConfig, proof}, }; use serde::Serialize; -use crate::middleware::F; +/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2 +pub type F = GoldilocksField; /// D defines the extension degree of the field used in the Plonky2 proofs (quadratic extension). pub const D: usize = 2; @@ -27,5 +28,11 @@ impl GenericConfig for C { type InnerHasher = PoseidonHash; } -/// proof system proof -pub type Proof = Plonky2Proof; +pub type CircuitData = circuit_data::CircuitData; +pub type CommonCircuitData = circuit_data::CommonCircuitData; +pub type ProverOnlyCircuitData = circuit_data::ProverOnlyCircuitData; +pub type VerifierOnlyCircuitData = circuit_data::VerifierOnlyCircuitData; +pub type VerifierCircuitData = circuit_data::VerifierCircuitData; +pub type CircuitBuilder = circuit_builder::CircuitBuilder; +pub type Proof = proof::Proof; +pub type ProofWithPublicInputs = proof::ProofWithPublicInputs; diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 633ed61..7e470d2 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -17,13 +17,12 @@ use plonky2::{ target::{BoolTarget, Target}, witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite}, }, - plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData}, util::serialization::{Buffer, IoResult, Read, Write}, }; use crate::{ backends::plonky2::{ - basetypes::D, + basetypes::{CircuitBuilder, CommonCircuitData, D}, circuits::mainpod::CustomPredicateVerification, error::Result, mainpod::{Operation, OperationArg, Statement}, @@ -47,13 +46,13 @@ pub struct ValueTarget { } impl ValueTarget { - pub fn zero(builder: &mut CircuitBuilder) -> Self { + pub fn zero(builder: &mut CircuitBuilder) -> Self { Self { elements: [builder.zero(); VALUE_SIZE], } } - pub fn one(builder: &mut CircuitBuilder) -> Self { + pub fn one(builder: &mut CircuitBuilder) -> Self { Self { elements: array::from_fn(|i| { if i == 0 { @@ -99,25 +98,25 @@ impl StatementArgTarget { } } - pub fn none(builder: &mut CircuitBuilder) -> Self { + pub fn none(builder: &mut CircuitBuilder) -> Self { let empty = builder.constant_value(EMPTY_VALUE); Self::new(empty, empty) } - pub fn literal(builder: &mut CircuitBuilder, value: &ValueTarget) -> Self { + pub fn literal(builder: &mut CircuitBuilder, value: &ValueTarget) -> Self { let empty = builder.constant_value(EMPTY_VALUE); Self::new(*value, empty) } pub fn anchored_key( - _builder: &mut CircuitBuilder, + _builder: &mut CircuitBuilder, pod_id: &ValueTarget, key: &ValueTarget, ) -> Self { Self::new(*pod_id, *key) } - pub fn wildcard_literal(builder: &mut CircuitBuilder, value: &ValueTarget) -> Self { + pub fn wildcard_literal(builder: &mut CircuitBuilder, value: &ValueTarget) -> Self { let empty = builder.constant_value(EMPTY_VALUE); Self::new(*value, empty) } @@ -137,17 +136,17 @@ pub struct StatementTarget { } pub trait Build { - fn build(self, builder: &mut CircuitBuilder, params: &Params) -> T; + fn build(self, builder: &mut CircuitBuilder, params: &Params) -> T; } impl Build for NativePredicate { - fn build(self, builder: &mut CircuitBuilder, params: &Params) -> NativePredicateTarget { + fn build(self, builder: &mut CircuitBuilder, params: &Params) -> NativePredicateTarget { NativePredicateTarget::constant(builder, params, self) } } impl Build for T { - fn build(self, _builder: &mut CircuitBuilder, _params: &Params) -> T { + fn build(self, _builder: &mut CircuitBuilder, _params: &Params) -> T { self } } @@ -155,7 +154,7 @@ impl Build for T { impl StatementTarget { /// Build a new native StatementTarget. Pads the arguments. pub fn new_native( - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, params: &Params, native_predicate: impl Build, args: &[StatementArgTarget], @@ -194,7 +193,7 @@ impl StatementTarget { pub fn has_native_type( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, params: &Params, t: NativePredicate, ) -> BoolTarget { @@ -210,7 +209,7 @@ pub struct OperationTypeTarget { impl OperationTypeTarget { pub fn new_custom( - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, batch_id: HashOutTarget, index: Target, ) -> Self { @@ -222,10 +221,7 @@ impl OperationTypeTarget { } } - pub fn as_custom( - &self, - builder: &mut CircuitBuilder, - ) -> (BoolTarget, HashOutTarget, Target) { + pub fn as_custom(&self, builder: &mut CircuitBuilder) -> (BoolTarget, HashOutTarget, Target) { // TODO: Use an enum for these prefixes let three = builder.constant(F::from_canonical_usize(3)); let op_is_custom = builder.is_equal(self.elements[0], three); @@ -234,7 +230,7 @@ impl OperationTypeTarget { (op_is_custom, batch_id, index) } - pub fn has_native(&self, builder: &mut CircuitBuilder, t: NativeOperation) -> BoolTarget { + pub fn has_native(&self, builder: &mut CircuitBuilder, t: NativeOperation) -> BoolTarget { // TODO: Use an enum for these prefixes let one = builder.one(); let op_is_native = builder.is_equal(self.elements[0], one); @@ -288,7 +284,7 @@ pub struct NativePredicateTarget(Target); impl NativePredicateTarget { pub fn constant( - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, params: &Params, native_predicate: NativePredicate, ) -> Self { @@ -316,7 +312,7 @@ pub struct PredicateTarget { impl PredicateTarget { pub fn new_native( - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, params: &Params, native_predicate: impl Build, ) -> Self { @@ -328,7 +324,7 @@ impl PredicateTarget { } } - pub fn new_batch_self(builder: &mut CircuitBuilder, index: Target) -> Self { + pub fn new_batch_self(builder: &mut CircuitBuilder, index: Target) -> Self { let prefix = builder.constant(F::from(PredicatePrefix::BatchSelf)); let zero = builder.zero(); Self { @@ -337,7 +333,7 @@ impl PredicateTarget { } pub fn new_custom( - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, batch_id: HashOutTarget, index: Target, ) -> Self { @@ -373,7 +369,7 @@ impl LiteralOrWildcardTarget { /// cases: ((is_key, key), (is_wildcard, wildcard_index)) pub fn cases( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, ) -> ((BoolTarget, ValueTarget), (BoolTarget, Target)) { let zero = builder.zero(); let is_zero_tail: Vec<_> = (1..4) @@ -397,12 +393,12 @@ pub struct StatementTmplArgTarget { } impl StatementTmplArgTarget { - pub fn as_none(&self, builder: &mut CircuitBuilder) -> BoolTarget { + pub fn as_none(&self, builder: &mut CircuitBuilder) -> BoolTarget { let prefix = builder.constant(F::from(StatementTmplArgPrefix::None)); builder.is_equal(self.elements[0], prefix) } - pub fn as_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, ValueTarget) { + pub fn as_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, ValueTarget) { let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal)); let case_ok = builder.is_equal(self.elements[0], prefix); let value = ValueTarget::from_slice(&self.elements[1..5]); @@ -411,7 +407,7 @@ impl StatementTmplArgTarget { pub fn as_anchored_key( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, ) -> (BoolTarget, Target, LiteralOrWildcardTarget) { let prefix = builder.constant(F::from(StatementTmplArgPrefix::AnchoredKey)); let case_ok = builder.is_equal(self.elements[0], prefix); @@ -420,7 +416,7 @@ impl StatementTmplArgTarget { (case_ok, id_wildcard_index, value_key_or_wildcard) } - pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, Target) { + pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, Target) { let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral)); let case_ok = builder.is_equal(self.elements[0], prefix); let wildcard_index = self.elements[1]; @@ -479,7 +475,7 @@ pub struct CustomPredicateBatchTarget { } impl CustomPredicateBatchTarget { - pub fn id(&self, builder: &mut CircuitBuilder) -> HashOutTarget { + pub fn id(&self, builder: &mut CircuitBuilder) -> HashOutTarget { let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect(); builder.hash_n_to_hash_no_pad::(flattened) } @@ -573,7 +569,7 @@ impl Flattenable for CustomPredicateEntryTarget { } impl CustomPredicateEntryTarget { - pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { + pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { builder.hash_n_to_hash_no_pad::(self.flatten()) } } @@ -630,7 +626,7 @@ pub struct CustomPredicateVerifyQueryTarget { } impl CustomPredicateVerifyQueryTarget { - pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { + pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { builder.hash_n_to_hash_no_pad::(self.flatten()) } } @@ -930,7 +926,7 @@ pub trait CircuitBuilderPod, const D: usize> { fn lt_mask(&mut self, len: usize, n: Target) -> Vec; } -impl CircuitBuilderPod for CircuitBuilder { +impl CircuitBuilderPod for CircuitBuilder { fn connect_slice(&mut self, xs: &[Target], ys: &[Target]) { assert_eq!(xs.len(), ys.len()); for (x, y) in xs.iter().zip(ys.iter()) { @@ -1267,11 +1263,11 @@ impl CircuitBuilderPod for CircuitBuilder { // then do `ts: &[HashCache]`. fn vec_ref(&mut self, params: &Params, ts: &[T], i: Target) -> T { // TODO: Revisit this when we need more than 64 statements. - let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| { + let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| { assert!(v.len() <= 64); builder.random_access(i, v.to_vec()) }; - let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec], i| { + let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec], i| { let num_rows = m.len(); let num_columns = m .first() @@ -1367,7 +1363,7 @@ pub struct LtMaskGenerator { pub(crate) mask: Vec, } -impl, const D: usize> SimpleGenerator for LtMaskGenerator { +impl SimpleGenerator for LtMaskGenerator { fn id(&self) -> String { "LtMaskGenerator".to_string() } @@ -1390,12 +1386,12 @@ impl, const D: usize> SimpleGenerator for LtM Ok(()) } - fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { dst.write_target(self.n)?; dst.write_target_vec(&self.mask) } - fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { let n = src.read_target()?; let mask = src.read_target_vec()?; Ok(Self { n, mask }) diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index a8faeca..5d303c9 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -10,14 +10,14 @@ use plonky2::{ }, iop::{ target::{BoolTarget, Target}, - witness::PartialWitness, + witness::{PartialWitness, WitnessWrite}, }, - plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher}, + plonk::config::AlgebraicHasher, }; use crate::{ backends::plonky2::{ - basetypes::D, + basetypes::CircuitBuilder, circuits::{ common::{ CircuitBuilderPod, CustomPredicateBatchTarget, CustomPredicateEntryTarget, @@ -28,18 +28,21 @@ use crate::{ }, signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget}, }, + emptypod::EmptyPod, error::Result, mainpod::{self, pad_statement}, primitives::merkletree::{ MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget, }, + recursion::{InnerCircuit, VerifiedProofTarget}, signedpod::SignedPod, }, measure_gates_begin, measure_gates_end, middleware::{ - AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields, - Value, WildcardValue, F, KEY_TYPE, SELF, VALUE_SIZE, + AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash, + NativeOperation, NativePredicate, Params, PodType, PredicatePrefix, Statement, + StatementArg, ToFields, Value, WildcardValue, EMPTY_VALUE, F, HASH_SIZE, KEY_TYPE, SELF, + VALUE_SIZE, }, }; @@ -47,6 +50,13 @@ use crate::{ // MainPod verification // +/// Offset in public inputs where we store the pod id +pub const PI_OFFSET_ID: usize = 0; +/// Offset in public inputs where we store the verified data array root +pub const PI_OFFSET_VDSROOT: usize = 4; + +pub const NUM_PUBLIC_INPUTS: usize = 8; + struct OperationVerifyGadget { params: Params, } @@ -58,7 +68,7 @@ impl OperationVerifyGadget { /// argument. fn first_n_args_as_values( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, resolved_op_args: &[StatementTarget], ) -> (BoolTarget, [ValueTarget; N]) { let arg_is_valueof = resolved_op_args[..N] @@ -80,7 +90,7 @@ impl OperationVerifyGadget { fn eval( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, prev_statements: &[StatementTarget], @@ -212,7 +222,7 @@ impl OperationVerifyGadget { fn eval_contains_from_entries( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_merkle_claim: MerkleClaimTarget, @@ -260,7 +270,7 @@ impl OperationVerifyGadget { fn eval_not_contains_from_entries( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_merkle_claim: MerkleClaimTarget, @@ -306,7 +316,7 @@ impl OperationVerifyGadget { fn eval_custom( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_custom_pred_verification: HashOutTarget, @@ -331,7 +341,7 @@ impl OperationVerifyGadget { /// NotEqualFromEntries. fn eval_eq_neq_from_entries( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -382,7 +392,7 @@ impl OperationVerifyGadget { /// LtEqFromEntries. fn eval_lt_lteq_from_entries( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -451,7 +461,7 @@ impl OperationVerifyGadget { fn eval_hash_of( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -485,7 +495,7 @@ impl OperationVerifyGadget { fn eval_sum_of( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -524,7 +534,7 @@ impl OperationVerifyGadget { fn eval_product_of( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -563,7 +573,7 @@ impl OperationVerifyGadget { fn eval_max_of( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -609,7 +619,7 @@ impl OperationVerifyGadget { fn eval_transitive_eq( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -645,7 +655,7 @@ impl OperationVerifyGadget { } fn eval_none( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, ) -> BoolTarget { @@ -663,7 +673,7 @@ impl OperationVerifyGadget { fn eval_new_entry( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, prev_statements: &[StatementTarget], @@ -701,7 +711,7 @@ impl OperationVerifyGadget { fn eval_lt_to_neq( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -730,7 +740,7 @@ impl OperationVerifyGadget { fn eval_copy( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], @@ -759,7 +769,7 @@ struct CustomOperationVerifyGadget { impl CustomOperationVerifyGadget { fn statement_arg_from_template( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st_tmpl_arg: &StatementTmplArgTarget, args: &[ValueTarget], ) -> StatementArgTarget { @@ -812,7 +822,7 @@ impl CustomOperationVerifyGadget { fn statement_from_template( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st_tmpl: &StatementTmplTarget, args: &[ValueTarget], ) -> StatementTarget { @@ -836,7 +846,7 @@ impl CustomOperationVerifyGadget { /// - Build the expected operation type fn eval( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, custom_predicate: &CustomPredicateEntryTarget, op_args: &[StatementTarget], args: &[ValueTarget], // arguments to the custom predicate, public and private @@ -901,10 +911,49 @@ impl CustomOperationVerifyGadget { } } -struct CalculateIdGadget { +/// Replace references to SELF by `self_id` in a statement. +struct NormalizeStatementGadget { params: Params, } +impl NormalizeStatementGadget { + fn eval( + &self, + builder: &mut CircuitBuilder, + statement: &StatementTarget, + self_id: &ValueTarget, + ) -> StatementTarget { + let zero_value = builder.constant_value(EMPTY_VALUE); + let self_value = builder.constant_value(SELF.0.into()); + let args = statement + .args + .iter() + .map(|arg| { + let first = ValueTarget::from_slice(&arg.elements[..VALUE_SIZE]); + let second = ValueTarget::from_slice(&arg.elements[VALUE_SIZE..]); + let is_not_ak = builder.is_equal_flattenable(&zero_value, &second); + let is_ak = builder.not(is_not_ak); + let is_self = builder.is_equal_flattenable(&self_value, &first); + let normalize = builder.and(is_ak, is_self); + let first_normalized = + builder.select_flattenable(&self.params, normalize, self_id, &first); + StatementArgTarget::new(first_normalized, second) + }) + .collect_vec(); + StatementTarget { + predicate: statement.predicate.clone(), + args, + } + } +} + +pub struct CalculateIdGadget { + /// `params.num_public_statements_id` is the total number of statements that will be hashed. + /// The id is calculated with front-padded none-statements and then the input statements + /// reversed. The part of the hash from the front-padded none-statements is precomputed. + pub params: Params, +} + impl CalculateIdGadget { /// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder /// elements that didn't fit into a chunk. @@ -923,7 +972,7 @@ impl CalculateIdGadget { /// Hash `inputs` starting from a circuit-constant `perm` state. fn hash_from_state, P: PlonkyPermutation>( - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, perm: P, inputs: &[Target], ) -> HashOutTarget { @@ -953,17 +1002,19 @@ impl CalculateIdGadget { } } - fn eval( + pub fn eval( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, + // These statements will be padded to reach `self.num_statements` statements: &[StatementTarget], ) -> HashOutTarget { + assert!(statements.len() <= self.params.num_public_statements_id); let measure = measure_gates_begin!(builder, "CalculateId"); let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten()); let mut none_st = mainpod::Statement::from(Statement::None); pad_statement(&self.params, &mut none_st); let front_pad_elts = iter::repeat(&none_st) - .take(self.params.num_public_statements_id - self.params.max_public_statements) + .take(self.params.num_public_statements_id - statements.len()) .flat_map(|s| s.to_fields(&self.params)) .collect_vec(); let (perm, front_pad_elts_rem) = @@ -992,7 +1043,7 @@ impl MainPodVerifyGadget { // index fn normalize_st_tmpl( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, st_tmpl: &StatementTmplTarget, id: HashOutTarget, ) -> StatementTmplTarget { @@ -1012,7 +1063,7 @@ impl MainPodVerifyGadget { /// calculate the id of each batch. fn build_custom_predicate_table( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, ) -> Result<(Vec, Vec)> { let measure = measure_gates_begin!(builder, "BuildCustomPredicateTable"); let params = &self.params; @@ -1053,7 +1104,7 @@ impl MainPodVerifyGadget { /// custom predicate against the operation and statement. fn build_custom_predicate_verification_table( &self, - builder: &mut CircuitBuilder, + builder: &mut CircuitBuilder, custom_predicate_table: &[HashOutTarget], ) -> Result<(Vec, Vec)> { let measure = measure_gates_begin!(builder, "BuildCustomPredicateVerificationTable"); @@ -1105,7 +1156,13 @@ impl MainPodVerifyGadget { )) } - fn eval(&self, builder: &mut CircuitBuilder) -> Result { + fn eval( + &self, + builder: &mut CircuitBuilder, + verified_proofs: &[VerifiedProofTarget], + ) -> Result { + assert_eq!(self.params.max_input_recursive_pods, verified_proofs.len()); + let measure = measure_gates_begin!(builder, "MainPodVerify"); let params = &self.params; // 1. Verify all input signed pods @@ -1115,6 +1172,7 @@ impl MainPodVerifyGadget { params: params.clone(), } .eval(builder)?; + builder.assert_one(signed_pod.signature.enabled.target); signed_pods.push(signed_pod); } @@ -1132,16 +1190,47 @@ impl MainPodVerifyGadget { statements.len(), 1 + self.params.max_input_signed_pods * self.params.max_signed_pod_values ); - // TODO: Fill with input main pods - for _main_pod in 0..self.params.max_input_main_pods { - for _statement in 0..self.params.max_public_statements { - statements.push(StatementTarget::new_native( - builder, - &self.params, - NativePredicate::None, - &[], - )) + + let id_gadget = CalculateIdGadget { + params: params.clone(), + }; + let mut input_pods_self_statements: Vec> = Vec::new(); + let normalize_statement_gadget = NormalizeStatementGadget { + params: self.params.clone(), + }; + for verified_proof in verified_proofs { + let expected_id = HashOutTarget::try_from( + &verified_proof.public_inputs[PI_OFFSET_ID..PI_OFFSET_ID + HASH_SIZE], + ) + .expect("4 elements"); + let id_value = ValueTarget { + elements: expected_id.elements, + }; + + let mut input_pod_self_statements = Vec::new(); + for _ in 0..self.params.max_input_pods_public_statements { + let self_st = builder.add_virtual_statement(params); + let normalized_st = normalize_statement_gadget.eval(builder, &self_st, &id_value); + input_pod_self_statements.push(self_st); + statements.push(normalized_st); } + let id = id_gadget.eval(builder, &input_pod_self_statements); + builder.connect_hashes(expected_id, id); + input_pods_self_statements.push(input_pod_self_statements); + } + + let vds_root = builder.add_virtual_hash(); + // TODO: verify that all input pod proofs use verifier data from the public input VD array + // This requires merkle proofs + // https://github.com/0xPARC/pod2/issues/250 + + // Verify that VD array that input pod uses is the same we use now. + for verified_proof in verified_proofs { + let verified_proof_vds_root = HashOutTarget::try_from( + &verified_proof.public_inputs[PI_OFFSET_VDSROOT..PI_OFFSET_VDSROOT + HASH_SIZE], + ) + .expect("4 elements"); + builder.connect_hashes(vds_root, verified_proof_vds_root); } // Add the input (private and public) statements and corresponding operations @@ -1220,8 +1309,10 @@ impl MainPodVerifyGadget { measure_gates_end!(builder, measure); Ok(MainPodVerifyTarget { params: params.clone(), + vds_root, id, signed_pods, + input_pods_self_statements, statements: input_statements.to_vec(), operations, merkle_proofs, @@ -1233,8 +1324,10 @@ impl MainPodVerifyGadget { pub struct MainPodVerifyTarget { params: Params, + vds_root: HashOutTarget, id: HashOutTarget, signed_pods: Vec, + input_pods_self_statements: Vec>, // The KEY_TYPE statement must be the first public one statements: Vec, operations: Vec, @@ -1251,7 +1344,9 @@ pub struct CustomPredicateVerification { } pub struct MainPodVerifyInput { + pub vds_root: Hash, pub signed_pods: Vec, + pub recursive_pods_pub_self_statements: Vec>, pub statements: Vec, pub operations: Vec, pub merkle_proofs: Vec, @@ -1259,18 +1354,44 @@ pub struct MainPodVerifyInput { pub custom_predicate_verifications: Vec, } +fn set_targets_input_pods_self_statements( + pw: &mut PartialWitness, + params: &Params, + statements_target: &[StatementTarget], + statements: &[Statement], +) -> Result<()> { + assert_eq!( + statements_target.len(), + params.max_input_pods_public_statements + ); + assert!(statements.len() <= params.num_public_statements_id); + + for (i, statement) in statements.iter().enumerate() { + statements_target[i].set_targets(pw, params, &statement.clone().into())?; + } + // Padding + let mut none_st = mainpod::Statement::from(Statement::None); + pad_statement(params, &mut none_st); + for statement_target in statements_target.iter().skip(statements.len()) { + statement_target.set_targets(pw, params, &none_st)?; + } + Ok(()) +} + impl MainPodVerifyTarget { pub fn set_targets( &self, pw: &mut PartialWitness, input: &MainPodVerifyInput, ) -> Result<()> { + pw.set_target_arr(&self.vds_root.elements, &input.vds_root.0)?; + assert!(input.signed_pods.len() <= self.params.max_input_signed_pods); for (i, signed_pod) in input.signed_pods.iter().enumerate() { self.signed_pods[i].set_targets(pw, signed_pod)?; } // Padding - if self.params.max_input_signed_pods > 0 { + if input.signed_pods.len() != self.params.max_input_signed_pods { // TODO: Instead of using an input for padding, use a canonical minimal SignedPod, // without it a MainPod configured to support input signed pods must have at least one // input signed pod :( @@ -1279,6 +1400,34 @@ impl MainPodVerifyTarget { self.signed_pods[i].set_targets(pw, pad_pod)?; } } + + assert!( + input.recursive_pods_pub_self_statements.len() <= self.params.max_input_recursive_pods + ); + for (i, pod_pub_statements) in input.recursive_pods_pub_self_statements.iter().enumerate() { + set_targets_input_pods_self_statements( + pw, + &self.params, + &self.input_pods_self_statements[i], + pod_pub_statements, + )?; + } + // Padding + if input.recursive_pods_pub_self_statements.len() != self.params.max_input_recursive_pods { + let empty_pod = EmptyPod::new_boxed(&self.params, input.vds_root); + let empty_pod_statements = empty_pod.pub_statements(); + for i in + input.recursive_pods_pub_self_statements.len()..self.params.max_input_recursive_pods + { + set_targets_input_pods_self_statements( + pw, + &self.params, + &self.input_pods_self_statements[i], + &empty_pod_statements, + )?; + } + } + assert_eq!(input.statements.len(), self.params.max_statements); for (i, (st, op)) in zip_eq(&input.statements, &input.operations).enumerate() { self.statements[i].set_targets(pw, &self.params, st)?; @@ -1342,17 +1491,44 @@ pub struct MainPodVerifyCircuit { pub params: Params, } +// TODO: Remove this type and implement it's logic directly in `impl InnerCircuit for MainPodVerifyTarget` impl MainPodVerifyCircuit { - pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { + pub fn eval( + &self, + builder: &mut CircuitBuilder, + verified_proofs: &[VerifiedProofTarget], + ) -> Result { let main_pod = MainPodVerifyGadget { params: self.params.clone(), } - .eval(builder)?; + .eval(builder, verified_proofs)?; builder.register_public_inputs(&main_pod.id.elements); + builder.register_public_inputs(&main_pod.vds_root.elements); Ok(main_pod) } } +impl InnerCircuit for MainPodVerifyTarget { + type Input = MainPodVerifyInput; + type Params = Params; + + fn build( + builder: &mut CircuitBuilder, + params: &Self::Params, + verified_proofs: &[VerifiedProofTarget], + ) -> Result { + MainPodVerifyCircuit { + params: params.clone(), + } + .eval(builder, verified_proofs) + } + + /// assigns the values to the targets + fn set_targets(&self, pw: &mut PartialWitness, input: &Self::Input) -> Result<()> { + self.set_targets(pw, input) + } +} + #[cfg(test)] mod tests { use std::{iter, ops::Not}; @@ -1395,7 +1571,7 @@ mod tests { }; let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::new(config); let st_target = builder.add_virtual_statement(¶ms); let op_target = builder.add_virtual_operation(¶ms); @@ -2268,7 +2444,7 @@ mod tests { expected_st_arg: StatementArg, ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::new(config); let gadget = CustomOperationVerifyGadget { params: params.clone(), }; @@ -2369,7 +2545,7 @@ mod tests { expected_st: Statement, ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::new(config); let gadget = CustomOperationVerifyGadget { params: params.clone(), }; @@ -2433,7 +2609,7 @@ mod tests { expected_st: Option, ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::new(config); let gadget = CustomOperationVerifyGadget { params: params.clone(), }; @@ -2775,7 +2951,7 @@ mod tests { fn helper_calculate_id(params: &Params, statements: &[Statement]) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); + let mut builder = CircuitBuilder::new(config); let gadget = CalculateIdGadget { params: params.clone(), }; diff --git a/src/backends/plonky2/circuits/signedpod.rs b/src/backends/plonky2/circuits/signedpod.rs index 7e460cd..d80f84d 100644 --- a/src/backends/plonky2/circuits/signedpod.rs +++ b/src/backends/plonky2/circuits/signedpod.rs @@ -83,7 +83,7 @@ pub struct SignedPodVerifyTarget { // the KEY_TYPE entry must be the first one // the KEY_SIGNER entry must be the second one mt_proofs: Vec, - signature: SignatureVerifyTarget, + pub(crate) signature: SignatureVerifyTarget, } impl SignedPodVerifyTarget { diff --git a/src/backends/plonky2/emptypod.rs b/src/backends/plonky2/emptypod.rs new file mode 100644 index 0000000..50d13c6 --- /dev/null +++ b/src/backends/plonky2/emptypod.rs @@ -0,0 +1,207 @@ +use std::{collections::HashMap, sync::Mutex}; + +use base64::{prelude::BASE64_STANDARD, Engine}; +use itertools::Itertools; +use plonky2::{ + hash::hash_types::HashOutTarget, + iop::witness::{PartialWitness, WitnessWrite}, + plonk::{ + circuit_builder::CircuitBuilder, + circuit_data::{self, CircuitConfig}, + proof::ProofWithPublicInputs, + }, +}; + +use crate::{ + backends::plonky2::{ + basetypes::{Proof, C, D}, + circuits::{ + common::{Flattenable, StatementTarget}, + mainpod::{CalculateIdGadget, PI_OFFSET_ID}, + }, + error::{Error, Result}, + mainpod::{self, calculate_id}, + recursion::pad_circuit, + LazyLock, DEFAULT_PARAMS, STANDARD_REC_MAIN_POD_CIRCUIT_DATA, + }, + middleware::{ + self, AnchoredKey, DynError, Hash, Params, Pod, PodId, PodType, RecursivePod, Statement, + ToFields, Value, VerifierOnlyCircuitData, EMPTY_HASH, F, HASH_SIZE, KEY_TYPE, SELF, + }, + timed, +}; + +struct EmptyPodVerifyCircuit { + params: Params, +} + +fn type_statement() -> Statement { + Statement::ValueOf( + AnchoredKey::from((SELF, KEY_TYPE)), + Value::from(PodType::Empty), + ) +} + +impl EmptyPodVerifyCircuit { + fn eval(&self, builder: &mut CircuitBuilder) -> Result { + let type_statement = StatementTarget::from_flattened( + &self.params, + &builder.constants(&type_statement().to_fields(&self.params)), + ); + let id = CalculateIdGadget { + params: self.params.clone(), + } + .eval(builder, &[type_statement]); + let vds_root = builder.add_virtual_hash(); + builder.register_public_inputs(&id.elements); + builder.register_public_inputs(&vds_root.elements); + Ok(EmptyPodVerifyTarget { vds_root }) + } +} + +pub struct EmptyPodVerifyTarget { + vds_root: HashOutTarget, +} + +impl EmptyPodVerifyTarget { + pub fn set_targets(&self, pw: &mut PartialWitness, vds_root: Hash) -> Result<()> { + Ok(pw.set_target_arr(&self.vds_root.elements, &vds_root.0)?) + } +} + +#[derive(Clone, Debug)] +pub struct EmptyPod { + params: Params, + id: PodId, + vds_root: Hash, + proof: Proof, +} + +type CircuitData = circuit_data::CircuitData; + +static STANDARD_EMPTY_POD_DATA: LazyLock<(EmptyPodVerifyTarget, CircuitData)> = + LazyLock::new(|| build().expect("successful build")); + +fn build() -> Result<(EmptyPodVerifyTarget, CircuitData)> { + let params = &*DEFAULT_PARAMS; + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let empty_pod_verify_target = EmptyPodVerifyCircuit { + params: params.clone(), + } + .eval(&mut builder)?; + let circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; + pad_circuit(&mut builder, &circuit_data.common); + + let data = timed!("EmptyPod build", builder.build::()); + assert_eq!(circuit_data.common, data.common); + Ok((empty_pod_verify_target, data)) +} + +static EMPTY_POD_CACHE: LazyLock>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +impl EmptyPod { + pub fn _prove(params: &Params, vds_root: Hash) -> Result { + let (empty_pod_verify_target, data) = &*STANDARD_EMPTY_POD_DATA; + + let mut pw = PartialWitness::::new(); + empty_pod_verify_target.set_targets(&mut pw, vds_root)?; + let proof = timed!("EmptyPod prove", data.prove(pw)?); + let id = &proof.public_inputs[PI_OFFSET_ID..PI_OFFSET_ID + HASH_SIZE]; + let id = PodId(Hash([id[0], id[1], id[2], id[3]])); + Ok(EmptyPod { + params: params.clone(), + id, + vds_root, + proof: proof.proof, + }) + } + pub fn new_boxed(params: &Params, vds_root: Hash) -> Box { + let default_params = &*DEFAULT_PARAMS; + assert_eq!(default_params.id_params(), params.id_params()); + + let empty_pod = EMPTY_POD_CACHE + .lock() + .unwrap() + .entry(vds_root) + .or_insert_with(|| Self::_prove(params, vds_root).expect("prove EmptyPod")) + .clone(); + Box::new(empty_pod) + } + fn _verify(&self) -> Result<()> { + let statements = self + .pub_self_statements() + .into_iter() + .map(mainpod::Statement::from) + .collect_vec(); + let id = PodId(calculate_id(&statements, &self.params)); + if id != self.id { + return Err(Error::id_not_equal(self.id, id)); + } + + let public_inputs = id + .to_fields(&self.params) + .iter() + .chain(EMPTY_HASH.0.iter()) // slot for the unused vds root + .cloned() + .collect_vec(); + + let (_, data) = &*STANDARD_EMPTY_POD_DATA; + data.verify(ProofWithPublicInputs { + proof: self.proof.clone(), + public_inputs, + }) + .map_err(|e| Error::custom(format!("EmptyPod proof verification failure: {:?}", e))) + } +} + +impl Pod for EmptyPod { + fn params(&self) -> &Params { + &self.params + } + fn verify(&self) -> Result<(), Box> { + Ok(self._verify()?) + } + + fn id(&self) -> PodId { + self.id + } + + fn pub_self_statements(&self) -> Vec { + vec![type_statement()] + } + + fn serialized_proof(&self) -> String { + let mut buffer = Vec::new(); + use plonky2::util::serialization::Write; + buffer.write_proof(&self.proof).unwrap(); + BASE64_STANDARD.encode(buffer) + } +} + +impl RecursivePod for EmptyPod { + fn verifier_data(&self) -> VerifierOnlyCircuitData { + let (_, data) = &*STANDARD_EMPTY_POD_DATA; + data.verifier_only.clone() + } + fn proof(&self) -> Proof { + self.proof.clone() + } + fn vds_root(&self) -> Hash { + self.vds_root + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + #[test] + fn test_empty_pod() { + let params = Params::default(); + + let empty_pod = EmptyPod::new_boxed(¶ms, EMPTY_HASH); + empty_pod.verify().unwrap(); + } +} diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index fa602b8..deaab64 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -7,31 +7,29 @@ use itertools::Itertools; pub use operation::*; use plonky2::{ hash::poseidon::PoseidonHash, - iop::witness::PartialWitness, - plonk::{ - circuit_builder::CircuitBuilder, - circuit_data::{CircuitConfig, CommonCircuitData}, - config::Hasher, - proof::{Proof, ProofWithPublicInputs}, - }, + plonk::{circuit_data::CommonCircuitData, config::Hasher}, util::serialization::{Buffer, Read}, }; pub use statement::*; use crate::{ backends::plonky2::{ - basetypes::{C, D}, + basetypes::{Proof, ProofWithPublicInputs, VerifierOnlyCircuitData, D}, circuits::mainpod::{ - CustomPredicateVerification, MainPodVerifyCircuit, MainPodVerifyInput, + CustomPredicateVerification, MainPodVerifyInput, MainPodVerifyTarget, NUM_PUBLIC_INPUTS, }, + emptypod::EmptyPod, error::{Error, Result}, + mock::emptypod::MockEmptyPod, primitives::merkletree::MerkleClaimAndProof, + recursion::{self, RecursiveCircuit, RecursiveParams}, signedpod::SignedPod, + STANDARD_REC_MAIN_POD_CIRCUIT_DATA, }, middleware::{ self, resolve_wildcard_values, AnchoredKey, CustomPredicateBatch, DynError, Hash, MainPodInputs, NativeOperation, NonePod, OperationType, Params, Pod, PodId, PodProver, - PodType, StatementArg, ToFields, F, KEY_TYPE, SELF, + PodType, RecursivePod, StatementArg, ToFields, F, KEY_TYPE, SELF, }, }; @@ -42,11 +40,8 @@ use crate::{ /// with a precomputed constant corresponding to the front-padding part: /// `id = hash(serialize(reverse(statements || none-statements)))` pub(crate) fn calculate_id(statements: &[Statement], params: &Params) -> middleware::Hash { - assert_eq!(params.max_public_statements, statements.len()); + assert!(statements.len() <= params.num_public_statements_id); assert!(params.max_public_statements <= params.num_public_statements_id); - statements - .iter() - .for_each(|st| assert_eq!(params.max_statement_args, st.1.len())); let mut none_st: Statement = middleware::Statement::None.into(); pad_statement(params, &mut none_st); @@ -250,7 +245,11 @@ fn pad_operation_args(params: &Params, args: &mut Vec) { /// Returns the statements from the given MainPodInputs, padding to the /// respective max lengths defined at the given Params. -pub(crate) fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec { +pub(crate) fn layout_statements( + params: &Params, + mock: bool, + inputs: &MainPodInputs, +) -> Result> { let mut statements = Vec::new(); // Statement at index 0 is always None to be used for padding operation arguments in custom @@ -258,6 +257,8 @@ pub(crate) fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec< statements.push(middleware::Statement::None.into()); // Input signed pods region + // TODO: Replace this with a dumb signed pod + // https://github.com/0xPARC/pod2/issues/246 let none_sig_pod_box: Box = Box::new(NonePod {}); let none_sig_pod = none_sig_pod_box.as_ref(); assert!(inputs.signed_pods.len() <= params.max_input_signed_pods); @@ -277,14 +278,20 @@ pub(crate) fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec< } // Input main pods region - let none_main_pod_box: Box = Box::new(NonePod {}); - let none_main_pod = none_main_pod_box.as_ref(); - assert!(inputs.main_pods.len() <= params.max_input_main_pods); - for i in 0..params.max_input_main_pods { - let pod = inputs.main_pods.get(i).copied().unwrap_or(none_main_pod); + let empty_pod_box: Box = + if mock || inputs.recursive_pods.len() == params.max_input_recursive_pods { + // We mocking or we don't need padding so we skip creating an EmptyPod + MockEmptyPod::new_boxed(params) + } else { + EmptyPod::new_boxed(params, inputs.vds_root) + }; + let empty_pod = empty_pod_box.as_ref(); + assert!(inputs.recursive_pods.len() <= params.max_input_recursive_pods); + for i in 0..params.max_input_recursive_pods { + let pod = inputs.recursive_pods.get(i).copied().unwrap_or(empty_pod); let sts = pod.pub_statements(); assert!(sts.len() <= params.max_public_statements); - for j in 0..params.max_public_statements { + for j in 0..params.max_input_pods_public_statements { let mut st = sts .get(j) .unwrap_or(&middleware::Statement::None) @@ -329,7 +336,7 @@ pub(crate) fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec< statements.push(st); } - statements + Ok(statements) } pub(crate) fn process_private_statements_operations( @@ -399,15 +406,25 @@ pub(crate) fn process_public_statements_operations( pub struct Prover {} impl Prover { - fn _prove(&mut self, params: &Params, inputs: MainPodInputs) -> Result { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let main_pod = MainPodVerifyCircuit { - params: params.clone(), - } - .eval(&mut builder)?; + fn _prove(&self, params: &Params, inputs: MainPodInputs) -> Result { + let rec_circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; + let (main_pod_target, circuit_data) = + RecursiveCircuit::::circuit_data_padded( + params.max_input_recursive_pods, + &rec_circuit_data.common, + params, + )?; + let rec_params = RecursiveParams { + arity: params.max_input_recursive_pods, + common_data: circuit_data.common.clone(), + verifier_data: circuit_data.verifier_data(), + }; + let main_pod = RecursiveCircuit { + params: rec_params, + prover: circuit_data.prover_data(), + target: main_pod_target, + }; - let mut pw = PartialWitness::::new(); let signed_pods_input: Vec = inputs .signed_pods .iter() @@ -419,6 +436,33 @@ impl Prover { }) .collect_vec(); + // Pad input recursive pods with empty pods if necessary + let empty_pod = if inputs.recursive_pods.len() == params.max_input_recursive_pods { + // We don't need padding so we skip creating an EmptyPod + MockEmptyPod::new_boxed(params) + } else { + EmptyPod::new_boxed(params, inputs.vds_root) + }; + let inputs = MainPodInputs { + recursive_pods: &inputs + .recursive_pods + .iter() + .copied() + .chain(iter::repeat(&*empty_pod)) + .take(params.max_input_recursive_pods) + .collect_vec(), + ..inputs + }; + + let recursive_pods_pub_self_statements = inputs + .recursive_pods + .iter() + .map(|pod| { + assert_eq!(params.id_params(), pod.params().id_params()); + pod.pub_self_statements() + }) + .collect_vec(); + let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?; let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?; let custom_predicate_verifications = extract_custom_predicate_verifications( @@ -427,7 +471,7 @@ impl Prover { &custom_predicate_batches, )?; - let statements = layout_statements(params, &inputs); + let statements = layout_statements(params, false, &inputs)?; let operations = process_private_statements_operations( params, &statements, @@ -442,23 +486,38 @@ impl Prover { // get the id out of the public statements let id: PodId = PodId(calculate_id(&public_statements, params)); + let proofs = inputs + .recursive_pods + .iter() + .map(|pod| { + assert_eq!(inputs.vds_root, pod.vds_root()); + ProofWithPublicInputs { + proof: pod.proof(), + public_inputs: [pod.id().0 .0, inputs.vds_root.0].concat(), + } + }) + .collect_vec(); + let verifier_datas = inputs + .recursive_pods + .iter() + .map(|pod| pod.verifier_data()) + .collect_vec(); let input = MainPodVerifyInput { + vds_root: inputs.vds_root, signed_pods: signed_pods_input, + recursive_pods_pub_self_statements, statements: statements[statements.len() - params.max_statements..].to_vec(), operations, merkle_proofs, custom_predicate_batches, custom_predicate_verifications, }; - main_pod.set_targets(&mut pw, &input)?; - - // generate & verify proof - let data = builder.build::(); - let proof_with_pis = data.prove(pw)?; + let proof_with_pis = main_pod.prove(&input, proofs, verifier_datas)?; Ok(MainPod { params: params.clone(), id, + vds_root: inputs.vds_root, public_statements, proof: proof_with_pis.proof, }) @@ -467,41 +526,21 @@ impl Prover { impl PodProver for Prover { fn prove( - &mut self, + &self, params: &Params, inputs: MainPodInputs, - ) -> Result, Box> { + ) -> Result, Box> { Ok(self._prove(params, inputs).map(Box::new)?) } } -pub type MainPodProof = Proof; - #[derive(Clone, Debug)] pub struct MainPod { params: Params, id: PodId, + vds_root: Hash, public_statements: Vec, - proof: MainPodProof, -} - -/// Convert a Statement into middleware::Statement and replace references to SELF by `self_id`. -pub(crate) fn normalize_statement(statement: &Statement, self_id: PodId) -> middleware::Statement { - Statement( - statement.0.clone(), - statement - .1 - .iter() - .map(|sa| match &sa { - StatementArg::Key(AnchoredKey { pod_id, key }) if *pod_id == SELF => { - StatementArg::Key(AnchoredKey::new(self_id, key.clone())) - } - _ => sa.clone(), - }) - .collect(), - ) - .try_into() - .unwrap() + proof: Proof, } // This is a helper function to get the CommonCircuitData necessary to decode @@ -509,66 +548,76 @@ pub(crate) fn normalize_statement(statement: &Statement, self_id: PodId) -> midd // as a constant or with static initialization, but in the meantime we can // generate it on-demand. fn get_common_data(params: &Params) -> Result, Error> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let _main_pod = MainPodVerifyCircuit { - params: params.clone(), - } - .eval(&mut builder) - .map_err(|e| Error::custom(format!("Failed to evaluate MainPodVerifyCircuit: {}", e)))?; - - let data = builder.build::(); - Ok(data.common) + // TODO: Cache this somehow + // https://github.com/0xPARC/pod2/issues/247 + let rec_params = recursion::new_params::( + params.max_input_recursive_pods, + NUM_PUBLIC_INPUTS, + params, + )?; + Ok(rec_params.common_data().clone()) } impl MainPod { fn _verify(&self) -> Result<()> { // 2. get the id out of the public statements - let id: PodId = PodId(calculate_id(&self.public_statements, &self.params)); + let id = PodId(calculate_id(&self.public_statements, &self.params)); if id != self.id { return Err(Error::id_not_equal(self.id, id)); } // 1, 3, 4, 5 verification via the zkSNARK proof + let rec_circuit_data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; // TODO: cache these artefacts - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let _main_pod = MainPodVerifyCircuit { - params: self.params.clone(), - } - .eval(&mut builder)?; - - let data = builder.build::(); - data.verify(ProofWithPublicInputs { - proof: self.proof.clone(), - public_inputs: id.to_fields(&self.params), - }) - .map_err(|e| Error::custom(format!("MainPod proof verification failure: {:?}", e))) + // https://github.com/0xPARC/pod2/issues/247 + let (_, circuit_data) = RecursiveCircuit::::circuit_data_padded( + self.params.max_input_recursive_pods, + &rec_circuit_data.common, + &self.params, + )?; + let public_inputs = id + .to_fields(&self.params) + .iter() + .chain(self.vds_root.0.iter()) + .cloned() + .collect_vec(); + circuit_data + .verify(ProofWithPublicInputs { + proof: self.proof.clone(), + public_inputs, + }) + .map_err(|e| Error::custom(format!("MainPod proof verification failure: {:?}", e))) } - pub fn proof(&self) -> MainPodProof { + pub fn proof(&self) -> Proof { self.proof.clone() } + pub fn vds_root(&self) -> Hash { + self.vds_root + } + pub fn params(&self) -> &Params { &self.params } pub(crate) fn new( - proof: MainPodProof, + proof: Proof, public_statements: Vec, id: PodId, + vds_root: Hash, params: Params, ) -> Self { Self { params, id, + vds_root, public_statements, proof, } } - pub fn decode_proof(proof: &str, params: &Params) -> Result { + pub fn decode_proof(proof: &str, params: &Params) -> Result { let decoded = BASE64_STANDARD.decode(proof).map_err(|e| { Error::custom(format!( "Failed to decode proof from base64: {}. Value: {}", @@ -590,6 +639,9 @@ impl MainPod { } impl Pod for MainPod { + fn params(&self) -> &Params { + &self.params + } fn verify(&self) -> Result<(), Box> { Ok(self._verify()?) } @@ -598,12 +650,11 @@ impl Pod for MainPod { self.id } - fn pub_statements(&self) -> Vec { - // return the public statements, where when origin=SELF is replaced by origin=self.id() + fn pub_self_statements(&self) -> Vec { self.public_statements .iter() .cloned() - .map(|statement| normalize_statement(&statement, self.id())) + .map(|st| st.try_into().expect("valid statement")) .collect() } @@ -615,6 +666,19 @@ impl Pod for MainPod { } } +impl RecursivePod for MainPod { + fn verifier_data(&self) -> VerifierOnlyCircuitData { + let data = &*STANDARD_REC_MAIN_POD_CIRCUIT_DATA; + data.verifier_only.clone() + } + fn proof(&self) -> Proof { + self.proof.clone() + } + fn vds_root(&self) -> Hash { + self.vds_root + } +} + #[cfg(test)] pub mod tests { use super::*; @@ -642,7 +706,7 @@ pub mod tests { let params = middleware::Params { // Currently the circuit uses random access that only supports vectors of length 64. // With max_input_main_pods=3 we need random access to a vector of length 73. - max_input_main_pods: 0, + max_input_recursive_pods: 0, max_custom_predicate_batches: 0, max_custom_predicate_verifications: 0, ..Default::default() @@ -672,10 +736,11 @@ pub mod tests { fn test_mini_0() { let params = middleware::Params { max_input_signed_pods: 1, - max_input_main_pods: 1, + max_input_recursive_pods: 1, max_signed_pod_values: 6, max_statements: 8, max_public_statements: 4, + max_input_pods_public_statements: 10, ..Default::default() }; @@ -715,7 +780,8 @@ pub mod tests { fn test_mainpod_small_empty() { let params = middleware::Params { max_input_signed_pods: 0, - max_input_main_pods: 0, + max_input_recursive_pods: 0, + max_input_pods_public_statements: 2, max_statements: 5, max_signed_pod_values: 2, max_public_statements: 2, @@ -753,14 +819,11 @@ pub mod tests { fn test_main_ethdos() -> frontend::Result<()> { let params = Params { max_input_signed_pods: 2, - max_input_main_pods: 1, + max_input_recursive_pods: 1, max_statements: 26, max_public_statements: 5, max_signed_pod_values: 8, - max_statement_args: 3, - max_operation_args: 4, - max_custom_predicate_arity: 4, - max_custom_batch_size: 3, + max_operation_args: 5, max_custom_predicate_wildcards: 6, max_custom_predicate_verifications: 8, ..Default::default() @@ -805,7 +868,7 @@ pub mod tests { fn test_main_mini_custom_1() -> frontend::Result<()> { let params = Params { max_input_signed_pods: 0, - max_input_main_pods: 0, + max_input_recursive_pods: 0, max_statements: 9, max_public_statements: 4, max_statement_args: 3, diff --git a/src/backends/plonky2/mainpod/statement.rs b/src/backends/plonky2/mainpod/statement.rs index 7fa9c6c..4f05ab3 100644 --- a/src/backends/plonky2/mainpod/statement.rs +++ b/src/backends/plonky2/mainpod/statement.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::{fmt, iter}; use serde::{Deserialize, Serialize}; @@ -28,9 +28,15 @@ impl Statement { } impl ToFields for Statement { - fn to_fields(&self, _params: &Params) -> Vec { - let mut fields = self.0.to_fields(_params); - fields.extend(self.1.iter().flat_map(|arg| arg.to_fields(_params))); + fn to_fields(&self, params: &Params) -> Vec { + let mut fields = self.0.to_fields(params); + fields.extend( + self.1 + .iter() + .chain(iter::repeat(&StatementArg::None)) + .take(params.max_statement_args) + .flat_map(|arg| arg.to_fields(params)), + ); fields } } diff --git a/src/backends/plonky2/mock/emptypod.rs b/src/backends/plonky2/mock/emptypod.rs new file mode 100644 index 0000000..981c0df --- /dev/null +++ b/src/backends/plonky2/mock/emptypod.rs @@ -0,0 +1,93 @@ +use itertools::Itertools; + +use crate::{ + backends::plonky2::{ + basetypes::{Proof, VerifierOnlyCircuitData}, + error::{Error, Result}, + mainpod::{self, calculate_id}, + }, + middleware::{ + AnchoredKey, DynError, Hash, Params, Pod, PodId, PodType, RecursivePod, Statement, Value, + KEY_TYPE, SELF, + }, +}; + +#[derive(Clone, Debug)] +pub struct MockEmptyPod { + params: Params, + id: PodId, +} + +fn type_statement() -> Statement { + Statement::ValueOf( + AnchoredKey::from((SELF, KEY_TYPE)), + Value::from(PodType::Empty), + ) +} + +impl MockEmptyPod { + pub fn new_boxed(params: &Params) -> Box { + let statements = [mainpod::Statement::from(type_statement())]; + let id = PodId(calculate_id(&statements, params)); + Box::new(Self { + params: params.clone(), + id, + }) + } + fn _verify(&self) -> Result<()> { + let statements = self + .pub_self_statements() + .into_iter() + .map(mainpod::Statement::from) + .collect_vec(); + let id = PodId(calculate_id(&statements, &self.params)); + if id != self.id { + return Err(Error::id_not_equal(self.id, id)); + } + Ok(()) + } +} + +impl Pod for MockEmptyPod { + fn params(&self) -> &Params { + &self.params + } + fn verify(&self) -> Result<(), Box> { + Ok(self._verify()?) + } + fn id(&self) -> PodId { + self.id + } + fn pub_self_statements(&self) -> Vec { + vec![type_statement()] + } + + fn serialized_proof(&self) -> String { + todo!() + } +} + +impl RecursivePod for MockEmptyPod { + fn verifier_data(&self) -> VerifierOnlyCircuitData { + panic!("MockEmptyPod can't be verified in a recursive MainPod circuit"); + } + fn proof(&self) -> Proof { + panic!("MockEmptyPod can't be verified in a recursive MainPod circuit"); + } + fn vds_root(&self) -> Hash { + panic!("MockEmptyPod can't be verified in a recursive MainPod circuit"); + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + #[test] + fn test_mock_empty_pod() { + let params = Params::default(); + + let empty_pod = MockEmptyPod::new_boxed(¶ms); + empty_pod.verify().unwrap(); + } +} diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index 00de19d..f93b7da 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -9,17 +9,18 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ + basetypes::{Proof, VerifierOnlyCircuitData}, error::{Error, Result}, mainpod::{ - calculate_id, extract_merkle_proofs, layout_statements, normalize_statement, + calculate_id, extract_merkle_proofs, layout_statements, process_private_statements_operations, process_public_statements_operations, Operation, Statement, }, primitives::merkletree::MerkleClaimAndProof, }, middleware::{ - self, hash_str, AnchoredKey, DynError, MainPodInputs, NativePredicate, Params, Pod, PodId, - PodProver, Predicate, StatementArg, KEY_TYPE, SELF, + self, hash_str, AnchoredKey, DynError, Hash, MainPodInputs, NativePredicate, Params, Pod, + PodId, PodProver, Predicate, RecursivePod, StatementArg, KEY_TYPE, SELF, }, }; @@ -27,10 +28,10 @@ pub struct MockProver {} impl PodProver for MockProver { fn prove( - &mut self, + &self, params: &Params, inputs: MainPodInputs, - ) -> Result, Box> { + ) -> Result, Box> { Ok(Box::new(MockMainPod::new(params, inputs)?)) } } @@ -73,7 +74,8 @@ impl fmt::Display for MockMainPod { } if (i >= offset_input_main_pods) && (i < offset_input_statements) - && ((i - offset_input_main_pods) % self.params.max_public_statements == 0) + && ((i - offset_input_main_pods) % self.params.max_input_pods_public_statements + == 0) { writeln!( f, @@ -137,7 +139,7 @@ impl MockMainPod { } fn offset_input_statements(&self) -> usize { self.offset_input_main_pods() - + self.params.max_input_main_pods * self.params.max_public_statements + + self.params.max_input_recursive_pods * self.params.max_input_pods_public_statements } fn offset_public_statements(&self) -> usize { self.offset_input_statements() + self.params.max_priv_statements() @@ -146,7 +148,7 @@ impl MockMainPod { pub fn new(params: &Params, inputs: MainPodInputs) -> Result { // TODO: Insert a new public statement of ValueOf with `key=KEY_TYPE, // value=PodType::MockMainPod` - let statements = layout_statements(params, &inputs); + let statements = layout_statements(params, true, &inputs)?; // Extract Merkle proofs and pad. let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?; @@ -278,20 +280,20 @@ impl MockMainPod { } impl Pod for MockMainPod { + fn params(&self) -> &Params { + &self.params + } fn verify(&self) -> Result<(), Box> { Ok(self._verify()?) } fn id(&self) -> PodId { self.id } - fn pub_statements(&self) -> Vec { - // return the public statements, where when origin=SELF is replaced by origin=self.id() - // By convention we expect the KEY_TYPE to be the first statement - self.statements + fn pub_self_statements(&self) -> Vec { + self.public_statements .iter() - .skip(self.offset_public_statements()) .cloned() - .map(|statement| normalize_statement(&statement, self.id())) + .map(|st| st.try_into().expect("valid statement")) .collect() } @@ -300,6 +302,18 @@ impl Pod for MockMainPod { } } +impl RecursivePod for MockMainPod { + fn verifier_data(&self) -> VerifierOnlyCircuitData { + panic!("MockMainPod can't be verified in a recursive MainPod circuit"); + } + fn proof(&self) -> Proof { + panic!("MockMainPod can't be verified in a recursive MainPod circuit"); + } + fn vds_root(&self) -> Hash { + panic!("MockMainPod can't be verified in a recursive MainPod circuit"); + } +} + #[cfg(test)] pub mod tests { use std::any::Any; diff --git a/src/backends/plonky2/mock/mod.rs b/src/backends/plonky2/mock/mod.rs index f6aa1b7..a174a59 100644 --- a/src/backends/plonky2/mock/mod.rs +++ b/src/backends/plonky2/mock/mod.rs @@ -1,2 +1,3 @@ +pub mod emptypod; pub mod mainpod; pub mod signedpod; diff --git a/src/backends/plonky2/mock/signedpod.rs b/src/backends/plonky2/mock/signedpod.rs index 278f257..010a16b 100644 --- a/src/backends/plonky2/mock/signedpod.rs +++ b/src/backends/plonky2/mock/signedpod.rs @@ -10,7 +10,7 @@ use crate::{ constants::MAX_DEPTH, middleware::{ containers::Dictionary, hash_str, AnchoredKey, DynError, Hash, Key, Params, Pod, PodId, - PodSigner, PodType, RawValue, Statement, Value, KEY_SIGNER, KEY_TYPE, + PodSigner, PodType, RawValue, Statement, Value, KEY_SIGNER, KEY_TYPE, SELF, }, }; @@ -111,6 +111,9 @@ impl MockSignedPod { } impl Pod for MockSignedPod { + fn params(&self) -> &Params { + panic!("MockSignedPod doesn't have params"); + } fn verify(&self) -> Result<(), Box> { Ok(self._verify()?) } @@ -119,8 +122,7 @@ impl Pod for MockSignedPod { self.id } - fn pub_statements(&self) -> Vec { - let id = self.id(); + fn pub_self_statements(&self) -> Vec { // By convention we put the KEY_TYPE first and KEY_SIGNER second let mut kvs = self.kvs.clone(); let key_type = Key::from(KEY_TYPE); @@ -130,7 +132,7 @@ impl Pod for MockSignedPod { [(key_type, value_type), (key_signer, value_signer)] .into_iter() .chain(kvs.into_iter().sorted_by_key(|kv| kv.0.hash())) - .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((id, k)), v)) + .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((SELF, k)), v)) .collect() } diff --git a/src/backends/plonky2/mod.rs b/src/backends/plonky2/mod.rs index 9d65f2f..0ed29a7 100644 --- a/src/backends/plonky2/mod.rs +++ b/src/backends/plonky2/mod.rs @@ -1,5 +1,6 @@ pub mod basetypes; pub mod circuits; +pub mod emptypod; mod error; pub mod mainpod; pub mod mock; @@ -7,4 +8,32 @@ pub mod primitives; pub mod recursion; pub mod signedpod; +use std::sync::LazyLock; + pub use error::*; + +use crate::{ + backends::plonky2::{ + basetypes::CircuitData, + circuits::mainpod::{MainPodVerifyTarget, NUM_PUBLIC_INPUTS}, + recursion::RecursiveCircuit, + }, + middleware::Params, + timed, +}; + +pub static DEFAULT_PARAMS: LazyLock = LazyLock::new(Params::default); + +pub static STANDARD_REC_MAIN_POD_CIRCUIT_DATA: LazyLock = LazyLock::new(|| { + let params = &*DEFAULT_PARAMS; + timed!( + "recursive MainPod circuit_data", + RecursiveCircuit::::target_and_circuit_data( + params.max_input_recursive_pods, + NUM_PUBLIC_INPUTS, + params + ) + .expect("calculate circuit_data") + .1 + ) +}); diff --git a/src/backends/plonky2/primitives/signature/circuit.rs b/src/backends/plonky2/primitives/signature/circuit.rs index 3c2846b..36b75ea 100644 --- a/src/backends/plonky2/primitives/signature/circuit.rs +++ b/src/backends/plonky2/primitives/signature/circuit.rs @@ -23,7 +23,7 @@ use plonky2::{ use crate::{ backends::plonky2::{ - basetypes::{Proof, C, D}, + basetypes::{C, D}, circuits::common::{CircuitBuilderPod, ValueTarget}, error::Result, primitives::signature::{ @@ -31,7 +31,7 @@ use crate::{ }, }, measure_gates_begin, measure_gates_end, - middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F, VALUE_SIZE}, + middleware::{Hash, Proof, RawValue, EMPTY_HASH, EMPTY_VALUE, F, VALUE_SIZE}, }; lazy_static! { diff --git a/src/backends/plonky2/recursion/circuit.rs b/src/backends/plonky2/recursion/circuit.rs index 9719ed2..0c728fb 100644 --- a/src/backends/plonky2/recursion/circuit.rs +++ b/src/backends/plonky2/recursion/circuit.rs @@ -8,48 +8,66 @@ /// verifies. When arity>1, using the RecursiveCircuit has the shape of a tree /// of the same arity. /// -use hashbrown::HashMap; +use itertools::Itertools; use plonky2::{ self, + field::types::Field, gates::noop::NoopGate, - hash::{ - hash_types::{HashOut, HashOutTarget}, - poseidon::PoseidonHash, - }, + hash::hash_types::HashOutTarget, iop::{ - target::{BoolTarget, Target}, + target::Target, witness::{PartialWitness, WitnessWrite}, }, plonk::{ circuit_builder::CircuitBuilder, circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, VerifierCircuitData, - VerifierCircuitTarget, + VerifierCircuitTarget, VerifierOnlyCircuitData, }, - config::Hasher, proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}, }, - recursion::dummy_circuit::{dummy_circuit, dummy_proof as plonky2_dummy_proof}, + util::log2_ceil, }; use crate::{ backends::plonky2::{ - basetypes::{Proof, C, D}, - error::{Error, Result}, + basetypes::{C, D}, + error::Result, }, middleware::F, + timed, }; +#[derive(Clone, Debug)] +pub struct VerifiedProofTarget { + pub public_inputs: Vec, + pub verifier_data_hash: HashOutTarget, +} + +/// Expected maximum number of constant gates +const MAX_CONSTANT_GATES: usize = 64; + +impl VerifiedProofTarget { + fn add_virtual(builder: &mut CircuitBuilder, num_public_inputs: usize) -> Self { + Self { + public_inputs: (0..num_public_inputs) + .map(|_| builder.add_virtual_target()) + .collect_vec(), + verifier_data_hash: builder.add_virtual_hash(), + } + } +} + /// InnerCircuit is the trait used to define the logic of the circuit that is /// computed inside the RecursiveCircuit. -pub trait InnerCircuit: Clone { - type Input; - type Params; +pub trait InnerCircuit: Sized { + type Input; // Input for witness generation + type Params; // Configuration parameters fn build( builder: &mut CircuitBuilder, params: &Self::Params, - selectors: Vec, + verified_proofs: &[VerifiedProofTarget], ) -> Result; /// assigns the values to the targets @@ -59,64 +77,81 @@ pub trait InnerCircuit: Clone { #[derive(Clone, Debug)] pub struct RecursiveParams { /// determines the arity of the RecursiveCircuit - arity: usize, - common_data: CommonCircuitData, - dummy_proof_pi: ProofWithPublicInputs, - dummy_verifier_data: VerifierCircuitData, + pub(crate) arity: usize, + pub(crate) common_data: CommonCircuitData, + pub(crate) verifier_data: VerifierCircuitData, +} + +impl RecursiveParams { + pub fn common_data(&self) -> &CommonCircuitData { + &self.common_data + } + pub fn verifier_data(&self) -> &VerifierCircuitData { + &self.verifier_data + } } pub fn new_params( arity: usize, + num_public_inputs: usize, inner_params: &I::Params, ) -> Result { - let circuit_data = RecursiveCircuit::::circuit_data(arity, inner_params)?; + let (_, circuit_data) = + RecursiveCircuit::::target_and_circuit_data(arity, num_public_inputs, inner_params)?; let common_data = circuit_data.common.clone(); let verifier_data = circuit_data.verifier_data(); - let dummy_proof_pi = RecursiveCircuit::::dummy_proof(circuit_data)?; Ok(RecursiveParams { arity, common_data, - dummy_proof_pi, - dummy_verifier_data: verifier_data, + verifier_data, + }) +} + +pub fn new_params_padded( + arity: usize, + common_data: &CommonCircuitData, + inner_params: &I::Params, +) -> Result { + let (_, circuit_data) = + RecursiveCircuit::::circuit_data_padded(arity, common_data, inner_params)?; + let common_data = circuit_data.common.clone(); + let verifier_data = circuit_data.verifier_data(); + Ok(RecursiveParams { + arity, + common_data, + verifier_data, }) } /// RecursiveCircuit defines the circuit that verifies `arity` proofs. pub struct RecursiveCircuit { - params: RecursiveParams, - prover: ProverCircuitData, - targets: RecursiveCircuitTarget, + pub(crate) params: RecursiveParams, + pub(crate) prover: ProverCircuitData, + pub(crate) target: RecursiveCircuitTarget, } #[derive(Clone, Debug)] pub struct RecursiveCircuitTarget { - selectors_targ: Vec, innercircuit_targ: I, proofs_targ: Vec>, - vds_hash: HashOutTarget, verifier_datas_targ: Vec, } impl RecursiveCircuit { pub fn prove( - &mut self, - inner_inputs: I::Input, - proofs: Vec, - proofs_inp: Vec>, - prev_hashes: Vec>, - verifier_datas: Vec>, - ) -> Result<(Proof, HashOut)> { + &self, + inner_inputs: &I::Input, + proofs: Vec>, + verifier_datas: Vec>, + ) -> Result> { let mut pw = PartialWitness::new(); - let vds_hash = self.set_targets( + self.set_targets( &mut pw, inner_inputs, // innercircuit_input proofs, - proofs_inp, - prev_hashes, verifier_datas, )?; - let proof = self.prover.prove(pw)?; - Ok((proof.proof, vds_hash)) + Ok(self.prover.prove(pw)?) } /// builds the targets and returns also a ProverCircuitData @@ -124,8 +159,12 @@ impl RecursiveCircuit { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config.clone()); - let targets: RecursiveCircuitTarget = - Self::build_targets(&mut builder, params, inner_params)?; + let targets: RecursiveCircuitTarget = Self::build_targets( + &mut builder, + params.arity, + ¶ms.common_data, + inner_params, + )?; println!("RecursiveCircuit num_gates {}", builder.num_gates()); @@ -133,333 +172,278 @@ impl RecursiveCircuit { Ok(Self { params: params.clone(), prover, - targets, + target: targets, }) } - /// builds the targets - fn build_targets( + /// builds the targets and pad to match the input `common_data` + pub fn build_targets( builder: &mut CircuitBuilder, - params: &RecursiveParams, + arity: usize, + common_data: &CommonCircuitData, inner_params: &I::Params, ) -> Result> { - let selectors_targ: Vec = (0..params.arity) - .map(|_| builder.add_virtual_bool_target_safe()) - .collect(); - - // TODO: investigate - builder.add_gate( - // add a ConstantGate, because without this, when later generating the `dummy_circuit` - // (inside the `conditionally_verify_proof_or_dummy`), it fails due the - // `CommonCircuitData` of the generated circuit not matching the given - // `CommonCircuitData` to create it. Without this it fails because it misses a - // ConstantGate. - plonky2::gates::constant::ConstantGate::new(params.common_data.config.num_constants), - vec![], - ); - // proof verification - let verifier_datas_targ: Vec = (0..params.arity) + let verifier_datas_targ: Vec = (0..arity) .map(|_| builder.add_virtual_verifier_data(builder.config.fri_config.cap_height)) .collect(); - let proofs_targ: Result>> = (0..params.arity) + let proofs_targ: Vec> = (0..arity) .map(|i| { - let proof_targ = builder.add_virtual_proof_with_pis(¶ms.common_data); - builder.conditionally_verify_proof_or_dummy::( - selectors_targ[i], - &proof_targ, - &verifier_datas_targ[i], - ¶ms.common_data, - )?; - Ok(proof_targ) + let proof_targ = builder.add_virtual_proof_with_pis(common_data); + builder.verify_proof::(&proof_targ, &verifier_datas_targ[i], common_data); + proof_targ }) .collect(); - let proofs_targ = proofs_targ?; - // hash the various verifier_data - let prev_verifier_datas_hashes: Vec = proofs_targ - .iter() - .map(|p| HashOutTarget::from_vec(p.public_inputs[..4].to_vec())) - .collect(); - let vds_hash = gadget_hash_verifier_datas( - builder, - params.arity, - prev_verifier_datas_hashes.clone(), - verifier_datas_targ.clone(), - ); - // set vds_hash as public input, which are registered before the - // InnerCircuit public inputs in case that there are - builder.register_public_inputs(&vds_hash.elements); + let verified_proofs = (0..arity) + .map(|i| VerifiedProofTarget { + public_inputs: proofs_targ[i].public_inputs.clone(), + verifier_data_hash: verifier_datas_targ[i].circuit_digest, + }) + .collect_vec(); - // build the InnerCircuit logic. Notice that if the InnerCircuit - // registers any public inputs, they will be placed after the - // `vds_hash` in the public inputs array - let innercircuit_targ: I = I::build(builder, inner_params, selectors_targ.clone())?; + // Build the InnerCircuit logic + let innercircuit_targ: I = I::build(builder, inner_params, &verified_proofs)?; + + pad_circuit(builder, common_data); Ok(RecursiveCircuitTarget { - selectors_targ, innercircuit_targ, proofs_targ, - vds_hash, verifier_datas_targ, }) } fn set_targets( - &mut self, + &self, pw: &mut PartialWitness, - innercircuit_input: I::Input, - recursive_proofs: Vec, - recursive_proofs_inp: Vec>, - prev_verifier_datas_hashes: Vec>, - verifier_datas: Vec>, - ) -> Result> { + innercircuit_input: &I::Input, + recursive_proofs: Vec>, + verifier_datas: Vec>, + ) -> Result<()> { let n = recursive_proofs.len(); - assert!(n <= self.params.arity); - assert_eq!(n, recursive_proofs_inp.len()); - assert_eq!(n, prev_verifier_datas_hashes.len()); + assert_eq!(n, self.params.arity); assert_eq!(n, verifier_datas.len()); - // fill the missing proofs with dummy_proofs - let dummy_proofs: Vec = (n..self.params.arity) - .map(|_| self.params.dummy_proof_pi.proof.clone()) - .collect(); - let recursive_proofs: Vec = [recursive_proofs, dummy_proofs].concat(); - - // fill the missing prev_verifier_data_hashes with the 'zero' hash - let mut prev_verifier_datas_hashes = prev_verifier_datas_hashes.clone(); - prev_verifier_datas_hashes.resize(self.params.arity, HashOut::::ZERO); - - let mut recursive_proofs_inp = recursive_proofs_inp.clone(); - recursive_proofs_inp.resize( - self.params.arity, - // skip the first 4 elements, which contain the vds_hash - self.params.dummy_proof_pi.public_inputs[4..].to_vec(), - ); - - // fill the missing verifier_datas with dummy_verifier_datas - let dummy_verifier_datas: Vec> = (n..self.params.arity) - .map(|_| self.params.dummy_verifier_data.clone()) - .collect(); - let verifier_datas: Vec> = - [verifier_datas, dummy_verifier_datas].concat(); - - // set the first n selectors to true, and the rest to false - for i in 0..n { - pw.set_bool_target(self.targets.selectors_targ[i], true)?; - } - for i in n..self.params.arity { - pw.set_bool_target(self.targets.selectors_targ[i], false)?; - } - // set the InnerCircuit related values - self.targets + self.target .innercircuit_targ - .set_targets(pw, &innercircuit_input)?; + .set_targets(pw, innercircuit_input)?; #[allow(clippy::needless_range_loop)] for i in 0..self.params.arity { - pw.set_verifier_data_target( - &self.targets.verifier_datas_targ[i], - &verifier_datas[i].verifier_only, - )?; - - // put together the public inputs with the verifier_data used to - // verify the current proof - let proof_i_public_inputs = Self::prepare_public_inputs( - prev_verifier_datas_hashes[i], - recursive_proofs_inp[i].clone(), - ); - - pw.set_proof_with_pis_target( - &self.targets.proofs_targ[i], - &ProofWithPublicInputs { - proof: recursive_proofs[i].clone(), - public_inputs: proof_i_public_inputs.clone(), - }, - )?; + pw.set_verifier_data_target(&self.target.verifier_datas_targ[i], &verifier_datas[i])?; + pw.set_proof_with_pis_target(&self.target.proofs_targ[i], &recursive_proofs[i])?; } - // vds_hash is returned since it will be used as public input to verify - // the proof of the current instance of the circuit - let vds_hash = hash_verifier_datas( - self.params.arity, - prev_verifier_datas_hashes.clone(), - verifier_datas.clone(), - ); - pw.set_hash_target(self.targets.vds_hash, vds_hash)?; - - Ok(vds_hash) + Ok(()) } - /// returns the full-recursive CircuitData - pub fn circuit_data(arity: usize, inner_params: &I::Params) -> Result> { - let data: CircuitData = common_data_for_recursion::(arity, inner_params)?; - let common_data = data.common.clone(); - let verifier_data = data.verifier_data(); - let dummy_proof_pi = Self::dummy_proof(data)?; - let params = RecursiveParams { - arity, - common_data, - dummy_proof_pi, - dummy_verifier_data: verifier_data, - }; + /// returns the target full-recursive circuit and its CircuitData + pub fn target_and_circuit_data( + arity: usize, + num_public_inputs: usize, + inner_params: &I::Params, + ) -> Result<(RecursiveCircuitTarget, CircuitData)> { + let rec_common_data = timed!( + "common_data_for_recursion", + common_data_for_recursion::(arity, num_public_inputs, inner_params)? + ); // build the actual RecursiveCircuit circuit data let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let _ = Self::build_targets(&mut builder, ¶ms, inner_params)?; - let data = builder.build::(); + let target = timed!( + "RecursiveCircuit::build_targets", + Self::build_targets(&mut builder, arity, &rec_common_data, inner_params)? + ); + let data = timed!("RecursiveCircuit build", builder.build::()); + assert_eq!(rec_common_data, data.common); - Ok(data) + Ok((target, data)) } - fn dummy_proof(circuit_data: CircuitData) -> Result> { - let dummy_circuit_data = dummy_circuit(&circuit_data.common); - let dummy_proof_pis = plonky2_dummy_proof(&dummy_circuit_data, HashMap::new())?; - Ok(dummy_proof_pis) - } + /// returns the full-recursive CircuitData padded to share the input `common_data` + pub fn circuit_data_padded( + arity: usize, + common_data: &CommonCircuitData, + inner_params: &I::Params, + ) -> Result<(RecursiveCircuitTarget, CircuitData)> { + // build the actual RecursiveCircuit circuit data + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); - pub fn prepare_public_inputs( - prev_verifier_datas_hash: HashOut, - inner_public_inputs: Vec, - ) -> Vec { - [ - prev_verifier_datas_hash.elements.to_vec(), - inner_public_inputs, - ] - .concat() + let target = timed!( + "RecursiveCircuit::build_targets", + Self::build_targets(&mut builder, arity, common_data, inner_params)? + ); + let data = timed!("RecursiveCircuit build", builder.build::()); + assert_eq!(*common_data, data.common); + + Ok((target, data)) } } -fn hash_verifier_datas( - arity: usize, - prev_hashes: Vec>, - verifier_datas: Vec>, -) -> HashOut { - // sanity check - assert_eq!(verifier_datas.len(), arity); - - let zero_hash = HashOut::::ZERO; - let mut prev_hashes = prev_hashes.clone(); - prev_hashes.resize(arity, zero_hash); - let prev_hashes: Vec = prev_hashes +fn coset_interpolation_gate( + subgroup_bits: usize, + degree: usize, + barycentric_weights: &[u64], +) -> plonky2::gates::coset_interpolation::CosetInterpolationGate { + #[allow(dead_code)] + struct Mirror { + subgroup_bits: usize, + degree: usize, + barycentric_weights: Vec, + } + let barycentric_weights = barycentric_weights .iter() - .flat_map(|h| h.elements.to_vec()) - .collect(); - - let hashes: Vec = verifier_datas - .iter() - .flat_map(|vd| vd.verifier_only.circuit_digest.elements) - .collect(); - - let inp: Vec = [prev_hashes, hashes].concat(); - - PoseidonHash::hash_no_pad(&inp) + .map(|v| F::from_canonical_u64(*v)) + .collect_vec(); + let gate = Mirror { + subgroup_bits, + degree, + barycentric_weights, + }; + unsafe { std::mem::transmute(gate) } } -fn gadget_hash_verifier_datas( - builder: &mut CircuitBuilder, - arity: usize, - prev_hashes: Vec, - verifier_datas: Vec, -) -> HashOutTarget { - // sanity checks - assert_eq!(prev_hashes.len(), arity); - assert_eq!(verifier_datas.len(), arity); - - let prev_hashes: Vec = prev_hashes - .iter() - .flat_map(|h| h.elements.to_vec()) - .collect(); - - let hashes: Vec = verifier_datas - .iter() - .flat_map(|vd| vd.circuit_digest.elements) - .collect(); - - let inp: Vec = [prev_hashes, hashes].concat(); - - builder.hash_n_to_hash_no_pad::(inp) -} - -fn common_data_for_recursion( +pub fn common_data_for_recursion( arity: usize, + num_public_inputs: usize, inner_params: &I::Params, -) -> Result> { - // 1st +) -> Result> { let config = CircuitConfig::standard_recursion_config(); - let builder = CircuitBuilder::::new(config); - let data = builder.build::(); - // 2nd - let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config.clone()); - for _ in 0..arity { - let verifier_data_i = - builder.add_virtual_verifier_data(builder.config.fri_config.cap_height); - - let proof = builder.add_virtual_proof_with_pis(&data.common); - builder.verify_proof::(&proof, &verifier_data_i, &data.common); - } - let data = builder.build::(); - - // 3rd - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config.clone()); - - builder.add_gate( - plonky2::gates::constant::ConstantGate::new(config.num_constants), - vec![], - ); - - let verifier_datas_targ: Vec = (0..arity) - .map(|_| builder.add_virtual_verifier_data(builder.config.fri_config.cap_height)) - .collect(); - for vd_i in verifier_datas_targ.iter() { - let proof = builder.add_virtual_proof_with_pis(&data.common); - builder.verify_proof::(&proof, vd_i, &data.common); + use plonky2::gates::gate::GateRef; + // Add our standard set of gates + for gate in [ + GateRef::new(plonky2::gates::noop::NoopGate {}), + GateRef::new(plonky2::gates::constant::ConstantGate::new( + config.num_constants, + )), + GateRef::new(plonky2::gates::poseidon_mds::PoseidonMdsGate::new()), + GateRef::new(plonky2::gates::poseidon::PoseidonGate::new()), + GateRef::new(plonky2::gates::public_input::PublicInputGate {}), + GateRef::new(plonky2::gates::base_sum::BaseSumGate::<2>::new_from_config::(&config)), + GateRef::new(plonky2::gates::reducing_extension::ReducingExtensionGate::new(32)), + GateRef::new(plonky2::gates::reducing::ReducingGate::new(43)), + GateRef::new( + plonky2::gates::arithmetic_extension::ArithmeticExtensionGate::new_from_config(&config), + ), + GateRef::new(plonky2::gates::arithmetic_base::ArithmeticGate::new_from_config(&config)), + GateRef::new( + plonky2::gates::multiplication_extension::MulExtensionGate::new_from_config(&config), + ), + GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 1)), + GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 2)), + GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 3)), + GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 4)), + GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 5)), + GateRef::new(plonky2::gates::random_access::RandomAccessGate::new_from_config(&config, 6)), + // It would be better do `CosetInterpolationGate::with_max_degree(4, 6)` but unfortunately + // that plonk2 method is `pub(crate)`, so we need to get around that somehow. + GateRef::new(coset_interpolation_gate( + 4, + 6, + &[ + 17293822565076172801, + 256, + 1048576, + 4294967296, + 17592186044416, + 72057594037927936, + 68719476720, + 281474976645120, + 1152921504338411520, + 18446744069414584065, + 18446744069413535745, + 18446744065119617025, + 18446726477228539905, + 18374686475376656385, + 18446744000695107601, + 18446462594437939201, + ], + )), + ] { + builder.add_gate_to_gate_set(gate); } - let prev_verifier_datas_hashes = builder.add_virtual_hashes(arity); - let vds_hash = gadget_hash_verifier_datas( - &mut builder, - arity, - prev_verifier_datas_hashes.clone(), - verifier_datas_targ.clone(), + let verified_proof = VerifiedProofTarget::add_virtual(&mut builder, num_public_inputs); + let verified_proofs = (0..arity).map(|_| verified_proof.clone()).collect_vec(); + let _ = timed!( + "common_data_for_recursion I::build", + I::build(&mut builder, inner_params, &verified_proofs)? ); - // set vds_hash as public input - builder.register_public_inputs(&vds_hash.elements); + let inner_num_gates = builder.num_gates(); - // set the targets for the InnerCircuit - let _ = I::build(&mut builder, inner_params, vec![])?; + let circuit_data = timed!( + "common_data_for_recursion builder.build", + builder.build::() + ); - // pad min gates - let n_gates = compute_num_gates(arity)?; - while builder.num_gates() < n_gates { + let estimate_verif_num_gates = |degree_bits: usize| { + // Formula obtained via linear regression using `test_measure_recursion` results with + // `standard_recursion_config`. + let num_gates: usize = 236 * degree_bits + 698; + // Add 8% for error because the results are not a clean line + num_gates * 108 / 100 + }; + + // Loop until we find a circuit size that can verify `arity` proofs of itself + let mut degree_bits = log2_ceil(inner_num_gates); + loop { + let verif_num_gates = estimate_verif_num_gates(degree_bits); + // Leave space for public input hashing, a `PublicInputGate` and some `ConstantGate`s (that's + // MAX_CONSTANT_GATES*2 constants in the standard_recursion_config). + let total_num_gates = inner_num_gates + + verif_num_gates * arity + + circuit_data.common.num_public_inputs.div_ceil(8) + + 1 + + MAX_CONSTANT_GATES; + if total_num_gates < (1 << degree_bits) { + break; + } + degree_bits = log2_ceil(total_num_gates); + } + + let mut common_data = circuit_data.common.clone(); + common_data.fri_params.degree_bits = degree_bits; + common_data.fri_params.reduction_arity_bits = vec![4, 4, 4]; + Ok(common_data) +} + +/// Pad the circuit to match a given `CommonCircuitData`. +pub fn pad_circuit(builder: &mut CircuitBuilder, common_data: &CommonCircuitData) { + assert_eq!(common_data.config, builder.config); + assert_eq!(common_data.num_public_inputs, builder.num_public_inputs()); + // TODO: We need to figure this out once we enable zero-knowledge + // https://github.com/0xPARC/pod2/issues/248 + assert!( + !common_data.config.zero_knowledge, + "Degree calculation can be off if zero-knowledge is on." + ); + + let degree = common_data.degree(); + // Need to account for public input hashing, a `PublicInputGate` and MAX_CONSTANT_GATES + // `ConstantGate`. NOTE: the builder doesn't have any public method to see how many constants + // have been registered, so we can't know exactly how many `ConstantGates` will be required. + // We hope that no more than MAX_CONSTANT_GATES*2 constants are used :pray:. Maybe we should + // make a PR to plonky2 to expose this? + let num_gates = degree - common_data.num_public_inputs.div_ceil(8) - 1 - MAX_CONSTANT_GATES; + assert!( + builder.num_gates() < num_gates, + "builder has more gates ({}) than the padding target ({})", + builder.num_gates(), + num_gates, + ); + while builder.num_gates() < num_gates { builder.add_gate(NoopGate, vec![]); } - Ok(builder.build::()) -} - -fn compute_num_gates(arity: usize) -> Result { - // Note: the following numbers are WIP, obtained by trial-error by running different - // configurations in the tests. - let n_gates = match arity { - 0..=1 => 1 << 12, - 2 => 1 << 13, - 3..=5 => 1 << 14, - 6 => 1 << 15, - _ => 0, - }; - if n_gates == 0 { - return Err(Error::custom(format!( - "arity={} not supported. Currently supported arity from 1 to 6 (both included)", - arity - ))); + for gate in &common_data.gates { + builder.add_gate_to_gate_set(gate.clone()); } - Ok(n_gates) } #[cfg(test)] @@ -476,6 +460,7 @@ mod tests { }; use super::*; + use crate::{measure_gates_begin, measure_gates_end, measure_gates_print}; // out-of-circuit input-output computation for Circuit1 fn circuit1_io(inp: HashOut) -> HashOut { @@ -514,7 +499,7 @@ mod tests { fn build( builder: &mut CircuitBuilder, _params: &Self::Params, - _selectors: Vec, + _verified_proofs: &[VerifiedProofTarget], ) -> Result { let input_targ = builder.add_virtual_hash(); let mut aux: Target = input_targ.elements[0]; @@ -549,7 +534,7 @@ mod tests { fn build( builder: &mut CircuitBuilder, _params: &Self::Params, - _selectors: Vec, + _verified_proofs: &[VerifiedProofTarget], ) -> Result { let input_targ = builder.add_virtual_hash(); @@ -583,7 +568,7 @@ mod tests { fn build( builder: &mut CircuitBuilder, _params: &Self::Params, - _selectors: Vec, + _verified_proofs: &[VerifiedProofTarget], ) -> Result { let input_targ = builder.add_virtual_hash(); @@ -632,7 +617,7 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = IC::build(&mut builder, inner_params, vec![])?; + let targets = IC::build(&mut builder, inner_params, &[])?; targets.set_targets(&mut pw, &inner_inputs)?; // generate & verify proof @@ -643,51 +628,35 @@ mod tests { Ok(()) } - #[test] - fn test_hash_verifier_datas() -> Result<()> { - let arity: usize = 2; - let circuit_data = RecursiveCircuit::::circuit_data(1, &())?; - let verifier_data = circuit_data.verifier_data(); + // Build an dummy empty circuit that uses common_data and make a proof from it. + fn dummy( + common_data: &CommonCircuitData, + num_public_inputs: usize, + ) -> Result<( + VerifierOnlyCircuitData, + ProofWithPublicInputs, + )> { + let config = common_data.config.clone(); + let mut builder = CircuitBuilder::new(config.clone()); - let h = hash_verifier_datas( - arity, - vec![], - vec![verifier_data.clone(), verifier_data.clone()], - ); + let public_inputs = (0..num_public_inputs) + .map(|_| { + let target = builder.add_virtual_target(); + builder.register_public_input(target); + target + }) + .collect_vec(); + pad_circuit(&mut builder, common_data); + + let circuit_data = builder.build::(); + assert_eq!(*common_data, circuit_data.common); - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - - // circuit logic - let vd_targ1 = builder.add_virtual_verifier_data(builder.config.fri_config.cap_height); - let vd_targ2 = builder.add_virtual_verifier_data(builder.config.fri_config.cap_height); - let expected_h = builder.add_virtual_hash(); - let prev_hashes_targ = builder.add_virtual_hashes(arity); - - let h_targ = gadget_hash_verifier_datas( - &mut builder, - arity, - prev_hashes_targ.clone(), - vec![vd_targ1.clone(), vd_targ2.clone()], - ); - builder.connect_hashes(expected_h, h_targ); - - // set targets - for ph_targ in prev_hashes_targ { - pw.set_hash_target(ph_targ, HashOut::::ZERO)?; + for target in &public_inputs { + pw.set_target(*target, F::ZERO)?; } - pw.set_hash_target(expected_h, h)?; - pw.set_verifier_data_target(&vd_targ1, &verifier_data.verifier_only)?; - pw.set_verifier_data_target(&vd_targ2, &verifier_data.verifier_only)?; - pw.set_hash_target(expected_h, h)?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof.clone())?; - - Ok(()) + let proof = circuit_data.prove(pw)?; + Ok((circuit_data.verifier_only, proof)) } // test that recurses with arity=2, with the following shape: @@ -702,128 +671,148 @@ mod tests { #[test] fn test_recursive_circuit() -> Result<()> { let arity: usize = 2; + let num_public_inputs: usize = 4; type RC = RecursiveCircuit; let inner_params = (); - let params: RecursiveParams = new_params::(arity, &inner_params)?; // build the circuit_data & verifier_data for the recursive circuit let start = Instant::now(); - let circuit_data_1 = RC::::circuit_data(arity, &inner_params)?; - let verifier_data_1 = circuit_data_1.verifier_data(); + let (_, circuit_data_3) = + RC::::target_and_circuit_data(arity, num_public_inputs, &inner_params)?; + let params_3 = RecursiveParams { + arity, + common_data: circuit_data_3.common.clone(), + verifier_data: circuit_data_3.verifier_data(), + }; + let common_data = &circuit_data_3.common; - let circuit_data_2 = RC::::circuit_data(arity, &inner_params)?; - let verifier_data_2 = circuit_data_2.verifier_data(); + let (_, circuit_data_1) = + RC::::circuit_data_padded(arity, &common_data, &inner_params)?; + let params_1 = RecursiveParams { + arity, + common_data: circuit_data_1.common.clone(), + verifier_data: circuit_data_1.verifier_data(), + }; - let circuit_data_3 = RC::::circuit_data(arity, &inner_params)?; - let verifier_data_3 = circuit_data_3.verifier_data(); + let (_, circuit_data_2) = + RC::::circuit_data_padded(arity, &common_data, &inner_params)?; + let params_2 = RecursiveParams { + arity, + common_data: circuit_data_2.common.clone(), + verifier_data: circuit_data_2.verifier_data(), + }; println!( "new_params & (c1, c2, c3).circuit_data generated {:?}", start.elapsed() ); - let mut circuit1 = RC::::build(¶ms, &())?; - let mut circuit2 = RC::::build(¶ms, &())?; - let mut circuit3 = RC::::build(¶ms, &())?; + let (dummy_verifier_data, dummy_proof) = dummy(&common_data, num_public_inputs)?; + + let circuit1 = RC::::build(¶ms_1, &())?; + let circuit2 = RC::::build(¶ms_2, &())?; + let circuit3 = RC::::build(¶ms_3, &())?; println!("circuit1.prove"); let inp = HashOut::::ZERO; let inner_inputs = (inp, circuit1_io(inp)); - let (proof_1a, vds_hash_1a) = - circuit1.prove(inner_inputs, vec![], vec![], vec![], vec![])?; - let inner_publicinputs_1a = circuit1_io(inp).elements.to_vec(); - let public_inputs = - RC::::prepare_public_inputs(vds_hash_1a, inner_publicinputs_1a.clone()); - verifier_data_1.clone().verify(ProofWithPublicInputs { - proof: proof_1a.clone(), - public_inputs: public_inputs.clone(), - })?; + let proof_1a = circuit1.prove( + &inner_inputs, + vec![dummy_proof.clone(), dummy_proof.clone()], + vec![dummy_verifier_data.clone(), dummy_verifier_data.clone()], + )?; + params_1.verifier_data.verify(proof_1a.clone())?; println!( "circuit1.prove (2nd iteration), verifies the proof of 1st iteration with circuit1" ); let inp = HashOut::::ZERO; let inner_inputs = (inp, circuit1_io(inp)); - let (proof_1b, vds_hash_1b) = circuit1.prove( - inner_inputs, - vec![proof_1a.clone()], - vec![inner_publicinputs_1a], - vec![vds_hash_1a], - vec![verifier_data_1.clone()], + let proof_1b = circuit1.prove( + &inner_inputs, + vec![proof_1a.clone(), dummy_proof.clone()], + vec![ + params_1.verifier_data.verifier_only.clone(), + dummy_verifier_data.clone(), + ], )?; - let inner_publicinputs_1b = circuit1_io(inp).elements.to_vec(); - let public_inputs = - RC::::prepare_public_inputs(vds_hash_1b, inner_publicinputs_1b.clone()); - verifier_data_1.clone().verify(ProofWithPublicInputs { - proof: proof_1b.clone(), - public_inputs: public_inputs.clone(), - })?; + params_1.verifier_data.verify(proof_1b.clone())?; println!("circuit3.prove"); let inp = HashOut::::ZERO; let inner_inputs = (inp, circuit3_io(inp)); - let (proof_3, vds_hash_3) = circuit3.prove(inner_inputs, vec![], vec![], vec![], vec![])?; - let inner_publicinputs_3 = circuit3_io(inp).elements.to_vec(); - let public_inputs = - RC::::prepare_public_inputs(vds_hash_3, inner_publicinputs_3.clone()); - verifier_data_3.clone().verify(ProofWithPublicInputs { - proof: proof_3.clone(), - public_inputs: public_inputs.clone(), - })?; + let proof_3 = circuit3.prove( + &inner_inputs, + vec![dummy_proof.clone(), dummy_proof.clone()], + vec![dummy_verifier_data.clone(), dummy_verifier_data.clone()], + )?; + params_3.verifier_data.verify(proof_3.clone())?; println!("circuit1.prove"); let inp = HashOut::::ZERO; let inner_inputs = (inp, circuit1_io(inp)); - let (proof_1c, vds_hash_1c) = - circuit1.prove(inner_inputs, vec![], vec![], vec![], vec![])?; - let inner_publicinputs_1c = circuit1_io(inp).elements.to_vec(); - let public_inputs = - RC::::prepare_public_inputs(vds_hash_1c, inner_publicinputs_1c.clone()); - verifier_data_1.clone().verify(ProofWithPublicInputs { - proof: proof_1c.clone(), - public_inputs: public_inputs.clone(), - })?; + let proof_1c = circuit1.prove( + &inner_inputs, + vec![dummy_proof.clone(), dummy_proof.clone()], + vec![dummy_verifier_data.clone(), dummy_verifier_data.clone()], + )?; + params_1.verifier_data.verify(proof_1c.clone())?; // generate a proof of Circuit2, which internally verifies the proof_3 & proof_1c println!("circuit2.prove, which internally verifies the proof_3 & proof_1c"); let inner_inputs = (inp, circuit2_io(inp)); - let (proof_2, vds_hash_2) = circuit2.prove( - inner_inputs, + let proof_2 = circuit2.prove( + &inner_inputs, vec![proof_3.clone(), proof_1c], - vec![inner_publicinputs_3.clone(), inner_publicinputs_1c.clone()], - vec![vds_hash_3, vds_hash_1c], - vec![verifier_data_3.clone(), verifier_data_1.clone()], + vec![ + params_3.verifier_data.verifier_only.clone(), + params_1.verifier_data.verifier_only.clone(), + ], )?; - let inner_publicinputs_2 = circuit2_io(inp).elements.to_vec(); - let public_inputs = - RC::::prepare_public_inputs(vds_hash_2, inner_publicinputs_2.clone()); - verifier_data_2.clone().verify(ProofWithPublicInputs { - proof: proof_2.clone(), - public_inputs: public_inputs.clone(), - })?; + params_2.verifier_data.verify(proof_2.clone())?; // verify the last proof of circuit2, inside a new circuit1's proof println!("proof_1d = c1.prove([proof_1b, proof_2], [verifier_data_1, verifier_data_2])"); let inp = HashOut::::ZERO; let inner_inputs = (inp, circuit1_io(inp)); - let (proof_1d, vds_hash_1d) = circuit1.prove( - inner_inputs, - // NOTE: if it makes external usage easier, we could group as a - // single input the: proof + inner_publicinputs + vds_hash, in a - // single object `ProofWithPublicInputs`. + let proof_1d = circuit1.prove( + &inner_inputs, vec![proof_1b, proof_2], - vec![inner_publicinputs_1b, inner_publicinputs_2], - vec![vds_hash_1b, vds_hash_2], - vec![verifier_data_1.clone(), verifier_data_2.clone()], + vec![ + params_1.verifier_data.verifier_only.clone(), + params_2.verifier_data.verifier_only.clone(), + ], )?; - let inner_publicinputs = circuit1_io(inp).elements.to_vec(); - let public_inputs = RC::::prepare_public_inputs(vds_hash_1d, inner_publicinputs); - verifier_data_1.clone().verify(ProofWithPublicInputs { - proof: proof_1d.clone(), - public_inputs: public_inputs.clone(), - })?; + params_1.verifier_data.verify(proof_1d)?; Ok(()) } + + #[ignore] + #[test] + fn test_measure_recursion() { + let config = CircuitConfig::standard_recursion_config(); + for i in 7..18 { + let mut builder = CircuitBuilder::new(config.clone()); + builder.add_gate_to_gate_set(plonky2::gates::gate::GateRef::new( + plonky2::gates::constant::ConstantGate::new(config.num_constants), + )); + while builder.num_gates() < (1 << i) - MAX_CONSTANT_GATES { + builder.add_gate(NoopGate, vec![]); + } + println!("build degree 2^{} ...", i); + let circuit_data = builder.build::(); + assert_eq!(circuit_data.common.degree_bits(), i); + + let mut builder = CircuitBuilder::new(config.clone()); + let measure = measure_gates_begin!(&builder, format!("verifier for 2^{}", i)); + let verifier_data_i = + builder.add_virtual_verifier_data(builder.config.fri_config.cap_height); + let proof = builder.add_virtual_proof_with_pis(&circuit_data.common); + builder.verify_proof::(&proof, &verifier_data_i, &circuit_data.common); + measure_gates_end!(&builder, measure); + } + measure_gates_print!(); + } } diff --git a/src/backends/plonky2/recursion/mod.rs b/src/backends/plonky2/recursion/mod.rs index f9452a5..fcf679b 100644 --- a/src/backends/plonky2/recursion/mod.rs +++ b/src/backends/plonky2/recursion/mod.rs @@ -1,2 +1,5 @@ pub mod circuit; -pub use circuit::{InnerCircuit, RecursiveCircuit, RecursiveParams}; +pub use circuit::{ + common_data_for_recursion, new_params, new_params_padded, pad_circuit, InnerCircuit, + RecursiveCircuit, RecursiveParams, VerifiedProofTarget, +}; diff --git a/src/backends/plonky2/signedpod.rs b/src/backends/plonky2/signedpod.rs index bd9f464..0254c2a 100644 --- a/src/backends/plonky2/signedpod.rs +++ b/src/backends/plonky2/signedpod.rs @@ -15,7 +15,7 @@ use crate::{ constants::MAX_DEPTH, middleware::{ containers::Dictionary, AnchoredKey, DynError, Hash, Key, Params, Pod, PodId, PodSigner, - PodType, RawValue, Statement, Value, KEY_SIGNER, KEY_TYPE, + PodType, RawValue, Statement, Value, KEY_SIGNER, KEY_TYPE, SELF, }, }; @@ -120,6 +120,9 @@ impl SignedPod { } impl Pod for SignedPod { + fn params(&self) -> &Params { + panic!("SignedPod doesn't have params"); + } fn verify(&self) -> Result<(), Box> { Ok(self._verify().map_err(Box::new)?) } @@ -128,8 +131,7 @@ impl Pod for SignedPod { self.id } - fn pub_statements(&self) -> Vec { - let id = self.id(); + fn pub_self_statements(&self) -> Vec { // By convention we put the KEY_TYPE first and KEY_SIGNER second let mut kvs: HashMap = self.dict.kvs().clone(); let key_type = Key::from(KEY_TYPE); @@ -139,7 +141,7 @@ impl Pod for SignedPod { [(key_type, value_type), (key_signer, value_signer)] .into_iter() .chain(kvs.into_iter().sorted_by_key(|kv| kv.0.hash())) - .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((id, k)), v)) + .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((SELF, k)), v)) .collect() } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 5ff0c98..c39ddcc 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -10,7 +10,8 @@ use serialization::{SerializedMainPod, SerializedSignedPod}; use crate::middleware::{ self, check_st_tmpl, hash_str, hash_values, AnchoredKey, Hash, Key, MainPodInputs, NativeOperation, NativePredicate, OperationAux, OperationType, Params, PodId, PodProver, - PodSigner, Predicate, Statement, StatementArg, Value, WildcardValue, KEY_TYPE, SELF, + PodSigner, Predicate, Statement, StatementArg, Value, WildcardValue, EMPTY_HASH, KEY_TYPE, + SELF, }; mod custom; @@ -553,7 +554,7 @@ impl MainPodBuilder { .iter() .map(|p| p.pod.as_ref()) .collect_vec(), - main_pods: &self + recursive_pods: &self .input_main_pods .iter() .map(|p| p.pod.as_ref()) @@ -561,6 +562,7 @@ impl MainPodBuilder { statements: &statements, operations: &operations, public_statements: &public_statements, + vds_root: EMPTY_HASH, // TODO https://github.com/0xPARC/pod2/issues/249 }; let pod = prover.prove(&self.params, inputs)?; @@ -618,7 +620,7 @@ impl MainPodBuilder { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(try_from = "SerializedMainPod", into = "SerializedMainPod")] pub struct MainPod { - pub pod: Box, + pub pod: Box, pub public_statements: Vec, pub params: Params, } @@ -898,7 +900,7 @@ pub mod tests { fn test_ethdos() -> Result<()> { let params = Params { max_input_signed_pods: 3, - max_input_main_pods: 3, + max_input_recursive_pods: 3, max_statements: 31, max_signed_pod_values: 8, max_public_statements: 10, diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index c084e5a..29abbe9 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -12,8 +12,8 @@ use crate::{ }, frontend::{MainPod, SignedPod}, middleware::{ - self, containers::Dictionary, serialization::ordered_map, AnchoredKey, Key, Params, PodId, - Statement, StatementArg, Value, SELF, + self, containers::Dictionary, serialization::ordered_map, AnchoredKey, Hash, Key, Params, + PodId, Statement, StatementArg, Value, EMPTY_HASH, SELF, }, }; @@ -45,6 +45,7 @@ pub enum MainPodType { #[schemars(rename = "MainPod")] pub struct SerializedMainPod { id: PodId, + vds_root: Hash, public_statements: Vec, proof: String, params: Params, @@ -99,23 +100,23 @@ impl From for SignedPod { impl From for SerializedMainPod { fn from(pod: MainPod) -> Self { - SerializedMainPod { - id: pod.id(), - proof: pod.pod.serialized_proof(), - params: pod.params.clone(), - pod_type: if (&*pod.pod as &dyn Any) - .downcast_ref::() - .is_some() - { - MainPodType::Main + let (pod_type, vds_root) = + if let Some(pod) = (&*pod.pod as &dyn Any).downcast_ref::() { + (MainPodType::Main, pod.vds_root()) } else if (&*pod.pod as &dyn Any) .downcast_ref::() .is_some() { - MainPodType::MockMain + (MainPodType::MockMain, EMPTY_HASH) } else { unreachable!() - }, + }; + SerializedMainPod { + id: pod.id(), + vds_root, + proof: pod.pod.serialized_proof(), + params: pod.params.clone(), + pod_type, public_statements: pod.public_statements.clone(), } } @@ -142,6 +143,7 @@ impl TryFrom for MainPod { serialized.id, ), serialized.id, + serialized.vds_root, serialized.params.clone(), )), public_statements: serialized.public_statements, @@ -369,7 +371,7 @@ mod tests { let params = middleware::Params { // Currently the circuit uses random access that only supports vectors of length 64. // With max_input_main_pods=3 we need random access to a vector of length 73. - max_input_main_pods: 1, + max_input_recursive_pods: 1, ..Default::default() }; @@ -420,7 +422,7 @@ mod tests { fn build_ethdos_pod() -> Result { let params = Params { max_input_signed_pods: 3, - max_input_main_pods: 3, + max_input_recursive_pods: 3, max_statements: 31, max_signed_pod_values: 8, max_public_statements: 10, diff --git a/src/lib.rs b/src/lib.rs index f7823a5..72c1bbc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![allow(clippy::get_first)] #![feature(trait_upcasting)] +#![feature(mapped_lock_guards)] pub mod backends; pub mod constants; @@ -8,3 +9,30 @@ pub mod middleware; #[cfg(test)] pub mod examples; + +#[cfg(feature = "time")] +pub mod time_macros { + #[macro_export] + macro_rules! timed { + ($ctx:expr, $exp:expr) => {{ + let start = std::time::Instant::now(); + let res = $exp; + println!( + "timed \"{}\": {:?}", + $ctx, + std::time::Instant::now() - start + ); + res + }}; + } +} + +#[cfg(not(feature = "time"))] +pub mod time_macros { + #[macro_export] + macro_rules! timed { + ($ctx:expr, $exp:expr) => {{ + $exp + }}; + } +} diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index f2c4827..b43e81f 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -1,45 +1,30 @@ -// TODO: Update this doc -//! This file exposes the backend dependent basetypes as middleware types, -//! taking them from the feature-enabled backend. +//! This file exposes the imported backend dependent basetypes as middleware types, taking them +//! from the feature-enabled backend. //! -//! This is done in order to avoid inconsistencies where a type or parameter is -//! defined in the middleware to have certain carachteristic and later in the -//! backend it gets used differently. The idea is that those types and -//! parameters (eg. lengths) have a single source of truth in the code; and in -//! the case of the "base types" this is determined by the backend being used -//! under the hood, not by a choice of the middleware parameters. +//! This is done in order to avoid inconsistencies where a type or parameter is defined in the +//! middleware to have certain carachteristic and later in the backend it gets used differently. +//! The idea is that those types and parameters (eg. lengths) have a single source of truth in the +//! code; and in the case of the "base types" this is determined by the backend being used under +//! the hood, not by a choice of the middleware parameters. //! -//! The idea with this approach, is that the frontend & middleware should not -//! need to import the proving library used by the backend (eg. plonky2, -//! plonky3, etc). +//! The idea with this approach, is that the frontend & middleware should not need to import the +//! proving library used by the backend (eg. plonky2, plonky3, etc). //! -//! For example, the `Hash` and `Value` types are types belonging at the -//! middleware, and is the middleware who reasons about them, but depending on -//! the backend being used, the `Hash` and `Value` types will have different -//! sizes. So it's the backend being used who actually defines their nature -//! under the hood. For example with a plonky2 backend, these types will have a -//! length of 4 field elements, whereas with a plonky3 backend they will have a -//! length of 8 field eleements. +//! For example, the `Hash` and `Value` types are types belonging at the middleware, and is the +//! middleware who reasons about them, but depending on the backend being used, the `Hash` and +//! `Value` types will have different sizes. So it's the backend being used who actually defines +//! their nature under the hood. For example with a plonky2 backend, these types will have a length +//! of 4 field elements, whereas with a plonky3 backend they will have a length of 8 field +//! eleements. //! -//! Note that his approach does not introduce new traits or abstract code, -//! just makes use of rust features to define 'base types' that are being used -//! in the middleware. +//! Note that his approach does not introduce new traits or abstract code, just makes use of rust +//! features to define 'base types' that are being used in the middleware. //! //! -//! NOTE (TMP): current implementation still uses plonky2 in the middleware for -//! u64/i64 to F conversion. Eventually we will do those conversions through the -//! approach described in this file, removing the imports of plonky2 in the -//! middleware. -//! TODO: Update this doc +//! NOTE (TMP): current implementation still uses plonky2 in the middleware for u64/i64 to F +//! conversion. Eventually we will do those conversions through the approach described in this +//! file, removing the imports of plonky2 in the middleware. -/// Value, Hash and F are imported based on 'features'. For example by default -/// we use the 'plonky2' feature, but it could be used a 'plonky3' feature, so -/// then the Value, Hash and F types would come from the plonky3 backend. -// #[cfg(feature = "backend_plonky2")] -// pub use crate::backends::plonky2::basetypes::{ -// hash_fields, hash_str, hash_value, Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F, HASH_SIZE, -// SELF_ID_HASH, VALUE_SIZE, -// }; use std::{ cmp::{Ord, Ordering}, fmt, @@ -47,10 +32,7 @@ use std::{ use hex::{FromHex, FromHexError}; use plonky2::{ - field::{ - goldilocks_field::GoldilocksField, - types::{Field, PrimeField64}, - }, + field::types::{Field, PrimeField64}, hash::poseidon::PoseidonHash, plonk::config::Hasher, }; @@ -58,11 +40,14 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use super::serialization::*; +// Plonky2 specific types. +// Value, Hash, F and other types are imported based on 'features'. For example by default we use +// theg'plonky2' feature, but it could be used a 'plonky3' feature, so then the Value, Hash and F +// types would come from the plonky3 backend. +#[cfg(feature = "backend_plonky2")] +pub use crate::backends::plonky2::basetypes::*; use crate::middleware::{Params, ToFields, Value}; -/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2 -pub type F = GoldilocksField; - pub const HASH_SIZE: usize = 4; pub const VALUE_SIZE: usize = 4; diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 41d7e10..b6b9239 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -558,8 +558,10 @@ pub enum PodType { None = 0, MockSigned = 1, MockMain = 2, - Signed = 3, - Main = 4, + MockEmpty = 3, + Signed = 4, + Main = 5, + Empty = 6, } impl fmt::Display for PodType { @@ -568,45 +570,56 @@ impl fmt::Display for PodType { PodType::None => write!(f, "None"), PodType::MockSigned => write!(f, "MockSigned"), PodType::MockMain => write!(f, "MockMain"), + PodType::MockEmpty => write!(f, "MockEmpty"), PodType::Signed => write!(f, "Signed"), PodType::Main => write!(f, "Main"), + PodType::Empty => write!(f, "Empty"), } } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] #[serde(rename_all = "camelCase")] pub struct Params { pub max_input_signed_pods: usize, - pub max_input_main_pods: usize, + pub max_input_recursive_pods: usize, + pub max_input_pods_public_statements: usize, pub max_statements: usize, pub max_signed_pod_values: usize, pub max_public_statements: usize, - // Number of public statements to hash to calculate the id. Must be equal or greater than - // `max_public_statements`. - pub num_public_statements_id: usize, - pub max_statement_args: usize, pub max_operation_args: usize, // max number of custom predicates batches that a MainPod can use pub max_custom_predicate_batches: usize, // max number of operations using custom predicates that can be verified in the MainPod pub max_custom_predicate_verifications: usize, - // max number of statements that can be ANDed or ORed together - // in a custom predicate - pub max_custom_predicate_arity: usize, pub max_custom_predicate_wildcards: usize, - pub max_custom_batch_size: usize, // maximum number of merkle proofs pub max_merkle_proofs: usize, // maximum depth for merkle tree gadget pub max_depth_mt_gadget: usize, + // + // The following parameters define how a pod id is calculated. They need to be the same among + // different circuits to be compatible in their verification. + // + // Number of public statements to hash to calculate the id. Must be equal or greater than + // `max_public_statements`. + pub num_public_statements_id: usize, + pub max_statement_args: usize, + // + // The following parameters define how a custom predicate batch id is calculated. + // + // max number of statements that can be ANDed or ORed together + // in a custom predicate + pub max_custom_predicate_arity: usize, + pub max_custom_batch_size: usize, } impl Default for Params { fn default() -> Self { Self { max_input_signed_pods: 3, - max_input_main_pods: 3, + max_input_recursive_pods: 2, + max_input_pods_public_statements: 10, max_statements: 20, max_signed_pod_values: 8, max_public_statements: 10, @@ -661,6 +674,16 @@ impl Params { self.max_custom_batch_size * self.custom_predicate_size() } + /// Parameters that define how the id is calculated + pub fn id_params(&self) -> Vec { + vec![ + self.num_public_statements_id, + self.max_statement_args, + self.max_custom_predicate_arity, + self.max_custom_batch_size, + ] + } + pub fn print_serialized_sizes(&self) { println!("Parameter sizes:"); println!( @@ -678,12 +701,39 @@ impl Params { } } +/// Replace references to SELF by `self_id` in anchored keys of the statement. +pub fn normalize_statement(statement: &Statement, self_id: PodId) -> Statement { + let predicate = statement.predicate(); + let args = statement + .args() + .iter() + .map(|sa| match &sa { + StatementArg::Key(AnchoredKey { pod_id, key }) if *pod_id == SELF => { + StatementArg::Key(AnchoredKey::new(self_id, key.clone())) + } + _ => sa.clone(), + }) + .collect(); + Statement::from_args(predicate, args).expect("statement was valid before normalization") +} + pub type DynError = dyn std::error::Error + Send + Sync; pub trait Pod: fmt::Debug + DynClone + Any { + fn params(&self) -> &Params; fn verify(&self) -> Result<(), Box>; fn id(&self) -> PodId; - fn pub_statements(&self) -> Vec; + /// Statements as internally generated, where self-referencing arguments use SELF in the + /// anchored key. The serialization of these statements is used to calculate the id. + fn pub_self_statements(&self) -> Vec; + /// Normalized statements, where self-referencing arguments use the pod id instead of SELF in + /// the anchored key. + fn pub_statements(&self) -> Vec { + self.pub_self_statements() + .into_iter() + .map(|statement| normalize_statement(&statement, self.id())) + .collect() + } /// Extract key-values from ValueOf public statements fn kvs(&self) -> HashMap { self.pub_statements() @@ -708,9 +758,21 @@ pub trait Pod: fmt::Debug + DynClone + Any { fn serialized_proof(&self) -> String; } -// impl Clone for Box +// impl Clone for Box dyn_clone::clone_trait_object!(Pod); +/// Trait for pods that are generated with a plonky2 circuit and that can be verified by a +/// recursive MainPod circuit. A Pod implementing this trait does not necesarilly come from +/// recursion: for example an introduction Pod in general is not recursive. +pub trait RecursivePod: Pod { + fn verifier_data(&self) -> VerifierOnlyCircuitData; + fn proof(&self) -> Proof; + fn vds_root(&self) -> Hash; +} + +// impl Clone for Box +dyn_clone::clone_trait_object!(RecursivePod); + pub trait PodSigner { fn sign( &mut self, @@ -719,19 +781,24 @@ pub trait PodSigner { ) -> Result, Box>; } +// TODO: Delete once we have a fully working EmptyPod and a dumb SignedPod +// https://github.com/0xPARC/pod2/issues/246 /// This is a filler type that fulfills the Pod trait and always verifies. It's empty. This /// can be used to simulate padding in a circuit. #[derive(Debug, Clone)] pub struct NonePod {} impl Pod for NonePod { + fn params(&self) -> &Params { + panic!("NonePod doesn't have params"); + } fn verify(&self) -> Result<(), Box> { Ok(()) } fn id(&self) -> PodId { PodId(EMPTY_HASH) } - fn pub_statements(&self) -> Vec { + fn pub_self_statements(&self) -> Vec { Vec::new() } fn serialized_proof(&self) -> String { @@ -742,20 +809,21 @@ impl Pod for NonePod { #[derive(Debug)] pub struct MainPodInputs<'a> { pub signed_pods: &'a [&'a dyn Pod], - pub main_pods: &'a [&'a dyn Pod], + pub recursive_pods: &'a [&'a dyn RecursivePod], pub statements: &'a [Statement], pub operations: &'a [Operation], /// Statements that need to be made public (they can come from input pods or input /// statements) pub public_statements: &'a [Statement], + pub vds_root: Hash, // TODO: Figure out if we use Hash or a Map here https://github.com/0xPARC/pod2/issues/249 } pub trait PodProver { fn prove( - &mut self, + &self, params: &Params, inputs: MainPodInputs, - ) -> Result, Box>; + ) -> Result, Box>; } pub trait ToFields {