diff --git a/Cargo.toml b/Cargo.toml index cfe1a0c..458c0d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ serde_arrays = "0.2.0" sha2 = { version = "0.10.9" } # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. -# [patch."https://github.com/0xPolygonZero/plonky2"] +# [patch."https://github.com/0xPARC/plonky2"] # plonky2 = { path = "../plonky2/plonky2" } [dev-dependencies] @@ -61,3 +61,7 @@ time = [] examples = [] disk_cache = ["directories", "minicbor-serde"] mem_cache = [] + +# Uncomment in order to enable debug information in the release builds. This allows getting panic backtraces with a performance similar to regular release. +# [profile.release] +# debug = true diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index c189725..0d18d97 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -26,7 +26,7 @@ use crate::{ basetypes::{CircuitBuilder, CommonCircuitData, D}, circuits::mainpod::CustomPredicateVerification, error::Result, - mainpod::{Operation, OperationArg, Statement}, + mainpod::{Operation, OperationArg, OperationAux, Statement}, primitives::merkletree::MerkleClaimAndProofTarget, }, middleware::{ @@ -128,6 +128,10 @@ impl StatementArgTarget { pub fn as_value(&self) -> ValueTarget { ValueTarget::from_slice(&self.elements[..VALUE_SIZE]) } + + fn size(_params: &Params) -> usize { + STATEMENT_ARG_F_LEN + } } #[derive(Clone, Serialize, Deserialize)] @@ -249,6 +253,10 @@ impl OperationTypeTarget { ) -> Result<()> { Ok(pw.set_target_arr(&self.elements, &op_type.to_fields(params))?) } + + fn size(_params: &Params) -> usize { + Params::operation_type_size() + } } // TODO: Implement Operation::to_field to determine the size of each element @@ -256,8 +264,7 @@ impl OperationTypeTarget { pub struct OperationTarget { pub op_type: OperationTypeTarget, pub args: Vec, - #[serde(with = "serde_arrays")] - pub aux: [IndexTarget; 2], + pub aux_index: IndexTarget, } impl OperationTarget { @@ -277,11 +284,13 @@ impl OperationTarget { { self.args[i].set_targets(pw, arg.as_usize())?; } - let indexes = op.aux().as_usizes(); - for (index_target, index) in self.aux.iter().zip_eq(indexes.iter()) { - index_target.set_targets(pw, *index)?; - } - Ok(()) + self.aux_index.set_targets(pw, op.aux().table_index(params)) + } + + fn size(params: &Params) -> usize { + OperationTypeTarget::size(params) + + params.max_operation_args * IndexTarget::size(params) + + IndexTarget::size(params) } } @@ -570,12 +579,16 @@ impl Flattenable for CustomPredicateEntryTarget { .collect() } fn from_flattened(params: &Params, vs: &[Target]) -> Self { + assert_eq!(vs.len(), Self::size(params)); Self { id: HashOutTarget::from_flattened(params, &vs[0..4]), index: vs[4], predicate: CustomPredicateTarget::from_flattened(params, &vs[5..]), } } + fn size(params: &Params) -> usize { + HashOutTarget::size(params) + 1 + CustomPredicateTarget::size(params) + } } impl CustomPredicateEntryTarget { @@ -669,15 +682,16 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { .collect() } fn from_flattened(params: &Params, vs: &[Target]) -> Self { - let (pos, size) = (0, params.statement_size()); + assert_eq!(vs.len(), Self::size(params)); + let (pos, size) = (0, StatementTarget::size(params)); let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]); - let (pos, size) = (pos + size, params.operation_size(IndexTarget::f_len())); + let (pos, size) = (pos + size, OperationTypeTarget::size(params)); let op_type = OperationTypeTarget { elements: vs[pos..pos + size] .try_into() .expect("len = operation_type_size"), }; - let (pos, size) = (pos + size, params.statement_size()); + let (pos, size) = (pos + size, StatementTarget::size(params)); let op_args = (0..params.max_operation_args) .map(|i| { StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size]) @@ -689,6 +703,10 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { op_args, } } + fn size(params: &Params) -> usize { + StatementTarget::size(params) * (1 + params.max_operation_args) + + OperationTarget::size(params) + } } /// Trait for target structs that may be converted to and from vectors @@ -696,8 +714,11 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { pub trait Flattenable { fn flatten(&self) -> Vec; fn from_flattened(params: &Params, vs: &[Target]) -> Self; + /// Size in number of `Target`s + fn size(params: &Params) -> usize; } +// TODO: Figure out why this is defined in common and not in the merkletree directory /// For the purpose of op verification, we need only look up the /// Merkle claim rather than the Merkle proof since it is verified /// elsewhere. @@ -726,21 +747,28 @@ impl Flattenable for HashOutTarget { fn flatten(&self) -> Vec { self.elements.to_vec() } - fn from_flattened(_params: &Params, vs: &[Target]) -> Self { - assert_eq!(vs.len(), HASH_SIZE); + fn from_flattened(params: &Params, vs: &[Target]) -> Self { + assert_eq!(vs.len(), Self::size(params)); Self { elements: array::from_fn(|i| vs[i]), } } + fn size(_params: &Params) -> usize { + 4 + } } impl Flattenable for ValueTarget { fn flatten(&self) -> Vec { self.elements.to_vec() } - fn from_flattened(_params: &Params, vs: &[Target]) -> Self { + fn from_flattened(params: &Params, vs: &[Target]) -> Self { + assert_eq!(vs.len(), Self::size(params)); Self::from_slice(vs) } + fn size(_params: &Params) -> usize { + 4 + } } impl Flattenable for MerkleClaimTarget { @@ -755,7 +783,8 @@ impl Flattenable for MerkleClaimTarget { .concat() } - fn from_flattened(_params: &Params, vs: &[Target]) -> Self { + fn from_flattened(params: &Params, vs: &[Target]) -> Self { + assert_eq!(vs.len(), Self::size(params)); Self { enabled: BoolTarget::new_unsafe(vs[0]), root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()), @@ -768,6 +797,10 @@ impl Flattenable for MerkleClaimTarget { existence: BoolTarget::new_unsafe(vs[1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), } } + + fn size(params: &Params) -> usize { + 2 + HashOutTarget::size(params) + 2 * ValueTarget::size(params) + } } impl Flattenable for PredicateTarget { @@ -775,11 +808,15 @@ impl Flattenable for PredicateTarget { self.elements.to_vec() } - fn from_flattened(_params: &Params, v: &[Target]) -> Self { + fn from_flattened(params: &Params, v: &[Target]) -> Self { + assert_eq!(v.len(), Self::size(params)); Self { elements: v.try_into().expect("len is predicate_size"), } } + fn size(_params: &Params) -> usize { + Params::predicate_size() + } } impl Flattenable for StatementTarget { @@ -792,13 +829,9 @@ impl Flattenable for StatementTarget { } fn from_flattened(params: &Params, v: &[Target]) -> Self { - let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN; - assert_eq!( - v.len(), - Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN - ); + assert_eq!(v.len(), Self::size(params)); let predicate = PredicateTarget::from_flattened(params, &v[..Params::predicate_size()]); - let args = (0..num_args) + let args = (0..params.max_statement_args) .map(|i| StatementArgTarget { elements: array::from_fn(|j| { v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j] @@ -808,6 +841,10 @@ impl Flattenable for StatementTarget { Self { predicate, args } } + + fn size(params: &Params) -> usize { + PredicateTarget::size(params) + params.max_statement_args * StatementArgTarget::size(params) + } } impl Flattenable for CustomPredicateTarget { @@ -819,6 +856,7 @@ impl Flattenable for CustomPredicateTarget { } fn from_flattened(params: &Params, v: &[Target]) -> Self { + assert_eq!(v.len(), Self::size(params)); // We assume that `from_flattened` is always called with the output of `flattened`, so // this `BoolTarget` should actually safe. let conjunction = BoolTarget::new_unsafe(v[0]); @@ -836,6 +874,9 @@ impl Flattenable for CustomPredicateTarget { args_len, } } + fn size(params: &Params) -> usize { + 2 + params.max_custom_predicate_arity * StatementTmplTarget::size(params) + } } impl Flattenable for StatementTmplTarget { @@ -848,6 +889,7 @@ impl Flattenable for StatementTmplTarget { } fn from_flattened(params: &Params, v: &[Target]) -> Self { + assert_eq!(v.len(), Self::size(params)); let pred_end = Params::predicate_size(); let pred = PredicateTarget::from_flattened(params, &v[..pred_end]); let sta_size = Params::statement_tmpl_arg_size(); @@ -859,6 +901,11 @@ impl Flattenable for StatementTmplTarget { .collect(); Self { pred, args } } + + fn size(params: &Params) -> usize { + PredicateTarget::size(params) + + params.max_statement_args * StatementTmplArgTarget::size(params) + } } impl Flattenable for StatementTmplArgTarget { @@ -866,24 +913,28 @@ impl Flattenable for StatementTmplArgTarget { self.elements.to_vec() } - fn from_flattened(_params: &Params, v: &[Target]) -> Self { + fn from_flattened(params: &Params, v: &[Target]) -> Self { + assert_eq!(v.len(), Self::size(params)); Self { elements: v.try_into().expect("len is statement_tmpl_arg_size"), } } + fn size(_params: &Params) -> usize { + Params::statement_tmpl_arg_size() + } } /// Index to an array for random access -#[derive(Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct IndexTarget { - max_array_len: usize, - low: Target, - high: Target, + pub max_array_len: usize, + pub low: Target, + pub high: Target, } impl IndexTarget { // Length in field elements - pub const fn f_len() -> usize { + pub fn size(_params: &Params) -> usize { 2 } pub fn new_virtual(max_array_len: usize, builder: &mut CircuitBuilder) -> Self { @@ -1051,10 +1102,7 @@ impl CircuitBuilderPod for CircuitBuilder { args: (0..params.max_operation_args) .map(|_| IndexTarget::new_virtual(params.statement_table_size(), self)) .collect(), - aux: [ - IndexTarget::new_virtual(params.max_merkle_proofs_containers, self), - IndexTarget::new_virtual(params.max_custom_predicate_verifications, self), - ], + aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self), } } diff --git a/src/backends/plonky2/circuits/hash.rs b/src/backends/plonky2/circuits/hash.rs new file mode 100644 index 0000000..3ab4f90 --- /dev/null +++ b/src/backends/plonky2/circuits/hash.rs @@ -0,0 +1,57 @@ +use plonky2::{ + hash::{ + hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, + hashing::PlonkyPermutation, + }, + iop::target::Target, + plonk::config::AlgebraicHasher, +}; + +use crate::{backends::plonky2::basetypes::CircuitBuilder, middleware::F}; + +/// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder +/// elements that didn't fit into a chunk. +pub fn precompute_hash_state>(inputs: &[F]) -> (P, &[F]) { + let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE); + let mut perm = P::new(core::iter::repeat(F::ZERO)); + + // Absorb all inputs up to the biggest multiple of RATE. + for input_chunk in inputs.chunks(P::RATE) { + perm.set_from_slice(input_chunk, 0); + perm.permute(); + } + + (perm, inputs_rem) +} + +/// Hash `inputs` starting from a circuit-constant `perm` state. +pub fn hash_from_state_circuit, P: PlonkyPermutation>( + builder: &mut CircuitBuilder, + perm: P, + inputs: &[Target], +) -> HashOutTarget { + let mut state = + H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v))); + + // Absorb all input chunks. + for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) { + // Overwrite the first r elements with the inputs. This differs from a standard sponge, + // where we would xor or add in the inputs. This is a well-known variant, though, + // sometimes called "overwrite mode". + state.set_from_slice(input_chunk, 0); + state = builder.permute::(state); + } + + let num_outputs = NUM_HASH_OUT_ELTS; + // Squeeze until we have the desired number of outputs. + let mut outputs = Vec::with_capacity(num_outputs); + loop { + for &s in state.squeeze() { + outputs.push(s); + if outputs.len() == num_outputs { + return HashOutTarget::from_vec(outputs); + } + } + state = builder.permute::(state); + } +} diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 81bd5a6..306fe00 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -5,15 +5,13 @@ use num::{BigUint, One}; use plonky2::{ field::types::Field, hash::{ - hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, - hashing::PlonkyPermutation, + hash_types::HashOutTarget, poseidon::{PoseidonHash, PoseidonPermutation}, }, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, - plonk::config::AlgebraicHasher, }; use plonky2_u32::gadgets::multiple_comparison::list_le_circuit; use serde::{Deserialize, Serialize}; @@ -29,15 +27,18 @@ use crate::{ OperationTypeTarget, PredicateTarget, StatementArgTarget, StatementTarget, StatementTmplArgTarget, StatementTmplTarget, ValueTarget, }, + hash::{hash_from_state_circuit, precompute_hash_state}, + mux_table::{MuxTableTarget, TableEntryTarget}, signedpod::{verify_signed_pod_circuit, SignedPodVerifyTarget}, }, emptypod::{cache_get_standard_empty_pod_circuit_data, EmptyPod}, error::Result, - mainpod::{self, pad_statement, OperationArg}, + mainpod::{self, pad_statement}, primitives::{ ec::{ bits::{BigUInt320Target, CircuitBuilderBits}, curve::{CircuitBuilderElliptic, Point, WitnessWriteCurve, GROUP_ORDER}, + schnorr::SecretKey, }, merkletree::{ verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget, @@ -49,8 +50,8 @@ use crate::{ measure_gates_begin, measure_gates_end, middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, OperationType, Params, PodType, PredicatePrefix, Statement, StatementArg, - ToFields, TypedValue, Value, ValueRef, F, HASH_SIZE, KEY_TYPE, SELF, VALUE_SIZE, + NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields, + Value, ValueRef, F, HASH_SIZE, KEY_TYPE, SELF, VALUE_SIZE, }, }; // @@ -184,6 +185,144 @@ fn verify_operation_public_statement_circuit( Ok(()) } +enum OperationAuxTableTag { + None = 0, + MerkleProof = 1, + PublicKeyOf = 2, + CustomPredVerify = 3, +} + +fn max_operation_aux_entry_len(params: &Params) -> usize { + [ + (params.max_merkle_proofs_containers > 0).then(|| MerkleClaimTarget::size(params)), + (params.max_public_key_of > 0).then(|| KeyPairTarget::size(params)), + (params.max_custom_predicate_verifications > 0) + .then(|| CustomPredicateVerifyQueryTarget::size(params)), + ] + .into_iter() + .flatten() + .max() + .unwrap_or(0) +} + +#[derive(Copy, Clone)] +struct KeyPairTarget { + pk_hash: HashOutTarget, + sk_hash: HashOutTarget, +} + +impl Flattenable for KeyPairTarget { + fn flatten(&self) -> Vec { + self.pk_hash + .elements + .into_iter() + .chain(self.sk_hash.elements) + .collect() + } + fn from_flattened(params: &Params, vs: &[Target]) -> Self { + assert_eq!(vs.len(), Self::size(params)); + Self { + pk_hash: HashOutTarget::try_from(&vs[..4]).expect("len = 4"), + sk_hash: HashOutTarget::try_from(&vs[4..]).expect("len = 4"), + } + } + fn size(_params: &Params) -> usize { + 8 + } +} + +fn build_operation_aux_table_circuit( + params: &Params, + builder: &mut CircuitBuilder, + merkle_proofs: &[MerkleClaimAndProofTarget], + public_key_of_sks: &[BigUInt320Target], + custom_predicate_verifications: &[CustomPredicateVerifyEntryTarget], + custom_predicate_table: &[HashOutTarget], +) -> Result { + let measure = measure_gates_begin!(builder, "BuildOpAuxTbl"); + assert_eq!( + params.max_custom_predicate_verifications, + custom_predicate_verifications.len() + ); + assert_eq!(params.max_merkle_proofs_containers, merkle_proofs.len()); + let max_entry_len = max_operation_aux_entry_len(params); + let mut table = MuxTableTarget::new(params, max_entry_len); + + // None + table.push_flattened(builder, OperationAuxTableTag::None as u32, &[]); + + // MerkleProofs: verify container merkle proofs (inclusion/non-inclusion) + for merkle_proof in merkle_proofs { + verify_merkle_proof_circuit(builder, merkle_proof); + let entry = MerkleClaimTarget::from(merkle_proof.clone()); + + table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); + } + + // PublicKeyOf: verify the derivation from a Schnorr secret key to public key + for sk in public_key_of_sks { + let measure = measure_gates_begin!(builder, "PublicKeyOf"); + let invgenerator = builder.constant_point(Point::generator().inverse()); + let group_orderm1 = &*GROUP_ORDER - BigUint::one(); + let group_orderm1target = builder.constant_biguint320(&group_orderm1); + let compare_ok = list_le_circuit( + builder, + sk.limbs.to_vec(), + group_orderm1target.limbs.to_vec(), + 32, + ); + builder.assert_one(compare_ok.target); + // public_key = g^-secret key + let pk = builder.multiply_point(&sk.bits, &invgenerator); + let sk_hash = builder.hash_n_to_hash_no_pad::(sk.limbs.to_vec()); + let pk_hash = builder.hash_n_to_hash_no_pad::( + pk.x.components.into_iter().chain(pk.u.components).collect(), + ); + + let entry = KeyPairTarget { pk_hash, sk_hash }; + + table.push(builder, OperationAuxTableTag::PublicKeyOf as u32, &entry); + measure_gates_end!(builder, measure); + } + + // CustomPredVerify: verify custom predicate statements verification against operations + for entry in custom_predicate_verifications { + let measure = measure_gates_begin!(builder, "CustomPredVerify"); + // Verify the custom predicate operation + let (statement, op_type) = make_custom_statement_circuit( + params, + builder, + &entry.custom_predicate, + &entry.op_args, + &entry.args, + )?; + + // Check that the batch id is correct by querying the custom predicate batches table + let table_query_hash = builder.vec_ref( + params, + custom_predicate_table, + &entry.custom_predicate_table_index, + ); + let out_query_hash = entry.custom_predicate.hash(builder); + builder.connect_array(table_query_hash.elements, out_query_hash.elements); + + let query = CustomPredicateVerifyQueryTarget { + statement, // output + op_type, // output + op_args: entry.op_args.clone(), // input + }; + table.push( + builder, + OperationAuxTableTag::CustomPredVerify as u32, + &query, + ); + measure_gates_end!(builder, measure); + } + + measure_gates_end!(builder, measure); + Ok(table) +} + #[allow(clippy::too_many_arguments)] fn verify_operation_circuit( params: &Params, @@ -192,9 +331,7 @@ fn verify_operation_circuit( op: &OperationTarget, prev_statements: &[StatementTarget], input_statements_offset: usize, - merkle_claims: &[MerkleClaimTarget], - secret_key: &BigUInt320Target, - custom_predicate_verification_table: &[HashOutTarget], + aux_table: &MuxTableTarget, ) -> Result<()> { let measure = measure_gates_begin!(builder, "OpVerify"); let _true = builder._true(); @@ -206,81 +343,54 @@ fn verify_operation_circuit( let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); let cache = StatementCache::new(params, builder, op, st, prev_statements); measure_gates_end!(builder, measure_resolve_op_args); - // TODO: Can we have a single table with merkel claims and verified custom predicates - // together (with an identifying prefix) and then we only need one random access instead of - // two? - // Currently we use one slot of aux for the index to merkle claim and another slot of aux - // for the index to the verified custom predicate. We can't use the same slot because then - // if one table is different size the random access to the smaller one may use an index - // that is too big and not pass the constraints. Possible solutions to use a single slot - // are: - // - a. Use a single table (mux both tables) - // - b. select the index or 0 by checking the operation type here; but that breaks the - // current abstraction a little bit. - // Certain operations (Contains/NotContains) will refer to one - // of the provided Merkle proofs (if any). These proofs have already - // been verified, so we need only look up the claim. - let measure_resolve_merkle_claim = measure_gates_begin!(builder, "ResolveMerkleClaim"); - let resolved_merkle_claim = - (!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, &op.aux[0])); - measure_gates_end!(builder, measure_resolve_merkle_claim); + // Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified + // entries in a table (e.g.: Merkle proofs ). These entries have already been verified, so we + // need only look up the claim. - // Operations from custom statements will refer to one - // of the provided custom predicates verifications (if any). These operations have already - // been verified, so we need only look up the entry. - let measure_resolve_custom_pred_verification = - measure_gates_begin!(builder, "ResolveCustomPredVerification"); - let resolved_custom_pred_verification = (!custom_predicate_verification_table.is_empty()) - .then(|| builder.vec_ref(params, custom_predicate_verification_table, &op.aux[1])); - measure_gates_end!(builder, measure_resolve_custom_pred_verification); + // The aux table always has a fixed zero entry, so we check if there are more than 1 entries to + // trigger the unhashing. + let resolved_aux = (aux_table.len() > 1).then(|| aux_table.get(builder, &op.aux_index)); - // The verification may require aux data which needs to be stored in the - // `OperationVerifyTarget` so that we can set during witness generation. - - // For now only support native operations - // Op checks to carry out. Each 'eval_X' should - // be thought of as 'eval' restricted to the op of type X, - // where the returned target is `false` if the input targets + // Op checks to carry out. Each 'verify_X_circuit' should be thought of as operation check + // restricted to the op of type X, where the returned target is `false` if the input targets // lie outside of the domain. - let op_checks = [ - vec![ - verify_none_circuit(params, builder, st, &op.op_type), - verify_new_entry_circuit( - params, - builder, - st, - &op.op_type, - prev_statements, - input_statements_offset, - ), - ], - // Skip these if there are no resolved op args - if cache.op_args.is_empty() { - vec![] - } else { - vec![ - verify_copy_circuit(builder, st, &op.op_type, &cache.op_args), - verify_eq_neq_from_entries_circuit(params, builder, st, &op.op_type, &cache), - verify_lt_lteq_from_entries_circuit(params, builder, st, &op.op_type, &cache), - verify_transitive_eq_circuit(params, builder, st, &op.op_type, &cache.op_args), - verify_lt_to_neq_circuit(params, builder, st, &op.op_type, &cache.op_args), - verify_hash_of_circuit(params, builder, st, &op.op_type, &cache), - verify_public_key_of_circuit(params, builder, st, &op.op_type, secret_key, &cache), - verify_sum_of_circuit(params, builder, st, &op.op_type, &cache), - verify_product_of_circuit(params, builder, st, &op.op_type, &cache), - verify_max_of_circuit(params, builder, st, &op.op_type, &cache), - ] - }, - // Skip these if there are no resolved Merkle claims - if let Some(resolved_merkle_claim) = resolved_merkle_claim { - vec![ + let mut op_checks = Vec::new(); + op_checks.extend_from_slice(&[ + verify_none_circuit(params, builder, st, &op.op_type), + verify_new_entry_circuit( + params, + builder, + st, + &op.op_type, + prev_statements, + input_statements_offset, + ), + ]); + // Skip these if there are no resolved op args + if !cache.op_args.is_empty() { + op_checks.extend_from_slice(&[ + verify_copy_circuit(builder, st, &op.op_type, &cache.op_args), + verify_eq_neq_from_entries_circuit(params, builder, st, &op.op_type, &cache), + verify_lt_lteq_from_entries_circuit(params, builder, st, &op.op_type, &cache), + verify_transitive_eq_circuit(params, builder, st, &op.op_type, &cache.op_args), + verify_lt_to_neq_circuit(params, builder, st, &op.op_type, &cache.op_args), + verify_hash_of_circuit(params, builder, st, &op.op_type, &cache), + verify_sum_of_circuit(params, builder, st, &op.op_type, &cache), + verify_product_of_circuit(params, builder, st, &op.op_type, &cache), + verify_max_of_circuit(params, builder, st, &op.op_type, &cache), + ]); + } + // Skip these if there are no resolved aux entries + if let Some(resolved_aux) = resolved_aux { + if params.max_merkle_proofs_containers > 0 { + op_checks.extend_from_slice(&[ verify_contains_from_entries_circuit( params, builder, st, &op.op_type, - resolved_merkle_claim, + &resolved_aux, &cache, ), verify_not_contains_from_entries_circuit( @@ -288,27 +398,31 @@ fn verify_operation_circuit( builder, st, &op.op_type, - resolved_merkle_claim, + &resolved_aux, &cache, ), - ] - } else { - vec![] - }, - // Skip these if there are no resolved custom predicate verifications - if let Some(resolved_custom_pred_verification) = resolved_custom_pred_verification { - vec![verify_custom_circuit( + ]); + } + if params.max_public_key_of > 0 { + op_checks.push(verify_public_key_of_circuit( + params, builder, st, &op.op_type, - resolved_custom_pred_verification, + &resolved_aux, + &cache, + )); + } + if params.max_custom_predicate_verifications > 0 { + op_checks.push(verify_custom_circuit( + builder, + st, + &op.op_type, + &resolved_aux, &cache.op_args, - )] - } else { - vec![] - }, - ] - .concat(); + )); + } + } let ok = builder.any(op_checks); builder.assert_one(ok.target); @@ -326,10 +440,12 @@ fn verify_contains_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_merkle_claim: MerkleClaimTarget, + aux: &TableEntryTarget, cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); + let (aux_tag_ok, resolved_merkle_claim) = + aux.as_type::(builder, OperationAuxTableTag::MerkleProof as u32); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainsFromEntries); let (arg_types_ok, [merkle_root_value, key_value, value_value]) = @@ -364,7 +480,7 @@ fn verify_contains_from_entries_circuit( ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); - let ok = builder.all([op_code_ok, arg_types_ok, merkle_proof_ok, st_ok]); + let ok = builder.all([op_code_ok, aux_tag_ok, arg_types_ok, merkle_proof_ok, st_ok]); measure_gates_end!(builder, measure); ok } @@ -374,10 +490,12 @@ fn verify_not_contains_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_merkle_claim: MerkleClaimTarget, + aux: &TableEntryTarget, cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); + let (aux_tag_ok, resolved_merkle_claim) = + aux.as_type::(builder, OperationAuxTableTag::MerkleProof as u32); let op_code_ok = op_type.has_native(builder, NativeOperation::NotContainsFromEntries); let (arg_types_ok, [merkle_root_value, key_value]) = cache.first_n_args_as_values(); @@ -409,7 +527,7 @@ fn verify_not_contains_from_entries_circuit( ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); - let ok = builder.all([op_code_ok, arg_types_ok, merkle_proof_ok, st_ok]); + let ok = builder.all([op_code_ok, aux_tag_ok, arg_types_ok, merkle_proof_ok, st_ok]); measure_gates_end!(builder, measure); ok } @@ -418,20 +536,24 @@ fn verify_custom_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_custom_pred_verification: HashOutTarget, + aux: &TableEntryTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpCustom"); - let query = CustomPredicateVerifyQueryTarget { - statement: st.clone(), - op_type: op_type.clone(), - op_args: resolved_op_args.to_vec(), - }; - let out_query_hash = query.hash(builder); - let ok = builder.is_equal_slice( - &resolved_custom_pred_verification.elements, - &out_query_hash.elements, + let (aux_tag_ok, resolved_query) = aux.as_type::( + builder, + OperationAuxTableTag::CustomPredVerify as u32, ); + + let query_ok = builder.is_equal_flattenable( + &resolved_query, + &CustomPredicateVerifyQueryTarget { + statement: st.clone(), + op_type: op_type.clone(), + op_args: resolved_op_args.to_vec(), + }, + ); + let ok = builder.all([aux_tag_ok, query_ok]); measure_gates_end!(builder, measure); ok } @@ -593,41 +715,27 @@ fn verify_public_key_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - secret_key: &BigUInt320Target, + aux: &TableEntryTarget, cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpPublicKeyOf"); + let (aux_tag_ok, resolved_key_pair) = + aux.as_type::(builder, OperationAuxTableTag::PublicKeyOf as u32); let op_code_ok = op_type.has_native(builder, NativeOperation::PublicKeyOf); let (arg_types_ok, [arg1_value, arg2_value]) = cache.first_n_args_as_values(); // inputting public_key, secret_key - let public_key_hash = arg1_value.elements; - let secret_key_hash = arg2_value.elements; + let public_key_hash = arg1_value; + let secret_key_hash = arg2_value; - let secret_key_hash_v = - builder.hash_n_to_hash_no_pad::(secret_key.limbs.to_vec()); - let skey_hash_ok = builder.is_equal_slice(&secret_key_hash, &secret_key_hash_v.elements); - let invgenerator = builder.constant_point(Point::generator().inverse()); - let secret_key_bits = secret_key.bits; - let group_orderm1 = &*GROUP_ORDER - BigUint::one(); - let group_orderm1target = builder.constant_biguint320(&group_orderm1); - let compare_ok = list_le_circuit( - builder, - secret_key.limbs.to_vec(), - group_orderm1target.limbs.to_vec(), - 32, + let skey_hash_ok = builder.is_equal_slice( + &secret_key_hash.elements, + &resolved_key_pair.sk_hash.elements, ); - // public_key = g^-secret key - let public_key = builder.multiply_point(&secret_key_bits, &invgenerator); - let public_key_hash_v = builder.hash_n_to_hash_no_pad::( - public_key - .x - .components - .into_iter() - .chain(public_key.u.components) - .collect(), + let pkey_hash_ok = builder.is_equal_slice( + &public_key_hash.elements, + &resolved_key_pair.pk_hash.elements, ); - let pkey_hash_ok = builder.is_equal_slice(&public_key_hash, &public_key_hash_v.elements); let arg1_expected = cache.equations[0].lhs.clone(); let arg2_expected = cache.equations[1].lhs.clone(); @@ -641,10 +749,10 @@ fn verify_public_key_of_circuit( let ok = builder.all([ op_code_ok, + aux_tag_ok, arg_types_ok, pkey_hash_ok, skey_hash_ok, - compare_ok, st_ok, ]); measure_gates_end!(builder, measure); @@ -1078,53 +1186,6 @@ fn normalize_statement_circuit( } } -/// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder -/// elements that didn't fit into a chunk. -fn precompute_hash_state>(inputs: &[F]) -> (P, &[F]) { - let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE); - let mut perm = P::new(core::iter::repeat(F::ZERO)); - - // Absorb all inputs up to the biggest multiple of RATE. - for input_chunk in inputs.chunks(P::RATE) { - perm.set_from_slice(input_chunk, 0); - perm.permute(); - } - - (perm, inputs_rem) -} - -/// Hash `inputs` starting from a circuit-constant `perm` state. -fn hash_from_state_circuit, P: PlonkyPermutation>( - builder: &mut CircuitBuilder, - perm: P, - inputs: &[Target], -) -> HashOutTarget { - let mut state = - H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v))); - - // Absorb all input chunks. - for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) { - // Overwrite the first r elements with the inputs. This differs from a standard sponge, - // where we would xor or add in the inputs. This is a well-known variant, though, - // sometimes called "overwrite mode". - state.set_from_slice(input_chunk, 0); - state = builder.permute::(state); - } - - let num_outputs = NUM_HASH_OUT_ELTS; - // Squeeze until we have the desired number of outputs. - let mut outputs = Vec::with_capacity(num_outputs); - loop { - for &s in state.squeeze() { - outputs.push(s); - if outputs.len() == num_outputs { - return HashOutTarget::from_vec(outputs); - } - } - state = builder.permute::(state); - } -} - /// `params.num_public_statements_id` is the total number of statements that will be hashed. /// The id is calculated with front-padded none-statements and then the input statements /// reversed. The part of the hash from the front-padded none-statements is precomputed. @@ -1216,50 +1277,6 @@ fn build_custom_predicate_table_circuit( Ok(custom_predicate_table) } -/// Build table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args] -/// with queryable part as hash([st, op, op_args]). While building the table we verify each -/// custom predicate against the operation and statement. Return the hash of each table "query" -/// sub-entry. -fn build_custom_predicate_verification_table_circuit( - params: &Params, - builder: &mut CircuitBuilder, - custom_predicate_table: &[HashOutTarget], - custom_predicate_verifications: &[CustomPredicateVerifyEntryTarget], -) -> Result> { - let measure = measure_gates_begin!(builder, "BuildCustomPredVerifyTbl"); - let mut custom_predicate_verification_table = - Vec::with_capacity(params.max_custom_predicate_verifications); - for entry in custom_predicate_verifications { - // Verify the custom predicate operation - let (statement, op_type) = make_custom_statement_circuit( - params, - builder, - &entry.custom_predicate, - &entry.op_args, - &entry.args, - )?; - - // Check that the batch id is correct by querying the custom predicate batches table - let table_query_hash = builder.vec_ref( - params, - custom_predicate_table, - &entry.custom_predicate_table_index, - ); - let out_query_hash = entry.custom_predicate.hash(builder); - builder.connect_array(table_query_hash.elements, out_query_hash.elements); - - let query = CustomPredicateVerifyQueryTarget { - statement, // output - op_type, // output - op_args: entry.op_args.clone(), // input - }; - let in_query_hash = query.hash(builder); - custom_predicate_verification_table.push(in_query_hash); - } - measure_gates_end!(builder, measure); - Ok(custom_predicate_verification_table) -} - fn verify_main_pod_circuit( builder: &mut CircuitBuilder, main_pod: &MainPodVerifyTarget, @@ -1354,26 +1371,17 @@ fn verify_main_pod_circuit( let public_statements_offset = main_pod.input_statements.len() - params.max_public_statements; let pub_statements = &main_pod.input_statements[public_statements_offset..]; - // Verify Merkle claim/proof targets - let merkle_claims = main_pod - .merkle_proofs - .iter() - .map(|mt_proof| { - verify_merkle_proof_circuit(builder, mt_proof); - MerkleClaimTarget::from(mt_proof.clone()) - }) - .collect_vec(); - // Table of custom predicate batches with batch_id calculation let custom_predicate_table = build_custom_predicate_table_circuit(params, builder, &main_pod.custom_predicate_batches)?; - // Table of custom predicate statements verification against operations - let custom_predicate_verification_table = build_custom_predicate_verification_table_circuit( + let aux_table = build_operation_aux_table_circuit( params, builder, - &custom_predicate_table, + &main_pod.merkle_proofs, + &main_pod.public_key_of_sks, &main_pod.custom_predicate_verifications, + &custom_predicate_table, )?; // 2. Calculate the Pod Id from the public statements @@ -1408,9 +1416,7 @@ fn verify_main_pod_circuit( op, prev_statements, input_statements_offset, - &merkle_claims, - &main_pod.secret_keys[i], - &custom_predicate_verification_table, + &aux_table, )?; } else { verify_operation_public_statement_circuit( @@ -1439,7 +1445,7 @@ pub struct MainPodVerifyTarget { input_statements: Vec, operations: Vec, merkle_proofs: Vec, - secret_keys: Vec, + public_key_of_sks: Vec, custom_predicate_batches: Vec, custom_predicate_verifications: Vec, } @@ -1473,7 +1479,7 @@ impl MainPodVerifyTarget { MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, builder) }) .collect(), - secret_keys: (0..params.max_statements) + public_key_of_sks: (0..params.max_public_key_of) .map(|_| builder.add_virtual_biguint320_target()) .collect(), custom_predicate_batches: (0..params.max_custom_predicate_batches) @@ -1504,6 +1510,7 @@ pub struct MainPodVerifyInput { pub statements: Vec, pub operations: Vec, pub merkle_proofs: Vec, + pub public_key_of_sks: Vec, pub custom_predicate_batches: Vec>, pub custom_predicate_verifications: Vec, } @@ -1612,41 +1619,6 @@ impl InnerCircuit for MainPodVerifyTarget { for (i, (st, op)) in zip_eq(&input.statements, &input.operations).enumerate() { self.input_statements[i].set_targets(pw, &self.params, st)?; self.operations[i].set_targets(pw, &self.params, op)?; - if matches!( - op.op_type(), - OperationType::Native(NativeOperation::PublicKeyOf) - ) { - if let StatementArg::Literal(value) = &st.1[1] { - if let TypedValue::SecretKey(sk) = value.typed() { - pw.set_biguint320_target(&self.secret_keys[i], &sk.0)?; - } else { - panic!("SecretKey literal of incorrect type!") - } - } else if let OperationArg::Index(ind) = op.1[1] { - // TODO: This adjustment only works if the secret key came - // from a statement in the current POD, which is the most - // common case. A more general solution needs to be able - // index across the virtual array of statements from all - // input PODs, similar to what's done in - // plonky2::mainpod::layout_statements. - let adjusted_index = ind - - (1 + self.params.max_input_signed_pods - * self.params.max_signed_pod_values - + self.params.max_input_recursive_pods - * self.params.max_public_statements); - if let StatementArg::Literal(value) = &input.statements[adjusted_index].1[1] { - if let TypedValue::SecretKey(sk) = value.typed() { - pw.set_biguint320_target(&self.secret_keys[i], &sk.0)?; - } else { - panic!("SecretKey literal of incorrect type!") - } - } - } else { - panic!("SecretKey arg not found!") - } - } else { - pw.set_biguint320_target(&self.secret_keys[i], &BigUint::ZERO)?; - } } assert!(input.merkle_proofs.len() <= self.params.max_merkle_proofs_containers); @@ -1659,6 +1631,16 @@ impl InnerCircuit for MainPodVerifyTarget { self.merkle_proofs[i].set_targets(pw, false, &pad_mp)?; } + assert!(input.public_key_of_sks.len() <= self.params.max_public_key_of); + for (i, sk) in input.public_key_of_sks.iter().enumerate() { + pw.set_biguint320_target(&self.public_key_of_sks[i], &sk.0)?; + } + // Padding + let pad_sk = BigUint::ZERO; + for i in input.public_key_of_sks.len()..self.params.max_public_key_of { + pw.set_biguint320_target(&self.public_key_of_sks[i], &pad_sk)?; + } + assert!(input.custom_predicate_batches.len() <= self.params.max_custom_predicate_batches); for (i, cpb) in input.custom_predicate_batches.iter().enumerate() { self.custom_predicate_batches[i].set_targets(pw, &self.params, cpb)?; @@ -1727,7 +1709,7 @@ mod tests { frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, middleware::{ hash_str, hash_values, Hash, Key, OperationType, PodId, Predicate, RawValue, - StatementTmpl, StatementTmplArg, Wildcard, + StatementTmpl, StatementTmplArg, TypedValue, Wildcard, }, }; @@ -1736,11 +1718,13 @@ mod tests { op: mainpod::Operation, prev_statements: Vec, merkle_proofs: Vec, - secret_key: &SecretKey, + secret_keys: Vec, ) -> Result<()> { let params = Params { max_custom_predicate_batches: 0, max_custom_predicate_verifications: 0, + max_merkle_proofs_containers: merkle_proofs.len(), + max_public_key_of: secret_keys.len(), ..Default::default() }; @@ -1752,24 +1736,29 @@ mod tests { let prev_statements_target: Vec<_> = (0..prev_statements.len()) .map(|_| builder.add_virtual_statement(¶ms)) .collect(); + let merkle_proofs_target: Vec<_> = merkle_proofs .iter() .map(|_| { - let mt_proof = MerkleClaimAndProofTarget::new_virtual( - params.max_depth_mt_containers, - &mut builder, - ); - verify_merkle_proof_circuit(&mut builder, &mt_proof); - mt_proof + MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, &mut builder) }) .collect(); - let merkle_claims_target: Vec<_> = merkle_proofs_target - .clone() - .into_iter() - .map(|pf| pf.into()) + + let secret_keys_target: Vec<_> = secret_keys + .iter() + .map(|sk| builder.constant_biguint320(&sk.0)) .collect(); - let secret_key_target = builder.constant_biguint320(&secret_key.0); - let custom_predicate_verification_table = vec![]; + + let aux_table = build_operation_aux_table_circuit( + ¶ms, + &mut builder, + &merkle_proofs_target, + &secret_keys_target, + &[], + &[], + )?; + // let max_aux_entry_len = max_operation_aux_entry_len(¶ms); + // let aux = builder.add_virtual_targets(1 + max_aux_entry_len); verify_operation_circuit( ¶ms, @@ -1778,9 +1767,7 @@ mod tests { &op_target, &prev_statements_target, 0, - &merkle_claims_target, - &secret_key_target, - &custom_predicate_verification_table, + &aux_table, )?; let mut pw = PartialWitness::::new(); @@ -1940,13 +1927,7 @@ mod tests { .into_iter() .for_each(|(op, st)| { let check = std::panic::catch_unwind(|| { - operation_verify( - st, - op, - prev_statements.to_vec(), - vec![], - &SecretKey(BigUint::ZERO), - ) + operation_verify(st, op, prev_statements.to_vec(), vec![], vec![]) }); match check { Err(e) => { @@ -2011,14 +1992,7 @@ mod tests { ] .into_iter() .for_each(|(op, st)| { - assert!(operation_verify( - st, - op, - prev_statements.to_vec(), - vec![], - &SecretKey(BigUint::ZERO) - ) - .is_err()) + assert!(operation_verify(st, op, prev_statements.to_vec(), vec![], vec![]).is_err()) }); } @@ -2031,7 +2005,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2049,7 +2023,7 @@ mod tests { vec![], OperationAux::None, ); - operation_verify(st1, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st1, op, prev_statements, vec![], vec![]) } #[test] @@ -2061,7 +2035,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2084,7 +2058,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2107,7 +2081,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2130,7 +2104,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO))?; + operation_verify(st, op, prev_statements, vec![], vec![])?; // Also check negative < negative let st3: mainpod::Statement = Statement::equal( @@ -2154,7 +2128,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO))?; + operation_verify(st, op, prev_statements, vec![], vec![])?; // Also check negative < positive let st: mainpod::Statement = Statement::lt( @@ -2168,7 +2142,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st3, st2]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2191,7 +2165,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO))?; + operation_verify(st, op, prev_statements, vec![], vec![])?; // Also check negative <= negative let st3: mainpod::Statement = Statement::equal( @@ -2215,7 +2189,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO))?; + operation_verify(st, op, prev_statements, vec![], vec![])?; // Also check negative <= positive let st: mainpod::Statement = Statement::lt_eq( @@ -2229,13 +2203,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st3, st2]; - operation_verify( - st, - op, - prev_statements.clone(), - vec![], - &SecretKey(BigUint::ZERO), - )?; + operation_verify(st, op, prev_statements.clone(), vec![], vec![])?; // Also check equality, both positive and negative. let st: mainpod::Statement = Statement::lt_eq( @@ -2248,13 +2216,7 @@ mod tests { vec![OperationArg::Index(0), OperationArg::Index(0)], OperationAux::None, ); - operation_verify( - st, - op, - prev_statements.clone(), - vec![], - &SecretKey(BigUint::ZERO), - )?; + operation_verify(st, op, prev_statements.clone(), vec![], vec![])?; let st: mainpod::Statement = Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), @@ -2265,7 +2227,7 @@ mod tests { vec![OperationArg::Index(1), OperationArg::Index(1)], OperationAux::None, ); - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2314,7 +2276,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2360,7 +2322,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) }) } @@ -2407,7 +2369,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) }) } @@ -2449,7 +2411,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) }) } @@ -2494,13 +2456,7 @@ mod tests { let prev_statements = [st1, st2, st3]; let check = std::panic::catch_unwind(|| { - operation_verify( - st, - op, - prev_statements.to_vec(), - vec![], - &SecretKey(BigUint::ZERO), - ) + operation_verify(st, op, prev_statements.to_vec(), vec![], vec![]) }); match check { Err(e) => { @@ -2533,7 +2489,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2559,7 +2515,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, vec![], &SecretKey(BigUint::ZERO)) + operation_verify(st, op, prev_statements, vec![], vec![]) } #[test] @@ -2599,13 +2555,7 @@ mod tests { no_key_pf, )]; let prev_statements = vec![root_st, key_st]; - operation_verify( - st, - op, - prev_statements, - merkle_proofs, - &SecretKey(BigUint::ZERO), - ) + operation_verify(st, op, prev_statements, merkle_proofs, vec![]) } #[test] @@ -2652,21 +2602,15 @@ mod tests { key_pf, )]; let prev_statements = vec![root_st, key_st, value_st]; - operation_verify( - st, - op, - prev_statements, - merkle_proofs, - &SecretKey(BigUint::ZERO), - ) + operation_verify(st, op, prev_statements, merkle_proofs, vec![]) } #[test] - fn test_operation_verify_publickeyof() -> Result<()> { + fn test_operation_verify_publickeyof_ok() -> Result<()> { [ - &SecretKey(BigUint::one()), - &SecretKey::new_rand(), - &SecretKey(&*GROUP_ORDER - BigUint::one()), + SecretKey(BigUint::one()), + SecretKey::new_rand(), + SecretKey(&*GROUP_ORDER - BigUint::one()), ] .into_iter() .try_for_each(|secret_key| { @@ -2684,10 +2628,10 @@ mod tests { let op = mainpod::Operation( OperationType::Native(NativeOperation::PublicKeyOf), vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, + OperationAux::PublicKeyOfIndex(0), ); let prev_statements = vec![public_key_st, secret_key_st]; - operation_verify(st, op, prev_statements, vec![], secret_key) + operation_verify(st, op, prev_statements, vec![], vec![secret_key]) }) } @@ -2707,10 +2651,10 @@ mod tests { let op = mainpod::Operation( OperationType::Native(NativeOperation::PublicKeyOf), vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, + OperationAux::PublicKeyOfIndex(0), ); let prev_statements = vec![public_key_st, secret_key_st]; - assert!(operation_verify(st, op, prev_statements, vec![], &secret_key).is_err()) + assert!(operation_verify(st, op, prev_statements, vec![], vec![secret_key]).is_err()) } #[test] @@ -2732,7 +2676,7 @@ mod tests { OperationAux::None, ); let prev_statements = vec![public_key_st, secret_key_st]; - assert!(operation_verify(st, op, prev_statements, vec![], &secret_key).is_err()) + assert!(operation_verify(st, op, prev_statements, vec![], vec![secret_key]).is_err()) } #[test] @@ -2751,7 +2695,7 @@ mod tests { let op = mainpod::Operation( OperationType::Native(NativeOperation::PublicKeyOf), vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, + OperationAux::PublicKeyOfIndex(0), ); let prev_statements = vec![public_key_st, secret_key_st]; assert!(operation_verify( @@ -2759,7 +2703,7 @@ mod tests { op, prev_statements, vec![], - &SecretKey(BigUint::from(123u32)) + vec![SecretKey(BigUint::from(123u32))] ) .is_err()) } @@ -2780,10 +2724,10 @@ mod tests { let op = mainpod::Operation( OperationType::Native(NativeOperation::PublicKeyOf), vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, + OperationAux::PublicKeyOfIndex(0), ); let prev_statements = vec![public_key_st, secret_key_st]; - assert!(operation_verify(st, op, prev_statements, vec![], &secret_key).is_err()) + assert!(operation_verify(st, op, prev_statements, vec![], vec![secret_key]).is_err()) } fn helper_statement_arg_from_template( diff --git a/src/backends/plonky2/circuits/mod.rs b/src/backends/plonky2/circuits/mod.rs index a865b3c..d98c848 100644 --- a/src/backends/plonky2/circuits/mod.rs +++ b/src/backends/plonky2/circuits/mod.rs @@ -1,5 +1,7 @@ pub mod common; +pub mod hash; pub mod mainpod; pub mod metrics; +pub mod mux_table; pub mod signedpod; pub mod utils; diff --git a/src/backends/plonky2/circuits/mux_table.rs b/src/backends/plonky2/circuits/mux_table.rs new file mode 100644 index 0000000..110dac9 --- /dev/null +++ b/src/backends/plonky2/circuits/mux_table.rs @@ -0,0 +1,216 @@ +use std::iter; + +use itertools::Itertools; +use plonky2::{ + field::{extension::Extendable, types::Field}, + hash::{ + hash_types::{HashOutTarget, RichField}, + poseidon::{PoseidonHash, PoseidonPermutation}, + }, + iop::{ + generator::{GeneratedValues, SimpleGenerator}, + target::{BoolTarget, Target}, + witness::{PartitionWitness, Witness, WitnessWrite}, + }, + plonk::circuit_data::CommonCircuitData, + util::serialization::{Buffer, IoResult, Read, Write}, +}; + +use crate::{ + backends::plonky2::{ + basetypes::CircuitBuilder, + circuits::{ + common::{CircuitBuilderPod, Flattenable, IndexTarget}, + hash::{hash_from_state_circuit, precompute_hash_state}, + }, + }, + measure_gates_begin, measure_gates_end, + middleware::{Params, F}, +}; + +// This structure allows multiplexing multiple tables into one by using tags. The table entries +// are computed by hashing the concatenation of the tag with the flattened target, with zero +// padding to normalize the size of all flattened entries. We use zero-padding on then reverse the +// array so that smaller entries can skip the initial hashes by using the precomputed hash state of +// the prefixed zeroes. +// The table offers an indexing API that returns a flattened entry that includes the "unhashing", +// this allows doing a single lookup for different possible tagged entries at the same time. +pub struct MuxTableTarget { + params: Params, + max_flattened_entry_len: usize, + hashed_tagged_entries: Vec, + tagged_entries: Vec>, +} + +impl MuxTableTarget { + pub fn new(params: &Params, max_flattened_entry_len: usize) -> Self { + Self { + params: params.clone(), + max_flattened_entry_len, + hashed_tagged_entries: Vec::new(), + tagged_entries: Vec::new(), + } + } + + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.hashed_tagged_entries.len() + } + + pub fn push(&mut self, builder: &mut CircuitBuilder, tag: u32, entry: &T) { + let flattened_entry = entry.flatten(); + self.push_flattened(builder, tag, &flattened_entry); + } + + pub fn push_flattened( + &mut self, + builder: &mut CircuitBuilder, + tag: u32, + flattened_entry: &[Target], + ) { + let measure = measure_gates_begin!(builder, "HashTaggedTblEntry"); + assert!(flattened_entry.len() <= self.max_flattened_entry_len); + let flattened = [&[builder.constant(F(tag as u64))], flattened_entry].concat(); + self.tagged_entries.push(flattened.clone()); + + let tagged_entry_max_len = 1 + self.max_flattened_entry_len; + let front_pad_elts = iter::repeat(F::ZERO) + .take(tagged_entry_max_len - flattened.len()) + .collect_vec(); + + let (perm, front_pad_elts_rem) = + precompute_hash_state::>(&front_pad_elts); + + let rev_flattened = flattened.iter().rev().copied(); + // Precompute the Poseidon state for the initial padding chunks + let inputs = front_pad_elts_rem + .iter() + .map(|v| builder.constant(*v)) + .chain(rev_flattened) + .collect_vec(); + let hash = + hash_from_state_circuit::>(builder, perm, &inputs); + + measure_gates_end!(builder, measure); + self.hashed_tagged_entries.push(hash); + } + + pub fn get(&self, builder: &mut CircuitBuilder, index: &IndexTarget) -> TableEntryTarget { + let measure = measure_gates_begin!(builder, "GetTaggedTblEntry"); + let entry_hash = builder.vec_ref(&self.params, &self.hashed_tagged_entries, index); + + let mut rev_resolved_tagged_flattened = + builder.add_virtual_targets(1 + self.max_flattened_entry_len); + let query_hash = + builder.hash_n_to_hash_no_pad::(rev_resolved_tagged_flattened.clone()); + builder.connect_flattenable(&entry_hash, &query_hash); + rev_resolved_tagged_flattened.reverse(); + let resolved_tagged_flattened = rev_resolved_tagged_flattened; + + builder.add_simple_generator(TableGetGenerator { + index: index.clone(), + tagged_entries: self.tagged_entries.clone(), + get_tagged_entry: resolved_tagged_flattened.clone(), + }); + measure_gates_end!(builder, measure); + TableEntryTarget { + params: self.params.clone(), + tagged_flattened_entry: resolved_tagged_flattened, + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct TableGetGenerator { + index: IndexTarget, + tagged_entries: Vec>, + get_tagged_entry: Vec, +} + +impl, const D: usize> SimpleGenerator for TableGetGenerator { + fn id(&self) -> String { + "TableGetGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + [self.index.low, self.index.high] + .into_iter() + .chain(self.tagged_entries.iter().flatten().copied()) + .collect() + } + + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> anyhow::Result<()> { + let index_low = witness.get_target(self.index.low); + let index_high = witness.get_target(self.index.high); + let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64(); + + let entry = witness.get_targets(&self.tagged_entries[index as usize]); + + for (target, value) in self.get_tagged_entry.iter().zip( + entry + .iter() + .chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())), + ) { + out_buffer.set_target(*target, *value)?; + } + + Ok(()) + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.index.max_array_len)?; + dst.write_target(self.index.low)?; + dst.write_target(self.index.high)?; + + dst.write_usize(self.tagged_entries.len())?; + for tagged_entry in &self.tagged_entries { + dst.write_target_vec(tagged_entry)?; + } + + dst.write_target_vec(&self.get_tagged_entry) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let index = IndexTarget { + max_array_len: src.read_usize()?, + low: src.read_target()?, + high: src.read_target()?, + }; + let len = src.read_usize()?; + let mut tagged_entries = Vec::with_capacity(len); + for _ in 0..len { + tagged_entries.push(src.read_target_vec()?); + } + let get_tagged_entry = src.read_target_vec()?; + + Ok(Self { + index, + tagged_entries, + get_tagged_entry, + }) + } +} + +pub struct TableEntryTarget { + params: Params, + tagged_flattened_entry: Vec, +} + +impl TableEntryTarget { + pub fn as_type( + &self, + builder: &mut CircuitBuilder, + tag: u32, + ) -> (BoolTarget, T) { + let tag_target = self.tagged_flattened_entry[0]; + let flattened_entry = &self.tagged_flattened_entry[1..]; + let entry = T::from_flattened(&self.params, &flattened_entry[..T::size(&self.params)]); + let tag_expect = builder.constant(F(tag as u64)); + let tag_ok = builder.is_equal(tag_expect, tag_target); + (tag_ok, entry) + } +} diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 2a61889..7939e60 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -18,7 +18,7 @@ use crate::{ emptypod::EmptyPod, error::{Error, Result}, mock::emptypod::MockEmptyPod, - primitives::merkletree::MerkleClaimAndProof, + primitives::{ec::schnorr::SecretKey, merkletree::MerkleClaimAndProof}, recursion::{ hash_verifier_data, prove_rec_circuit, RecursiveCircuit, RecursiveCircuitTarget, }, @@ -29,9 +29,9 @@ use crate::{ signedpod::SignedPod, }, middleware::{ - self, resolve_wildcard_values, value_from_op, AnchoredKey, CustomPredicateBatch, Hash, - MainPodInputs, NativeOperation, OperationType, Params, Pod, PodId, PodProver, PodType, - RecursivePod, StatementArg, ToFields, VDSet, KEY_TYPE, SELF, + self, resolve_wildcard_values, value_from_op, AnchoredKey, CustomPredicateBatch, + Error as MiddlewareError, Hash, MainPodInputs, NativeOperation, OperationType, Params, Pod, + PodId, PodProver, PodType, RecursivePod, StatementArg, ToFields, VDSet, KEY_TYPE, SELF, }, timed, }; @@ -87,16 +87,13 @@ pub(crate) fn extract_custom_predicate_batches( /// Extracts all custom predicate operations with all the data required to verify them. pub(crate) fn extract_custom_predicate_verifications( params: &Params, + aux_list: &mut [OperationAux], operations: &[middleware::Operation], custom_predicate_batches: &[Arc], ) -> Result> { - let custom_predicate_data: Vec<_> = operations - .iter() - .flat_map(|op| match op { - middleware::Operation::Custom(cpr, sts) => Some((cpr, sts)), - _ => None, - }) - .map(|(cpr, sts)| { + let mut table = Vec::new(); + for (i, op) in operations.iter().enumerate() { + if let middleware::Operation::Custom(cpr, sts) = op { let wildcard_values = resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards"); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); @@ -107,73 +104,105 @@ pub(crate) fn extract_custom_predicate_verifications( .expect("find the custom predicate from the extracted unique list"); let custom_predicate_table_index = batch_index * params.max_custom_batch_size + cpr.index; - CustomPredicateVerification { + aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len()); + table.push(CustomPredicateVerification { custom_predicate_table_index, custom_predicate: cpr.clone(), args: wildcard_values, op_args: sts, - } - }) - .collect(); - if custom_predicate_data.len() > params.max_custom_predicate_verifications { + }); + } + } + + if table.len() > params.max_custom_predicate_verifications { return Err(Error::custom(format!( "The number of required custom predicate verifications ({}) exceeds the maximum number ({}).", - custom_predicate_data.len(), + table.len(), params.max_custom_predicate_verifications ))); } - Ok(custom_predicate_data) + Ok(table) } /// Extracts Merkle proofs from Contains/NotContains ops. pub(crate) fn extract_merkle_proofs( params: &Params, + aux_list: &mut [OperationAux], operations: &[middleware::Operation], statements: &[middleware::Statement], ) -> Result> { - assert_eq!(operations.len(), statements.len()); - let merkle_proofs: Vec<_> = operations - .iter() - .zip(statements.iter()) - .flat_map(|(op, st)| match (op, st) { + let mut table = Vec::new(); + for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() { + let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); + let (root, key, value, pf) = match (op, st) { ( middleware::Operation::ContainsFromEntries(root_s, key_s, value_s, pf), middleware::Statement::Contains(root_ref, key_ref, value_ref), ) => { - let root = value_from_op(root_s, root_ref)?; - let key = value_from_op(key_s, key_ref)?; - let value = value_from_op(value_s, value_ref)?; - Some(MerkleClaimAndProof::new( - Hash::from(root.raw()), - key.raw(), - Some(value.raw()), - pf.clone(), - )) + let root = value_from_op(root_s, root_ref).ok_or_else(deduction_err)?; + let key = value_from_op(key_s, key_ref).ok_or_else(deduction_err)?; + let value = value_from_op(value_s, value_ref).ok_or_else(deduction_err)?; + (root.raw(), key.raw(), Some(value.raw()), pf) } ( middleware::Operation::NotContainsFromEntries(root_s, key_s, pf), middleware::Statement::NotContains(root_ref, key_ref), ) => { - let root = value_from_op(root_s, root_ref)?; - let key = value_from_op(key_s, key_ref)?; - Some(MerkleClaimAndProof::new( - Hash::from(root.raw()), - key.raw(), - None, - pf.clone(), - )) + let root = value_from_op(root_s, root_ref).ok_or_else(deduction_err)?; + let key = value_from_op(key_s, key_ref).ok_or_else(deduction_err)?; + (root.raw(), key.raw(), None, pf) } - _ => None, - }) - .collect(); - if merkle_proofs.len() > params.max_merkle_proofs_containers { + _ => continue, + }; + aux_list[i] = OperationAux::MerkleProofIndex(table.len()); + table.push(MerkleClaimAndProof::new( + Hash::from(root), + key, + value, + pf.clone(), + )); + } + if table.len() > params.max_merkle_proofs_containers { return Err(Error::custom(format!( "The number of required Merkle proofs ({}) exceeds the maximum number ({}).", - merkle_proofs.len(), + table.len(), params.max_merkle_proofs_containers ))); } - Ok(merkle_proofs) + Ok(table) +} + +pub(crate) fn extract_public_key_of( + params: &Params, + aux_list: &mut [OperationAux], + operations: &[middleware::Operation], + statements: &[middleware::Statement], +) -> Result> { + let mut table = Vec::new(); + for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() { + if let ( + middleware::Operation::PublicKeyOf(_, sk_s), + middleware::Statement::PublicKeyOf(_, sk_ref), + ) = (op, st) + { + let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); + let sk = SecretKey::try_from( + value_from_op(sk_s, sk_ref) + .ok_or_else(deduction_err)? + .typed(), + )?; + aux_list[i] = OperationAux::PublicKeyOfIndex(table.len()); + table.push(sk); + } + } + if table.len() > params.max_public_key_of { + return Err(Error::custom(format!( + "The number of required PublicKeyOf verifications ({}) exceeds the maximum number ({}).", + table.len(), + params.max_public_statements + ))); + } + Ok(table) } /// Find the operation argument statement in the list of previous statements and return the index. @@ -192,52 +221,6 @@ fn find_op_arg(statements: &[Statement], op_arg: &middleware::Statement) -> Resu ))) } -/// Find the operation auxiliary data in the list of auxiliary data and return the index. -// NOTE: The `custom_predicate_verifications` is optional because in the MainPod we want to store -// the index of a custom predicate verification in the aux data, but in the MockMainPod we don't -// need that because we keep a reference to the custom predicate in the operation type, which -// removes the need for indexing. We could change the OperationType and Predicate for the backend -// to not keep a reference to the custom predicate and instead just keep the id and index and then -// do the same double indexing that the MainPod does to verify custom predicates. -fn find_op_aux( - merkle_proofs: &[MerkleClaimAndProof], - custom_predicate_verifications: Option<&[CustomPredicateVerification]>, - op: &middleware::Operation, -) -> Result { - let op_aux = op.aux(); - if let (middleware::Operation::Custom(cpr, op_args), Some(cpvs)) = - (op, custom_predicate_verifications) - { - return Ok(cpvs - .iter() - .enumerate() - .find_map(|(i, cpv)| { - (cpv.custom_predicate.batch.id() == cpr.batch.id() - && cpv.custom_predicate.index == cpr.index - && cpv - .op_args - .iter() - .zip_eq(op_args.iter()) - .all(|(a0, a1)| a0.0 == a1.predicate() && a0.1 == a1.args())) - .then_some(i) - }) - .map(OperationAux::CustomPredVerifyIndex) - .expect("custom predicate verification in the list")); - } - match &op_aux { - middleware::OperationAux::None => Ok(OperationAux::None), - middleware::OperationAux::MerkleProof(pf_arg) => merkle_proofs - .iter() - .enumerate() - .find_map(|(i, pf)| (pf.proof == *pf_arg).then_some(i)) - .map(OperationAux::MerkleProofIndex) - .ok_or(Error::custom(format!( - "Merkle proof corresponding to op arg {} not found", - op_aux - ))), - } -} - fn fill_pad(v: &mut Vec, pad_value: T, len: usize) { if v.len() > len { panic!("length exceeded"); @@ -367,12 +350,12 @@ pub(crate) fn layout_statements( pub(crate) fn process_private_statements_operations( params: &Params, statements: &[Statement], - merkle_proofs: &[MerkleClaimAndProof], - custom_predicate_verifications: Option<&[CustomPredicateVerification]>, + aux_list: &[OperationAux], input_operations: &[middleware::Operation], ) -> Result> { + assert_eq!(params.max_priv_statements(), aux_list.len()); let mut operations = Vec::new(); - for i in 0..params.max_priv_statements() { + for (i, aux) in aux_list.iter().enumerate() { let op = input_operations .get(i) .unwrap_or(&middleware::Operation::None) @@ -383,10 +366,8 @@ pub(crate) fn process_private_statements_operations( .map(|mid_arg| find_op_arg(statements, mid_arg)) .collect::>>()?; - let aux = find_op_aux(merkle_proofs, custom_predicate_verifications, &op)?; - pad_operation_args(params, &mut args); - operations.push(Operation(op.op_type(), args, aux)); + operations.push(Operation(op.op_type(), args, *aux)); } Ok(operations) } @@ -475,20 +456,25 @@ impl PodProver for Prover { }) .collect_vec(); - let merkle_proofs = extract_merkle_proofs(params, inputs.operations, inputs.statements)?; + // Aux values for backend::Operation + let mut aux_list = vec![OperationAux::None; params.max_priv_statements()]; + let merkle_proofs = + extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?; let custom_predicate_verifications = extract_custom_predicate_verifications( params, + &mut aux_list, inputs.operations, &custom_predicate_batches, )?; + let public_key_of_sks = + extract_public_key_of(params, &mut aux_list, inputs.operations, inputs.statements)?; let (statements, public_statements) = layout_statements(params, false, &inputs)?; let operations = process_private_statements_operations( params, &statements, - &merkle_proofs, - Some(&custom_predicate_verifications), + &aux_list, inputs.operations, )?; let operations = process_public_statements_operations(params, &statements, operations)?; @@ -523,6 +509,7 @@ impl PodProver for Prover { statements: statements[statements.len() - params.max_statements..].to_vec(), operations, merkle_proofs, + public_key_of_sks, custom_predicate_batches, custom_predicate_verifications, }; @@ -845,6 +832,45 @@ pub mod tests { pod.verify().unwrap() } + // This pod does nothing but it's useful for debugging to keep things small. + #[ignore] + #[test] + fn test_mini_1() { + let params = middleware::Params { + max_input_signed_pods: 0, + max_input_recursive_pods: 0, + max_signed_pod_values: 0, + max_statements: 2, + max_public_statements: 1, + max_input_pods_public_statements: 0, + max_merkle_proofs_containers: 0, + max_public_key_of: 0, + max_custom_predicate_verifications: 0, + max_custom_predicate_batches: 0, + ..Default::default() + }; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); + + let builder = frontend::MainPodBuilder::new(¶ms, &vd_set); + println!("{}", builder); + println!(); + + // Mock + let prover = MockProver {}; + let pod = builder.prove(&prover).unwrap(); + let pod = (pod.pod as Box).downcast::().unwrap(); + pod.verify().unwrap(); + println!("{:#}", pod); + + // Real + let prover = Prover {}; + let pod = builder.prove(&prover).unwrap(); + let pod = (pod.pod as Box).downcast::().unwrap(); + pod.verify().unwrap() + } + #[test] fn test_mainpod_small_empty() { let params = middleware::Params { @@ -863,6 +889,7 @@ pub mod tests { max_custom_predicate_wildcards: 3, max_custom_batch_size: 2, max_merkle_proofs_containers: 2, + max_public_key_of: 2, max_depth_mt_containers: 4, max_depth_mt_vds: 6, }; @@ -927,6 +954,7 @@ pub mod tests { max_custom_batch_size: 3, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, + max_merkle_proofs_containers: 0, ..Default::default() }; println!("{:#?}", params); @@ -980,7 +1008,7 @@ pub mod tests { let st = builder .pub_op(frontend::Operation::new_entry( "entry", - Set::new(params.max_merkle_proofs_containers, set).unwrap(), + Set::new(params.max_depth_mt_containers, set).unwrap(), )) .unwrap(); diff --git a/src/backends/plonky2/mainpod/operation.rs b/src/backends/plonky2/mainpod/operation.rs index 49f6770..6cc63b5 100644 --- a/src/backends/plonky2/mainpod/operation.rs +++ b/src/backends/plonky2/mainpod/operation.rs @@ -8,7 +8,7 @@ use crate::{ mainpod::Statement, primitives::merkletree::MerkleClaimAndProof, }, - middleware::{self, OperationType}, + middleware::{self, OperationType, Params}, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -30,19 +30,36 @@ impl OperationArg { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub enum OperationAux { None, MerkleProofIndex(usize), + PublicKeyOfIndex(usize), CustomPredVerifyIndex(usize), } impl OperationAux { - pub fn as_usizes(&self) -> [usize; 2] { + fn table_offset_merkle_proof(_params: &Params) -> usize { + // At index 0 we store a zero entry + 1 + } + fn table_offset_public_key_of(params: &Params) -> usize { + Self::table_offset_merkle_proof(params) + params.max_merkle_proofs_containers + } + fn table_offset_custom_pred_verify(params: &Params) -> usize { + Self::table_offset_public_key_of(params) + params.max_public_key_of + } + pub(crate) fn table_size(params: &Params) -> usize { + 1 + params.max_merkle_proofs_containers + + params.max_public_key_of + + params.max_custom_predicate_verifications + } + pub fn table_index(&self, params: &Params) -> usize { match self { - Self::None => [0, 0], - Self::MerkleProofIndex(i) => [*i, 0], - Self::CustomPredVerifyIndex(i) => [0, *i], + Self::None => 0, + Self::MerkleProofIndex(i) => Self::table_offset_merkle_proof(params) + *i, + Self::PublicKeyOfIndex(i) => Self::table_offset_public_key_of(params) + *i, + Self::CustomPredVerifyIndex(i) => Self::table_offset_custom_pred_verify(params) + *i, } } } @@ -87,6 +104,7 @@ impl Operation { .proof .clone(), ), + OperationAux::PublicKeyOfIndex(_) => crate::middleware::OperationAux::None, }; Ok(middleware::Operation::op( self.0.clone(), @@ -114,6 +132,7 @@ impl fmt::Display for Operation { OperationAux::None => (), OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?, + OperationAux::PublicKeyOfIndex(i) => write!(f, " public_key_of_{:02}", i)?, } Ok(()) } diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index bac4378..76b2fbe 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -14,7 +14,7 @@ use crate::{ mainpod::{ calculate_id, extract_merkle_proofs, layout_statements, process_private_statements_operations, process_public_statements_operations, Operation, - Statement, + OperationAux, Statement, }, mock::emptypod::MockEmptyPod, primitives::merkletree::MerkleClaimAndProof, @@ -172,14 +172,15 @@ impl MockMainPod { pub fn new(params: &Params, inputs: MainPodInputs) -> Result { let (statements, public_statements) = layout_statements(params, true, &inputs)?; + let mut aux_list = vec![OperationAux::None; params.max_priv_statements()]; // Extract Merkle proofs and pad. - let merkle_proofs = extract_merkle_proofs(params, inputs.operations, inputs.statements)?; + let merkle_proofs = + extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; let operations = process_private_statements_operations( params, &statements, - &merkle_proofs, - None, + &aux_list, inputs.operations, )?; let operations = process_public_statements_operations(params, &statements, operations)?; diff --git a/src/backends/plonky2/serialization.rs b/src/backends/plonky2/serialization.rs index 845325a..aa9d145 100644 --- a/src/backends/plonky2/serialization.rs +++ b/src/backends/plonky2/serialization.rs @@ -19,7 +19,7 @@ use serde::{de, ser, Deserialize, Serialize}; use crate::backends::plonky2::{ basetypes::{CircuitData, CommonCircuitData, VerifierCircuitData, C, D, F}, - circuits::{common::LtMaskGenerator, utils::DebugGenerator}, + circuits::{common::LtMaskGenerator, mux_table::TableGetGenerator, utils::DebugGenerator}, primitives::ec::{ bits::ConditionalZeroGenerator, curve::PointSquareRootGenerator, @@ -92,7 +92,6 @@ use plonky2::{ #[derive(Debug)] pub(crate) struct Pod2GeneratorSerializer {} -// TODO: Add pod2 custom generators impl WitnessGeneratorSerializer for Pod2GeneratorSerializer { impl_generator_serializer! { Pod2GeneratorSerializer, @@ -130,7 +129,8 @@ impl WitnessGeneratorSerializer for Pod2GeneratorSerializer { RecursiveGenerator<1, NNFMulSimple<5, QuinticExtension>>, RecursiveGenerator, RecursiveGenerator<1, ECAddHomogOffset>, - ComparisonGenerator + ComparisonGenerator, + TableGetGenerator } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 6d75071..97b5a12 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -764,8 +764,11 @@ pub struct Params { pub max_depth_mt_containers: usize, // maximum depth of the merkle tree gadget used for verifier_data membership // check. This allows creating verifying sets of pod circuits of size - // 2^max_depth_mt_vds. + // 2^max_depth_mt_vds. Limits the number of container operations of the type Contains, + // NotContains. pub max_depth_mt_vds: usize, + // maximum number of public key derivations used for PublicKeyOf operation + pub max_public_key_of: usize, // // The following parameters define how a pod id is calculated. They need to be the same among // different circuits to be compatible in their verification. @@ -803,6 +806,7 @@ impl Default for Params { max_merkle_proofs_containers: 5, max_depth_mt_containers: 32, max_depth_mt_vds: 6, // up to 64 (2^6) different pod circuits + max_public_key_of: 2, } } } @@ -828,10 +832,6 @@ impl Params { Self::predicate_size() + STATEMENT_ARG_F_LEN * self.max_statement_args } - pub fn operation_size(&self, operation_arg_f_len: usize) -> usize { - Self::operation_type_size() + operation_arg_f_len * self.max_operation_args - } - pub const fn statement_tmpl_size(&self) -> usize { Self::predicate_size() + self.max_statement_args * Self::statement_tmpl_arg_size() }