diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index d7d6b39..f65eb7b 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -51,7 +51,7 @@ use crate::{ mainpod::cache_get_rec_main_pod_verifier_circuit_data, primitives::merkletree::MerkleClaimAndProof, }, - middleware::{containers::Array, Hash, Params, RawValue, Result, Value}, + middleware::{containers::Array, Hash, Params, RawValue, Result, Value, EMPTY_HASH}, }; pub static DEFAULT_VD_LIST: LazyLock> = LazyLock::new(|| { @@ -95,6 +95,12 @@ impl Eq for VDSet {} impl VDSet { fn new_from_vds_hashes(mut vds_hashes: Vec) -> Self { + // If vds_hashes is empty we add an zero entry to be used as padding when verifying merkle + // proofs of inclusion in the vds set. This zero entry can't be abused because no circuit + // exists with a vds_hash = 0. + if vds_hashes.is_empty() { + vds_hashes.push(EMPTY_HASH); + } // before using the hash values, sort them, so that each set of // verifier_datas gets the same VDSet root vds_hashes.sort(); @@ -150,6 +156,9 @@ impl VDSet { ))? .clone()) } + pub fn get_vds_proof_0(&self) -> MerkleClaimAndProof { + self.proofs_map[&self.vds_hashes[0]].clone() + } /// Returns true if the `verifier_data_hash` is in the set pub fn contains(&self, verifier_data_hash: HashOut) -> bool { self.proofs_map diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index dfee8a0..bb194a0 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -30,7 +30,7 @@ use crate::{ mainpod::{Operation, OperationArg, OperationAux, Statement}, primitives::merkletree::{ verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget, - MerkleProof, MerkleTreeStateTransitionProofTarget, + MerkleProof, MerkleProofExistenceTarget, MerkleTreeStateTransitionProofTarget, }, }, middleware::{ @@ -725,7 +725,6 @@ impl CustomPredicateInBatchTarget { let mtp = MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder); let _true = builder._true(); - builder.connect(_true.target, mtp.enabled.target); builder.connect(_true.target, mtp.existence.target); let zero = builder.constant(F(0)); let key = ValueTarget { @@ -763,7 +762,7 @@ impl CustomPredicateInBatchTarget { value: RawValue::from(hash_fields(&predicate.to_fields())), proof: mtp.clone(), }; - self.mtp.set_targets(pw, true, &mtp_claim)?; + self.mtp.set_targets(pw, &mtp_claim)?; Ok(()) } } @@ -987,7 +986,6 @@ pub trait Flattenable { /// elsewhere. #[derive(Copy, Clone)] pub struct MerkleClaimTarget { - pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -997,7 +995,6 @@ pub struct MerkleClaimTarget { impl From for MerkleClaimTarget { fn from(pf: MerkleClaimAndProofTarget) -> Self { Self { - enabled: pf.enabled, root: pf.root, key: pf.key, value: pf.value, @@ -1006,12 +1003,25 @@ impl From for MerkleClaimTarget { } } +impl MerkleClaimTarget { + pub fn from_proof_existence( + builder: &mut CircuitBuilder, + pf: MerkleProofExistenceTarget, + ) -> Self { + Self { + root: pf.root, + key: pf.key, + value: pf.value, + existence: builder._true(), + } + } +} + /// For the purpose of op verification, we need only look up the /// Merkle state transition claim rather than the Merkle state /// transition proof since it is verified elsewhere. #[derive(Copy, Clone)] pub struct MerkleTreeStateTransitionClaimTarget { - pub(crate) enabled: BoolTarget, pub(crate) op: Target, pub(crate) old_root: HashOutTarget, pub(crate) new_root: HashOutTarget, @@ -1022,7 +1032,6 @@ pub struct MerkleTreeStateTransitionClaimTarget { impl From for MerkleTreeStateTransitionClaimTarget { fn from(pf: MerkleTreeStateTransitionProofTarget) -> Self { Self { - enabled: pf.enabled, op: pf.op, old_root: pf.old_root, new_root: pf.new_root, @@ -1063,7 +1072,6 @@ impl Flattenable for ValueTarget { impl Flattenable for MerkleClaimTarget { fn flatten(&self) -> Vec { [ - vec![self.enabled.target], self.root.elements.to_vec(), self.key.elements.to_vec(), self.value.elements.to_vec(), @@ -1075,31 +1083,28 @@ impl Flattenable for MerkleClaimTarget { fn from_flattened(params: &Params, vs: &[Target]) -> Self { assert_eq!(vs.len(), Self::size(params)); Self { - enabled: BoolTarget::new_unsafe(vs[0]), - root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()), - key: ValueTarget::from_slice( - &vs[1 + NUM_HASH_OUT_ELTS..1 + NUM_HASH_OUT_ELTS + VALUE_SIZE], - ), + root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()), + key: ValueTarget::from_slice(&vs[NUM_HASH_OUT_ELTS..NUM_HASH_OUT_ELTS + VALUE_SIZE]), value: ValueTarget::from_slice( - &vs[1 + NUM_HASH_OUT_ELTS + VALUE_SIZE..1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], + &vs[NUM_HASH_OUT_ELTS + VALUE_SIZE..NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], ), - existence: BoolTarget::new_unsafe(vs[1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), + existence: BoolTarget::new_unsafe(vs[NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), } } fn size(params: &Params) -> usize { - 2 + HashOutTarget::size(params) + 2 * ValueTarget::size(params) + HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1 } } impl Flattenable for MerkleTreeStateTransitionClaimTarget { fn flatten(&self) -> Vec { [ - vec![self.enabled.target, self.op], self.old_root.elements.to_vec(), self.new_root.elements.to_vec(), self.op_key.elements.to_vec(), self.op_value.elements.to_vec(), + vec![self.op], ] .concat() } @@ -1107,24 +1112,22 @@ impl Flattenable for MerkleTreeStateTransitionClaimTarget { fn from_flattened(params: &Params, vs: &[Target]) -> Self { assert_eq!(vs.len(), Self::size(params)); Self { - enabled: BoolTarget::new_unsafe(vs[0]), - op: vs[1], - old_root: HashOutTarget::from_vec(vs[2..2 + NUM_HASH_OUT_ELTS].to_vec()), + old_root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()), new_root: HashOutTarget::from_vec( - vs[2 + NUM_HASH_OUT_ELTS..2 * (1 + NUM_HASH_OUT_ELTS)].to_vec(), + vs[NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS].to_vec(), ), op_key: ValueTarget::from_slice( - &vs[2 * (1 + NUM_HASH_OUT_ELTS)..2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE], + &vs[2 * NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS + VALUE_SIZE], ), op_value: ValueTarget::from_slice( - &vs[2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE - ..2 * (1 + NUM_HASH_OUT_ELTS) + 2 * VALUE_SIZE], + &vs[2 * NUM_HASH_OUT_ELTS + VALUE_SIZE..2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], ), + op: vs[2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], } } fn size(params: &Params) -> usize { - 2 * (1 + HashOutTarget::size(params)) + 2 * ValueTarget::size(params) + 2 * HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1 } } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod/mod.rs similarity index 53% rename from src/backends/plonky2/circuits/mainpod.rs rename to src/backends/plonky2/circuits/mainpod/mod.rs index 0605e07..89ed3cf 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod/mod.rs @@ -16,6 +16,9 @@ use plonky2::{ use plonky2_u32::gadgets::multiple_comparison::list_le_circuit; use serde::{Deserialize, Serialize}; +#[cfg(test)] +mod tests; + use crate::{ backends::plonky2::{ basetypes::{CircuitBuilder, VDSet}, @@ -33,7 +36,7 @@ use crate::{ }, emptypod::EmptyPod, error::Result, - mainpod::{self, pad_statement, SignedBy}, + mainpod::{self, pad_statement, MerkleProofs, MerkleTransitionProofs, SignedBy}, primitives::{ ec::{ bits::{BigUInt320Target, CircuitBuilderBits}, @@ -44,8 +47,9 @@ use crate::{ schnorr::{CircuitBuilderSchnorr, SecretKey, SignatureTarget, WitnessWriteSchnorr}, }, merkletree::{ - verify_merkle_proof_circuit, verify_merkle_state_transition_circuit, - MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProof, MerkleTreeOp, + verify_merkle_proof_circuit, verify_merkle_proof_existence_circuit, + verify_merkle_state_transition_circuit, MerkleClaimAndProof, + MerkleClaimAndProofTarget, MerkleProof, MerkleProofExistenceTarget, MerkleTreeOp, MerkleTreeStateTransitionProof, MerkleTreeStateTransitionProofTarget, }, signature::{verify_signature_circuit, SignatureVerifyTarget}, @@ -55,8 +59,8 @@ use crate::{ measure_gates_begin, measure_gates_end, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix, - ToFields, Value, BASE_PARAMS, F, HASH_SIZE, + NativePredicate, Params, PredicatePrefix, Statement, StatementTmplArgPrefix, ToFields, + Value, BASE_PARAMS, F, HASH_SIZE, VALUE_SIZE, }, }; // @@ -238,21 +242,21 @@ fn verify_operation_public_statement_circuit( enum OperationAuxTableTag { None = 0, MerkleProof = 1, - PublicKeyOf = 2, - SignedBy = 3, - MerkleTreeStateTransitionProof = 4, - CustomPredVerify = 5, + MerkleTransitionProof = 2, + CustomPredVerify = 3, + PublicKeyOf = 4, + SignedBy = 5, } fn max_operation_aux_entry_len(params: &Params) -> usize { [ - (params.max_merkle_proofs_containers > 0).then(|| MerkleClaimTarget::size(params)), - (params.max_public_key_of > 0).then(|| PubKeySecKeyTarget::size(params)), - (params.max_signed_by > 0).then(|| MsgPubKeyTarget::size(params)), - (params.max_merkle_tree_state_transition_proofs_containers > 0) + (params.containers.state.max_total() > 0).then(|| MerkleClaimTarget::size(params)), + (params.containers.transition.max_total() > 0) .then(|| MerkleTreeStateTransitionClaimTarget::size(params)), (params.max_custom_predicate_verifications > 0) .then(|| CustomPredicateVerifyQueryTarget::size(params)), + (params.max_public_key_of > 0).then(|| PubKeySecKeyTarget::size(params)), + (params.max_signed_by > 0).then(|| MsgPubKeyTarget::size(params)), ] .into_iter() .flatten() @@ -306,14 +310,59 @@ impl SignedByTarget { } } +fn append_container_proofs_operation_aux_table_circuit( + builder: &mut CircuitBuilder, + table: &mut MuxTableTarget, + merkle_proofs: &MerkleProofsTarget, + merkle_transition_proofs: &MerkleTransitionProofsTarget, +) { + // Small MerkleProofs: verify container merkle proofs (only inclusion) + for merkle_proof in &merkle_proofs.small { + verify_merkle_proof_existence_circuit(builder, merkle_proof); + let entry = MerkleClaimTarget::from_proof_existence(builder, merkle_proof.clone()); + + table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); + } + // Medium MerkleProofs: verify container merkle proofs (inclusion/non-inclusion) + for merkle_proof in &merkle_proofs.medium { + verify_merkle_proof_circuit(builder, merkle_proof); + let entry = MerkleClaimTarget::from(merkle_proof.clone()); + + table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); + } + + // Small Merkle state transition proofs: verify op proof (only update) + for merkle_transition_proof in &merkle_transition_proofs.small { + verify_merkle_state_transition_circuit(builder, merkle_transition_proof); + let entry = MerkleTreeStateTransitionClaimTarget::from(merkle_transition_proof.clone()); + + table.push( + builder, + OperationAuxTableTag::MerkleTransitionProof as u32, + &entry, + ); + } + // Medium Merkle state transition proofs: verify op proof (insert/update/delete) + for merkle_transition_proof in &merkle_transition_proofs.medium { + verify_merkle_state_transition_circuit(builder, merkle_transition_proof); + let entry = MerkleTreeStateTransitionClaimTarget::from(merkle_transition_proof.clone()); + + table.push( + builder, + OperationAuxTableTag::MerkleTransitionProof as u32, + &entry, + ); + } +} + #[allow(clippy::too_many_arguments)] fn build_operation_aux_table_circuit( params: &Params, builder: &mut CircuitBuilder, - merkle_proofs: &[MerkleClaimAndProofTarget], + merkle_proofs: &MerkleProofsTarget, + merkle_transition_proofs: &MerkleTransitionProofsTarget, public_key_of_sks: &[BigUInt320Target], signed_bys: &[SignedByTarget], - merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProofTarget], custom_predicate_verifications: &[CustomPredicateVerifyEntryTarget], custom_predicate_table: &[HashOutTarget], ) -> Result { @@ -322,19 +371,56 @@ fn build_operation_aux_table_circuit( params.max_custom_predicate_verifications, custom_predicate_verifications.len() ); - assert_eq!(params.max_merkle_proofs_containers, merkle_proofs.len()); + assert_eq!(params.containers.state.max_small, merkle_proofs.small.len()); + assert_eq!( + params.containers.state.max_medium, + merkle_proofs.medium.len() + ); let max_entry_len = max_operation_aux_entry_len(params); let mut table = MuxTableTarget::new(params, max_entry_len); // None table.push_flattened(builder, OperationAuxTableTag::None as u32, &[]); - // MerkleProofs: verify container merkle proofs (inclusion/non-inclusion) - for merkle_proof in merkle_proofs { - verify_merkle_proof_circuit(builder, merkle_proof); - let entry = MerkleClaimTarget::from(merkle_proof.clone()); + append_container_proofs_operation_aux_table_circuit( + builder, + &mut table, + merkle_proofs, + merkle_transition_proofs, + ); - table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); + // CustomPredVerify: verify custom predicate statements verification against operations + for entry in custom_predicate_verifications { + let measure = measure_gates_begin!(builder, "CustomPredVerify"); + // Verify the custom predicate operation + let (statement, op_type) = make_custom_statement_circuit( + params, + builder, + &entry.custom_predicate, + &entry.op_args, + &entry.args, + )?; + + // Check that the batch id is correct by querying the custom predicate batches table + let table_query_hash = builder.vec_ref( + params, + custom_predicate_table, + &entry.custom_predicate_table_index, + ); + let out_query_hash = entry.custom_predicate.hash(builder); + builder.connect_array(table_query_hash.elements, out_query_hash.elements); + + let query = CustomPredicateVerifyQueryTarget { + statement, // output + op_type, // output + op_args: entry.op_args.clone(), // input + }; + table.push( + builder, + OperationAuxTableTag::CustomPredVerify as u32, + &query, + ); + measure_gates_end!(builder, measure); } // PublicKeyOf: verify the derivation from a Schnorr secret key to public key @@ -394,53 +480,6 @@ fn build_operation_aux_table_circuit( measure_gates_end!(builder, measure); } - // Merkle state transition proofs: verify op proof (insert/update/delete) - for merkle_tree_state_transition_proof in merkle_tree_state_transition_proofs { - verify_merkle_state_transition_circuit(builder, merkle_tree_state_transition_proof); - let entry = - MerkleTreeStateTransitionClaimTarget::from(merkle_tree_state_transition_proof.clone()); - - table.push( - builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, - &entry, - ); - } - - // CustomPredVerify: verify custom predicate statements verification against operations - for entry in custom_predicate_verifications { - let measure = measure_gates_begin!(builder, "CustomPredVerify"); - // Verify the custom predicate operation - let (statement, op_type) = make_custom_statement_circuit( - params, - builder, - &entry.custom_predicate, - &entry.op_args, - &entry.args, - )?; - - // Check that the batch id is correct by querying the custom predicate batches table - let table_query_hash = builder.vec_ref( - params, - custom_predicate_table, - &entry.custom_predicate_table_index, - ); - let out_query_hash = entry.custom_predicate.hash(builder); - builder.connect_array(table_query_hash.elements, out_query_hash.elements); - - let query = CustomPredicateVerifyQueryTarget { - statement, // output - op_type, // output - op_args: entry.op_args.clone(), // input - }; - table.push( - builder, - OperationAuxTableTag::CustomPredVerify as u32, - &query, - ); - measure_gates_end!(builder, measure); - } - measure_gates_end!(builder, measure); Ok(table) } @@ -504,7 +543,7 @@ fn verify_operation_circuit( } // Skip these if there are no resolved aux entries if let Some(resolved_aux) = resolved_aux { - if params.max_merkle_proofs_containers > 0 { + if params.containers.state.max_total() > 0 { op_checks.extend_from_slice(&[ verify_contains_from_entries_circuit( params, @@ -544,7 +583,7 @@ fn verify_operation_circuit( &cache, )); } - if params.max_merkle_tree_state_transition_proofs_containers > 0 { + if params.containers.transition.max_total() > 0 { op_checks.extend_from_slice(&[ verify_merkle_insert_circuit( params, @@ -612,8 +651,6 @@ fn verify_contains_from_entries_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle proof must be enabled. */ - resolved_merkle_claim.enabled, /* ...and it must be an existence proof. */ resolved_merkle_claim.existence, /* ...for the root-key-value triple in the resolved op args. */ @@ -661,8 +698,6 @@ fn verify_not_contains_from_entries_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle proof must be enabled. */ - resolved_merkle_claim.enabled, /* ...and it must be a nonexistence proof. */ builder.not(resolved_merkle_claim.existence), /* ...for the root-key pair in the resolved op args. */ @@ -703,7 +738,7 @@ fn verify_merkle_insert_circuit( let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, + OperationAuxTableTag::MerkleTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerInsertFromEntries); @@ -714,8 +749,6 @@ fn verify_merkle_insert_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle transition proof must be enabled. */ - resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be an insertion proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -778,7 +811,7 @@ fn verify_merkle_update_circuit( let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, + OperationAuxTableTag::MerkleTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerUpdateFromEntries); @@ -789,8 +822,6 @@ fn verify_merkle_update_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle transition proof must be enabled. */ - resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be an update proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -853,7 +884,7 @@ fn verify_merkle_delete_circuit( let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, + OperationAuxTableTag::MerkleTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerDeleteFromEntries); @@ -864,8 +895,6 @@ fn verify_merkle_delete_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle transition proof must be enabled. */ - resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be a deletion proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -1774,19 +1803,20 @@ fn verify_main_pod_circuit( // NOTE: We use an EmptyPod for padding input pod slots. The EmptyPod is an introduction // pod that declares a statement with no arguments. - let is_blank_intro = input_pod_self_statements[0].pred_is_blank_intro(builder); + let st0_is_intro = input_pod_self_statements[0].pred_is_blank_intro(builder); // Introduction pods can only have Introduction or None statements - let mut intro_ok = is_blank_intro; + let mut intro_ok = st0_is_intro; for self_st in &input_pod_self_statements[1..] { let st_is_intro = self_st.pred_is_blank_intro(builder); let st_is_none = self_st.has_native_type(builder, NativePredicate::None); let st_is_intro_or_none = builder.or(st_is_intro, st_is_none); intro_ok = builder.and(intro_ok, st_is_intro_or_none); } - builder.connect(is_blank_intro.target, intro_ok.target); + builder.connect(st0_is_intro.target, intro_ok.target); - let is_main = builder.not(is_blank_intro); + let is_not_main = st0_is_intro; + let is_main = builder.not(is_not_main); for self_st in input_pod_self_statements { let normalized_st = normalize_statement_circuit( params, @@ -1805,18 +1835,19 @@ fn verify_main_pod_circuit( // their verifier_data_hash appears in their introduction statement. // - verify_merkle_proof_circuit(builder, vd_mt_proof); + verify_merkle_proof_existence_circuit(builder, vd_mt_proof); - // ensure that mt_proof is enabled if it's a main pod - builder.connect(vd_mt_proof.enabled.target, is_main.target); // connect the vd_mt_proof's root to the actual vds_root, to ensure that the mt proof // verifies against the vds_root builder.connect_hashes(main_pod.vds_root, vd_mt_proof.root); - // connect vd_mt_proof's value with the verified_proof.verifier_data_hash - builder.connect_hashes( - verified_proof.verifier_data_hash, - HashOutTarget::from_vec(vd_mt_proof.value.elements.to_vec()), - ); + // connect vd_mt_proof's value with the verified_proof.verifier_data_hash only when is_main + for i in 0..VALUE_SIZE { + builder.conditional_assert_eq( + is_main.target, + verified_proof.verifier_data_hash.elements[i], + vd_mt_proof.value.elements[i], + ) + } // // Verify that VD array that input pod uses is the same we use now. @@ -1846,9 +1877,9 @@ fn verify_main_pod_circuit( params, builder, &main_pod.merkle_proofs, + &main_pod.merkle_transition_proofs, &main_pod.public_key_of_sks, &main_pod.signed_bys, - &main_pod.merkle_tree_state_transition_proofs, &main_pod.custom_predicate_verifications, &custom_predicate_table, )?; @@ -1894,19 +1925,77 @@ fn verify_main_pod_circuit( Ok(sts_hash) } +#[derive(Clone, Serialize, Deserialize)] +pub struct MerkleProofsTarget { + small: Vec, + medium: Vec, +} + +impl MerkleProofsTarget { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + Self { + small: (0..params.containers.state.max_small) + .map(|_| { + MerkleProofExistenceTarget::new_virtual( + params.containers.max_depth_small, + builder, + ) + }) + .collect(), + medium: (0..params.containers.state.max_medium) + .map(|_| { + MerkleClaimAndProofTarget::new_virtual( + params.containers.max_depth_medium, + builder, + ) + }) + .collect(), + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct MerkleTransitionProofsTarget { + small: Vec, + medium: Vec, +} + +impl MerkleTransitionProofsTarget { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + Self { + small: (0..params.containers.transition.max_small) + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.containers.max_depth_small, + builder, + ) + }) + .collect(), + medium: (0..params.containers.transition.max_medium) + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.containers.max_depth_medium, + builder, + ) + }) + .collect(), + } + } +} + #[derive(Clone, Serialize, Deserialize)] pub struct MainPodVerifyTarget { params: Params, vds_root: HashOutTarget, - vd_mt_proofs: Vec, + vd_mt_proofs: Vec, input_pods_self_statements: Vec>, // The KEY_TYPE statement must be the first public one input_statements: Vec, operations: Vec, - merkle_proofs: Vec, + merkle_proofs: MerkleProofsTarget, + merkle_transition_proofs: MerkleTransitionProofsTarget, public_key_of_sks: Vec, signed_bys: Vec, - merkle_tree_state_transition_proofs: Vec, custom_predicates: Vec, custom_predicate_verifications: Vec, } @@ -1917,7 +2006,7 @@ impl MainPodVerifyTarget { params: params.clone(), vds_root: builder.add_virtual_hash(), vd_mt_proofs: (0..params.max_input_pods) - .map(|_| MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_vds, builder)) + .map(|_| MerkleProofExistenceTarget::new_virtual(params.max_depth_mt_vds, builder)) .collect(), input_pods_self_statements: (0..params.max_input_pods) .map(|_| { @@ -1932,26 +2021,14 @@ impl MainPodVerifyTarget { operations: (0..params.max_statements) .map(|_| builder.add_virtual_operation(params)) .collect(), - merkle_proofs: (0..params.max_merkle_proofs_containers) - .map(|_| { - MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, builder) - }) - .collect(), + merkle_proofs: MerkleProofsTarget::new_virtual(params, builder), + merkle_transition_proofs: MerkleTransitionProofsTarget::new_virtual(params, builder), public_key_of_sks: (0..params.max_public_key_of) .map(|_| builder.add_virtual_biguint320_target()) .collect(), signed_bys: (0..params.max_signed_by) .map(|_| SignedByTarget::new_virtual(builder)) .collect(), - merkle_tree_state_transition_proofs: (0..params - .max_merkle_tree_state_transition_proofs_containers) - .map(|_| { - MerkleTreeStateTransitionProofTarget::new_virtual( - params.max_depth_mt_containers, - builder, - ) - }) - .collect(), custom_predicates: (0..params.max_custom_predicates) .map(|_| CustomPredicateInBatchTarget::new_virtual(builder)) .collect(), @@ -1960,6 +2037,64 @@ impl MainPodVerifyTarget { .collect(), } } + + fn set_container_mtp_targets( + &self, + pw: &mut PartialWitness, + input: &MainPodVerifyInput, + ) -> Result<()> { + assert!(input.merkle_proofs.small.len() <= self.params.containers.state.max_small); + for (i, mp) in input.merkle_proofs.small.iter().enumerate() { + self.merkle_proofs.small[i].set_targets(pw, mp)?; + } + // Padding + let pad_mp = MerkleClaimAndProof::pad(); + for i in input.merkle_proofs.small.len()..self.params.containers.state.max_small { + self.merkle_proofs.small[i].set_targets(pw, &pad_mp)?; + } + + assert!(input.merkle_proofs.medium.len() <= self.params.containers.state.max_medium); + for (i, mp) in input.merkle_proofs.medium.iter().enumerate() { + self.merkle_proofs.medium[i].set_targets(pw, mp)?; + } + // Padding + let pad_mp = MerkleClaimAndProof::pad(); + for i in input.merkle_proofs.medium.len()..self.params.containers.state.max_medium { + self.merkle_proofs.medium[i].set_targets(pw, &pad_mp)?; + } + + assert!( + input.merkle_transition_proofs.small.len() + <= self.params.containers.transition.max_small + ); + for (i, mtp) in input.merkle_transition_proofs.small.iter().enumerate() { + self.merkle_transition_proofs.small[i].set_targets(pw, mtp)?; + } + // Padding + let pad_mtp = MerkleTreeStateTransitionProof::pad(); + for i in + input.merkle_transition_proofs.small.len()..self.params.containers.transition.max_small + { + self.merkle_transition_proofs.small[i].set_targets(pw, &pad_mtp)?; + } + + assert!( + input.merkle_transition_proofs.medium.len() + <= self.params.containers.transition.max_medium + ); + for (i, mtp) in input.merkle_transition_proofs.medium.iter().enumerate() { + self.merkle_transition_proofs.medium[i].set_targets(pw, mtp)?; + } + // Padding + let pad_mtp = MerkleTreeStateTransitionProof::pad(); + for i in input.merkle_transition_proofs.medium.len() + ..self.params.containers.transition.max_medium + { + self.merkle_transition_proofs.medium[i].set_targets(pw, &pad_mtp)?; + } + + Ok(()) + } } pub struct CustomPredicateVerification { @@ -1974,15 +2109,14 @@ pub struct MainPodVerifyInput { /// field containing the `vd_mt_proofs` aside from the `vds_set`, because /// inside the MainPodVerifyTarget circuit, since it is the InnerCircuit for /// the RecursiveCircuit, we don't have access to the used verifier_datas. - /// The bool is used as `enabled` and will be false for intro pods. - pub vd_mt_proofs: Vec<(bool, MerkleClaimAndProof)>, + pub vd_mt_proofs: Vec, pub input_pods_pub_self_statements: Vec>, pub statements: Vec, pub operations: Vec, - pub merkle_proofs: Vec, + pub merkle_proofs: MerkleProofs, + pub merkle_transition_proofs: MerkleTransitionProofs, pub public_key_of_sks: Vec, pub signed_bys: Vec, - pub merkle_tree_state_transition_proofs: Vec, pub custom_predicates_with_mpt_proofs: Vec<(CustomPredicateRef, MerkleProof)>, pub custom_predicate_verifications: Vec, } @@ -2038,8 +2172,8 @@ impl InnerCircuit for MainPodVerifyTarget { ); let input_pods_len = input.vd_mt_proofs.len(); assert!(input_pods_len <= self.params.max_input_pods); - for (i, (enable, vd_mt_proof)) in input.vd_mt_proofs.iter().enumerate() { - self.vd_mt_proofs[i].set_targets(pw, *enable, vd_mt_proof)?; + for (i, vd_mt_proof) in input.vd_mt_proofs.iter().enumerate() { + self.vd_mt_proofs[i].set_targets(pw, vd_mt_proof)?; } for (i, pod_pub_statements) in input.input_pods_pub_self_statements.iter().enumerate() { set_targets_input_pods_self_statements( @@ -2053,14 +2187,10 @@ impl InnerCircuit for MainPodVerifyTarget { if input_pods_len != self.params.max_input_pods { let empty_pod = EmptyPod::new_boxed(input.vds_set.clone()); let empty_pod_statements = empty_pod.pub_statements(); - let empty_mt_proof = MerkleClaimAndProof { - root: input.vds_set.root(), - value: RawValue::from(empty_pod.verifier_data_hash()), - ..MerkleClaimAndProof::empty() - }; + let pad_mt_proof = input.vds_set.get_vds_proof_0(); for i in input_pods_len..self.params.max_input_pods { - self.vd_mt_proofs[i].set_targets(pw, false, &empty_mt_proof)?; + self.vd_mt_proofs[i].set_targets(pw, &pad_mt_proof)?; set_targets_input_pods_self_statements( pw, &self.params, @@ -2076,15 +2206,7 @@ impl InnerCircuit for MainPodVerifyTarget { self.operations[i].set_targets(pw, &self.params, op)?; } - assert!(input.merkle_proofs.len() <= self.params.max_merkle_proofs_containers); - for (i, mp) in input.merkle_proofs.iter().enumerate() { - self.merkle_proofs[i].set_targets(pw, true, mp)?; - } - // Padding - let pad_mp = MerkleClaimAndProof::empty(); - for i in input.merkle_proofs.len()..self.params.max_merkle_proofs_containers { - self.merkle_proofs[i].set_targets(pw, false, &pad_mp)?; - } + self.set_container_mtp_targets(pw, input)?; assert!(input.public_key_of_sks.len() <= self.params.max_public_key_of); for (i, sk) in input.public_key_of_sks.iter().enumerate() { @@ -2106,25 +2228,6 @@ impl InnerCircuit for MainPodVerifyTarget { self.signed_bys[i].set_targets(pw, &pad_signed_by)?; } - assert!( - input.merkle_tree_state_transition_proofs.len() - <= self - .params - .max_merkle_tree_state_transition_proofs_containers - ); - for (i, mtp) in input.merkle_tree_state_transition_proofs.iter().enumerate() { - self.merkle_tree_state_transition_proofs[i].set_targets(pw, true, mtp)?; - } - // Padding - let pad_mtp = MerkleTreeStateTransitionProof::empty(); - for i in input.merkle_tree_state_transition_proofs.len() - ..self - .params - .max_merkle_tree_state_transition_proofs_containers - { - self.merkle_tree_state_transition_proofs[i].set_targets(pw, false, &pad_mtp)?; - } - assert!(input.custom_predicates_with_mpt_proofs.len() <= self.params.max_custom_predicates); for (i, (cp, mtp)) in input.custom_predicates_with_mpt_proofs.iter().enumerate() { self.custom_predicates[i].set_targets(pw, cp, mtp)?; @@ -2169,1729 +2272,3 @@ impl InnerCircuit for MainPodVerifyTarget { Ok(()) } } - -#[cfg(test)] -mod tests { - use std::{iter, ops::Not}; - - use num::FromPrimitive; - use plonky2::{ - field::{goldilocks_field::GoldilocksField, types::Field}, - hash::hash_types::HashOut, - iop::witness::WitnessWrite, - plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, - }; - - use super::*; - use crate::{ - backends::plonky2::{ - basetypes::C, - circuits::common::tests::I64_TEST_PAIRS, - mainpod::{calculate_statements_hash, OperationArg, OperationAux}, - primitives::{ - ec::schnorr::SecretKey, - merkletree::{MerkleClaimAndProof, MerkleTree, MerkleTreeStateTransitionProof}, - }, - signer, - }, - dict, - frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, - middleware::{ - hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, - RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard, - BASE_PARAMS, EMPTY_VALUE, - }, - }; - - #[derive(Default)] - struct Aux { - merkle_proofs: Vec, - secret_keys: Vec, - signed_bys: Vec, - merkle_tree_state_transition_proofs: Vec, - } - - impl Aux { - fn merkle_proof(v: MerkleClaimAndProof) -> Self { - Self { - merkle_proofs: vec![v], - ..Default::default() - } - } - fn secret_key(v: SecretKey) -> Self { - Self { - secret_keys: vec![v], - ..Default::default() - } - } - fn signed_by(v: SignedBy) -> Self { - Self { - signed_bys: vec![v], - ..Default::default() - } - } - fn merkle_tree_state_transition_proof(v: MerkleTreeStateTransitionProof) -> Self { - Self { - merkle_tree_state_transition_proofs: vec![v], - ..Default::default() - } - } - } - - fn operation_verify( - st: mainpod::Statement, - op: mainpod::Operation, - prev_statements: Vec, - aux: Aux, - ) -> Result<()> { - let params = Params { - max_merkle_proofs_containers: aux.merkle_proofs.len(), - max_public_key_of: aux.secret_keys.len(), - max_signed_by: aux.signed_bys.len(), - max_merkle_tree_state_transition_proofs_containers: aux - .merkle_tree_state_transition_proofs - .len(), - max_custom_predicate_verifications: 0, - max_custom_predicates: 0, - ..Default::default() - }; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_target = builder.add_virtual_statement(false); - let op_target = builder.add_virtual_operation(¶ms); - let prev_statements_target: Vec<_> = (0..prev_statements.len()) - .map(|_| builder.add_virtual_statement(false)) - .collect(); - let prev_statement_flatteneds_target: Vec> = prev_statements_target - .iter() - .map(|st| st.flatten()) - .collect(); - let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target - .iter() - .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) - .collect(); - - let merkle_proofs_target: Vec<_> = aux - .merkle_proofs - .iter() - .map(|_| { - MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, &mut builder) - }) - .collect(); - - let secret_keys_target: Vec<_> = aux - .secret_keys - .iter() - .map(|sk| builder.constant_biguint320(&sk.0)) - .collect(); - - let signed_by_targets: Vec<_> = aux - .signed_bys - .iter() - .map(|_| SignedByTarget::new_virtual(&mut builder)) - .collect(); - - let merkle_tree_state_transition_proofs_target: Vec<_> = aux - .merkle_tree_state_transition_proofs - .iter() - .map(|_| { - MerkleTreeStateTransitionProofTarget::new_virtual( - params.max_depth_mt_containers, - &mut builder, - ) - }) - .collect(); - - let aux_table = build_operation_aux_table_circuit( - ¶ms, - &mut builder, - &merkle_proofs_target, - &secret_keys_target, - &signed_by_targets, - &merkle_tree_state_transition_proofs_target, - &[], - &[], - )?; - - verify_operation_circuit( - ¶ms, - &mut builder, - &st_target, - &op_target, - &prev_statement_flatteneds_target, - &prev_statement_hashes_target, - &aux_table, - )?; - - let mut pw = PartialWitness::::new(); - st_target.set_targets(&mut pw, &st)?; - op_target.set_targets(&mut pw, ¶ms, &op)?; - for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { - prev_st_target.set_targets(&mut pw, prev_st)?; - } - for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) { - signed_by_target.set_targets(&mut pw, signed_by)? - } - for (merkle_proof_target, merkle_proof) in - merkle_proofs_target.iter().zip(aux.merkle_proofs.iter()) - { - merkle_proof_target.set_targets(&mut pw, true, merkle_proof)? - } - for (merkle_tree_state_transition_proof_target, merkle_tree_state_transition_proof) in - merkle_tree_state_transition_proofs_target - .iter() - .zip(aux.merkle_tree_state_transition_proofs.iter()) - { - merkle_tree_state_transition_proof_target.set_targets( - &mut pw, - true, - merkle_tree_state_transition_proof, - )? - } - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - - Ok(()) - } - - #[test] - fn test_lt_lteq_verify_failures() { - let invalid_int = RawValue([ - GoldilocksField::NEG_ONE, - GoldilocksField::ZERO, - GoldilocksField::ZERO, - GoldilocksField::ZERO, - ]); - - let prev_statements = [Statement::None.into()]; - - [ - // 56 < 55, 55 < 55, 56 <= 55, -55 < -55, -55 < -56, -55 <= -56 should fail to verify - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(55, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(-55, -55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(-55, -56).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(-55, -56).into(), - ), - // 56 < p-1 and p-1 <= p-1 should fail to verify, where p - // is the Goldilocks prime and 'p-1' occupies a single - // limb. - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(56, invalid_int).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(invalid_int, invalid_int).into(), - ), - ] - .into_iter() - .for_each(|(op, st)| { - let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - })); - match check { - Err(e) => { - let err_string = e.downcast_ref::().unwrap(); - if !err_string.contains("Integer too large to fit") { - panic!("Test failed with an unexpected error: {}", err_string); - } - } - Ok(Err(_)) => {} - _ => panic!("Test passed, yet it should have failed!"), - } - }); - } - - #[test] - fn test_eq_neq_verify_failures() { - let prev_statements = [Statement::None.into()]; - - [ - // 56 == 55, 55 != 55 should fail to verify - ( - mainpod::Operation( - OperationType::Native(NativeOperation::EqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::equal(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::NotEqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::not_equal(55, 55).into(), - ), - ] - .into_iter() - .for_each(|(op, st)| { - assert!(operation_verify(st, op, prev_statements.to_vec(), Aux::default()).is_err()) - }); - } - - #[test] - fn test_operation_verify_none() -> Result<()> { - let st: mainpod::Statement = Statement::None.into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::None), - vec![], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_copy() -> Result<()> { - let st: mainpod::Statement = Statement::None.into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::CopyStatement), - vec![OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_eq() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"world" => 55}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 55).into(); - let st: mainpod::Statement = Statement::equal( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "world")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::EqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_neq() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"world" => 75}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 75).into(); - let st: mainpod::Statement = Statement::not_equal( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "world")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::NotEqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_lt() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"hello" => 56}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "hello", 56).into(); - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "hello")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative < negative - let dict3 = dict!({"hola" => -56}); - let dict4 = dict!({"mundo" => -55}); - let st3: mainpod::Statement = Statement::contains(dict3.clone(), "hola", -56).into(); - let st4: mainpod::Statement = Statement::contains(dict4.clone(), "mundo", -55).into(); - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict3, "hola")), - AnchoredKey::from((&dict4, "mundo")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative < positive - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict3, "hola")), - AnchoredKey::from((&dict2, "hello")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_lteq() -> Result<()> { - let local = dict!({ - "n55" => 55, - "n56" => 56, - "n_56" => -56, - "n_55" => -55, - }); - let st1: mainpod::Statement = Statement::contains(local.clone(), "n55", 55).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "n56", 56).into(); - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n55")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative <= negative - let st3: mainpod::Statement = Statement::contains(local.clone(), "n_56", -56).into(); - let st4: mainpod::Statement = Statement::contains(local.clone(), "n_55", -55).into(); - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n_55")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative <= positive - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3, st2]; - operation_verify(st, op, prev_statements.clone(), Aux::default())?; - - // Also check equality, both positive and negative. - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n_56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ); - operation_verify(st, op, prev_statements.clone(), Aux::default())?; - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n56")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(1), OperationArg::Index(1)], - OperationAux::None, - ); - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_hashof() -> Result<()> { - let input_values = [ - Value::from(RawValue([ - GoldilocksField(1), - GoldilocksField(2), - GoldilocksField(3), - GoldilocksField(4), - ])), - Value::from(512), - ]; - let v1 = hash_values(&input_values); - let [v2, v3] = input_values; - - let local = dict!({ - "hola" => v1, - "mundo" => v2.clone(), - "!" => v3.clone(), - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "hola", v1).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "mundo", v2).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "!", v3).into(); - - let st: mainpod::Statement = Statement::hash_of( - AnchoredKey::from((&local, "hola")), - AnchoredKey::from((&local, "mundo")), - AnchoredKey::from((&local, "!")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::HashOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_sumof() -> Result<()> { - I64_TEST_PAIRS - .into_iter() - .flat_map(|(a, b)| { - let (sum, overflow) = a.overflowing_add(b); - overflow.not().then_some((a, b, sum)) - }) - .try_for_each(|(a, b, sum)| { - let local = dict!({ - "sum" => sum, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "sum", sum).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::sum_of( - AnchoredKey::from((&local, "sum")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SumOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) - } - - #[test] - fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> { - let local = dict!({ - "a" => 3, - "noise" => 99, - "sum" => 6, - }); - let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into(); - let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into(); - let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into(); - - let st: mainpod::Statement = Statement::sum_of( - AnchoredKey::from((&local, "sum")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "a")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SumOf), - vec![ - // Non-monotonic and repeated indices to stress random-access resolution. - OperationArg::Index(2), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::None, - ); - let prev_statements = vec![st_a, st_noise, st_sum]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_productof() -> Result<()> { - I64_TEST_PAIRS - .into_iter() - .flat_map(|(a, b)| { - let (prod, overflow) = a.overflowing_mul(b); - overflow.not().then_some((a, b, prod)) - }) - .try_for_each(|(a, b, prod)| { - let local = dict!({ - "prod" => prod, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = - Statement::contains(local.clone(), "prod", prod).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::product_of( - AnchoredKey::from((&local, "prod")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ProductOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) - } - - #[test] - fn test_operation_verify_maxof() -> Result<()> { - I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { - let max = i64::max(a, b); - let local = dict!({ - "max" => max, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "max", max).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::max_of( - AnchoredKey::from((&local, "max")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - - let op = mainpod::Operation( - OperationType::Native(NativeOperation::MaxOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) - } - - #[test] - fn test_operation_verify_maxof_failures() { - [(5, 3, 4), (5, 5, 8), (3, 4, 5)] - .into_iter() - .for_each(|(max, a, b)| { - let st: mainpod::Statement = Statement::max_of(max, a, b).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::MaxOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::None, - ); - let prev_statements = [Statement::None.into()]; - - let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - })); - match check { - Err(e) => { - let err_string = e.downcast_ref::().unwrap(); - if !err_string.contains("Integer too large to fit") { - panic!("Test failed with an unexpected error: {}", err_string); - } - } - Ok(Err(_)) => {} - _ => panic!("Test passed, yet it should have failed!"), - } - }) - } - - #[test] - fn test_operation_verify_lt_to_neq() -> Result<()> { - let local = dict!({ - "a" => 10, - "b" => 20, - }); - let st: mainpod::Statement = Statement::not_equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let st1: mainpod::Statement = Statement::lt( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtToNotEqual), - vec![OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![st1]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_transitive_eq() -> Result<()> { - let local = dict!({ - "a" => 10, - "b" => 10, - "c" => 10, - }); - let st: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "c")), - ) - .into(); - let st1: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let st2: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "b")), - AnchoredKey::from((&local, "c")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::TransitiveEqualFromStatements), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_sintains() -> Result<()> { - let kvs = [ - (1.into(), 55.into()), - (2.into(), 88.into()), - (175.into(), 0.into()), - ] - .into_iter() - .collect(); - let mt = MerkleTree::new(&kvs); - - let root = mt.root(); - let key = Value::from(5); - let local = dict!({ - "merkle_root" => root, - "key" => key.clone(), - }); - let root_ak = AnchoredKey::from((&local, "merkle_root")); - let key_ak = AnchoredKey::from((&local, "key")); - - let no_key_pf = mt.prove_nonexistence(&key.raw())?; - - let root_st: mainpod::Statement = - Statement::contains(local.clone(), "merkle_root", root).into(); - let key_st: mainpod::Statement = - Statement::contains(local.clone(), "key", key.clone()).into(); - let st: mainpod::Statement = Statement::not_contains(root_ak, key_ak).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::NotContainsFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::MerkleProofIndex(0), - ); - - let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), None, no_key_pf); - let prev_statements = vec![root_st, key_st]; - operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) - } - - #[test] - fn test_operation_verify_contains() -> Result<()> { - let kvs = [ - (1.into(), 55.into()), - (2.into(), 88.into()), - (175.into(), 0.into()), - ] - .into_iter() - .collect(); - let mt = MerkleTree::new(&kvs); - - let root = mt.root(); - let key = Value::from(175); - let (value, key_pf) = mt.prove(&key.raw())?; - let local = dict!({ - "merkle_root" => root, - "key" => key.clone(), - "value" => value, - }); - let root_ak = AnchoredKey::from((&local, "merkle_root")); - let key_ak = AnchoredKey::from((&local, "key")); - let value_ak = AnchoredKey::from((&local, "value")); - - let root_st: mainpod::Statement = - Statement::contains(local.clone(), "merkle_root", root).into(); - let key_st: mainpod::Statement = - Statement::contains(local.clone(), "key", key.clone()).into(); - let value_st: mainpod::Statement = - Statement::contains(local.clone(), "value", value).into(); - - let st: mainpod::Statement = Statement::contains(root_ak, key_ak, value_ak).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainsFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::MerkleProofIndex(0), - ); - - let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), Some(value), key_pf); - let prev_statements = vec![root_st, key_st, value_st]; - operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) - } - - #[test] - fn test_operation_verify_merkle_insert() -> Result<()> { - let mut tree = MerkleTree::new(&[].into()); - - let key = Value::from(175); - let value = Value::from(0); - let state_transition_proof = tree.insert(&key.raw(), &value.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::insert(new_root, old_root, key, value).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerInsertFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTreeStateTransitionProofIndex(0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) - } - - #[test] - fn test_operation_verify_merkle_update() -> Result<()> { - let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); - - let key = Value::from(175); - let value = Value::from(0); - let state_transition_proof = tree.update(&key.raw(), &value.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::update(new_root, old_root, key, value).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerUpdateFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTreeStateTransitionProofIndex(0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) - } - - #[test] - fn test_operation_verify_merkle_delete() -> Result<()> { - let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); - - let key = Value::from(175); - let state_transition_proof = tree.delete(&key.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::delete(new_root, old_root, key).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerDeleteFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTreeStateTransitionProofIndex(0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) - } - - #[test] - fn test_operation_verify_publickeyof_ok() -> Result<()> { - [ - SecretKey(BigUint::one()), - SecretKey::new_rand(), - SecretKey(&*GROUP_ORDER - BigUint::one()), - ] - .into_iter() - .try_for_each(|secret_key| { - let public_key = secret_key.public_key(); - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)) - }) - } - - #[test] - fn test_operation_verify_publickeyof_failure_wrong_key() { - let secret_key = SecretKey(BigUint::one()); - let public_key = SecretKey(BigUint::ZERO).public_key(); - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) - } - - #[test] - fn test_operation_verify_publickeyof_failure_pk_type() { - let secret_key = SecretKey(BigUint::one()); - let public_key = 123i64; - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) - } - - #[test] - fn test_operation_verify_publickeyof_failure_sk_type() { - let secret_key = 123i64; - let public_key = SecretKey(BigUint::from(123u32)).public_key(); - - let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - let aux = Aux::secret_key(SecretKey(BigUint::from(123u32))); - assert!(operation_verify(st, op, prev_statements, aux,).is_err()) - } - - #[test] - fn test_operation_verify_publickeyof_failure_sk_size() { - let secret_key = SecretKey(&*GROUP_ORDER - BigUint::ZERO); - let public_key = secret_key.public_key(); - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) - } - - #[test] - fn test_operation_verify_signedby_ok() -> Result<()> { - let sk = SecretKey(BigUint::from_u32(0xbadcafe).unwrap()); - let pk = sk.public_key(); - let msg = RawValue([F(1), F(2), F(3), F(4)]); - let nonce = BigUint::from_u32(123).unwrap(); - let sig = signer::Signer(sk).sign_with_nonce(nonce, msg); - let signed_by = SignedBy { - msg, - pk, - sig: sig.clone(), - }; - - let st: mainpod::Statement = Statement::signed_by(msg, pk).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SignedBy), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::SignedByIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) - } - - #[test] - fn test_operation_replace_value_with_entry() -> Result<()> { - let d = dict!({"a" => 42, "b" => 33}); - - // 0: None - // 1: Lt(5, 42) - let st_in: mainpod::Statement = Statement::lt(5, 42).into(); - // 2: Contains(d, "a", 42) - let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into(); - - let st_out: mainpod::Statement = - Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into(); - let mut op_args: Vec<_> = iter::repeat(OperationArg::None) - .take(BASE_PARAMS.max_statement_args + 1) - .collect(); - op_args[1] = OperationArg::Index(2); - op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ReplaceValueWithEntry), - op_args, - OperationAux::None, - ); - - let prev_statements = vec![Statement::None.into(), st_in, st_entry]; - operation_verify(st_out, op, prev_statements, Aux::default()) - } - - fn helper_statement_arg_from_template( - params: &Params, - st_tmpl_arg: StatementTmplArg, - args: Vec, - expected_st_arg: StatementArg, - ) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_arg_target = make_statement_arg_from_template_circuit( - params, - &mut builder, - &st_tmpl_arg_target, - &args_target, - ); - // TODO: Instead of connect, assign witness to result - let expected_st_arg_target = builder.add_virtual_statement_arg(); - builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); - - let mut pw = PartialWitness::::new(); - - st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?; - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, arg)?; - } - expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); - - Ok(()) - } - - #[test] - fn test_statement_arg_from_template() -> Result<()> { - let params = Params::default(); - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - // case: None - let st_tmpl_arg = StatementTmplArg::None; - let args = vec![Value::from(1), Value::from(2), Value::from(3)]; - let expected_st_arg = StatementArg::None; - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: Literal - let st_tmpl_arg = StatementTmplArg::Literal(Value::from("foo")); - let args = vec![Value::from(1), Value::from(2), Value::from(3)]; - let expected_st_arg = StatementArg::Literal(Value::from("foo")); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: AnchoredKey(id_wildcard, key_literal) - let st_tmpl_arg = - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("foo")); - let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; - let expected_st_arg = StatementArg::Key(AnchoredKey::new(dict, Key::from("foo"))); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: WildcardLiteral(wildcard) - let st_tmpl_arg = StatementTmplArg::Wildcard(Wildcard::new("a".to_string(), 1)); - let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; - let expected_st_arg = StatementArg::Literal(Value::from("key")); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - Ok(()) - } - - fn helper_statement_from_template( - params: &Params, - st_tmpl: StatementTmpl, - args: Vec, - expected_st: Statement, - ) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_tmpl_target = builder.add_virtual_statement_tmpl(false); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_target = make_statement_from_template_circuit( - params, - &mut builder, - &st_tmpl_target, - &args_target, - ); - // TODO: Instead of connect, assign witness to result - let expected_st_target = builder.add_virtual_statement(false); - builder.connect_flattenable(&expected_st_target, &st_target); - - let mut pw = PartialWitness::::new(); - - st_tmpl_target.set_targets(&mut pw, &st_tmpl)?; - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, arg)?; - } - expected_st_target.set_targets(&mut pw, &expected_st.into())?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); - - Ok(()) - } - - #[test] - fn test_statement_from_template() -> Result<()> { - let params = Params::default(); - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - let st_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)), - args: vec![ - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), - StatementTmplArg::Literal(Value::from("value")), - ], - }; - let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; - let expected_st = Statement::equal( - AnchoredKey::new(dict, Key::from("key")), - Value::from("value"), - ); - helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; - - let st_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)), - args: vec![ - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), - StatementTmplArg::Literal(Value::from("value")), - ], - }; - let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(); - let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; - let expected_st = Statement::not_equal( - AnchoredKey::new(dict, Key::from("key")), - Value::from("value"), - ); - helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; - - Ok(()) - } - - fn helper_custom_operation_verify_gadget( - params: &Params, - custom_predicate: CustomPredicateRef, - mut op_args: Vec, - mut args: Vec, - expected_st: Option, - ) -> Result<()> { - // Pad - for _ in op_args.len()..BASE_PARAMS.max_operation_args { - op_args.push(Statement::None); - } - for _ in args.len()..params.max_custom_predicate_wildcards { - args.push(Value::from(EMPTY_VALUE)); - } - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let custom_predicate_target = builder.add_virtual_custom_predicate_entry(); - let op_args_target: Vec<_> = (0..op_args.len()) - .map(|_| builder.add_virtual_statement(false)) - .collect(); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let (st_target, op_type_target) = make_custom_statement_circuit( - params, - &mut builder, - &custom_predicate_target, - &op_args_target, - &args_target, - )?; - - let mut pw = PartialWitness::::new(); - - // Input - custom_predicate_target.set_targets(&mut pw, &custom_predicate)?; - for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { - op_arg_target.set_targets(&mut pw, &op_arg.into())?; - } - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; - } - // Expected Output - if let Some(expected_st) = expected_st { - st_target.set_targets(&mut pw, &expected_st.into())?; - } - - let expected_op_type = OperationType::Custom(custom_predicate); - op_type_target.set_targets(&mut pw, &expected_op_type)?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - Ok(data.verify(proof.clone())?) - } - - fn value_ref(v: impl Into) -> ValueRef { - v.into() - } - - // TODO: Add negative tests - #[test] - fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { - let params = Params::default(); - - use NativePredicate as NP; - use StatementTmplBuilder as STB; - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new_from_pred(NP::Equal) - .arg(("id", "score")) - .arg(literal(42)); - let stb1 = STB::new_from_pred(NP::Equal) - .arg(("id", "key")) - .arg("secret"); - let _ = builder.predicate_and( - "pred_and", - &["id"], - &["secret"], - &[stb0.clone(), stb1.clone()], - )?; - let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; - let batch = builder.finish()?; - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - // AND - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), - ]; - let args = vec![Value::from(dict), Value::from(1234)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // OR (1) - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::None, - ]; - let args = vec![Value::from(dict), Value::from(0)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // OR (2) - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![ - Statement::None, - Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), - ]; - let args = vec![Value::from(dict), Value::from(1234)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - Ok(()) - } - - #[test] - fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { - let params = Params::default(); - - use NativePredicate as NP; - use StatementTmplBuilder as STB; - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new_from_pred(NP::Equal) - .arg(("id", "score")) - .arg(literal(42)); - let stb1 = STB::new_from_pred(NP::Equal) - .arg(("secret_id", "key")) - .arg(("id", "score")); - let _ = builder.predicate_and( - "pred_and", - &["id"], - &["secret_id"], - &[stb0.clone(), stb1.clone()], - )?; - let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; - let batch = builder.finish()?; - - let dict = Hash([F(1), F(2), F(3), F(4)]); - let secret_dict = Hash([F(6), F(7), F(8), F(9)]); - - // AND (0) Sanity check with correct values - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // AND (1) Different dict for same wildcard - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(Hash([F(0), F(5), F(1), F(6)]), Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // AND (2) key doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("BAD")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // AND (3) literal doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal( - AnchoredKey::new(dict, Key::from("score")), - Value::from(0xbad), - ), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // AND (4) predicate doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::not_equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // OR (1) Two Nones - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![Statement::None, Statement::None]; - let args = vec![Value::from(dict), Value::from(0)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None - ) - .is_err()); - - Ok(()) - } - - fn helper_calculate_statements_hash(params: &Params, statements: &[Statement]) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let statements_target = (0..params.max_public_statements) - .map(|_| builder.add_virtual_statement(false)) - .collect_vec(); - let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target); - - let mut pw = PartialWitness::::new(); - - // Input - let statements = statements - .iter() - .map(|st| { - let mut st = mainpod::Statement::from(st.clone()); - pad_statement(&mut st); - st - }) - .collect_vec(); - for (st_target, st) in statements_target.iter().zip(statements.iter()) { - st_target.set_targets(&mut pw, st)?; - } - // Expected Output - let expected_sts_hash = calculate_statements_hash(&statements); - pw.set_hash_target( - sts_hash_target, - HashOut { - elements: expected_sts_hash.0, - }, - )?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - Ok(data.verify(proof.clone())?) - } - - #[test] - fn test_calculate_sts_hash() -> frontend::Result<()> { - assert_eq!(Params::num_public_statements_hash(), 16); - // Case with no public public statements - let params = Params { - max_public_statements: 0, - ..Default::default() - }; - - helper_calculate_statements_hash(¶ms, &[]).unwrap(); - - // Case with number of statements for the sts_hash equal to number of public statements - let params = Params { - max_public_statements: Params::num_public_statements_hash(), - ..Default::default() - }; - - let dict = Hash([F(1), F(2), F(3), F(4)]); - let statements = (0..Params::num_public_statements_hash()) - .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64))) - .collect_vec(); - - helper_calculate_statements_hash(¶ms, &statements).unwrap(); - - // Case with more statements for the sts_hash than the number of public statements - let params = Params { - max_public_statements: 4, - ..Default::default() - }; - - let dict2 = Hash([F(5), F(6), F(7), F(8)]); - let statements = [ - Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)), - Statement::equal( - AnchoredKey::from((dict, "bar")), - AnchoredKey::from((dict, "baz")), - ), - Statement::lt( - AnchoredKey::from((dict2, "one")), - AnchoredKey::from((dict2, "two")), - ), - ] - .into_iter() - .chain(iter::repeat(Statement::None)) - .take(params.max_public_statements) - .collect_vec(); - - helper_calculate_statements_hash(¶ms, &statements).unwrap(); - - Ok(()) - } - - #[test] - fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> { - let params = Params::default(); - - // Build a batch with two predicates: - // pred_A: Equal(x, y) - // pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash - use NativePredicate as NP; - let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal) - .arg("x") - .arg("y"); - cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a]) - .unwrap(); - - // Build pred_B's template manually with SelfPredicateHash(0) - let stb_b_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)), - args: vec![ - StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)), - StatementTmplArg::SelfPredicateHash(0), - ], - }; - let pred_b = CustomPredicate::new( - ¶ms, - "pred_B".into(), - true, - vec![stb_b_tmpl], - 1, - vec!["x".to_string()], - ) - .unwrap(); - cpb.predicates.push(pred_b); - let batch = cpb.finish().unwrap(); - - // Compute the expected resolved hash of pred_A - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_a_hash = Predicate::Custom(pred_a_ref).hash(); - let expected_pred_a_value = Value::from(pred_a_hash); - - // Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to - // Literal(pred_a_hash). Then make_statement_from_template_circuit should produce - // a statement with that literal value. - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let pred_b_tmpl = &pred_b_ref.predicate().statements[0]; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - // Create the template target and batch id target - let st_tmpl_target = builder.add_virtual_statement_tmpl(true); - let batch_id = builder.add_virtual_hash(); - - // Normalize the template (this is what we're testing) - let normalized = - normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id); - - // Feed normalized template into statement generation - let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_target = - make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target); - - // Connect to expected output - let expected_st_target = builder.add_virtual_statement(false); - builder.connect_flattenable(&expected_st_target, &st_target); - - // Set witness - let mut pw = PartialWitness::::new(); - st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?; - pw.set_target_arr(&batch_id.elements, &batch.id().0)?; - - let some_value = Value::from(42); - // args: first wildcard is "x" = some_value, rest are padding - let mut args_values = vec![some_value.clone()]; - for _ in 1..params.max_custom_predicate_wildcards { - args_values.push(Value::from(EMPTY_VALUE)); - } - for (target, value) in args_target.iter().zip(args_values.iter()) { - target.set_targets(&mut pw, value)?; - } - - // Expected statement: Equal(Literal(some_value), Literal(pred_a_hash)) - let expected_st: crate::backends::plonky2::mainpod::Statement = - Statement::equal(some_value, expected_pred_a_value).into(); - expected_st_target.set_targets(&mut pw, &expected_st)?; - - // Build and verify - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - - Ok(()) - } -} diff --git a/src/backends/plonky2/circuits/mainpod/tests.rs b/src/backends/plonky2/circuits/mainpod/tests.rs new file mode 100644 index 0000000..49fe4a0 --- /dev/null +++ b/src/backends/plonky2/circuits/mainpod/tests.rs @@ -0,0 +1,1707 @@ +use std::{iter, ops::Not}; + +use num::FromPrimitive; +use plonky2::{ + field::{goldilocks_field::GoldilocksField, types::Field}, + hash::hash_types::HashOut, + iop::witness::WitnessWrite, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, +}; + +use super::*; +use crate::{ + backends::plonky2::{ + basetypes::C, + circuits::common::tests::I64_TEST_PAIRS, + mainpod::{calculate_statements_hash, OperationArg, OperationAux, Size}, + primitives::{ + ec::schnorr::SecretKey, + merkletree::{MerkleClaimAndProof, MerkleTree, MerkleTreeStateTransitionProof}, + }, + signer, + }, + dict, + frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::{ + self, hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, + RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard, BASE_PARAMS, + EMPTY_VALUE, + }, +}; + +#[derive(Default)] +struct Aux { + merkle_proofs: Vec, + secret_keys: Vec, + signed_bys: Vec, + merkle_transition_proofs: Vec, +} + +impl Aux { + fn merkle_proof(v: MerkleClaimAndProof) -> Self { + Self { + merkle_proofs: vec![v], + ..Default::default() + } + } + fn secret_key(v: SecretKey) -> Self { + Self { + secret_keys: vec![v], + ..Default::default() + } + } + fn signed_by(v: SignedBy) -> Self { + Self { + signed_bys: vec![v], + ..Default::default() + } + } + fn merkle_tree_state_transition_proof(v: MerkleTreeStateTransitionProof) -> Self { + Self { + merkle_transition_proofs: vec![v], + ..Default::default() + } + } +} + +fn operation_verify( + st: mainpod::Statement, + op: mainpod::Operation, + prev_statements: Vec, + aux: Aux, +) -> Result<()> { + let params = Params { + max_public_key_of: aux.secret_keys.len(), + max_signed_by: aux.signed_bys.len(), + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: aux.merkle_proofs.len(), + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: aux.merkle_transition_proofs.len(), + }, + max_depth_small: 8, + max_depth_medium: 32, + }, + max_custom_predicate_verifications: 0, + max_custom_predicates: 0, + ..Default::default() + }; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_target = builder.add_virtual_statement(false); + let op_target = builder.add_virtual_operation(¶ms); + let prev_statements_target: Vec<_> = (0..prev_statements.len()) + .map(|_| builder.add_virtual_statement(false)) + .collect(); + let prev_statement_flatteneds_target: Vec> = prev_statements_target + .iter() + .map(|st| st.flatten()) + .collect(); + let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target + .iter() + .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) + .collect(); + + let merkle_proofs_target = MerkleProofsTarget { + medium: aux + .merkle_proofs + .iter() + .map(|_| { + MerkleClaimAndProofTarget::new_virtual( + params.containers.max_depth_medium, + &mut builder, + ) + }) + .collect(), + small: Vec::new(), + }; + + let secret_keys_target: Vec<_> = aux + .secret_keys + .iter() + .map(|sk| builder.constant_biguint320(&sk.0)) + .collect(); + + let signed_by_targets: Vec<_> = aux + .signed_bys + .iter() + .map(|_| SignedByTarget::new_virtual(&mut builder)) + .collect(); + + let merkle_transition_proofs_target = MerkleTransitionProofsTarget { + medium: aux + .merkle_transition_proofs + .iter() + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.containers.max_depth_medium, + &mut builder, + ) + }) + .collect(), + small: Vec::new(), + }; + + let aux_table = build_operation_aux_table_circuit( + ¶ms, + &mut builder, + &merkle_proofs_target, + &merkle_transition_proofs_target, + &secret_keys_target, + &signed_by_targets, + &[], + &[], + )?; + + verify_operation_circuit( + ¶ms, + &mut builder, + &st_target, + &op_target, + &prev_statement_flatteneds_target, + &prev_statement_hashes_target, + &aux_table, + )?; + + let mut pw = PartialWitness::::new(); + st_target.set_targets(&mut pw, &st)?; + op_target.set_targets(&mut pw, ¶ms, &op)?; + for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { + prev_st_target.set_targets(&mut pw, prev_st)?; + } + for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) { + signed_by_target.set_targets(&mut pw, signed_by)? + } + for (merkle_proof_target, merkle_proof) in merkle_proofs_target + .medium + .iter() + .zip(aux.merkle_proofs.iter()) + { + merkle_proof_target.set_targets(&mut pw, merkle_proof)? + } + for (merkle_tree_state_transition_proof_target, merkle_tree_state_transition_proof) in + merkle_transition_proofs_target + .medium + .iter() + .zip(aux.merkle_transition_proofs.iter()) + { + merkle_tree_state_transition_proof_target + .set_targets(&mut pw, merkle_tree_state_transition_proof)? + } + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) +} + +#[test] +fn test_lt_lteq_verify_failures() { + let invalid_int = RawValue([ + GoldilocksField::NEG_ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + ]); + + let prev_statements = [Statement::None.into()]; + + [ + // 56 < 55, 55 < 55, 56 <= 55, -55 < -55, -55 < -56, -55 <= -56 should fail to verify + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(55, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(-55, -55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(-55, -56).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(-55, -56).into(), + ), + // 56 < p-1 and p-1 <= p-1 should fail to verify, where p + // is the Goldilocks prime and 'p-1' occupies a single + // limb. + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(56, invalid_int).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(invalid_int, invalid_int).into(), + ), + ] + .into_iter() + .for_each(|(op, st)| { + let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + operation_verify(st, op, prev_statements.to_vec(), Aux::default()) + })); + match check { + Err(e) => { + let err_string = e.downcast_ref::().unwrap(); + if !err_string.contains("Integer too large to fit") { + panic!("Test failed with an unexpected error: {}", err_string); + } + } + Ok(Err(_)) => {} + _ => panic!("Test passed, yet it should have failed!"), + } + }); +} + +#[test] +fn test_eq_neq_verify_failures() { + let prev_statements = [Statement::None.into()]; + + [ + // 56 == 55, 55 != 55 should fail to verify + ( + mainpod::Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::equal(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::NotEqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::not_equal(55, 55).into(), + ), + ] + .into_iter() + .for_each(|(op, st)| { + assert!(operation_verify(st, op, prev_statements.to_vec(), Aux::default()).is_err()) + }); +} + +#[test] +fn test_operation_verify_none() -> Result<()> { + let st: mainpod::Statement = Statement::None.into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::None), + vec![], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_copy() -> Result<()> { + let st: mainpod::Statement = Statement::None.into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::CopyStatement), + vec![OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_eq() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 55}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 55).into(); + let st: mainpod::Statement = Statement::equal( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "world")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_neq() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 75}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 75).into(); + let st: mainpod::Statement = Statement::not_equal( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "world")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::NotEqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_lt() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"hello" => 56}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "hello", 56).into(); + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "hello")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2.clone()]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative < negative + let dict3 = dict!({"hola" => -56}); + let dict4 = dict!({"mundo" => -55}); + let st3: mainpod::Statement = Statement::contains(dict3.clone(), "hola", -56).into(); + let st4: mainpod::Statement = Statement::contains(dict4.clone(), "mundo", -55).into(); + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict3, "hola")), + AnchoredKey::from((&dict4, "mundo")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3.clone(), st4]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative < positive + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict3, "hola")), + AnchoredKey::from((&dict2, "hello")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_lteq() -> Result<()> { + let local = dict!({ + "n55" => 55, + "n56" => 56, + "n_56" => -56, + "n_55" => -55, + }); + let st1: mainpod::Statement = Statement::contains(local.clone(), "n55", 55).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "n56", 56).into(); + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n55")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2.clone()]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative <= negative + let st3: mainpod::Statement = Statement::contains(local.clone(), "n_56", -56).into(); + let st4: mainpod::Statement = Statement::contains(local.clone(), "n_55", -55).into(); + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n_55")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3.clone(), st4]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative <= positive + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3, st2]; + operation_verify(st, op, prev_statements.clone(), Aux::default())?; + + // Also check equality, both positive and negative. + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n_56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ); + operation_verify(st, op, prev_statements.clone(), Aux::default())?; + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n56")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(1), OperationArg::Index(1)], + OperationAux::None, + ); + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_hashof() -> Result<()> { + let input_values = [ + Value::from(RawValue([ + GoldilocksField(1), + GoldilocksField(2), + GoldilocksField(3), + GoldilocksField(4), + ])), + Value::from(512), + ]; + let v1 = hash_values(&input_values); + let [v2, v3] = input_values; + + let local = dict!({ + "hola" => v1, + "mundo" => v2.clone(), + "!" => v3.clone(), + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "hola", v1).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "mundo", v2).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "!", v3).into(); + + let st: mainpod::Statement = Statement::hash_of( + AnchoredKey::from((&local, "hola")), + AnchoredKey::from((&local, "mundo")), + AnchoredKey::from((&local, "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::HashOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_sumof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (sum, overflow) = a.overflowing_add(b); + overflow.not().then_some((a, b, sum)) + }) + .try_for_each(|(a, b, sum)| { + let local = dict!({ + "sum" => sum, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "sum", sum).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::sum_of( + AnchoredKey::from((&local, "sum")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) +} + +#[test] +fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> { + let local = dict!({ + "a" => 3, + "noise" => 99, + "sum" => 6, + }); + let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into(); + let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into(); + let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into(); + + let st: mainpod::Statement = Statement::sum_of( + AnchoredKey::from((&local, "sum")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "a")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + // Non-monotonic and repeated indices to stress random-access resolution. + OperationArg::Index(2), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::None, + ); + let prev_statements = vec![st_a, st_noise, st_sum]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_productof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (prod, overflow) = a.overflowing_mul(b); + overflow.not().then_some((a, b, prod)) + }) + .try_for_each(|(a, b, prod)| { + let local = dict!({ + "prod" => prod, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "prod", prod).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::product_of( + AnchoredKey::from((&local, "prod")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ProductOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) +} + +#[test] +fn test_operation_verify_maxof() -> Result<()> { + I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { + let max = i64::max(a, b); + let local = dict!({ + "max" => max, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "max", max).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::max_of( + AnchoredKey::from((&local, "max")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) +} + +#[test] +fn test_operation_verify_maxof_failures() { + [(5, 3, 4), (5, 5, 8), (3, 4, 5)] + .into_iter() + .for_each(|(max, a, b)| { + let st: mainpod::Statement = Statement::max_of(max, a, b).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::None, + ); + let prev_statements = [Statement::None.into()]; + + let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + operation_verify(st, op, prev_statements.to_vec(), Aux::default()) + })); + match check { + Err(e) => { + let err_string = e.downcast_ref::().unwrap(); + if !err_string.contains("Integer too large to fit") { + panic!("Test failed with an unexpected error: {}", err_string); + } + } + Ok(Err(_)) => {} + _ => panic!("Test passed, yet it should have failed!"), + } + }) +} + +#[test] +fn test_operation_verify_lt_to_neq() -> Result<()> { + let local = dict!({ + "a" => 10, + "b" => 20, + }); + let st: mainpod::Statement = Statement::not_equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let st1: mainpod::Statement = Statement::lt( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtToNotEqual), + vec![OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![st1]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_transitive_eq() -> Result<()> { + let local = dict!({ + "a" => 10, + "b" => 10, + "c" => 10, + }); + let st: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "c")), + ) + .into(); + let st1: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let st2: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "b")), + AnchoredKey::from((&local, "c")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::TransitiveEqualFromStatements), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_sintains() -> Result<()> { + let kvs = [ + (1.into(), 55.into()), + (2.into(), 88.into()), + (175.into(), 0.into()), + ] + .into_iter() + .collect(); + let mt = MerkleTree::new(&kvs); + + let root = mt.root(); + let key = Value::from(5); + let local = dict!({ + "merkle_root" => root, + "key" => key.clone(), + }); + let root_ak = AnchoredKey::from((&local, "merkle_root")); + let key_ak = AnchoredKey::from((&local, "key")); + + let no_key_pf = mt.prove_nonexistence(&key.raw())?; + + let root_st: mainpod::Statement = + Statement::contains(local.clone(), "merkle_root", root).into(); + let key_st: mainpod::Statement = Statement::contains(local.clone(), "key", key.clone()).into(); + let st: mainpod::Statement = Statement::not_contains(root_ak, key_ak).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::NotContainsFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::MerkleProofIndex(Size::Medium, 0), + ); + + let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), None, no_key_pf); + let prev_statements = vec![root_st, key_st]; + operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) +} + +#[test] +fn test_operation_verify_contains() -> Result<()> { + let kvs = [ + (1.into(), 55.into()), + (2.into(), 88.into()), + (175.into(), 0.into()), + ] + .into_iter() + .collect(); + let mt = MerkleTree::new(&kvs); + + let root = mt.root(); + let key = Value::from(175); + let (value, key_pf) = mt.prove(&key.raw())?; + let local = dict!({ + "merkle_root" => root, + "key" => key.clone(), + "value" => value, + }); + let root_ak = AnchoredKey::from((&local, "merkle_root")); + let key_ak = AnchoredKey::from((&local, "key")); + let value_ak = AnchoredKey::from((&local, "value")); + + let root_st: mainpod::Statement = + Statement::contains(local.clone(), "merkle_root", root).into(); + let key_st: mainpod::Statement = Statement::contains(local.clone(), "key", key.clone()).into(); + let value_st: mainpod::Statement = Statement::contains(local.clone(), "value", value).into(); + + let st: mainpod::Statement = Statement::contains(root_ak, key_ak, value_ak).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainsFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::MerkleProofIndex(Size::Medium, 0), + ); + + let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), Some(value), key_pf); + let prev_statements = vec![root_st, key_st, value_st]; + operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) +} + +#[test] +fn test_operation_verify_merkle_insert() -> Result<()> { + let mut tree = MerkleTree::new(&[].into()); + + let key = Value::from(175); + let value = Value::from(0); + let state_transition_proof = tree.insert(&key.raw(), &value.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::insert(new_root, old_root, key, value).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerInsertFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) +} + +#[test] +fn test_operation_verify_merkle_update() -> Result<()> { + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); + + let key = Value::from(175); + let value = Value::from(0); + let state_transition_proof = tree.update(&key.raw(), &value.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::update(new_root, old_root, key, value).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerUpdateFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) +} + +#[test] +fn test_operation_verify_merkle_delete() -> Result<()> { + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); + + let key = Value::from(175); + let state_transition_proof = tree.delete(&key.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::delete(new_root, old_root, key).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerDeleteFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) +} + +#[test] +fn test_operation_verify_publickeyof_ok() -> Result<()> { + [ + SecretKey(BigUint::one()), + SecretKey::new_rand(), + SecretKey(&*GROUP_ORDER - BigUint::one()), + ] + .into_iter() + .try_for_each(|secret_key| { + let public_key = secret_key.public_key(); + + let st: mainpod::Statement = + Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)) + }) +} + +#[test] +fn test_operation_verify_publickeyof_failure_wrong_key() { + let secret_key = SecretKey(BigUint::one()); + let public_key = SecretKey(BigUint::ZERO).public_key(); + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) +} + +#[test] +fn test_operation_verify_publickeyof_failure_pk_type() { + let secret_key = SecretKey(BigUint::one()); + let public_key = 123i64; + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) +} + +#[test] +fn test_operation_verify_publickeyof_failure_sk_type() { + let secret_key = 123i64; + let public_key = SecretKey(BigUint::from(123u32)).public_key(); + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + let aux = Aux::secret_key(SecretKey(BigUint::from(123u32))); + assert!(operation_verify(st, op, prev_statements, aux,).is_err()) +} + +#[test] +fn test_operation_verify_publickeyof_failure_sk_size() { + let secret_key = SecretKey(&*GROUP_ORDER - BigUint::ZERO); + let public_key = secret_key.public_key(); + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) +} + +#[test] +fn test_operation_verify_signedby_ok() -> Result<()> { + let sk = SecretKey(BigUint::from_u32(0xbadcafe).unwrap()); + let pk = sk.public_key(); + let msg = RawValue([F(1), F(2), F(3), F(4)]); + let nonce = BigUint::from_u32(123).unwrap(); + let sig = signer::Signer(sk).sign_with_nonce(nonce, msg); + let signed_by = SignedBy { + msg, + pk, + sig: sig.clone(), + }; + + let st: mainpod::Statement = Statement::signed_by(msg, pk).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SignedBy), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::SignedByIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) +} + +#[test] +fn test_operation_replace_value_with_entry() -> Result<()> { + let d = dict!({"a" => 42, "b" => 33}); + + // 0: None + // 1: Lt(5, 42) + let st_in: mainpod::Statement = Statement::lt(5, 42).into(); + // 2: Contains(d, "a", 42) + let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into(); + + let st_out: mainpod::Statement = + Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into(); + let mut op_args: Vec<_> = iter::repeat(OperationArg::None) + .take(BASE_PARAMS.max_statement_args + 1) + .collect(); + op_args[1] = OperationArg::Index(2); + op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ReplaceValueWithEntry), + op_args, + OperationAux::None, + ); + + let prev_statements = vec![Statement::None.into(), st_in, st_entry]; + operation_verify(st_out, op, prev_statements, Aux::default()) +} + +fn helper_statement_arg_from_template( + params: &Params, + st_tmpl_arg: StatementTmplArg, + args: Vec, + expected_st_arg: StatementArg, +) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_arg_target = make_statement_arg_from_template_circuit( + params, + &mut builder, + &st_tmpl_arg_target, + &args_target, + ); + // TODO: Instead of connect, assign witness to result + let expected_st_arg_target = builder.add_virtual_statement_arg(); + builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); + + let mut pw = PartialWitness::::new(); + + st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) +} + +#[test] +fn test_statement_arg_from_template() -> Result<()> { + let params = Params::default(); + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + // case: None + let st_tmpl_arg = StatementTmplArg::None; + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::None; + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: Literal + let st_tmpl_arg = StatementTmplArg::Literal(Value::from("foo")); + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::Literal(Value::from("foo")); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: AnchoredKey(id_wildcard, key_literal) + let st_tmpl_arg = + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("foo")); + let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; + let expected_st_arg = StatementArg::Key(AnchoredKey::new(dict, Key::from("foo"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: WildcardLiteral(wildcard) + let st_tmpl_arg = StatementTmplArg::Wildcard(Wildcard::new("a".to_string(), 1)); + let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; + let expected_st_arg = StatementArg::Literal(Value::from("key")); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + Ok(()) +} + +fn helper_statement_from_template( + params: &Params, + st_tmpl: StatementTmpl, + args: Vec, + expected_st: Statement, +) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_tmpl_target = builder.add_virtual_statement_tmpl(false); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = + make_statement_from_template_circuit(params, &mut builder, &st_tmpl_target, &args_target); + // TODO: Instead of connect, assign witness to result + let expected_st_target = builder.add_virtual_statement(false); + builder.connect_flattenable(&expected_st_target, &st_target); + + let mut pw = PartialWitness::::new(); + + st_tmpl_target.set_targets(&mut pw, &st_tmpl)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_target.set_targets(&mut pw, &expected_st.into())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) +} + +#[test] +fn test_statement_from_template() -> Result<()> { + let params = Params::default(); + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + let st_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)), + args: vec![ + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; + let expected_st = Statement::equal( + AnchoredKey::new(dict, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + + let st_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)), + args: vec![ + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(); + let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; + let expected_st = Statement::not_equal( + AnchoredKey::new(dict, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + + Ok(()) +} + +fn helper_custom_operation_verify_gadget( + params: &Params, + custom_predicate: CustomPredicateRef, + mut op_args: Vec, + mut args: Vec, + expected_st: Option, +) -> Result<()> { + // Pad + for _ in op_args.len()..BASE_PARAMS.max_operation_args { + op_args.push(Statement::None); + } + for _ in args.len()..params.max_custom_predicate_wildcards { + args.push(Value::from(EMPTY_VALUE)); + } + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let custom_predicate_target = builder.add_virtual_custom_predicate_entry(); + let op_args_target: Vec<_> = (0..op_args.len()) + .map(|_| builder.add_virtual_statement(false)) + .collect(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let (st_target, op_type_target) = make_custom_statement_circuit( + params, + &mut builder, + &custom_predicate_target, + &op_args_target, + &args_target, + )?; + + let mut pw = PartialWitness::::new(); + + // Input + custom_predicate_target.set_targets(&mut pw, &custom_predicate)?; + for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { + op_arg_target.set_targets(&mut pw, &op_arg.into())?; + } + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; + } + // Expected Output + if let Some(expected_st) = expected_st { + st_target.set_targets(&mut pw, &expected_st.into())?; + } + + let expected_op_type = OperationType::Custom(custom_predicate); + op_type_target.set_targets(&mut pw, &expected_op_type)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) +} + +fn value_ref(v: impl Into) -> ValueRef { + v.into() +} + +// TODO: Add negative tests +#[test] +fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { + let params = Params::default(); + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) + .arg(("id", "key")) + .arg("secret"); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; + let batch = builder.finish()?; + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + // AND + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), + ]; + let args = vec![Value::from(dict), Value::from(1234)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // OR (1) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::None, + ]; + let args = vec![Value::from(dict), Value::from(0)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // OR (2) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::None, + Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), + ]; + let args = vec![Value::from(dict), Value::from(1234)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + Ok(()) +} + +#[test] +fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { + let params = Params::default(); + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) + .arg(("secret_id", "key")) + .arg(("id", "score")); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret_id"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; + let batch = builder.finish()?; + + let dict = Hash([F(1), F(2), F(3), F(4)]); + let secret_dict = Hash([F(6), F(7), F(8), F(9)]); + + // AND (0) Sanity check with correct values + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // AND (1) Different dict for same wildcard + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(Hash([F(0), F(5), F(1), F(6)]), Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // AND (2) key doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("BAD")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // AND (3) literal doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal( + AnchoredKey::new(dict, Key::from("score")), + Value::from(0xbad), + ), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // AND (4) predicate doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::not_equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // OR (1) Two Nones + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![Statement::None, Statement::None]; + let args = vec![Value::from(dict), Value::from(0)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None) + .is_err() + ); + + Ok(()) +} + +fn helper_calculate_statements_hash(params: &Params, statements: &[Statement]) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let statements_target = (0..params.max_public_statements) + .map(|_| builder.add_virtual_statement(false)) + .collect_vec(); + let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target); + + let mut pw = PartialWitness::::new(); + + // Input + let statements = statements + .iter() + .map(|st| { + let mut st = mainpod::Statement::from(st.clone()); + pad_statement(&mut st); + st + }) + .collect_vec(); + for (st_target, st) in statements_target.iter().zip(statements.iter()) { + st_target.set_targets(&mut pw, st)?; + } + // Expected Output + let expected_sts_hash = calculate_statements_hash(&statements); + pw.set_hash_target( + sts_hash_target, + HashOut { + elements: expected_sts_hash.0, + }, + )?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) +} + +#[test] +fn test_calculate_sts_hash() -> frontend::Result<()> { + assert_eq!(Params::num_public_statements_hash(), 16); + // Case with no public public statements + let params = Params { + max_public_statements: 0, + ..Default::default() + }; + + helper_calculate_statements_hash(¶ms, &[]).unwrap(); + + // Case with number of statements for the sts_hash equal to number of public statements + let params = Params { + max_public_statements: Params::num_public_statements_hash(), + ..Default::default() + }; + + let dict = Hash([F(1), F(2), F(3), F(4)]); + let statements = (0..Params::num_public_statements_hash()) + .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64))) + .collect_vec(); + + helper_calculate_statements_hash(¶ms, &statements).unwrap(); + + // Case with more statements for the sts_hash than the number of public statements + let params = Params { + max_public_statements: 4, + ..Default::default() + }; + + let dict2 = Hash([F(5), F(6), F(7), F(8)]); + let statements = [ + Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)), + Statement::equal( + AnchoredKey::from((dict, "bar")), + AnchoredKey::from((dict, "baz")), + ), + Statement::lt( + AnchoredKey::from((dict2, "one")), + AnchoredKey::from((dict2, "two")), + ), + ] + .into_iter() + .chain(iter::repeat(Statement::None)) + .take(params.max_public_statements) + .collect_vec(); + + helper_calculate_statements_hash(¶ms, &statements).unwrap(); + + Ok(()) +} + +#[test] +fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> { + let params = Params::default(); + + // Build a batch with two predicates: + // pred_A: Equal(x, y) + // pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash + use NativePredicate as NP; + let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal) + .arg("x") + .arg("y"); + cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a]) + .unwrap(); + + // Build pred_B's template manually with SelfPredicateHash(0) + let stb_b_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)), + args: vec![ + StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)), + StatementTmplArg::SelfPredicateHash(0), + ], + }; + let pred_b = CustomPredicate::new( + ¶ms, + "pred_B".into(), + true, + vec![stb_b_tmpl], + 1, + vec!["x".to_string()], + ) + .unwrap(); + cpb.predicates.push(pred_b); + let batch = cpb.finish().unwrap(); + + // Compute the expected resolved hash of pred_A + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_a_hash = Predicate::Custom(pred_a_ref).hash(); + let expected_pred_a_value = Value::from(pred_a_hash); + + // Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to + // Literal(pred_a_hash). Then make_statement_from_template_circuit should produce + // a statement with that literal value. + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_tmpl = &pred_b_ref.predicate().statements[0]; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + // Create the template target and batch id target + let st_tmpl_target = builder.add_virtual_statement_tmpl(true); + let batch_id = builder.add_virtual_hash(); + + // Normalize the template (this is what we're testing) + let normalized = normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id); + + // Feed normalized template into statement generation + let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = + make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target); + + // Connect to expected output + let expected_st_target = builder.add_virtual_statement(false); + builder.connect_flattenable(&expected_st_target, &st_target); + + // Set witness + let mut pw = PartialWitness::::new(); + st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?; + pw.set_target_arr(&batch_id.elements, &batch.id().0)?; + + let some_value = Value::from(42); + // args: first wildcard is "x" = some_value, rest are padding + let mut args_values = vec![some_value.clone()]; + for _ in 1..params.max_custom_predicate_wildcards { + args_values.push(Value::from(EMPTY_VALUE)); + } + for (target, value) in args_target.iter().zip(args_values.iter()) { + target.set_targets(&mut pw, value)?; + } + + // Expected statement: Equal(Literal(some_value), Literal(pred_a_hash)) + let expected_st: crate::backends::plonky2::mainpod::Statement = + Statement::equal(some_value, expected_pred_a_value).into(); + expected_st_target.set_targets(&mut pw, &expected_st)?; + + // Build and verify + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) +} diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 5e9df2e..513b1da 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -148,14 +148,20 @@ pub(crate) fn extract_custom_predicate_verifications( Ok(table) } +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MerkleProofs { + pub(crate) medium: Vec, + pub(crate) small: Vec, +} + /// Extracts Merkle proofs from Contains/NotContains ops. pub(crate) fn extract_merkle_proofs( params: &Params, aux_list: &mut [OperationAux], operations: &[middleware::Operation], statements: &[middleware::Statement], -) -> Result> { - let mut table = Vec::new(); +) -> Result { + let mut tables = MerkleProofs::default(); for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() { let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); let (root, key, value, pf) = match (op, st) { @@ -178,31 +184,42 @@ pub(crate) fn extract_merkle_proofs( } _ => continue, }; - aux_list[i] = OperationAux::MerkleProofIndex(table.len()); - table.push(MerkleClaimAndProof::new( - Hash::from(root), - key, - value, - pf.clone(), - )); + let claim_proof = MerkleClaimAndProof::new(Hash::from(root), key, value, pf.clone()); + if pf.existence + // TODO: Make sure there's no off-by-one error here + && pf.siblings.len() <= params.containers.max_depth_small + && tables.small.len() < params.containers.state.max_small + { + aux_list[i] = OperationAux::MerkleProofIndex(Size::Small, tables.small.len()); + tables.small.push(claim_proof); + } else { + aux_list[i] = OperationAux::MerkleProofIndex(Size::Medium, tables.medium.len()); + tables.medium.push(claim_proof); + } } - if table.len() > params.max_merkle_proofs_containers { + if tables.medium.len() > params.containers.state.max_medium { return Err(Error::custom(format!( "The number of required Merkle proofs ({}) exceeds the maximum number ({}).", - table.len(), - params.max_merkle_proofs_containers + tables.medium.len(), + params.containers.state.max_medium ))); } - Ok(table) + Ok(tables) +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MerkleTransitionProofs { + pub(crate) medium: Vec, + pub(crate) small: Vec, } /// Extracts Merkle state transition proofs from container update ops. -pub(crate) fn extract_merkle_tree_state_transition_proofs( +pub(crate) fn extract_merkle_transition_proofs( params: &Params, aux_list: &mut [OperationAux], operations: &[middleware::Operation], -) -> Result> { - let mut table = Vec::new(); +) -> Result { + let mut tables = MerkleTransitionProofs::default(); for (i, op) in operations.iter().enumerate() { let pf = match op { middleware::Operation::ContainerInsertFromEntries(_, _, _, _, pf) @@ -210,17 +227,27 @@ pub(crate) fn extract_merkle_tree_state_transition_proofs( | middleware::Operation::ContainerDeleteFromEntries(_, _, _, pf) => pf.clone(), _ => continue, }; - aux_list[i] = OperationAux::MerkleTreeStateTransitionProofIndex(table.len()); - table.push(pf); + if pf.op_proof.existence + // TODO: Make sure there's no off-by-one error here + && pf.siblings.len() <= params.containers.max_depth_small + && tables.small.len() < params.containers.transition.max_small + { + aux_list[i] = OperationAux::MerkleTransitionProofIndex(Size::Small, tables.small.len()); + tables.small.push(pf); + } else { + aux_list[i] = + OperationAux::MerkleTransitionProofIndex(Size::Medium, tables.medium.len()); + tables.medium.push(pf); + } } - if table.len() > params.max_merkle_tree_state_transition_proofs_containers { + if tables.medium.len() > params.containers.transition.max_medium { return Err(Error::custom(format!( "The number of required Merkle proofs ({}) exceeds the maximum number ({}).", - table.len(), - params.max_merkle_tree_state_transition_proofs_containers + tables.medium.len(), + params.containers.transition.max_medium ))); } - Ok(table) + Ok(tables) } pub(crate) fn extract_public_key_of( @@ -513,6 +540,8 @@ impl MainPodProver for Prover { let mut aux_list = vec![OperationAux::None; params.max_priv_statements()]; let merkle_proofs = extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; + let merkle_transition_proofs = + extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?; let custom_predicates = extract_custom_predicates(params, inputs.operations)?; let custom_predicate_verifications = extract_custom_predicate_verifications( params, @@ -537,9 +566,6 @@ impl MainPodProver for Prover { let signed_bys = extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; - let merkle_tree_state_transition_proofs = - extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?; - let (statements, public_statements) = layout_statements(params, false, &inputs)?; let operations = process_private_statements_operations( params, @@ -572,20 +598,15 @@ impl MainPodProver for Prover { .collect_vec(); let mut vd_mt_proofs = Vec::with_capacity(inputs.pods.len()); + let pad_vd_mt_proof = inputs.vd_set.get_vds_proof_0(); for (pod, vd) in inputs.pods.iter().zip(&verifier_datas) { vd_mt_proofs.push(if pod.is_main() { - (true, inputs.vd_set.get_vds_proof(vd)?) + inputs.vd_set.get_vds_proof(vd)? } else { // For intro pods we don't verify inclusion of their vk into the vd set, so we - // generate a dummy mt proof with expected root and value to pass some constraints - ( - false, - MerkleClaimAndProof { - root: inputs.vd_set.root(), - value: RawValue::from(pod.verifier_data_hash()), - ..MerkleClaimAndProof::empty() - }, - ) + // use a valid vds proof that matches the expected root but not the value to pass + // the constraints + pad_vd_mt_proof.clone() }); } @@ -598,7 +619,7 @@ impl MainPodProver for Prover { merkle_proofs, public_key_of_sks, signed_bys, - merkle_tree_state_transition_proofs, + merkle_transition_proofs, custom_predicates_with_mpt_proofs, custom_predicate_verifications, }; @@ -985,7 +1006,18 @@ pub mod tests { max_statements: 2, max_public_statements: 1, max_input_pods_public_statements: 0, - max_merkle_proofs_containers: 0, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + max_depth_small: 8, + max_depth_medium: 32, + }, max_public_key_of: 0, max_custom_predicate_verifications: 0, max_custom_predicates: 0, @@ -1024,11 +1056,20 @@ pub mod tests { max_custom_predicates: 2, max_custom_predicate_verifications: 2, max_custom_predicate_wildcards: 3, - max_merkle_proofs_containers: 2, - max_merkle_tree_state_transition_proofs_containers: 2, max_public_key_of: 2, - max_depth_mt_containers: 4, max_depth_mt_vds: 6, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 2, + max_medium: 2, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 2, + max_medium: 2, + }, + max_depth_small: 2, + max_depth_medium: 4, + }, }; let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); @@ -1087,8 +1128,18 @@ pub mod tests { max_public_statements: 4, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, - max_merkle_proofs_containers: 3, - max_merkle_tree_state_transition_proofs_containers: 0, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 3, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + max_depth_small: 8, + max_depth_medium: 32, + }, ..Default::default() }; println!("{:#?}", params); @@ -1156,8 +1207,18 @@ pub mod tests { max_public_statements: 2, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, - max_merkle_proofs_containers: 0, - max_merkle_tree_state_transition_proofs_containers: 0, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + max_depth_small: 8, + max_depth_medium: 32, + }, ..Default::default() }; let mut vds = DEFAULT_VD_LIST.clone(); diff --git a/src/backends/plonky2/mainpod/operation.rs b/src/backends/plonky2/mainpod/operation.rs index d7b44bb..2060ac7 100644 --- a/src/backends/plonky2/mainpod/operation.rs +++ b/src/backends/plonky2/mainpod/operation.rs @@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ error::{Error, Result}, - mainpod::{SignedBy, Statement}, - primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof}, + mainpod::{MerkleProofs, MerkleTransitionProofs, SignedBy, Statement}, }, middleware::{self, OperationType, Params}, }; @@ -30,50 +29,89 @@ impl OperationArg { } } +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] +pub enum Size { + Small, + Medium, +} + +impl fmt::Display for Size { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Small => write!(f, "small"), + Self::Medium => write!(f, "medium"), + } + } +} + +impl Size { + pub const fn min() -> Self { + Self::Small + } + pub const fn max() -> Self { + Self::Medium + } +} + #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub enum OperationAux { None, - MerkleProofIndex(usize), + MerkleProofIndex(Size, usize), + MerkleTransitionProofIndex(Size, usize), PublicKeyOfIndex(usize), SignedByIndex(usize), - MerkleTreeStateTransitionProofIndex(usize), CustomPredVerifyIndex(usize), } impl OperationAux { - fn table_offset_merkle_proof(_params: &Params) -> usize { - // At index 0 we store a zero entry - 1 + fn table_offset_merkle_proof(params: &Params, size: Size) -> usize { + match size { + // At index 0 we store a zero entry + Size::Small => 1, + Size::Medium => { + Self::table_offset_merkle_proof(params, Size::Small) + + params.containers.state.max_small + } + } + } + fn table_offset_merkle_transition_proof(params: &Params, size: Size) -> usize { + match size { + Size::Small => { + Self::table_offset_merkle_proof(params, Size::min()) + + params.containers.state.max_total() + } + Size::Medium => { + Self::table_offset_merkle_transition_proof(params, Size::Small) + + params.containers.transition.max_small + } + } + } + fn table_offset_custom_pred_verify(params: &Params) -> usize { + Self::table_offset_merkle_transition_proof(params, Size::min()) + + params.containers.transition.max_total() } fn table_offset_public_key_of(params: &Params) -> usize { - Self::table_offset_merkle_proof(params) + params.max_merkle_proofs_containers + Self::table_offset_custom_pred_verify(params) + params.max_custom_predicate_verifications } fn table_offset_signed_by(params: &Params) -> usize { Self::table_offset_public_key_of(params) + params.max_public_key_of } - fn table_offset_merkle_tree_state_transition_proof(params: &Params) -> usize { - Self::table_offset_signed_by(params) + params.max_signed_by - } - fn table_offset_custom_pred_verify(params: &Params) -> usize { - Self::table_offset_merkle_tree_state_transition_proof(params) - + params.max_merkle_tree_state_transition_proofs_containers - } pub(crate) fn table_size(params: &Params) -> usize { - 1 + params.max_merkle_proofs_containers + 1 + params.containers.state.max_total() + + params.containers.transition.max_total() + + params.max_custom_predicate_verifications + params.max_public_key_of + params.max_signed_by - + params.max_merkle_tree_state_transition_proofs_containers - + params.max_custom_predicate_verifications } pub fn table_index(&self, params: &Params) -> usize { match self { Self::None => 0, - Self::MerkleProofIndex(i) => Self::table_offset_merkle_proof(params) + *i, + Self::MerkleProofIndex(size, i) => Self::table_offset_merkle_proof(params, *size) + *i, + Self::MerkleTransitionProofIndex(size, i) => { + Self::table_offset_merkle_transition_proof(params, *size) + *i + } Self::PublicKeyOfIndex(i) => Self::table_offset_public_key_of(params) + *i, Self::SignedByIndex(i) => Self::table_offset_signed_by(params) + *i, - Self::MerkleTreeStateTransitionProofIndex(i) => { - Self::table_offset_merkle_tree_state_transition_proof(params) + *i - } Self::CustomPredVerifyIndex(i) => Self::table_offset_custom_pred_verify(params) + *i, } } @@ -96,8 +134,8 @@ impl Operation { &self, statements: &[Statement], signatures: &[SignedBy], - merkle_proofs: &[MerkleClaimAndProof], - merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProof], + merkle_proofs: &MerkleProofs, + merkle_transition_proofs: &MerkleTransitionProofs, ) -> Result { let deref_args = self .1 @@ -113,17 +151,26 @@ impl Operation { .collect::>>()?; let deref_aux = match self.2 { OperationAux::None => crate::middleware::OperationAux::None, - OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, - OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof( - merkle_proofs - .get(i) - .ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))? - .proof - .clone(), - ), - OperationAux::MerkleTreeStateTransitionProofIndex(i) => { + OperationAux::MerkleProofIndex(size, i) => { + let table = match size { + Size::Small => &merkle_proofs.small, + Size::Medium => &merkle_proofs.medium, + }; + crate::middleware::OperationAux::MerkleProof( + table + .get(i) + .ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))? + .proof + .clone(), + ) + } + OperationAux::MerkleTransitionProofIndex(size, i) => { + let table = match size { + Size::Small => &merkle_transition_proofs.small, + Size::Medium => &merkle_transition_proofs.medium, + }; crate::middleware::OperationAux::MerkleTreeStateTransitionProof( - merkle_tree_state_transition_proofs + table .get(i) .ok_or(Error::custom(format!( "Missing Merkle state transition proof index {}", @@ -132,6 +179,7 @@ impl Operation { .clone(), ) } + OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, OperationAux::SignedByIndex(i) => crate::middleware::OperationAux::Signature( signatures .get(i) @@ -165,12 +213,14 @@ impl fmt::Display for Operation { } match self.2 { OperationAux::None => (), - OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, + OperationAux::MerkleProofIndex(size, i) => { + write!(f, " {}_merkle_proof_{:02}", size, i)? + } OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?, OperationAux::PublicKeyOfIndex(i) => write!(f, " public_key_of_{:02}", i)?, OperationAux::SignedByIndex(i) => write!(f, " signed_by_{:02}", i)?, - OperationAux::MerkleTreeStateTransitionProofIndex(i) => { - write!(f, " merkle_tree_state_transition_proof_{:02}", i)? + OperationAux::MerkleTransitionProofIndex(size, i) => { + write!(f, " {}_merkle_transition_proof_{:02}", size, i)? } } Ok(()) diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index b8c6a03..8dd710a 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -11,13 +11,12 @@ use crate::{ basetypes::{Proof, VerifierOnlyCircuitData}, error::{Error, Result}, mainpod::{ - calculate_statements_hash, extract_merkle_proofs, - extract_merkle_tree_state_transition_proofs, extract_signatures, layout_statements, - process_private_statements_operations, process_public_statements_operations, Operation, + calculate_statements_hash, extract_merkle_proofs, extract_merkle_transition_proofs, + extract_signatures, layout_statements, process_private_statements_operations, + process_public_statements_operations, MerkleProofs, MerkleTransitionProofs, Operation, OperationAux, SignedBy, Statement, }, mock::emptypod::MockEmptyPod, - primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof}, recursion::hash_verifier_data, }, middleware::{ @@ -45,10 +44,10 @@ pub struct MockMainPod { operations: Vec, // public subset of the `statements` vector public_statements: Vec, - // All Merkle proofs - merkle_proofs_containers: Vec, - // All Merkle tree state transition proofs - merkle_tree_state_transition_proofs_containers: Vec, + // All Merkle proofs for containers + merkle_proofs: MerkleProofs, + // All Merkle tree state transition proofs for containers + merkle_transition_proofs: MerkleTransitionProofs, // All verified signatures signatures: Vec, } @@ -124,8 +123,8 @@ struct Data { public_statements: Vec, operations: Vec, statements: Vec, - merkle_proofs: Vec, - merkle_tree_state_transition_proofs: Vec, + merkle_proofs: MerkleProofs, + merkle_transition_proofs: MerkleTransitionProofs, signatures: Vec, input_pods: Vec<(usize, Params, Hash, VDSet, serde_json::Value)>, } @@ -153,8 +152,8 @@ impl MockMainPod { let merkle_proofs = extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; // Similarly for Merkle state transition proofs. - let merkle_tree_state_transition_proofs = - extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?; + let merkle_transition_proofs = + extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?; let signatures = extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; @@ -185,8 +184,8 @@ impl MockMainPod { public_statements, statements, operations, - merkle_proofs_containers: merkle_proofs, - merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, + merkle_proofs, + merkle_transition_proofs, signatures, }) } @@ -260,8 +259,8 @@ impl Pod for MockMainPod { .deref( &self.statements[..input_statement_offset + i], &self.signatures, - &self.merkle_proofs_containers, - &self.merkle_tree_state_transition_proofs_containers, + &self.merkle_proofs, + &self.merkle_transition_proofs, )? .check_and_log(&self.params, &s.clone().try_into()?) .map_err(|e| e.into()) @@ -321,10 +320,8 @@ impl Pod for MockMainPod { public_statements: self.public_statements.clone(), operations: self.operations.clone(), statements: self.statements.clone(), - merkle_proofs: self.merkle_proofs_containers.clone(), - merkle_tree_state_transition_proofs: self - .merkle_tree_state_transition_proofs_containers - .clone(), + merkle_proofs: self.merkle_proofs.clone(), + merkle_transition_proofs: self.merkle_transition_proofs.clone(), signatures: self.signatures.clone(), input_pods, }) @@ -344,7 +341,7 @@ impl Pod for MockMainPod { operations, statements, merkle_proofs, - merkle_tree_state_transition_proofs, + merkle_transition_proofs, signatures, input_pods, } = serde_json::from_value(data)?; @@ -362,8 +359,8 @@ impl Pod for MockMainPod { public_statements, operations, statements, - merkle_proofs_containers: merkle_proofs, - merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, + merkle_proofs, + merkle_transition_proofs, signatures, }) } diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index 2c54b8b..f53a143 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -42,8 +42,6 @@ use crate::{ #[derive(Clone, Debug, Serialize, Deserialize)] pub struct MerkleClaimAndProofTarget { pub(crate) max_depth: usize, - // `enabled` determines if the merkleproof verification is enabled - pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -121,16 +119,9 @@ pub fn verify_merkle_proof_circuit( let obtained_root = compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); - // check that obtained_root==root (from inputs), when enabled==true - let zero = builder.zero(); - let expected_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) - .collect(); - let computed_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) - .collect(); + // check that obtained_root==root (from inputs) for j in 0..HASH_SIZE { - builder.connect(computed_root[j], expected_root[j]); + builder.connect(obtained_root.elements[j], proof.root.elements[j]); } measure_gates_end!(builder, measure); } @@ -139,7 +130,6 @@ impl MerkleClaimAndProofTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { MerkleClaimAndProofTarget { max_depth, - enabled: builder.add_virtual_bool_target_safe(), root: builder.add_virtual_hash(), key: builder.add_virtual_value(), value: builder.add_virtual_value(), @@ -154,12 +144,7 @@ impl MerkleClaimAndProofTarget { } /// assigns the given values to the targets #[allow(clippy::too_many_arguments)] - pub fn set_targets( - &self, - pw: &mut PartialWitness, - enabled: bool, - mp: &MerkleClaimAndProof, - ) -> Result<()> { + pub fn set_targets(&self, pw: &mut PartialWitness, mp: &MerkleClaimAndProof) -> Result<()> { if mp.proof.siblings.len() > self.max_depth { return Err(Error::Tree(TreeError::circuit_depth_too_small( self.max_depth, @@ -167,7 +152,6 @@ impl MerkleClaimAndProofTarget { ))); } - pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?; @@ -207,8 +191,6 @@ impl MerkleClaimAndProofTarget { #[derive(Clone, Serialize, Deserialize)] pub struct MerkleProofExistenceTarget { max_depth: usize, - // `enabled` determines if the merkleproof verification is enabled - pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -236,16 +218,9 @@ pub fn verify_merkle_proof_existence_circuit( let obtained_root = compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); - // check that obtained_root==root (from inputs), when enabled==true - let zero = builder.zero(); - let expected_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) - .collect(); - let computed_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) - .collect(); + // check that obtained_root==root (from inputs) for j in 0..HASH_SIZE { - builder.connect(computed_root[j], expected_root[j]); + builder.connect(obtained_root.elements[j], proof.root.elements[j]); } measure_gates_end!(builder, measure); @@ -256,7 +231,6 @@ impl MerkleProofExistenceTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { MerkleProofExistenceTarget { max_depth, - enabled: builder.add_virtual_bool_target_safe(), root: builder.add_virtual_hash(), key: builder.add_virtual_value(), value: builder.add_virtual_value(), @@ -265,12 +239,7 @@ impl MerkleProofExistenceTarget { } } /// assigns the given values to the targets - pub fn set_targets( - &self, - pw: &mut PartialWitness, - enabled: bool, - mp: &MerkleClaimAndProof, - ) -> Result<()> { + pub fn set_targets(&self, pw: &mut PartialWitness, mp: &MerkleClaimAndProof) -> Result<()> { assert!(mp.proof.existence); // sanity check if mp.proof.siblings.len() > self.max_depth { return Err(Error::Tree(TreeError::circuit_depth_too_small( @@ -279,7 +248,6 @@ impl MerkleProofExistenceTarget { ))); } - pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?; @@ -456,8 +424,6 @@ fn hash_with_flag_target>( #[derive(Clone, Serialize, Deserialize)] pub struct MerkleTreeStateTransitionProofTarget { pub(crate) max_depth: usize, - // `enabled` determines if the merkleproof state transition verification is enabled - pub(crate) enabled: BoolTarget, pub(crate) op: Target, pub(crate) old_root: HashOutTarget, pub(crate) op_proof: MerkleClaimAndProofTarget, @@ -511,7 +477,6 @@ pub fn verify_merkle_state_transition_circuit( }; let new_key_proof = MerkleProofExistenceTarget { max_depth: proof.max_depth, - enabled: proof.enabled, root, key: proof.op_key, value: proof.op_value, @@ -523,13 +488,7 @@ pub fn verify_merkle_state_transition_circuit( // Insert/Delete: Non-existence // Update: Existence let proof_type = is_update; - builder.conditional_assert_eq( - proof.enabled.target, - proof.op_proof.existence.target, - proof_type.target, - ); - // 3.2) assert that proof.enabled matches with op_proof.enabled - builder.connect(proof.op_proof.enabled.target, proof.enabled.target); + builder.connect(proof.op_proof.existence.target, proof_type.target); // 4) assert proof_non_existence.root corresponds to the root // specified by the op (old_root for Insert/Update and new_root @@ -545,17 +504,9 @@ pub fn verify_merkle_state_transition_circuit( }; for j in 0..HASH_SIZE { // 4.1) assert that proof.proof_non_existence.root == proof.old_root - builder.conditional_assert_eq( - proof.enabled.target, - proof.op_proof.root.elements[j], - claim_root.elements[j], - ); + builder.connect(proof.op_proof.root.elements[j], claim_root.elements[j]); // 4.2) assert that the non-existence proof uses the op_key (value not needed). - builder.conditional_assert_eq( - proof.enabled.target, - proof.op_proof.key.elements[j], - proof.op_key.elements[j], - ); + builder.connect(proof.op_proof.key.elements[j], proof.op_key.elements[j]); } // prepare value for check 5.2) @@ -593,7 +544,7 @@ pub fn verify_merkle_state_transition_circuit( .map(|j| builder.select(is_divergence_level, zero, new_siblings[i].elements[j])) .collect(); for j in 0..HASH_SIZE { - builder.conditional_assert_eq(proof.enabled.target, old_sibling_i[j], new_sibling_i[j]); + builder.connect(old_sibling_i[j], new_sibling_i[j]); } // 5.2) when i==d && if old_siblings[i] != new_siblings[i], check that: @@ -611,7 +562,7 @@ pub fn verify_merkle_state_transition_circuit( let in_case_5_2 = builder.and(old_is_noteq_new, is_divergence_level); // do the case2's checks - let sel = builder.and(proof.enabled, in_case_5_2); + let sel = in_case_5_2; for j in 0..HASH_SIZE { builder.conditional_assert_eq(sel.target, old_siblings[i].elements[j], zero); builder.conditional_assert_eq( @@ -641,7 +592,6 @@ impl MerkleTreeStateTransitionProofTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { Self { max_depth, - enabled: builder.add_virtual_bool_target_safe(), op: builder.add_virtual_target(), old_root: builder.add_virtual_hash(), @@ -661,7 +611,6 @@ impl MerkleTreeStateTransitionProofTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - enabled: bool, mp: &MerkleTreeStateTransitionProof, ) -> Result<()> { let new_siblings = mp.siblings.clone(); @@ -672,13 +621,11 @@ impl MerkleTreeStateTransitionProofTarget { ))); } - pw.set_bool_target(self.enabled, enabled)?; pw.set_target(self.op, F::from_canonical_u8(mp.op as u8))?; pw.set_hash_target(self.old_root, HashOut::from_vec(mp.old_root.0.to_vec()))?; self.op_proof.set_targets( pw, - enabled, &MerkleClaimAndProof { root: if mp.op == MerkleTreeOp::Delete { mp.new_root @@ -859,7 +806,6 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, - true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -871,6 +817,42 @@ pub mod tests { Ok(()) } + #[test] + fn test_merkleproof_pad_valid() -> Result<()> { + // circuit + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::::new(); + + let targets = MerkleClaimAndProofTarget::new_virtual(32, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); + targets.set_targets(&mut pw, &MerkleClaimAndProof::pad())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } + + #[test] + fn test_merkleproof_transition_pad_valid() -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::::new(); + + let targets = MerkleTreeStateTransitionProofTarget::new_virtual(32, &mut builder); + verify_merkle_state_transition_circuit(&mut builder, &targets); + targets.set_targets(&mut pw, &MerkleTreeStateTransitionProof::pad())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } + #[test] fn test_merkleproof_only_existence_verify() -> Result<()> { for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { @@ -906,7 +888,6 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, - true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -982,7 +963,6 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, - true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -1028,32 +1008,15 @@ pub mod tests { let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); verify_merkle_proof_circuit(&mut builder, &targets); - // verification enabled & proof of existence + // proof of existence let mp = MerkleClaimAndProof::new(tree2.root(), key, Some(value), proof); - targets.set_targets(&mut pw, true, &mp)?; + targets.set_targets(&mut pw, &mp)?; // generate proof, expecting it to fail (since we're using the wrong // root) let data = builder.build::(); assert!(data.prove(pw).is_err()); - // Now generate a new proof, using `enabled=false`, which should pass the verification - // despite containing 'wrong' witness. - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); - verify_merkle_proof_circuit(&mut builder, &targets); - // verification disabled & proof of existence - targets.set_targets(&mut pw, false, &mp)?; - - // generate proof, should pass despite using wrong witness, since the - // `enabled=false` - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - Ok(()) } @@ -1076,7 +1039,7 @@ pub mod tests { let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, true, state_transition_proof)?; + targets.set_targets(&mut pw, state_transition_proof)?; // generate & verify proof let data = builder.build::(); @@ -1273,71 +1236,4 @@ pub mod tests { assert_ne!(state_transition_proof.new_root, tree.root()); // Tamper check Ok(()) } - - #[test] - fn test_state_transition_gadget_disabled() -> Result<()> { - let max_depth: usize = 32; - let mut kvs = HashMap::new(); - for i in 0..8 { - kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); - } - let mut tree = MerkleTree::new(&kvs); - - let key = RawValue::from(37); - let value = RawValue::from(1037); - let _ = tree.insert(&key, &value)?; - - let key = RawValue::from(21); - let value = RawValue::from(1021); - let original_state_transition_proof = tree.insert(&key, &value)?; - - let mut state_transition_proof = original_state_transition_proof.clone(); - - // modify the proof, so that it should fail when `enabled=true`, by - // changing the new_root - state_transition_proof.new_root = state_transition_proof.old_root; - - run_circuit_disabled(max_depth, &state_transition_proof)?; - - // modify the proof, so that it should fail when `enabled=true`, by - // changing the new_sibling at the divergence level, which should not - // pass the verification in the case where we're inserting key=21 - let mut state_transition_proof = original_state_transition_proof.clone(); - state_transition_proof.siblings[4] = EMPTY_HASH; - - run_circuit_disabled(max_depth, &state_transition_proof)?; - - Ok(()) - } - - fn run_circuit_disabled( - max_depth: usize, - state_transition_proof: &MerkleTreeStateTransitionProof, - ) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); - verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, true, state_transition_proof)?; - - // generate proof, and expect it to fail - let data = builder.build::(); - assert!(data.prove(pw).is_err()); // expect prove to fail - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); - verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, false, state_transition_proof)?; - - // generate and expect it to pass - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - Ok(()) - } } diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index 0e29e14..e84da20 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -921,6 +921,21 @@ impl MerkleClaimAndProof { }, } } + /// Value used for padding. This is a valid merkle proof. + pub fn pad() -> Self { + let [key, value] = [EMPTY_VALUE, EMPTY_VALUE]; + let root = kv_hash(&key, Some(value)); + Self { + root, + key, + value, + proof: MerkleProof { + existence: true, + siblings: vec![], + other_leaf: None, + }, + } + } pub fn new(root: Hash, key: RawValue, value: Option, proof: MerkleProof) -> Self { Self { root, @@ -974,7 +989,6 @@ pub struct MerkleTreeStateTransitionProof { } impl MerkleTreeStateTransitionProof { - /// Value used for padding. pub fn empty() -> Self { let empty_proof_and_claim = MerkleClaimAndProof::empty(); Self { @@ -988,6 +1002,20 @@ impl MerkleTreeStateTransitionProof { siblings: vec![], } } + /// Value used for padding. This is a valid transition proof. + pub fn pad() -> Self { + let pad_proof_and_claim = MerkleClaimAndProof::pad(); + Self { + op: MerkleTreeOp::Update, + old_root: pad_proof_and_claim.root, + op_proof: pad_proof_and_claim.proof, + new_root: pad_proof_and_claim.root, + op_key: pad_proof_and_claim.key, + op_value: pad_proof_and_claim.value, + value: Some(pad_proof_and_claim.value), + siblings: vec![], + } + } } // NOTE: currently we use automatic serialization/deserialization, which is @@ -1165,6 +1193,15 @@ pub mod tests { Ok(()) } + #[test] + fn test_merkletree_pad() { + let claim = MerkleClaimAndProof::pad(); + MerkleTree::verify(claim.root, &claim.proof, &claim.key, &claim.value).unwrap(); + + let proof = MerkleTreeStateTransitionProof::pad(); + MerkleTree::verify_state_transition(&proof).unwrap(); + } + #[test] fn test_key_not_found() -> Result<()> { let db = Box::new(db::MemDB::new()); diff --git a/src/frontend/multi_pod/diagnostics.rs b/src/frontend/multi_pod/diagnostics.rs index 438f379..f56778f 100644 --- a/src/frontend/multi_pod/diagnostics.rs +++ b/src/frontend/multi_pod/diagnostics.rs @@ -78,12 +78,12 @@ fn aggregate_rows<'a>( UtilizationRow { name: "merkle proofs", used: merkle_proofs, - limit: params.max_merkle_proofs_containers, + limit: params.containers.state.max_medium, }, UtilizationRow { name: "merkle state transitions", used: merkle_state_transitions, - limit: params.max_merkle_tree_state_transition_proofs_containers, + limit: params.containers.transition.max_medium, }, UtilizationRow { name: "custom pred verifications", @@ -278,15 +278,24 @@ mod tests { use super::*; use crate::{ frontend::multi_pod::cost::CustomPredicateId, - middleware::{Hash, RawValue}, + middleware::{Hash, ParamsContainers, ParamsMerkleProofs, RawValue}, }; fn default_params() -> Params { Params { max_statements: 48, max_public_statements: 8, - max_merkle_proofs_containers: 8, - max_merkle_tree_state_transition_proofs_containers: 4, + containers: ParamsContainers { + state: ParamsMerkleProofs { + max_small: 0, + max_medium: 8, + }, + transition: ParamsMerkleProofs { + max_small: 0, + max_medium: 4, + }, + ..Default::default() + }, max_custom_predicate_verifications: 10, max_custom_predicates: 2, max_signed_by: 4, diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs index db1502e..8d81ab3 100644 --- a/src/frontend/multi_pod/solver.rs +++ b/src/frontend/multi_pod/solver.rs @@ -395,13 +395,11 @@ pub fn solve(input: &SolverInput) -> Result { let lb_statement_groups = lower_bound_from_total(input.num_statements, max_stmts_per_pod); let lb_merkle = lower_bound_from_total( resource_totals.merkle_proofs, - input.params.max_merkle_proofs_containers, + input.params.containers.state.max_medium, ); let lb_merkle_transitions = lower_bound_from_total( resource_totals.merkle_state_transitions, - input - .params - .max_merkle_tree_state_transition_proofs_containers, + input.params.containers.transition.max_medium, ); let lb_custom_pred_verifications = lower_bound_from_total( resource_totals.custom_pred_verifications, @@ -753,7 +751,7 @@ fn try_solve_with_pods( .map(|s| (input.costs[s].merkle_proofs as f64) * prove[s][p]) .sum(); model.add_constraint(constraint!( - merkle_sum <= (input.params.max_merkle_proofs_containers as f64) * pod_used[p] + merkle_sum <= (input.params.containers.state.max_medium as f64) * pod_used[p] )); // 6d: Merkle state transitions @@ -761,11 +759,7 @@ fn try_solve_with_pods( .map(|s| (input.costs[s].merkle_state_transitions as f64) * prove[s][p]) .sum(); model.add_constraint(constraint!( - mst_sum - <= (input - .params - .max_merkle_tree_state_transition_proofs_containers as f64) - * pod_used[p] + mst_sum <= (input.params.containers.transition.max_medium as f64) * pod_used[p] )); // 6e: Custom predicate verifications diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 82675d7..d212ca8 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -780,6 +780,50 @@ pub const BASE_PARAMS: BaseParams = BaseParams { max_operation_args: 5 + 1, }; +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] +#[serde(rename_all = "camelCase")] +pub struct ParamsMerkleProofs { + pub max_small: usize, + pub max_medium: usize, +} + +impl ParamsMerkleProofs { + pub fn max_total(&self) -> usize { + self.max_small + self.max_medium + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] +#[serde(rename_all = "camelCase")] +pub struct ParamsContainers { + // Parameters for exists/nonexists container operations. The small set only supports exists + pub state: ParamsMerkleProofs, + // Parameters for transition container operations (insert, delete, update). The small set only + // supports update. + pub transition: ParamsMerkleProofs, + // Max depth of small proofs + pub max_depth_small: usize, + // Max depth of medium proofs + pub max_depth_medium: usize, +} + +impl Default for ParamsContainers { + fn default() -> Self { + Self { + state: ParamsMerkleProofs { + max_small: 22, + max_medium: 8, + }, + transition: ParamsMerkleProofs { + max_small: 12, + max_medium: 6, + }, + max_depth_small: 8, + max_depth_medium: 32, + } + } +} + /// Params: non dynamic parameters that define the circuit. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] #[serde(rename_all = "camelCase")] @@ -793,12 +837,7 @@ pub struct Params { // max number of operations using custom predicates that can be verified in the MainPod pub max_custom_predicate_verifications: usize, pub max_custom_predicate_wildcards: usize, - // maximum number of merkle proofs used for container operations - pub max_merkle_proofs_containers: usize, - // maximum number of merkle tree state transition proofs used for container update operations - pub max_merkle_tree_state_transition_proofs_containers: usize, - // maximum depth for merkle tree gadget used for container operations - pub max_depth_mt_containers: usize, + pub containers: ParamsContainers, // maximum depth of the merkle tree gadget used for verifier_data membership // check. This allows creating verifying sets of pod circuits of size // 2^max_depth_mt_vds. Limits the number of container operations of the type Contains, @@ -820,9 +859,7 @@ impl Default for Params { max_custom_predicates: 8, max_custom_predicate_verifications: 8, max_custom_predicate_wildcards: 8, - max_merkle_proofs_containers: 20, - max_merkle_tree_state_transition_proofs_containers: 6, - max_depth_mt_containers: 32, + containers: ParamsContainers::default(), max_depth_mt_vds: 6, // up to 64 (2^6) different pod circuits max_public_key_of: 2, max_signed_by: 4,