diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c34d0ea..d3741ab 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,8 +24,6 @@ jobs: run: cargo build --features metrics - name: Build time run: cargo build --features time - - name: Build db_rocksdb - run: cargo build --features db_rocksdb - name: Build disk_cache run: cargo build --no-default-features --features backend_plonky2,zk,disk_cache diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b3b389a..3d1ba0e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,5 +17,4 @@ jobs: - name: Set up Rust uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run tests - # RocksDB is disabled by default but we still want to test it. - run: cargo test --release --features db_rocksdb + run: cargo test --release diff --git a/Cargo.toml b/Cargo.toml index 704fe89..a1f7511 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,6 @@ good_lp = { version = "1.8", default-features = false, features = [ "scip_bundled", ] } annotate-snippets = "0.11" -rocksdb = { version = "0.24.0", optional = true } # keyvalue database for merkletree # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. # [patch."https://github.com/0xPARC/plonky2"] @@ -58,7 +57,6 @@ rocksdb = { version = "0.24.0", optional = true } # keyvalue database for merkle pretty_assertions = "1.4.1" # Used only for testing JSON Schema generation and validation. jsonschema = "0.30.0" -tempfile = "3" [build-dependencies] vergen-gitcl = { version = "1.0.0", features = ["build"] } @@ -72,7 +70,6 @@ time = [] examples = [] disk_cache = ["directories", "minicbor-serde"] mem_cache = [] -db_rocksdb = ["rocksdb"] # Uncomment in order to enable debug information in the release builds. This allows getting panic backtraces with a performance similar to regular release. # [profile.release] diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index f65eb7b..d7d6b39 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -51,7 +51,7 @@ use crate::{ mainpod::cache_get_rec_main_pod_verifier_circuit_data, primitives::merkletree::MerkleClaimAndProof, }, - middleware::{containers::Array, Hash, Params, RawValue, Result, Value, EMPTY_HASH}, + middleware::{containers::Array, Hash, Params, RawValue, Result, Value}, }; pub static DEFAULT_VD_LIST: LazyLock> = LazyLock::new(|| { @@ -95,12 +95,6 @@ impl Eq for VDSet {} impl VDSet { fn new_from_vds_hashes(mut vds_hashes: Vec) -> Self { - // If vds_hashes is empty we add an zero entry to be used as padding when verifying merkle - // proofs of inclusion in the vds set. This zero entry can't be abused because no circuit - // exists with a vds_hash = 0. - if vds_hashes.is_empty() { - vds_hashes.push(EMPTY_HASH); - } // before using the hash values, sort them, so that each set of // verifier_datas gets the same VDSet root vds_hashes.sort(); @@ -156,9 +150,6 @@ impl VDSet { ))? .clone()) } - pub fn get_vds_proof_0(&self) -> MerkleClaimAndProof { - self.proofs_map[&self.vds_hashes[0]].clone() - } /// Returns true if the `verifier_data_hash` is in the set pub fn contains(&self, verifier_data_hash: HashOut) -> bool { self.proofs_map diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index bb194a0..db8c32a 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -25,20 +25,20 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ basetypes::{CircuitBuilder, CommonCircuitData, D}, - circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator}, + circuits::mainpod::CustomPredicateVerification, error::Result, mainpod::{Operation, OperationArg, OperationAux, Statement}, primitives::merkletree::{ verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget, - MerkleProof, MerkleProofExistenceTarget, MerkleTreeStateTransitionProofTarget, + MerkleProof, MerkleTreeStateTransitionProofTarget, }, }, middleware::{ hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg, - StatementTmplArgPrefix, ToFields, Value, BASE_PARAMS, EMPTY_VALUE, F, HASH_SIZE, - STATEMENT_ARG_F_LEN, VALUE_SIZE, + StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN, + VALUE_SIZE, }, }; @@ -103,20 +103,6 @@ pub struct StatementArgTarget { pub elements: [Target; STATEMENT_ARG_F_LEN], } -impl Flattenable for StatementArgTarget { - fn flatten(&self) -> Vec { - self.elements.to_vec() - } - fn from_flattened(_params: &Params, vs: &[Target]) -> Self { - Self { - elements: vs.try_into().expect("STATEMENT_ARG_F_LEN elements"), - } - } - fn size(_params: &Params) -> usize { - STATEMENT_ARG_F_LEN - } -} - impl StatementArgTarget { pub fn set_targets(&self, pw: &mut PartialWitness, arg: &StatementArg) -> Result<()> { Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?) @@ -332,7 +318,7 @@ impl OperationTarget { .args() .iter() .chain(iter::repeat(&OperationArg::None)) - .take(BASE_PARAMS.max_operation_args) + .take(params.max_operation_args) .enumerate() { self.args[i].set_targets(pw, arg.as_usize())?; @@ -342,7 +328,7 @@ impl OperationTarget { fn size(params: &Params) -> usize { OperationTypeTarget::size(params) - + BASE_PARAMS.max_operation_args * IndexTarget::size(params) + + params.max_operation_args * IndexTarget::size(params) + IndexTarget::size(params) } } @@ -725,6 +711,7 @@ impl CustomPredicateInBatchTarget { let mtp = MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder); let _true = builder._true(); + builder.connect(_true.target, mtp.enabled.target); builder.connect(_true.target, mtp.existence.target); let zero = builder.constant(F(0)); let key = ValueTarget { @@ -762,7 +749,7 @@ impl CustomPredicateInBatchTarget { value: RawValue::from(hash_fields(&predicate.to_fields())), proof: mtp.clone(), }; - self.mtp.set_targets(pw, &mtp_claim)?; + self.mtp.set_targets(pw, true, &mtp_claim)?; Ok(()) } } @@ -784,8 +771,7 @@ impl CustomPredicateEntryTarget { pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; pw.set_target(self.index, F::from_canonical_usize(predicate.index))?; - // Replace BatchSelf predicates with Custom(batch, i), and - // SelfPredicateHash args with Literal(hash(Custom(batch, i))) + // Replace statement templates of batch-self with (id,index) let batch = &predicate.batch; let predicate = predicate.predicate(); let statements = predicate @@ -802,22 +788,10 @@ impl CustomPredicateEntryTarget { } x => x.clone(), }; - let args = st_tmpl - .args - .into_iter() - .map(|arg| match arg { - StatementTmplArg::SelfPredicateHash(i) => { - let pred_hash = Predicate::Custom(CustomPredicateRef { - batch: batch.clone(), - index: i, - }) - .hash(); - StatementTmplArg::Literal(Value::from(pred_hash)) - } - other => other, - }) - .collect(); - StatementTmpl { pred_or_wc, args } + StatementTmpl { + pred_or_wc, + args: st_tmpl.args, + } }) .collect_vec(); let predicate = CustomPredicate { @@ -881,7 +855,7 @@ impl CustomPredicateVerifyEntryTarget { args: (0..params.max_custom_predicate_wildcards) .map(|_| builder.add_virtual_value()) .collect(), - op_args: (0..BASE_PARAMS.max_operation_args) + op_args: (0..params.max_operation_args) .map(|_| builder.add_virtual_statement(false)) .collect(), } @@ -911,7 +885,7 @@ impl CustomPredicateVerifyEntryTarget { cpv.op_args .iter() .chain(iter::repeat(&pad_op_arg)) - .take(BASE_PARAMS.max_operation_args), + .take(params.max_operation_args), ) { op_arg_target.set_targets(pw, op_arg)? } @@ -954,7 +928,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { .expect("len = operation_type_size"), }; let (pos, size) = (pos + size, StatementTarget::size(params)); - let op_args = (0..BASE_PARAMS.max_operation_args) + let op_args = (0..params.max_operation_args) .map(|i| { StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size]) }) @@ -966,7 +940,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { } } fn size(params: &Params) -> usize { - StatementTarget::size(params) * (1 + BASE_PARAMS.max_operation_args) + StatementTarget::size(params) * (1 + params.max_operation_args) + OperationTarget::size(params) } } @@ -986,6 +960,7 @@ pub trait Flattenable { /// elsewhere. #[derive(Copy, Clone)] pub struct MerkleClaimTarget { + pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -995,6 +970,7 @@ pub struct MerkleClaimTarget { impl From for MerkleClaimTarget { fn from(pf: MerkleClaimAndProofTarget) -> Self { Self { + enabled: pf.enabled, root: pf.root, key: pf.key, value: pf.value, @@ -1003,25 +979,12 @@ impl From for MerkleClaimTarget { } } -impl MerkleClaimTarget { - pub fn from_proof_existence( - builder: &mut CircuitBuilder, - pf: MerkleProofExistenceTarget, - ) -> Self { - Self { - root: pf.root, - key: pf.key, - value: pf.value, - existence: builder._true(), - } - } -} - /// For the purpose of op verification, we need only look up the /// Merkle state transition claim rather than the Merkle state /// transition proof since it is verified elsewhere. #[derive(Copy, Clone)] pub struct MerkleTreeStateTransitionClaimTarget { + pub(crate) enabled: BoolTarget, pub(crate) op: Target, pub(crate) old_root: HashOutTarget, pub(crate) new_root: HashOutTarget, @@ -1032,6 +995,7 @@ pub struct MerkleTreeStateTransitionClaimTarget { impl From for MerkleTreeStateTransitionClaimTarget { fn from(pf: MerkleTreeStateTransitionProofTarget) -> Self { Self { + enabled: pf.enabled, op: pf.op, old_root: pf.old_root, new_root: pf.new_root, @@ -1072,6 +1036,7 @@ impl Flattenable for ValueTarget { impl Flattenable for MerkleClaimTarget { fn flatten(&self) -> Vec { [ + vec![self.enabled.target], self.root.elements.to_vec(), self.key.elements.to_vec(), self.value.elements.to_vec(), @@ -1083,28 +1048,31 @@ impl Flattenable for MerkleClaimTarget { fn from_flattened(params: &Params, vs: &[Target]) -> Self { assert_eq!(vs.len(), Self::size(params)); Self { - root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()), - key: ValueTarget::from_slice(&vs[NUM_HASH_OUT_ELTS..NUM_HASH_OUT_ELTS + VALUE_SIZE]), - value: ValueTarget::from_slice( - &vs[NUM_HASH_OUT_ELTS + VALUE_SIZE..NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], + enabled: BoolTarget::new_unsafe(vs[0]), + root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()), + key: ValueTarget::from_slice( + &vs[1 + NUM_HASH_OUT_ELTS..1 + NUM_HASH_OUT_ELTS + VALUE_SIZE], ), - existence: BoolTarget::new_unsafe(vs[NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), + value: ValueTarget::from_slice( + &vs[1 + NUM_HASH_OUT_ELTS + VALUE_SIZE..1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], + ), + existence: BoolTarget::new_unsafe(vs[1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), } } fn size(params: &Params) -> usize { - HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1 + 2 + HashOutTarget::size(params) + 2 * ValueTarget::size(params) } } impl Flattenable for MerkleTreeStateTransitionClaimTarget { fn flatten(&self) -> Vec { [ + vec![self.enabled.target, self.op], self.old_root.elements.to_vec(), self.new_root.elements.to_vec(), self.op_key.elements.to_vec(), self.op_value.elements.to_vec(), - vec![self.op], ] .concat() } @@ -1112,22 +1080,24 @@ impl Flattenable for MerkleTreeStateTransitionClaimTarget { fn from_flattened(params: &Params, vs: &[Target]) -> Self { assert_eq!(vs.len(), Self::size(params)); Self { - old_root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()), + enabled: BoolTarget::new_unsafe(vs[0]), + op: vs[1], + old_root: HashOutTarget::from_vec(vs[2..2 + NUM_HASH_OUT_ELTS].to_vec()), new_root: HashOutTarget::from_vec( - vs[NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS].to_vec(), + vs[2 + NUM_HASH_OUT_ELTS..2 * (1 + NUM_HASH_OUT_ELTS)].to_vec(), ), op_key: ValueTarget::from_slice( - &vs[2 * NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS + VALUE_SIZE], + &vs[2 * (1 + NUM_HASH_OUT_ELTS)..2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE], ), op_value: ValueTarget::from_slice( - &vs[2 * NUM_HASH_OUT_ELTS + VALUE_SIZE..2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], + &vs[2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE + ..2 * (1 + NUM_HASH_OUT_ELTS) + 2 * VALUE_SIZE], ), - op: vs[2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], } } fn size(params: &Params) -> usize { - 2 * HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1 + 2 * (1 + HashOutTarget::size(params)) + 2 * ValueTarget::size(params) } } @@ -1365,18 +1335,6 @@ pub trait CircuitBuilderPod, const D: usize> { fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T; /// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target` fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T; - /// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then - /// materializes the selected row via a witness generator and constrains its hash. Cheaper than - /// `vec_ref` when each entry has many fields, since random access runs only over the 4-field - /// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and - /// reusing the same slices across multiple lookups. - fn vec_ref_projected( - &mut self, - params: &Params, - ts_flattened: &[Vec], - ts_hashes: &[HashOutTarget], - i: &IndexTarget, - ) -> T; fn select_flattenable( &mut self, params: &Params, @@ -1454,7 +1412,7 @@ impl CircuitBuilderPod for CircuitBuilder { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { OperationTarget { op_type: self.add_virtual_operation_type(), - args: (0..BASE_PARAMS.max_operation_args) + args: (0..params.max_operation_args) .map(|_| IndexTarget::new_virtual(params.statement_table_size(), self)) .collect(), aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self), @@ -1764,7 +1722,7 @@ impl CircuitBuilderPod for CircuitBuilder { let num_chunks = array.len().div_ceil(CHUNK_LEN); for chunk in array.chunks(CHUNK_LEN) { let mut index_chunk = i.low; - // If we have several chunks and the last one is smaller (it's index needs less than 6 + // I we have several chunks and the last one is smaller (it's index needs less than 6 // bits), make it zero except when it's used so that the range check over the index // passes. if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 { @@ -1779,6 +1737,12 @@ impl CircuitBuilderPod for CircuitBuilder { self.random_access(i.high, chunk_res) } + // TODO: Implement a version of vec_ref for types `T` which are big and support hashing. + // The idea would be the following: Take the array `ts` and hash each element. Then do the + // random access on the hash result. Finally "unhash" to recover the resolved element. + // We don't want to hash each element from the array each time, so we should cache the hashed + // result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and + // then do `ts: &[HashCache]`. fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T { let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec], i| { let num_rows = m.len(); @@ -1802,28 +1766,6 @@ impl CircuitBuilderPod for CircuitBuilder { T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i)) } - fn vec_ref_projected( - &mut self, - params: &Params, - ts_flattened: &[Vec], - ts_hashes: &[HashOutTarget], - i: &IndexTarget, - ) -> T { - assert_eq!(ts_flattened.len(), ts_hashes.len()); - let selected_hash = self.vec_ref(params, ts_hashes, i); - let selected_flattened = self.add_virtual_targets(T::size(params)); - let selected_flattened_hash = - self.hash_n_to_hash_no_pad::(selected_flattened.clone()); - self.connect_hashes(selected_hash, selected_flattened_hash); - let result = T::from_flattened(params, &selected_flattened); - self.add_simple_generator(TableGetGenerator::new( - i.clone(), - ts_flattened.to_vec(), - selected_flattened, - )); - result - } - fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T { let zero = self.zero(); self.vec_ref( @@ -2070,7 +2012,7 @@ pub(crate) mod tests { // Empty case let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); _ = cpb_builder.predicate_and("empty", &[], &[], &[])?; - let custom_predicate_batch = cpb_builder.finish()?; + let custom_predicate_batch = cpb_builder.finish(); helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); // Some cases from the examples diff --git a/src/backends/plonky2/circuits/mainpod/mod.rs b/src/backends/plonky2/circuits/mainpod.rs similarity index 52% rename from src/backends/plonky2/circuits/mainpod/mod.rs rename to src/backends/plonky2/circuits/mainpod.rs index 89ed3cf..ebe77b4 100644 --- a/src/backends/plonky2/circuits/mainpod/mod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -16,9 +16,6 @@ use plonky2::{ use plonky2_u32::gadgets::multiple_comparison::list_le_circuit; use serde::{Deserialize, Serialize}; -#[cfg(test)] -mod tests; - use crate::{ backends::plonky2::{ basetypes::{CircuitBuilder, VDSet}, @@ -36,20 +33,18 @@ use crate::{ }, emptypod::EmptyPod, error::Result, - mainpod::{self, pad_statement, MerkleProofs, MerkleTransitionProofs, SignedBy}, + mainpod::{self, pad_statement, SignedBy}, primitives::{ ec::{ bits::{BigUInt320Target, CircuitBuilderBits}, curve::{ - CircuitBuilderElliptic, CircuitBuilderSignature, Point, PointTarget, - WitnessWriteCurve, GROUP_ORDER, + CircuitBuilderElliptic, Point, PointTarget, WitnessWriteCurve, GROUP_ORDER, }, schnorr::{CircuitBuilderSchnorr, SecretKey, SignatureTarget, WitnessWriteSchnorr}, }, merkletree::{ - verify_merkle_proof_circuit, verify_merkle_proof_existence_circuit, - verify_merkle_state_transition_circuit, MerkleClaimAndProof, - MerkleClaimAndProofTarget, MerkleProof, MerkleProofExistenceTarget, MerkleTreeOp, + verify_merkle_proof_circuit, verify_merkle_state_transition_circuit, + MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProof, MerkleTreeOp, MerkleTreeStateTransitionProof, MerkleTreeStateTransitionProofTarget, }, signature::{verify_signature_circuit, SignatureVerifyTarget}, @@ -59,8 +54,8 @@ use crate::{ measure_gates_begin, measure_gates_end, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, Params, PredicatePrefix, Statement, StatementTmplArgPrefix, ToFields, - Value, BASE_PARAMS, F, HASH_SIZE, VALUE_SIZE, + NativePredicate, Params, PredicatePrefix, RawValue, Statement, ToFields, Value, F, + HASH_SIZE, }, }; // @@ -74,38 +69,30 @@ pub const PI_OFFSET_VDSROOT: usize = 4; pub const NUM_PUBLIC_INPUTS: usize = 8; -const MAX_VALUE_ARGS: usize = 5; +const MAX_VALUE_ARGS: usize = 4; struct StatementArgCache { rhs: ValueTarget, lhs: StatementArgTarget, valid: BoolTarget, - pred_is_none: BoolTarget, - is_reference: BoolTarget, - // if `is_reference` then this is the AnchoredKey found in the Contains statement - reference: StatementArgTarget, - // if `is_reference` then this is the value found in the Contains statement - value: ValueTarget, } -struct StatementCache { - equations: [StatementArgCache; MAX_EQS], - first_n_equations_valid: [BoolTarget; MAX_EQS], +struct StatementCache { + equations: [StatementArgCache; MAX_VALUE_ARGS], + first_n_equations_valid: [BoolTarget; MAX_VALUE_ARGS], op_args: Vec, } -impl StatementCache { +impl StatementCache { fn new( params: &Params, - max_operation_args: usize, builder: &mut CircuitBuilder, op: &OperationTarget, st: &StatementTarget, - prev_statement_flatteneds: &[Vec], - prev_statement_hashes: &[HashOutTarget], + prev_statements: &[StatementTarget], ) -> Self { - let op_args = if prev_statement_flatteneds.is_empty() { - (0..max_operation_args) + let op_args = if prev_statements.is_empty() { + (0..params.max_operation_args) .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) .collect_vec() } else { @@ -113,17 +100,10 @@ impl StatementCache { // converting a length 1 array into a scalar. op.args .iter() - .take(max_operation_args) - .map(|i| { - builder.vec_ref_projected( - params, - prev_statement_flatteneds, - prev_statement_hashes, - i, - ) - }) + .map(|i| builder.vec_ref(params, prev_statements, i)) .collect::>() }; + assert!(params.max_operation_args >= MAX_VALUE_ARGS); assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); let equations = array::from_fn(|i| { let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None); @@ -137,9 +117,9 @@ impl StatementCache { let is_reference = builder.and(pred_is_contains, ref_is_value); let valid = builder.or(is_literal, is_reference); - let rhs_from_literal = st.args[i].as_value(); - let rhs_from_reference = op_args[i].args[2].as_value(); - let rhs = builder.select_value(pred_is_none, rhs_from_literal, rhs_from_reference); + let rhs_literal = st.args[i].as_value(); + let rhs_reference = op_args[i].args[2].as_value(); + let rhs = builder.select_value(pred_is_none, rhs_literal, rhs_reference); let lhs_literal = &st.args[i]; let lhs_reference = StatementArgTarget::anchored_key( builder, @@ -147,22 +127,10 @@ impl StatementCache { &op_args[i].args[1].as_value(), ); let lhs = builder.select_statement_arg(pred_is_none, lhs_literal, &lhs_reference); - StatementArgCache { - rhs, - lhs, - valid, - pred_is_none, - is_reference, - reference: lhs_reference, - value: rhs_from_reference, - } + StatementArgCache { rhs, lhs, valid } }); - let mut first_n_equations_valid = if MAX_EQS != 0 { - [equations[0].valid; MAX_EQS] - } else { - [builder._false(); MAX_EQS] - }; - for i in 1..MAX_EQS { + let mut first_n_equations_valid = [equations[0].valid; MAX_VALUE_ARGS]; + for i in 1..MAX_VALUE_ARGS { first_n_equations_valid[i] = builder.and(equations[i].valid, first_n_equations_valid[i - 1]); } @@ -177,7 +145,7 @@ impl StatementCache { /// /// If the operation argument is a statement of type `None`, then the value /// should be the corresponding argument of the current statement. - /// If the operation argument is a statement of type `Contains`, then the value + /// If the operation argument is a statement of type `Equals`, then the value /// should be the argument at index 1 of that statement. /// If the function successfully interprets the arguments as values, /// returns `True` along with those values. Otherwise, returns `False` @@ -190,12 +158,6 @@ impl StatementCache { } } -/// Statement cache for private statements -type StatementCachePriv = StatementCache; -/// Statement cache for public statements. Since the operations can only be None or Copy, no -/// equation is needed because none of these operations dereference entries. -type StatementCachePub = StatementCache<0>; - /// Specialized implementation of `verify_operation_circuit` for operations that generate public /// statement. This only allows operations to be None, NewEntry or Copy and accounts for the fact /// that public statements in the current implementation are always generated by copying private @@ -205,26 +167,15 @@ fn verify_operation_public_statement_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, - prev_statement_flatteneds: &[Vec], - prev_statement_hashes: &[HashOutTarget], + prev_statements: &[StatementTarget], ) -> Result<()> { - let measure = measure_gates_begin!(builder, "OpVerifyPub"); + let measure = measure_gates_begin!(builder, "OpVerify"); // Verify that the operation `op` correctly generates the statement `st`. The operation // can reference any of the `prev_statements`. // TODO: Clean this up. let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - // None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the - // StatementCache requires. - let cache = StatementCachePub::new( - params, - 1, - builder, - op, - st, - prev_statement_flatteneds, - prev_statement_hashes, - ); + let cache = StatementCache::new(params, builder, op, st, prev_statements); measure_gates_end!(builder, measure_resolve_op_args); let op_checks = vec![ @@ -242,21 +193,21 @@ fn verify_operation_public_statement_circuit( enum OperationAuxTableTag { None = 0, MerkleProof = 1, - MerkleTransitionProof = 2, - CustomPredVerify = 3, - PublicKeyOf = 4, - SignedBy = 5, + PublicKeyOf = 2, + SignedBy = 3, + MerkleTreeStateTransitionProof = 4, + CustomPredVerify = 5, } fn max_operation_aux_entry_len(params: &Params) -> usize { [ - (params.containers.state.max_total() > 0).then(|| MerkleClaimTarget::size(params)), - (params.containers.transition.max_total() > 0) + (params.max_merkle_proofs_containers > 0).then(|| MerkleClaimTarget::size(params)), + (params.max_public_key_of > 0).then(|| PubKeySecKeyTarget::size(params)), + (params.max_signed_by > 0).then(|| MsgPubKeyTarget::size(params)), + (params.max_merkle_tree_state_transition_proofs_containers > 0) .then(|| MerkleTreeStateTransitionClaimTarget::size(params)), (params.max_custom_predicate_verifications > 0) .then(|| CustomPredicateVerifyQueryTarget::size(params)), - (params.max_public_key_of > 0).then(|| PubKeySecKeyTarget::size(params)), - (params.max_signed_by > 0).then(|| MsgPubKeyTarget::size(params)), ] .into_iter() .flatten() @@ -310,59 +261,14 @@ impl SignedByTarget { } } -fn append_container_proofs_operation_aux_table_circuit( - builder: &mut CircuitBuilder, - table: &mut MuxTableTarget, - merkle_proofs: &MerkleProofsTarget, - merkle_transition_proofs: &MerkleTransitionProofsTarget, -) { - // Small MerkleProofs: verify container merkle proofs (only inclusion) - for merkle_proof in &merkle_proofs.small { - verify_merkle_proof_existence_circuit(builder, merkle_proof); - let entry = MerkleClaimTarget::from_proof_existence(builder, merkle_proof.clone()); - - table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); - } - // Medium MerkleProofs: verify container merkle proofs (inclusion/non-inclusion) - for merkle_proof in &merkle_proofs.medium { - verify_merkle_proof_circuit(builder, merkle_proof); - let entry = MerkleClaimTarget::from(merkle_proof.clone()); - - table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); - } - - // Small Merkle state transition proofs: verify op proof (only update) - for merkle_transition_proof in &merkle_transition_proofs.small { - verify_merkle_state_transition_circuit(builder, merkle_transition_proof); - let entry = MerkleTreeStateTransitionClaimTarget::from(merkle_transition_proof.clone()); - - table.push( - builder, - OperationAuxTableTag::MerkleTransitionProof as u32, - &entry, - ); - } - // Medium Merkle state transition proofs: verify op proof (insert/update/delete) - for merkle_transition_proof in &merkle_transition_proofs.medium { - verify_merkle_state_transition_circuit(builder, merkle_transition_proof); - let entry = MerkleTreeStateTransitionClaimTarget::from(merkle_transition_proof.clone()); - - table.push( - builder, - OperationAuxTableTag::MerkleTransitionProof as u32, - &entry, - ); - } -} - #[allow(clippy::too_many_arguments)] fn build_operation_aux_table_circuit( params: &Params, builder: &mut CircuitBuilder, - merkle_proofs: &MerkleProofsTarget, - merkle_transition_proofs: &MerkleTransitionProofsTarget, + merkle_proofs: &[MerkleClaimAndProofTarget], public_key_of_sks: &[BigUInt320Target], signed_bys: &[SignedByTarget], + merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProofTarget], custom_predicate_verifications: &[CustomPredicateVerifyEntryTarget], custom_predicate_table: &[HashOutTarget], ) -> Result { @@ -371,63 +277,25 @@ fn build_operation_aux_table_circuit( params.max_custom_predicate_verifications, custom_predicate_verifications.len() ); - assert_eq!(params.containers.state.max_small, merkle_proofs.small.len()); - assert_eq!( - params.containers.state.max_medium, - merkle_proofs.medium.len() - ); + assert_eq!(params.max_merkle_proofs_containers, merkle_proofs.len()); let max_entry_len = max_operation_aux_entry_len(params); let mut table = MuxTableTarget::new(params, max_entry_len); // None table.push_flattened(builder, OperationAuxTableTag::None as u32, &[]); - append_container_proofs_operation_aux_table_circuit( - builder, - &mut table, - merkle_proofs, - merkle_transition_proofs, - ); + // MerkleProofs: verify container merkle proofs (inclusion/non-inclusion) + for merkle_proof in merkle_proofs { + verify_merkle_proof_circuit(builder, merkle_proof); + let entry = MerkleClaimTarget::from(merkle_proof.clone()); - // CustomPredVerify: verify custom predicate statements verification against operations - for entry in custom_predicate_verifications { - let measure = measure_gates_begin!(builder, "CustomPredVerify"); - // Verify the custom predicate operation - let (statement, op_type) = make_custom_statement_circuit( - params, - builder, - &entry.custom_predicate, - &entry.op_args, - &entry.args, - )?; - - // Check that the batch id is correct by querying the custom predicate batches table - let table_query_hash = builder.vec_ref( - params, - custom_predicate_table, - &entry.custom_predicate_table_index, - ); - let out_query_hash = entry.custom_predicate.hash(builder); - builder.connect_array(table_query_hash.elements, out_query_hash.elements); - - let query = CustomPredicateVerifyQueryTarget { - statement, // output - op_type, // output - op_args: entry.op_args.clone(), // input - }; - table.push( - builder, - OperationAuxTableTag::CustomPredVerify as u32, - &query, - ); - measure_gates_end!(builder, measure); + table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); } // PublicKeyOf: verify the derivation from a Schnorr secret key to public key - let invgenerator = builder.constant_point(Point::generator().inverse()); - let zero_bits: [BoolTarget; 320] = array::from_fn(|_| builder._false()); for sk in public_key_of_sks { let measure = measure_gates_begin!(builder, "PublicKeyOf"); + let invgenerator = builder.constant_point(Point::generator().inverse()); let group_orderm1 = &*GROUP_ORDER - BigUint::one(); let group_orderm1target = builder.constant_biguint320(&group_orderm1); let compare_ok = list_le_circuit( @@ -438,9 +306,7 @@ fn build_operation_aux_table_circuit( ); builder.assert_one(compare_ok.target); // public_key = g^-secret key - // Use the windowed ECAddXuGate (3-bit windows, 107 iterations) instead of the - // naive multiply_point (1-bit double-and-add, 320 iterations) for fewer gates. - let pk = builder.linear_combination_point_gen(&zero_bits, &sk.bits, &invgenerator); + let pk = builder.multiply_point(&sk.bits, &invgenerator); let sk_hash = builder.hash_n_to_hash_no_pad::(sk.limbs.to_vec()); let pk_hash = builder.hash_n_to_hash_no_pad::( pk.x.components.into_iter().chain(pk.u.components).collect(), @@ -480,6 +346,53 @@ fn build_operation_aux_table_circuit( measure_gates_end!(builder, measure); } + // Merkle state transition proofs: verify op proof (insert/update/delete) + for merkle_tree_state_transition_proof in merkle_tree_state_transition_proofs { + verify_merkle_state_transition_circuit(builder, merkle_tree_state_transition_proof); + let entry = + MerkleTreeStateTransitionClaimTarget::from(merkle_tree_state_transition_proof.clone()); + + table.push( + builder, + OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, + &entry, + ); + } + + // CustomPredVerify: verify custom predicate statements verification against operations + for entry in custom_predicate_verifications { + let measure = measure_gates_begin!(builder, "CustomPredVerify"); + // Verify the custom predicate operation + let (statement, op_type) = make_custom_statement_circuit( + params, + builder, + &entry.custom_predicate, + &entry.op_args, + &entry.args, + )?; + + // Check that the batch id is correct by querying the custom predicate batches table + let table_query_hash = builder.vec_ref( + params, + custom_predicate_table, + &entry.custom_predicate_table_index, + ); + let out_query_hash = entry.custom_predicate.hash(builder); + builder.connect_array(table_query_hash.elements, out_query_hash.elements); + + let query = CustomPredicateVerifyQueryTarget { + statement, // output + op_type, // output + op_args: entry.op_args.clone(), // input + }; + table.push( + builder, + OperationAuxTableTag::CustomPredVerify as u32, + &query, + ); + measure_gates_end!(builder, measure); + } + measure_gates_end!(builder, measure); Ok(table) } @@ -490,11 +403,10 @@ fn verify_operation_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, - prev_statement_flatteneds: &[Vec], - prev_statement_hashes: &[HashOutTarget], + prev_statements: &[StatementTarget], aux_table: &MuxTableTarget, ) -> Result<()> { - let measure = measure_gates_begin!(builder, "OpVerifyPriv"); + let measure = measure_gates_begin!(builder, "OpVerify"); let _true = builder._true(); let _false = builder._false(); @@ -502,15 +414,7 @@ fn verify_operation_circuit( // can reference any of the `prev_statements`. // TODO: Clean this up. let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - let cache = StatementCachePriv::new( - params, - BASE_PARAMS.max_operation_args, - builder, - op, - st, - prev_statement_flatteneds, - prev_statement_hashes, - ); + let cache = StatementCache::new(params, builder, op, st, prev_statements); measure_gates_end!(builder, measure_resolve_op_args); // Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified @@ -538,12 +442,11 @@ fn verify_operation_circuit( verify_sum_of_circuit(params, builder, st, &op.op_type, &cache), verify_product_of_circuit(params, builder, st, &op.op_type, &cache), verify_max_of_circuit(params, builder, st, &op.op_type, &cache), - verify_replace_value_with_entry_circuit(params, builder, st, &op.op_type, &cache), ]); } // Skip these if there are no resolved aux entries if let Some(resolved_aux) = resolved_aux { - if params.containers.state.max_total() > 0 { + if params.max_merkle_proofs_containers > 0 { op_checks.extend_from_slice(&[ verify_contains_from_entries_circuit( params, @@ -583,7 +486,7 @@ fn verify_operation_circuit( &cache, )); } - if params.containers.transition.max_total() > 0 { + if params.max_merkle_tree_state_transition_proofs_containers > 0 { op_checks.extend_from_slice(&[ verify_merkle_insert_circuit( params, @@ -639,7 +542,7 @@ fn verify_contains_from_entries_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); let (aux_tag_ok, resolved_merkle_claim) = @@ -651,6 +554,8 @@ fn verify_contains_from_entries_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ + /* The supplied Merkle proof must be enabled. */ + resolved_merkle_claim.enabled, /* ...and it must be an existence proof. */ resolved_merkle_claim.existence, /* ...for the root-key-value triple in the resolved op args. */ @@ -687,7 +592,7 @@ fn verify_not_contains_from_entries_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); let (aux_tag_ok, resolved_merkle_claim) = @@ -698,6 +603,8 @@ fn verify_not_contains_from_entries_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ + /* The supplied Merkle proof must be enabled. */ + resolved_merkle_claim.enabled, /* ...and it must be a nonexistence proof. */ builder.not(resolved_merkle_claim.existence), /* ...for the root-key pair in the resolved op args. */ @@ -732,13 +639,13 @@ fn verify_merkle_insert_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleInsertOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTransitionProof as u32, + OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerInsertFromEntries); @@ -749,6 +656,8 @@ fn verify_merkle_insert_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ + /* The supplied Merkle transition proof must be enabled. */ + resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be an insertion proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -805,13 +714,13 @@ fn verify_merkle_update_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleUpdateOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTransitionProof as u32, + OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerUpdateFromEntries); @@ -822,6 +731,8 @@ fn verify_merkle_update_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ + /* The supplied Merkle transition proof must be enabled. */ + resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be an update proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -878,13 +789,13 @@ fn verify_merkle_delete_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleDeleteOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTransitionProof as u32, + OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerDeleteFromEntries); @@ -895,6 +806,8 @@ fn verify_merkle_delete_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ + /* The supplied Merkle transition proof must be enabled. */ + resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be a deletion proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -970,7 +883,7 @@ fn verify_eq_neq_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); let eq_op_st_code_ok = { @@ -1019,9 +932,9 @@ fn verify_lt_lteq_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpLtEqFromEntries"); + let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries"); let zero = ValueTarget::zero(builder); let one = ValueTarget::one(builder); @@ -1087,7 +1000,7 @@ fn verify_hash_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpHashOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); @@ -1120,7 +1033,7 @@ fn verify_public_key_of_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpPublicKeyOf"); let (aux_tag_ok, resolved_pk_sk) = @@ -1156,7 +1069,7 @@ fn verify_signed_by_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpSignedBy"); let (aux_tag_ok, resolved_msg_pk) = @@ -1191,7 +1104,7 @@ fn verify_sum_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpSumOf"); let value_zero = ValueTarget::zero(builder); @@ -1229,7 +1142,7 @@ fn verify_product_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpProductOf"); let value_zero = ValueTarget::zero(builder); @@ -1267,7 +1180,7 @@ fn verify_max_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCachePriv, + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpMaxOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); @@ -1307,47 +1220,6 @@ fn verify_max_of_circuit( ok } -fn verify_replace_value_with_entry_circuit( - params: &Params, - builder: &mut CircuitBuilder, - st: &StatementTarget, - op_type: &OperationTypeTarget, - cache: &StatementCachePriv, -) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpReplaceValueWithEntry"); - let op_code_ok = op_type.has_native(builder, NativeOperation::ReplaceValueWithEntry); - - let st_in = &cache.op_args[BASE_PARAMS.max_statement_args]; - - let mut args = Vec::new(); - let mut args_ok = builder._true(); - for (arg_in, entry_cache) in zip_eq(&st_in.args, &cache.equations) { - // if the op_arg is None, keep the original argument, if it's a Contains swap the value by - // the reference Entry while checking that the value in Contains matches the original - // argument. - let arg = builder.select_flattenable( - params, - entry_cache.pred_is_none, - arg_in, - &entry_cache.reference, - ); - args.push(arg); - let arg_ref_ok = { - let arg_in_is_value = builder.statement_arg_is_value(arg_in); - let value_eq = builder.is_equal_flattenable(&arg_in.as_value(), &entry_cache.value); - builder.all([entry_cache.is_reference, arg_in_is_value, value_eq]) - }; - let arg_ok = builder.or(entry_cache.pred_is_none, arg_ref_ok); - args_ok = builder.and(args_ok, arg_ok); - } - let expected_statement = StatementTarget::new(*st_in.pred_hash(), args); - let st_ok = builder.is_equal_flattenable(st, &expected_statement); - - let ok = builder.all([op_code_ok, args_ok, st_ok]); - measure_gates_end!(builder, measure); - ok -} - fn verify_transitive_eq_circuit( params: &Params, builder: &mut CircuitBuilder, @@ -1557,7 +1429,7 @@ fn make_custom_statement_circuit( ) -> Result<(StatementTarget, OperationTypeTarget)> { let measure = measure_gates_begin!(builder, "CustomOpVerify"); // Some sanity checks - assert_eq!(BASE_PARAMS.max_operation_args, op_args.len()); + assert_eq!(params.max_operation_args, op_args.len()); assert_eq!(params.max_custom_predicate_wildcards, args.len()); let (batch_id, index) = (custom_predicate.id, custom_predicate.index); @@ -1591,6 +1463,7 @@ fn make_custom_statement_circuit( .collect(); // expected_sts.len() == params.max_custom_predicate_arity // op_args.len() == params.max_operation_args; + assert!(Params::max_custom_predicate_arity() <= params.max_operation_args); let sts_eq: Vec<_> = expected_sts .iter() @@ -1661,8 +1534,8 @@ pub fn calculate_statements_hash_circuit( sts_hash } -// Replace BatchSelf predicates with the corresponding Custom(batch_id, index), and -// SelfPredicateHash args with Literal(hash(Custom(batch_id, index))). +// Replace predicates of batch-self with the corresponding global custom predicate batch_id and +// index fn normalize_st_tmpl_circuit( params: &Params, builder: &mut CircuitBuilder, @@ -1691,41 +1564,7 @@ fn normalize_st_tmpl_circuit( ); let pred_hash_or_wc = PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data); - - // Normalize SelfPredicateHash args: replace prefix 4 with Literal containing the resolved - // predicate hash. Same pattern as the predicate normalization above. - let prefix_sph = builder.constant(F::from(StatementTmplArgPrefix::SelfPredicateHash)); - let prefix_literal = builder.constant(F::from(StatementTmplArgPrefix::Literal)); - let zero = builder.zero(); - let normalized_args = st_tmpl - .args - .iter() - .map(|arg| { - let is_sph = builder.is_equal(arg.elements[0], prefix_sph); - - // The predicate index is in elements[1] (same slot as WildcardLiteral). - let pred_index = arg.elements[1]; - - // Compute hash(Custom(batch_id, pred_index)) - let pred_target = PredicateTarget::new_custom(builder, id, pred_index); - let pred_hash = pred_target.hash(builder); - - // Build a Literal-encoded arg: [1, hash[0..4], 0, 0, 0, 0] - let mut literal_elements = [zero; Params::statement_tmpl_arg_size()]; - literal_elements[0] = prefix_literal; - literal_elements[1] = pred_hash.elements[0]; - literal_elements[2] = pred_hash.elements[1]; - literal_elements[3] = pred_hash.elements[2]; - literal_elements[4] = pred_hash.elements[3]; - let normalized = StatementTmplArgTarget { - elements: literal_elements, - }; - - builder.select_flattenable(params, is_sph, &normalized, arg) - }) - .collect(); - - StatementTmplTarget::new(pred_hash_or_wc, normalized_args) + StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone()) } /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as @@ -1803,20 +1642,19 @@ fn verify_main_pod_circuit( // NOTE: We use an EmptyPod for padding input pod slots. The EmptyPod is an introduction // pod that declares a statement with no arguments. - let st0_is_intro = input_pod_self_statements[0].pred_is_blank_intro(builder); + let is_blank_intro = input_pod_self_statements[0].pred_is_blank_intro(builder); // Introduction pods can only have Introduction or None statements - let mut intro_ok = st0_is_intro; + let mut intro_ok = is_blank_intro; for self_st in &input_pod_self_statements[1..] { let st_is_intro = self_st.pred_is_blank_intro(builder); let st_is_none = self_st.has_native_type(builder, NativePredicate::None); let st_is_intro_or_none = builder.or(st_is_intro, st_is_none); intro_ok = builder.and(intro_ok, st_is_intro_or_none); } - builder.connect(st0_is_intro.target, intro_ok.target); + builder.connect(is_blank_intro.target, intro_ok.target); - let is_not_main = st0_is_intro; - let is_main = builder.not(is_not_main); + let is_main = builder.not(is_blank_intro); for self_st in input_pod_self_statements { let normalized_st = normalize_statement_circuit( params, @@ -1835,19 +1673,18 @@ fn verify_main_pod_circuit( // their verifier_data_hash appears in their introduction statement. // - verify_merkle_proof_existence_circuit(builder, vd_mt_proof); + verify_merkle_proof_circuit(builder, vd_mt_proof); + // ensure that mt_proof is enabled if it's a main pod + builder.connect(vd_mt_proof.enabled.target, is_main.target); // connect the vd_mt_proof's root to the actual vds_root, to ensure that the mt proof // verifies against the vds_root builder.connect_hashes(main_pod.vds_root, vd_mt_proof.root); - // connect vd_mt_proof's value with the verified_proof.verifier_data_hash only when is_main - for i in 0..VALUE_SIZE { - builder.conditional_assert_eq( - is_main.target, - verified_proof.verifier_data_hash.elements[i], - vd_mt_proof.value.elements[i], - ) - } + // connect vd_mt_proof's value with the verified_proof.verifier_data_hash + builder.connect_hashes( + verified_proof.verifier_data_hash, + HashOutTarget::from_vec(vd_mt_proof.value.elements.to_vec()), + ); // // Verify that VD array that input pod uses is the same we use now. @@ -1877,9 +1714,9 @@ fn verify_main_pod_circuit( params, builder, &main_pod.merkle_proofs, - &main_pod.merkle_transition_proofs, &main_pod.public_key_of_sks, &main_pod.signed_bys, + &main_pod.merkle_tree_state_transition_proofs, &main_pod.custom_predicate_verifications, &custom_predicate_table, )?; @@ -1887,37 +1724,13 @@ fn verify_main_pod_circuit( // 2. Calculate the Pod Id from the public statements let sts_hash = calculate_statements_hash_circuit(builder, pub_statements); - // Precompute flattened statements and their hashes once, then resolve operation args using - // projected lookups. Reusing the flattened forms avoids re-flattening per op-arg lookup. - let statement_flatteneds: Vec> = statements.iter().map(|st| st.flatten()).collect(); - let statement_hashes = statement_flatteneds - .iter() - .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) - .collect_vec(); - // 5. Verify input statements for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() { - let prev_statement_flatteneds = &statement_flatteneds[..input_statements_offset + i]; - let prev_statement_hashes = &statement_hashes[..input_statements_offset + i]; + let prev_statements = &statements[..input_statements_offset + i]; if i < public_statements_offset { - verify_operation_circuit( - params, - builder, - st, - op, - prev_statement_flatteneds, - prev_statement_hashes, - &aux_table, - )?; + verify_operation_circuit(params, builder, st, op, prev_statements, &aux_table)?; } else { - verify_operation_public_statement_circuit( - params, - builder, - st, - op, - prev_statement_flatteneds, - prev_statement_hashes, - )?; + verify_operation_public_statement_circuit(params, builder, st, op, prev_statements)?; } } @@ -1925,77 +1738,19 @@ fn verify_main_pod_circuit( Ok(sts_hash) } -#[derive(Clone, Serialize, Deserialize)] -pub struct MerkleProofsTarget { - small: Vec, - medium: Vec, -} - -impl MerkleProofsTarget { - pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { - Self { - small: (0..params.containers.state.max_small) - .map(|_| { - MerkleProofExistenceTarget::new_virtual( - params.containers.max_depth_small, - builder, - ) - }) - .collect(), - medium: (0..params.containers.state.max_medium) - .map(|_| { - MerkleClaimAndProofTarget::new_virtual( - params.containers.max_depth_medium, - builder, - ) - }) - .collect(), - } - } -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct MerkleTransitionProofsTarget { - small: Vec, - medium: Vec, -} - -impl MerkleTransitionProofsTarget { - pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { - Self { - small: (0..params.containers.transition.max_small) - .map(|_| { - MerkleTreeStateTransitionProofTarget::new_virtual( - params.containers.max_depth_small, - builder, - ) - }) - .collect(), - medium: (0..params.containers.transition.max_medium) - .map(|_| { - MerkleTreeStateTransitionProofTarget::new_virtual( - params.containers.max_depth_medium, - builder, - ) - }) - .collect(), - } - } -} - #[derive(Clone, Serialize, Deserialize)] pub struct MainPodVerifyTarget { params: Params, vds_root: HashOutTarget, - vd_mt_proofs: Vec, + vd_mt_proofs: Vec, input_pods_self_statements: Vec>, // The KEY_TYPE statement must be the first public one input_statements: Vec, operations: Vec, - merkle_proofs: MerkleProofsTarget, - merkle_transition_proofs: MerkleTransitionProofsTarget, + merkle_proofs: Vec, public_key_of_sks: Vec, signed_bys: Vec, + merkle_tree_state_transition_proofs: Vec, custom_predicates: Vec, custom_predicate_verifications: Vec, } @@ -2006,7 +1761,7 @@ impl MainPodVerifyTarget { params: params.clone(), vds_root: builder.add_virtual_hash(), vd_mt_proofs: (0..params.max_input_pods) - .map(|_| MerkleProofExistenceTarget::new_virtual(params.max_depth_mt_vds, builder)) + .map(|_| MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_vds, builder)) .collect(), input_pods_self_statements: (0..params.max_input_pods) .map(|_| { @@ -2021,14 +1776,26 @@ impl MainPodVerifyTarget { operations: (0..params.max_statements) .map(|_| builder.add_virtual_operation(params)) .collect(), - merkle_proofs: MerkleProofsTarget::new_virtual(params, builder), - merkle_transition_proofs: MerkleTransitionProofsTarget::new_virtual(params, builder), + merkle_proofs: (0..params.max_merkle_proofs_containers) + .map(|_| { + MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, builder) + }) + .collect(), public_key_of_sks: (0..params.max_public_key_of) .map(|_| builder.add_virtual_biguint320_target()) .collect(), signed_bys: (0..params.max_signed_by) .map(|_| SignedByTarget::new_virtual(builder)) .collect(), + merkle_tree_state_transition_proofs: (0..params + .max_merkle_tree_state_transition_proofs_containers) + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.max_depth_mt_containers, + builder, + ) + }) + .collect(), custom_predicates: (0..params.max_custom_predicates) .map(|_| CustomPredicateInBatchTarget::new_virtual(builder)) .collect(), @@ -2037,64 +1804,6 @@ impl MainPodVerifyTarget { .collect(), } } - - fn set_container_mtp_targets( - &self, - pw: &mut PartialWitness, - input: &MainPodVerifyInput, - ) -> Result<()> { - assert!(input.merkle_proofs.small.len() <= self.params.containers.state.max_small); - for (i, mp) in input.merkle_proofs.small.iter().enumerate() { - self.merkle_proofs.small[i].set_targets(pw, mp)?; - } - // Padding - let pad_mp = MerkleClaimAndProof::pad(); - for i in input.merkle_proofs.small.len()..self.params.containers.state.max_small { - self.merkle_proofs.small[i].set_targets(pw, &pad_mp)?; - } - - assert!(input.merkle_proofs.medium.len() <= self.params.containers.state.max_medium); - for (i, mp) in input.merkle_proofs.medium.iter().enumerate() { - self.merkle_proofs.medium[i].set_targets(pw, mp)?; - } - // Padding - let pad_mp = MerkleClaimAndProof::pad(); - for i in input.merkle_proofs.medium.len()..self.params.containers.state.max_medium { - self.merkle_proofs.medium[i].set_targets(pw, &pad_mp)?; - } - - assert!( - input.merkle_transition_proofs.small.len() - <= self.params.containers.transition.max_small - ); - for (i, mtp) in input.merkle_transition_proofs.small.iter().enumerate() { - self.merkle_transition_proofs.small[i].set_targets(pw, mtp)?; - } - // Padding - let pad_mtp = MerkleTreeStateTransitionProof::pad(); - for i in - input.merkle_transition_proofs.small.len()..self.params.containers.transition.max_small - { - self.merkle_transition_proofs.small[i].set_targets(pw, &pad_mtp)?; - } - - assert!( - input.merkle_transition_proofs.medium.len() - <= self.params.containers.transition.max_medium - ); - for (i, mtp) in input.merkle_transition_proofs.medium.iter().enumerate() { - self.merkle_transition_proofs.medium[i].set_targets(pw, mtp)?; - } - // Padding - let pad_mtp = MerkleTreeStateTransitionProof::pad(); - for i in input.merkle_transition_proofs.medium.len() - ..self.params.containers.transition.max_medium - { - self.merkle_transition_proofs.medium[i].set_targets(pw, &pad_mtp)?; - } - - Ok(()) - } } pub struct CustomPredicateVerification { @@ -2109,14 +1818,15 @@ pub struct MainPodVerifyInput { /// field containing the `vd_mt_proofs` aside from the `vds_set`, because /// inside the MainPodVerifyTarget circuit, since it is the InnerCircuit for /// the RecursiveCircuit, we don't have access to the used verifier_datas. - pub vd_mt_proofs: Vec, + /// The bool is used as `enabled` and will be false for intro pods. + pub vd_mt_proofs: Vec<(bool, MerkleClaimAndProof)>, pub input_pods_pub_self_statements: Vec>, pub statements: Vec, pub operations: Vec, - pub merkle_proofs: MerkleProofs, - pub merkle_transition_proofs: MerkleTransitionProofs, + pub merkle_proofs: Vec, pub public_key_of_sks: Vec, pub signed_bys: Vec, + pub merkle_tree_state_transition_proofs: Vec, pub custom_predicates_with_mpt_proofs: Vec<(CustomPredicateRef, MerkleProof)>, pub custom_predicate_verifications: Vec, } @@ -2172,8 +1882,8 @@ impl InnerCircuit for MainPodVerifyTarget { ); let input_pods_len = input.vd_mt_proofs.len(); assert!(input_pods_len <= self.params.max_input_pods); - for (i, vd_mt_proof) in input.vd_mt_proofs.iter().enumerate() { - self.vd_mt_proofs[i].set_targets(pw, vd_mt_proof)?; + for (i, (enable, vd_mt_proof)) in input.vd_mt_proofs.iter().enumerate() { + self.vd_mt_proofs[i].set_targets(pw, *enable, vd_mt_proof)?; } for (i, pod_pub_statements) in input.input_pods_pub_self_statements.iter().enumerate() { set_targets_input_pods_self_statements( @@ -2187,10 +1897,14 @@ impl InnerCircuit for MainPodVerifyTarget { if input_pods_len != self.params.max_input_pods { let empty_pod = EmptyPod::new_boxed(input.vds_set.clone()); let empty_pod_statements = empty_pod.pub_statements(); - let pad_mt_proof = input.vds_set.get_vds_proof_0(); + let empty_mt_proof = MerkleClaimAndProof { + root: input.vds_set.root(), + value: RawValue::from(empty_pod.verifier_data_hash()), + ..MerkleClaimAndProof::empty() + }; for i in input_pods_len..self.params.max_input_pods { - self.vd_mt_proofs[i].set_targets(pw, &pad_mt_proof)?; + self.vd_mt_proofs[i].set_targets(pw, false, &empty_mt_proof)?; set_targets_input_pods_self_statements( pw, &self.params, @@ -2206,7 +1920,15 @@ impl InnerCircuit for MainPodVerifyTarget { self.operations[i].set_targets(pw, &self.params, op)?; } - self.set_container_mtp_targets(pw, input)?; + assert!(input.merkle_proofs.len() <= self.params.max_merkle_proofs_containers); + for (i, mp) in input.merkle_proofs.iter().enumerate() { + self.merkle_proofs[i].set_targets(pw, true, mp)?; + } + // Padding + let pad_mp = MerkleClaimAndProof::empty(); + for i in input.merkle_proofs.len()..self.params.max_merkle_proofs_containers { + self.merkle_proofs[i].set_targets(pw, false, &pad_mp)?; + } assert!(input.public_key_of_sks.len() <= self.params.max_public_key_of); for (i, sk) in input.public_key_of_sks.iter().enumerate() { @@ -2228,6 +1950,25 @@ impl InnerCircuit for MainPodVerifyTarget { self.signed_bys[i].set_targets(pw, &pad_signed_by)?; } + assert!( + input.merkle_tree_state_transition_proofs.len() + <= self + .params + .max_merkle_tree_state_transition_proofs_containers + ); + for (i, mtp) in input.merkle_tree_state_transition_proofs.iter().enumerate() { + self.merkle_tree_state_transition_proofs[i].set_targets(pw, true, mtp)?; + } + // Padding + let pad_mtp = MerkleTreeStateTransitionProof::empty(); + for i in input.merkle_tree_state_transition_proofs.len() + ..self + .params + .max_merkle_tree_state_transition_proofs_containers + { + self.merkle_tree_state_transition_proofs[i].set_targets(pw, false, &pad_mtp)?; + } + assert!(input.custom_predicates_with_mpt_proofs.len() <= self.params.max_custom_predicates); for (i, (cp, mtp)) in input.custom_predicates_with_mpt_proofs.iter().enumerate() { self.custom_predicates[i].set_targets(pw, cp, mtp)?; @@ -2272,3 +2013,1561 @@ impl InnerCircuit for MainPodVerifyTarget { Ok(()) } } + +#[cfg(test)] +mod tests { + use std::{iter, ops::Not}; + + use num::FromPrimitive; + use plonky2::{ + field::{goldilocks_field::GoldilocksField, types::Field}, + hash::hash_types::HashOut, + iop::witness::WitnessWrite, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, + }; + + use super::*; + use crate::{ + backends::plonky2::{ + basetypes::C, + circuits::common::tests::I64_TEST_PAIRS, + mainpod::{calculate_statements_hash, OperationArg, OperationAux}, + primitives::{ + ec::schnorr::SecretKey, + merkletree::{MerkleClaimAndProof, MerkleTree, MerkleTreeStateTransitionProof}, + }, + signer, + }, + dict, + frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::{ + hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, + RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, EMPTY_VALUE, + }, + }; + + #[derive(Default)] + struct Aux { + merkle_proofs: Vec, + secret_keys: Vec, + signed_bys: Vec, + merkle_tree_state_transition_proofs: Vec, + } + + impl Aux { + fn merkle_proof(v: MerkleClaimAndProof) -> Self { + Self { + merkle_proofs: vec![v], + ..Default::default() + } + } + fn secret_key(v: SecretKey) -> Self { + Self { + secret_keys: vec![v], + ..Default::default() + } + } + fn signed_by(v: SignedBy) -> Self { + Self { + signed_bys: vec![v], + ..Default::default() + } + } + fn merkle_tree_state_transition_proof(v: MerkleTreeStateTransitionProof) -> Self { + Self { + merkle_tree_state_transition_proofs: vec![v], + ..Default::default() + } + } + } + + fn operation_verify( + st: mainpod::Statement, + op: mainpod::Operation, + prev_statements: Vec, + aux: Aux, + ) -> Result<()> { + let params = Params { + max_merkle_proofs_containers: aux.merkle_proofs.len(), + max_public_key_of: aux.secret_keys.len(), + max_signed_by: aux.signed_bys.len(), + max_merkle_tree_state_transition_proofs_containers: aux + .merkle_tree_state_transition_proofs + .len(), + max_custom_predicate_verifications: 0, + max_custom_predicates: 0, + ..Default::default() + }; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_target = builder.add_virtual_statement(false); + let op_target = builder.add_virtual_operation(¶ms); + let prev_statements_target: Vec<_> = (0..prev_statements.len()) + .map(|_| builder.add_virtual_statement(false)) + .collect(); + + let merkle_proofs_target: Vec<_> = aux + .merkle_proofs + .iter() + .map(|_| { + MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, &mut builder) + }) + .collect(); + + let secret_keys_target: Vec<_> = aux + .secret_keys + .iter() + .map(|sk| builder.constant_biguint320(&sk.0)) + .collect(); + + let signed_by_targets: Vec<_> = aux + .signed_bys + .iter() + .map(|_| SignedByTarget::new_virtual(&mut builder)) + .collect(); + + let merkle_tree_state_transition_proofs_target: Vec<_> = aux + .merkle_tree_state_transition_proofs + .iter() + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.max_depth_mt_containers, + &mut builder, + ) + }) + .collect(); + + let aux_table = build_operation_aux_table_circuit( + ¶ms, + &mut builder, + &merkle_proofs_target, + &secret_keys_target, + &signed_by_targets, + &merkle_tree_state_transition_proofs_target, + &[], + &[], + )?; + + verify_operation_circuit( + ¶ms, + &mut builder, + &st_target, + &op_target, + &prev_statements_target, + &aux_table, + )?; + + let mut pw = PartialWitness::::new(); + st_target.set_targets(&mut pw, &st)?; + op_target.set_targets(&mut pw, ¶ms, &op)?; + for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { + prev_st_target.set_targets(&mut pw, prev_st)?; + } + for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) { + signed_by_target.set_targets(&mut pw, signed_by)? + } + for (merkle_proof_target, merkle_proof) in + merkle_proofs_target.iter().zip(aux.merkle_proofs.iter()) + { + merkle_proof_target.set_targets(&mut pw, true, merkle_proof)? + } + for (merkle_tree_state_transition_proof_target, merkle_tree_state_transition_proof) in + merkle_tree_state_transition_proofs_target + .iter() + .zip(aux.merkle_tree_state_transition_proofs.iter()) + { + merkle_tree_state_transition_proof_target.set_targets( + &mut pw, + true, + merkle_tree_state_transition_proof, + )? + } + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } + + #[test] + fn test_lt_lteq_verify_failures() { + let invalid_int = RawValue([ + GoldilocksField::NEG_ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + ]); + + let prev_statements = [Statement::None.into()]; + + [ + // 56 < 55, 55 < 55, 56 <= 55, -55 < -55, -55 < -56, -55 <= -56 should fail to verify + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(55, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(-55, -55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(-55, -56).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(-55, -56).into(), + ), + // 56 < p-1 and p-1 <= p-1 should fail to verify, where p + // is the Goldilocks prime and 'p-1' occupies a single + // limb. + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(56, invalid_int).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(invalid_int, invalid_int).into(), + ), + ] + .into_iter() + .for_each(|(op, st)| { + let check = std::panic::catch_unwind(|| { + operation_verify(st, op, prev_statements.to_vec(), Aux::default()) + }); + match check { + Err(e) => { + let err_string = e.downcast_ref::().unwrap(); + if !err_string.contains("Integer too large to fit") { + panic!("Test failed with an unexpected error: {}", err_string); + } + } + Ok(Err(_)) => {} + _ => panic!("Test passed, yet it should have failed!"), + } + }); + } + + #[test] + fn test_eq_neq_verify_failures() { + let prev_statements = [Statement::None.into()]; + + [ + // 56 == 55, 55 != 55 should fail to verify + ( + mainpod::Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::equal(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::NotEqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::not_equal(55, 55).into(), + ), + ] + .into_iter() + .for_each(|(op, st)| { + assert!(operation_verify(st, op, prev_statements.to_vec(), Aux::default()).is_err()) + }); + } + + #[test] + fn test_operation_verify_none() -> Result<()> { + let st: mainpod::Statement = Statement::None.into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::None), + vec![], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_copy() -> Result<()> { + let st: mainpod::Statement = Statement::None.into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::CopyStatement), + vec![OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_eq() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 55}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 55).into(); + let st: mainpod::Statement = Statement::equal( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "world")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_neq() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 75}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 75).into(); + let st: mainpod::Statement = Statement::not_equal( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "world")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::NotEqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_lt() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"hello" => 56}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "hello", 56).into(); + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "hello")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2.clone()]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative < negative + let dict3 = dict!({"hola" => -56}); + let dict4 = dict!({"mundo" => -55}); + let st3: mainpod::Statement = Statement::contains(dict3.clone(), "hola", -56).into(); + let st4: mainpod::Statement = Statement::contains(dict4.clone(), "mundo", -55).into(); + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict3, "hola")), + AnchoredKey::from((&dict4, "mundo")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3.clone(), st4]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative < positive + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict3, "hola")), + AnchoredKey::from((&dict2, "hello")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3, st2]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_lteq() -> Result<()> { + let local = dict!({ + "n55" => 55, + "n56" => 56, + "n_56" => -56, + "n_55" => -55, + }); + let st1: mainpod::Statement = Statement::contains(local.clone(), "n55", 55).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "n56", 56).into(); + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n55")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2.clone()]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative <= negative + let st3: mainpod::Statement = Statement::contains(local.clone(), "n_56", -56).into(); + let st4: mainpod::Statement = Statement::contains(local.clone(), "n_55", -55).into(); + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n_55")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3.clone(), st4]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative <= positive + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3, st2]; + operation_verify(st, op, prev_statements.clone(), Aux::default())?; + + // Also check equality, both positive and negative. + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n_56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ); + operation_verify(st, op, prev_statements.clone(), Aux::default())?; + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n56")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(1), OperationArg::Index(1)], + OperationAux::None, + ); + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_hashof() -> Result<()> { + let input_values = [ + Value::from(RawValue([ + GoldilocksField(1), + GoldilocksField(2), + GoldilocksField(3), + GoldilocksField(4), + ])), + Value::from(512), + ]; + let v1 = hash_values(&input_values); + let [v2, v3] = input_values; + + let local = dict!({ + "hola" => v1, + "mundo" => v2.clone(), + "!" => v3.clone(), + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "hola", v1).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "mundo", v2).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "!", v3).into(); + + let st: mainpod::Statement = Statement::hash_of( + AnchoredKey::from((&local, "hola")), + AnchoredKey::from((&local, "mundo")), + AnchoredKey::from((&local, "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::HashOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_sumof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (sum, overflow) = a.overflowing_add(b); + overflow.not().then_some((a, b, sum)) + }) + .try_for_each(|(a, b, sum)| { + let local = dict!({ + "sum" => sum, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "sum", sum).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::sum_of( + AnchoredKey::from((&local, "sum")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) + } + + #[test] + fn test_operation_verify_productof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (prod, overflow) = a.overflowing_mul(b); + overflow.not().then_some((a, b, prod)) + }) + .try_for_each(|(a, b, prod)| { + let local = dict!({ + "prod" => prod, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = + Statement::contains(local.clone(), "prod", prod).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::product_of( + AnchoredKey::from((&local, "prod")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ProductOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) + } + + #[test] + fn test_operation_verify_maxof() -> Result<()> { + I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { + let max = i64::max(a, b); + let local = dict!({ + "max" => max, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "max", max).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::max_of( + AnchoredKey::from((&local, "max")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) + } + + #[test] + fn test_operation_verify_maxof_failures() { + [(5, 3, 4), (5, 5, 8), (3, 4, 5)] + .into_iter() + .for_each(|(max, a, b)| { + let st: mainpod::Statement = Statement::max_of(max, a, b).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::None, + ); + let prev_statements = [Statement::None.into()]; + + let check = std::panic::catch_unwind(|| { + operation_verify(st, op, prev_statements.to_vec(), Aux::default()) + }); + match check { + Err(e) => { + let err_string = e.downcast_ref::().unwrap(); + if !err_string.contains("Integer too large to fit") { + panic!("Test failed with an unexpected error: {}", err_string); + } + } + Ok(Err(_)) => {} + _ => panic!("Test passed, yet it should have failed!"), + } + }) + } + + #[test] + fn test_operation_verify_lt_to_neq() -> Result<()> { + let local = dict!({ + "a" => 10, + "b" => 20, + }); + let st: mainpod::Statement = Statement::not_equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let st1: mainpod::Statement = Statement::lt( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtToNotEqual), + vec![OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![st1]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_transitive_eq() -> Result<()> { + let local = dict!({ + "a" => 10, + "b" => 10, + "c" => 10, + }); + let st: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "c")), + ) + .into(); + let st1: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let st2: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "b")), + AnchoredKey::from((&local, "c")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::TransitiveEqualFromStatements), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) + } + + #[test] + fn test_operation_verify_sintains() -> Result<()> { + let kvs = [ + (1.into(), 55.into()), + (2.into(), 88.into()), + (175.into(), 0.into()), + ] + .into_iter() + .collect(); + let mt = MerkleTree::new(&kvs); + + let root = mt.root(); + let key = Value::from(5); + let local = dict!({ + "merkle_root" => root, + "key" => key.clone(), + }); + let root_ak = AnchoredKey::from((&local, "merkle_root")); + let key_ak = AnchoredKey::from((&local, "key")); + + let no_key_pf = mt.prove_nonexistence(&key.raw())?; + + let root_st: mainpod::Statement = + Statement::contains(local.clone(), "merkle_root", root).into(); + let key_st: mainpod::Statement = + Statement::contains(local.clone(), "key", key.clone()).into(); + let st: mainpod::Statement = Statement::not_contains(root_ak, key_ak).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::NotContainsFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::MerkleProofIndex(0), + ); + + let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), None, no_key_pf); + let prev_statements = vec![root_st, key_st]; + operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) + } + + #[test] + fn test_operation_verify_contains() -> Result<()> { + let kvs = [ + (1.into(), 55.into()), + (2.into(), 88.into()), + (175.into(), 0.into()), + ] + .into_iter() + .collect(); + let mt = MerkleTree::new(&kvs); + + let root = mt.root(); + let key = Value::from(175); + let (value, key_pf) = mt.prove(&key.raw())?; + let local = dict!({ + "merkle_root" => root, + "key" => key.clone(), + "value" => value, + }); + let root_ak = AnchoredKey::from((&local, "merkle_root")); + let key_ak = AnchoredKey::from((&local, "key")); + let value_ak = AnchoredKey::from((&local, "value")); + + let root_st: mainpod::Statement = + Statement::contains(local.clone(), "merkle_root", root).into(); + let key_st: mainpod::Statement = + Statement::contains(local.clone(), "key", key.clone()).into(); + let value_st: mainpod::Statement = + Statement::contains(local.clone(), "value", value).into(); + + let st: mainpod::Statement = Statement::contains(root_ak, key_ak, value_ak).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainsFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::MerkleProofIndex(0), + ); + + let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), Some(value), key_pf); + let prev_statements = vec![root_st, key_st, value_st]; + operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) + } + + #[test] + fn test_operation_verify_merkle_insert() -> Result<()> { + let mut tree = MerkleTree::new(&[].into()); + + let key = Value::from(175); + let value = Value::from(0); + let state_transition_proof = tree.insert(&key.raw(), &value.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::insert(new_root, old_root, key, value).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerInsertFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTreeStateTransitionProofIndex(0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) + } + + #[test] + fn test_operation_verify_merkle_update() -> Result<()> { + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); + + let key = Value::from(175); + let value = Value::from(0); + let state_transition_proof = tree.update(&key.raw(), &value.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::update(new_root, old_root, key, value).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerUpdateFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTreeStateTransitionProofIndex(0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) + } + + #[test] + fn test_operation_verify_merkle_delete() -> Result<()> { + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); + + let key = Value::from(175); + let state_transition_proof = tree.delete(&key.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::delete(new_root, old_root, key).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerDeleteFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTreeStateTransitionProofIndex(0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) + } + + #[test] + fn test_operation_verify_publickeyof_ok() -> Result<()> { + [ + SecretKey(BigUint::one()), + SecretKey::new_rand(), + SecretKey(&*GROUP_ORDER - BigUint::one()), + ] + .into_iter() + .try_for_each(|secret_key| { + let public_key = secret_key.public_key(); + + let st: mainpod::Statement = + Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)) + }) + } + + #[test] + fn test_operation_verify_publickeyof_failure_wrong_key() { + let secret_key = SecretKey(BigUint::one()); + let public_key = SecretKey(BigUint::ZERO).public_key(); + + let st: mainpod::Statement = + Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) + } + + #[test] + fn test_operation_verify_publickeyof_failure_pk_type() { + let secret_key = SecretKey(BigUint::one()); + let public_key = 123i64; + + let st: mainpod::Statement = + Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) + } + + #[test] + fn test_operation_verify_publickeyof_failure_sk_type() { + let secret_key = 123i64; + let public_key = SecretKey(BigUint::from(123u32)).public_key(); + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + let aux = Aux::secret_key(SecretKey(BigUint::from(123u32))); + assert!(operation_verify(st, op, prev_statements, aux,).is_err()) + } + + #[test] + fn test_operation_verify_publickeyof_failure_sk_size() { + let secret_key = SecretKey(&*GROUP_ORDER - BigUint::ZERO); + let public_key = secret_key.public_key(); + + let st: mainpod::Statement = + Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) + } + + #[test] + fn test_operation_verify_signedby_ok() -> Result<()> { + let sk = SecretKey(BigUint::from_u32(0xbadcafe).unwrap()); + let pk = sk.public_key(); + let msg = RawValue([F(1), F(2), F(3), F(4)]); + let nonce = BigUint::from_u32(123).unwrap(); + let sig = signer::Signer(sk).sign_with_nonce(nonce, msg); + let signed_by = SignedBy { + msg, + pk, + sig: sig.clone(), + }; + + let st: mainpod::Statement = Statement::signed_by(msg, pk).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SignedBy), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::SignedByIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) + } + + fn helper_statement_arg_from_template( + params: &Params, + st_tmpl_arg: StatementTmplArg, + args: Vec, + expected_st_arg: StatementArg, + ) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_arg_target = make_statement_arg_from_template_circuit( + params, + &mut builder, + &st_tmpl_arg_target, + &args_target, + ); + // TODO: Instead of connect, assign witness to result + let expected_st_arg_target = builder.add_virtual_statement_arg(); + builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); + + let mut pw = PartialWitness::::new(); + + st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) + } + + #[test] + fn test_statement_arg_from_template() -> Result<()> { + let params = Params::default(); + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + // case: None + let st_tmpl_arg = StatementTmplArg::None; + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::None; + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: Literal + let st_tmpl_arg = StatementTmplArg::Literal(Value::from("foo")); + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::Literal(Value::from("foo")); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: AnchoredKey(id_wildcard, key_literal) + let st_tmpl_arg = + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("foo")); + let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; + let expected_st_arg = StatementArg::Key(AnchoredKey::new(dict, Key::from("foo"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: WildcardLiteral(wildcard) + let st_tmpl_arg = StatementTmplArg::Wildcard(Wildcard::new("a".to_string(), 1)); + let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; + let expected_st_arg = StatementArg::Literal(Value::from("key")); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + Ok(()) + } + + fn helper_statement_from_template( + params: &Params, + st_tmpl: StatementTmpl, + args: Vec, + expected_st: Statement, + ) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_tmpl_target = builder.add_virtual_statement_tmpl(false); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = make_statement_from_template_circuit( + params, + &mut builder, + &st_tmpl_target, + &args_target, + ); + // TODO: Instead of connect, assign witness to result + let expected_st_target = builder.add_virtual_statement(false); + builder.connect_flattenable(&expected_st_target, &st_target); + + let mut pw = PartialWitness::::new(); + + st_tmpl_target.set_targets(&mut pw, &st_tmpl)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_target.set_targets(&mut pw, &expected_st.into())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) + } + + #[test] + fn test_statement_from_template() -> Result<()> { + let params = Params::default(); + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + let st_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)), + args: vec![ + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; + let expected_st = Statement::equal( + AnchoredKey::new(dict, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + + let st_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)), + args: vec![ + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(); + let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; + let expected_st = Statement::not_equal( + AnchoredKey::new(dict, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + + Ok(()) + } + + fn helper_custom_operation_verify_gadget( + params: &Params, + custom_predicate: CustomPredicateRef, + mut op_args: Vec, + mut args: Vec, + expected_st: Option, + ) -> Result<()> { + // Pad + for _ in op_args.len()..params.max_operation_args { + op_args.push(Statement::None); + } + for _ in args.len()..params.max_custom_predicate_wildcards { + args.push(Value::from(EMPTY_VALUE)); + } + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let custom_predicate_target = builder.add_virtual_custom_predicate_entry(); + let op_args_target: Vec<_> = (0..op_args.len()) + .map(|_| builder.add_virtual_statement(false)) + .collect(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let (st_target, op_type_target) = make_custom_statement_circuit( + params, + &mut builder, + &custom_predicate_target, + &op_args_target, + &args_target, + )?; + + let mut pw = PartialWitness::::new(); + + // Input + custom_predicate_target.set_targets(&mut pw, &custom_predicate)?; + for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { + op_arg_target.set_targets(&mut pw, &op_arg.into())?; + } + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; + } + // Expected Output + if let Some(expected_st) = expected_st { + st_target.set_targets(&mut pw, &expected_st.into())?; + } + + let expected_op_type = OperationType::Custom(custom_predicate); + op_type_target.set_targets(&mut pw, &expected_op_type)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) + } + + // TODO: Add negative tests + #[test] + fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { + let params = Params::default(); + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) + .arg(("id", "key")) + .arg("secret"); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; + let batch = builder.finish(); + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + // AND + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), + ]; + let args = vec![Value::from(dict), Value::from(1234)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), Value::from(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // OR (1) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::None, + ]; + let args = vec![Value::from(dict), Value::from(0)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), Value::from(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // OR (2) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::None, + Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), + ]; + let args = vec![Value::from(dict), Value::from(1234)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), Value::from(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + Ok(()) + } + + #[test] + fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { + let params = Params::default(); + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) + .arg(("secret_id", "key")) + .arg(("id", "score")); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret_id"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; + let batch = builder.finish(); + + let dict = Hash([F(1), F(2), F(3), F(4)]); + let secret_dict = Hash([F(6), F(7), F(8), F(9)]); + + // AND (0) Sanity check with correct values + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), Value::from(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // AND (1) Different dict for same wildcard + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(Hash([F(0), F(5), F(1), F(6)]), Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // AND (2) key doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("BAD")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // AND (3) literal doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal( + AnchoredKey::new(dict, Key::from("score")), + Value::from(0xbad), + ), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // AND (4) predicate doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::not_equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // OR (1) Two Nones + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![Statement::None, Statement::None]; + let args = vec![Value::from(dict), Value::from(0)]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None + ) + .is_err()); + + Ok(()) + } + + fn helper_calculate_statements_hash(params: &Params, statements: &[Statement]) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let statements_target = (0..params.max_public_statements) + .map(|_| builder.add_virtual_statement(false)) + .collect_vec(); + let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target); + + let mut pw = PartialWitness::::new(); + + // Input + let statements = statements + .iter() + .map(|st| { + let mut st = mainpod::Statement::from(st.clone()); + pad_statement(&mut st); + st + }) + .collect_vec(); + for (st_target, st) in statements_target.iter().zip(statements.iter()) { + st_target.set_targets(&mut pw, st)?; + } + // Expected Output + let expected_sts_hash = calculate_statements_hash(&statements); + pw.set_hash_target( + sts_hash_target, + HashOut { + elements: expected_sts_hash.0, + }, + )?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) + } + + #[test] + fn test_calculate_sts_hash() -> frontend::Result<()> { + assert_eq!(Params::num_public_statements_hash(), 16); + // Case with no public public statements + let params = Params { + max_public_statements: 0, + ..Default::default() + }; + + helper_calculate_statements_hash(¶ms, &[]).unwrap(); + + // Case with number of statements for the sts_hash equal to number of public statements + let params = Params { + max_public_statements: Params::num_public_statements_hash(), + ..Default::default() + }; + + let dict = Hash([F(1), F(2), F(3), F(4)]); + let statements = (0..Params::num_public_statements_hash()) + .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64))) + .collect_vec(); + + helper_calculate_statements_hash(¶ms, &statements).unwrap(); + + // Case with more statements for the sts_hash than the number of public statements + let params = Params { + max_public_statements: 4, + ..Default::default() + }; + + let dict2 = Hash([F(5), F(6), F(7), F(8)]); + let statements = [ + Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)), + Statement::equal( + AnchoredKey::from((dict, "bar")), + AnchoredKey::from((dict, "baz")), + ), + Statement::lt( + AnchoredKey::from((dict2, "one")), + AnchoredKey::from((dict2, "two")), + ), + ] + .into_iter() + .chain(iter::repeat(Statement::None)) + .take(params.max_public_statements) + .collect_vec(); + + helper_calculate_statements_hash(¶ms, &statements).unwrap(); + + Ok(()) + } +} diff --git a/src/backends/plonky2/circuits/mainpod/tests.rs b/src/backends/plonky2/circuits/mainpod/tests.rs deleted file mode 100644 index 49fe4a0..0000000 --- a/src/backends/plonky2/circuits/mainpod/tests.rs +++ /dev/null @@ -1,1707 +0,0 @@ -use std::{iter, ops::Not}; - -use num::FromPrimitive; -use plonky2::{ - field::{goldilocks_field::GoldilocksField, types::Field}, - hash::hash_types::HashOut, - iop::witness::WitnessWrite, - plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, -}; - -use super::*; -use crate::{ - backends::plonky2::{ - basetypes::C, - circuits::common::tests::I64_TEST_PAIRS, - mainpod::{calculate_statements_hash, OperationArg, OperationAux, Size}, - primitives::{ - ec::schnorr::SecretKey, - merkletree::{MerkleClaimAndProof, MerkleTree, MerkleTreeStateTransitionProof}, - }, - signer, - }, - dict, - frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, - middleware::{ - self, hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, - RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard, BASE_PARAMS, - EMPTY_VALUE, - }, -}; - -#[derive(Default)] -struct Aux { - merkle_proofs: Vec, - secret_keys: Vec, - signed_bys: Vec, - merkle_transition_proofs: Vec, -} - -impl Aux { - fn merkle_proof(v: MerkleClaimAndProof) -> Self { - Self { - merkle_proofs: vec![v], - ..Default::default() - } - } - fn secret_key(v: SecretKey) -> Self { - Self { - secret_keys: vec![v], - ..Default::default() - } - } - fn signed_by(v: SignedBy) -> Self { - Self { - signed_bys: vec![v], - ..Default::default() - } - } - fn merkle_tree_state_transition_proof(v: MerkleTreeStateTransitionProof) -> Self { - Self { - merkle_transition_proofs: vec![v], - ..Default::default() - } - } -} - -fn operation_verify( - st: mainpod::Statement, - op: mainpod::Operation, - prev_statements: Vec, - aux: Aux, -) -> Result<()> { - let params = Params { - max_public_key_of: aux.secret_keys.len(), - max_signed_by: aux.signed_bys.len(), - containers: middleware::ParamsContainers { - state: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: aux.merkle_proofs.len(), - }, - transition: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: aux.merkle_transition_proofs.len(), - }, - max_depth_small: 8, - max_depth_medium: 32, - }, - max_custom_predicate_verifications: 0, - max_custom_predicates: 0, - ..Default::default() - }; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_target = builder.add_virtual_statement(false); - let op_target = builder.add_virtual_operation(¶ms); - let prev_statements_target: Vec<_> = (0..prev_statements.len()) - .map(|_| builder.add_virtual_statement(false)) - .collect(); - let prev_statement_flatteneds_target: Vec> = prev_statements_target - .iter() - .map(|st| st.flatten()) - .collect(); - let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target - .iter() - .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) - .collect(); - - let merkle_proofs_target = MerkleProofsTarget { - medium: aux - .merkle_proofs - .iter() - .map(|_| { - MerkleClaimAndProofTarget::new_virtual( - params.containers.max_depth_medium, - &mut builder, - ) - }) - .collect(), - small: Vec::new(), - }; - - let secret_keys_target: Vec<_> = aux - .secret_keys - .iter() - .map(|sk| builder.constant_biguint320(&sk.0)) - .collect(); - - let signed_by_targets: Vec<_> = aux - .signed_bys - .iter() - .map(|_| SignedByTarget::new_virtual(&mut builder)) - .collect(); - - let merkle_transition_proofs_target = MerkleTransitionProofsTarget { - medium: aux - .merkle_transition_proofs - .iter() - .map(|_| { - MerkleTreeStateTransitionProofTarget::new_virtual( - params.containers.max_depth_medium, - &mut builder, - ) - }) - .collect(), - small: Vec::new(), - }; - - let aux_table = build_operation_aux_table_circuit( - ¶ms, - &mut builder, - &merkle_proofs_target, - &merkle_transition_proofs_target, - &secret_keys_target, - &signed_by_targets, - &[], - &[], - )?; - - verify_operation_circuit( - ¶ms, - &mut builder, - &st_target, - &op_target, - &prev_statement_flatteneds_target, - &prev_statement_hashes_target, - &aux_table, - )?; - - let mut pw = PartialWitness::::new(); - st_target.set_targets(&mut pw, &st)?; - op_target.set_targets(&mut pw, ¶ms, &op)?; - for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { - prev_st_target.set_targets(&mut pw, prev_st)?; - } - for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) { - signed_by_target.set_targets(&mut pw, signed_by)? - } - for (merkle_proof_target, merkle_proof) in merkle_proofs_target - .medium - .iter() - .zip(aux.merkle_proofs.iter()) - { - merkle_proof_target.set_targets(&mut pw, merkle_proof)? - } - for (merkle_tree_state_transition_proof_target, merkle_tree_state_transition_proof) in - merkle_transition_proofs_target - .medium - .iter() - .zip(aux.merkle_transition_proofs.iter()) - { - merkle_tree_state_transition_proof_target - .set_targets(&mut pw, merkle_tree_state_transition_proof)? - } - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - - Ok(()) -} - -#[test] -fn test_lt_lteq_verify_failures() { - let invalid_int = RawValue([ - GoldilocksField::NEG_ONE, - GoldilocksField::ZERO, - GoldilocksField::ZERO, - GoldilocksField::ZERO, - ]); - - let prev_statements = [Statement::None.into()]; - - [ - // 56 < 55, 55 < 55, 56 <= 55, -55 < -55, -55 < -56, -55 <= -56 should fail to verify - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(55, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(-55, -55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(-55, -56).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(-55, -56).into(), - ), - // 56 < p-1 and p-1 <= p-1 should fail to verify, where p - // is the Goldilocks prime and 'p-1' occupies a single - // limb. - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(56, invalid_int).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(invalid_int, invalid_int).into(), - ), - ] - .into_iter() - .for_each(|(op, st)| { - let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - })); - match check { - Err(e) => { - let err_string = e.downcast_ref::().unwrap(); - if !err_string.contains("Integer too large to fit") { - panic!("Test failed with an unexpected error: {}", err_string); - } - } - Ok(Err(_)) => {} - _ => panic!("Test passed, yet it should have failed!"), - } - }); -} - -#[test] -fn test_eq_neq_verify_failures() { - let prev_statements = [Statement::None.into()]; - - [ - // 56 == 55, 55 != 55 should fail to verify - ( - mainpod::Operation( - OperationType::Native(NativeOperation::EqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::equal(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::NotEqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::not_equal(55, 55).into(), - ), - ] - .into_iter() - .for_each(|(op, st)| { - assert!(operation_verify(st, op, prev_statements.to_vec(), Aux::default()).is_err()) - }); -} - -#[test] -fn test_operation_verify_none() -> Result<()> { - let st: mainpod::Statement = Statement::None.into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::None), - vec![], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_copy() -> Result<()> { - let st: mainpod::Statement = Statement::None.into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::CopyStatement), - vec![OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_eq() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"world" => 55}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 55).into(); - let st: mainpod::Statement = Statement::equal( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "world")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::EqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_neq() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"world" => 75}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 75).into(); - let st: mainpod::Statement = Statement::not_equal( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "world")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::NotEqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_lt() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"hello" => 56}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "hello", 56).into(); - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "hello")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative < negative - let dict3 = dict!({"hola" => -56}); - let dict4 = dict!({"mundo" => -55}); - let st3: mainpod::Statement = Statement::contains(dict3.clone(), "hola", -56).into(); - let st4: mainpod::Statement = Statement::contains(dict4.clone(), "mundo", -55).into(); - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict3, "hola")), - AnchoredKey::from((&dict4, "mundo")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative < positive - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict3, "hola")), - AnchoredKey::from((&dict2, "hello")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3, st2]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_lteq() -> Result<()> { - let local = dict!({ - "n55" => 55, - "n56" => 56, - "n_56" => -56, - "n_55" => -55, - }); - let st1: mainpod::Statement = Statement::contains(local.clone(), "n55", 55).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "n56", 56).into(); - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n55")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative <= negative - let st3: mainpod::Statement = Statement::contains(local.clone(), "n_56", -56).into(); - let st4: mainpod::Statement = Statement::contains(local.clone(), "n_55", -55).into(); - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n_55")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative <= positive - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3, st2]; - operation_verify(st, op, prev_statements.clone(), Aux::default())?; - - // Also check equality, both positive and negative. - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n_56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ); - operation_verify(st, op, prev_statements.clone(), Aux::default())?; - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n56")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(1), OperationArg::Index(1)], - OperationAux::None, - ); - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_hashof() -> Result<()> { - let input_values = [ - Value::from(RawValue([ - GoldilocksField(1), - GoldilocksField(2), - GoldilocksField(3), - GoldilocksField(4), - ])), - Value::from(512), - ]; - let v1 = hash_values(&input_values); - let [v2, v3] = input_values; - - let local = dict!({ - "hola" => v1, - "mundo" => v2.clone(), - "!" => v3.clone(), - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "hola", v1).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "mundo", v2).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "!", v3).into(); - - let st: mainpod::Statement = Statement::hash_of( - AnchoredKey::from((&local, "hola")), - AnchoredKey::from((&local, "mundo")), - AnchoredKey::from((&local, "!")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::HashOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_sumof() -> Result<()> { - I64_TEST_PAIRS - .into_iter() - .flat_map(|(a, b)| { - let (sum, overflow) = a.overflowing_add(b); - overflow.not().then_some((a, b, sum)) - }) - .try_for_each(|(a, b, sum)| { - let local = dict!({ - "sum" => sum, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "sum", sum).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::sum_of( - AnchoredKey::from((&local, "sum")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SumOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) -} - -#[test] -fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> { - let local = dict!({ - "a" => 3, - "noise" => 99, - "sum" => 6, - }); - let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into(); - let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into(); - let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into(); - - let st: mainpod::Statement = Statement::sum_of( - AnchoredKey::from((&local, "sum")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "a")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SumOf), - vec![ - // Non-monotonic and repeated indices to stress random-access resolution. - OperationArg::Index(2), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::None, - ); - let prev_statements = vec![st_a, st_noise, st_sum]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_productof() -> Result<()> { - I64_TEST_PAIRS - .into_iter() - .flat_map(|(a, b)| { - let (prod, overflow) = a.overflowing_mul(b); - overflow.not().then_some((a, b, prod)) - }) - .try_for_each(|(a, b, prod)| { - let local = dict!({ - "prod" => prod, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "prod", prod).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::product_of( - AnchoredKey::from((&local, "prod")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ProductOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) -} - -#[test] -fn test_operation_verify_maxof() -> Result<()> { - I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { - let max = i64::max(a, b); - let local = dict!({ - "max" => max, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "max", max).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::max_of( - AnchoredKey::from((&local, "max")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - - let op = mainpod::Operation( - OperationType::Native(NativeOperation::MaxOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) -} - -#[test] -fn test_operation_verify_maxof_failures() { - [(5, 3, 4), (5, 5, 8), (3, 4, 5)] - .into_iter() - .for_each(|(max, a, b)| { - let st: mainpod::Statement = Statement::max_of(max, a, b).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::MaxOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::None, - ); - let prev_statements = [Statement::None.into()]; - - let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - })); - match check { - Err(e) => { - let err_string = e.downcast_ref::().unwrap(); - if !err_string.contains("Integer too large to fit") { - panic!("Test failed with an unexpected error: {}", err_string); - } - } - Ok(Err(_)) => {} - _ => panic!("Test passed, yet it should have failed!"), - } - }) -} - -#[test] -fn test_operation_verify_lt_to_neq() -> Result<()> { - let local = dict!({ - "a" => 10, - "b" => 20, - }); - let st: mainpod::Statement = Statement::not_equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let st1: mainpod::Statement = Statement::lt( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtToNotEqual), - vec![OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![st1]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_transitive_eq() -> Result<()> { - let local = dict!({ - "a" => 10, - "b" => 10, - "c" => 10, - }); - let st: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "c")), - ) - .into(); - let st1: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let st2: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "b")), - AnchoredKey::from((&local, "c")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::TransitiveEqualFromStatements), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) -} - -#[test] -fn test_operation_verify_sintains() -> Result<()> { - let kvs = [ - (1.into(), 55.into()), - (2.into(), 88.into()), - (175.into(), 0.into()), - ] - .into_iter() - .collect(); - let mt = MerkleTree::new(&kvs); - - let root = mt.root(); - let key = Value::from(5); - let local = dict!({ - "merkle_root" => root, - "key" => key.clone(), - }); - let root_ak = AnchoredKey::from((&local, "merkle_root")); - let key_ak = AnchoredKey::from((&local, "key")); - - let no_key_pf = mt.prove_nonexistence(&key.raw())?; - - let root_st: mainpod::Statement = - Statement::contains(local.clone(), "merkle_root", root).into(); - let key_st: mainpod::Statement = Statement::contains(local.clone(), "key", key.clone()).into(); - let st: mainpod::Statement = Statement::not_contains(root_ak, key_ak).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::NotContainsFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::MerkleProofIndex(Size::Medium, 0), - ); - - let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), None, no_key_pf); - let prev_statements = vec![root_st, key_st]; - operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) -} - -#[test] -fn test_operation_verify_contains() -> Result<()> { - let kvs = [ - (1.into(), 55.into()), - (2.into(), 88.into()), - (175.into(), 0.into()), - ] - .into_iter() - .collect(); - let mt = MerkleTree::new(&kvs); - - let root = mt.root(); - let key = Value::from(175); - let (value, key_pf) = mt.prove(&key.raw())?; - let local = dict!({ - "merkle_root" => root, - "key" => key.clone(), - "value" => value, - }); - let root_ak = AnchoredKey::from((&local, "merkle_root")); - let key_ak = AnchoredKey::from((&local, "key")); - let value_ak = AnchoredKey::from((&local, "value")); - - let root_st: mainpod::Statement = - Statement::contains(local.clone(), "merkle_root", root).into(); - let key_st: mainpod::Statement = Statement::contains(local.clone(), "key", key.clone()).into(); - let value_st: mainpod::Statement = Statement::contains(local.clone(), "value", value).into(); - - let st: mainpod::Statement = Statement::contains(root_ak, key_ak, value_ak).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainsFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::MerkleProofIndex(Size::Medium, 0), - ); - - let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), Some(value), key_pf); - let prev_statements = vec![root_st, key_st, value_st]; - operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) -} - -#[test] -fn test_operation_verify_merkle_insert() -> Result<()> { - let mut tree = MerkleTree::new(&[].into()); - - let key = Value::from(175); - let value = Value::from(0); - let state_transition_proof = tree.insert(&key.raw(), &value.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::insert(new_root, old_root, key, value).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerInsertFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) -} - -#[test] -fn test_operation_verify_merkle_update() -> Result<()> { - let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); - - let key = Value::from(175); - let value = Value::from(0); - let state_transition_proof = tree.update(&key.raw(), &value.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::update(new_root, old_root, key, value).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerUpdateFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) -} - -#[test] -fn test_operation_verify_merkle_delete() -> Result<()> { - let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); - - let key = Value::from(175); - let state_transition_proof = tree.delete(&key.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::delete(new_root, old_root, key).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerDeleteFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) -} - -#[test] -fn test_operation_verify_publickeyof_ok() -> Result<()> { - [ - SecretKey(BigUint::one()), - SecretKey::new_rand(), - SecretKey(&*GROUP_ORDER - BigUint::one()), - ] - .into_iter() - .try_for_each(|secret_key| { - let public_key = secret_key.public_key(); - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)) - }) -} - -#[test] -fn test_operation_verify_publickeyof_failure_wrong_key() { - let secret_key = SecretKey(BigUint::one()); - let public_key = SecretKey(BigUint::ZERO).public_key(); - - let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) -} - -#[test] -fn test_operation_verify_publickeyof_failure_pk_type() { - let secret_key = SecretKey(BigUint::one()); - let public_key = 123i64; - - let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) -} - -#[test] -fn test_operation_verify_publickeyof_failure_sk_type() { - let secret_key = 123i64; - let public_key = SecretKey(BigUint::from(123u32)).public_key(); - - let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - let aux = Aux::secret_key(SecretKey(BigUint::from(123u32))); - assert!(operation_verify(st, op, prev_statements, aux,).is_err()) -} - -#[test] -fn test_operation_verify_publickeyof_failure_sk_size() { - let secret_key = SecretKey(&*GROUP_ORDER - BigUint::ZERO); - let public_key = secret_key.public_key(); - - let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) -} - -#[test] -fn test_operation_verify_signedby_ok() -> Result<()> { - let sk = SecretKey(BigUint::from_u32(0xbadcafe).unwrap()); - let pk = sk.public_key(); - let msg = RawValue([F(1), F(2), F(3), F(4)]); - let nonce = BigUint::from_u32(123).unwrap(); - let sig = signer::Signer(sk).sign_with_nonce(nonce, msg); - let signed_by = SignedBy { - msg, - pk, - sig: sig.clone(), - }; - - let st: mainpod::Statement = Statement::signed_by(msg, pk).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SignedBy), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::SignedByIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) -} - -#[test] -fn test_operation_replace_value_with_entry() -> Result<()> { - let d = dict!({"a" => 42, "b" => 33}); - - // 0: None - // 1: Lt(5, 42) - let st_in: mainpod::Statement = Statement::lt(5, 42).into(); - // 2: Contains(d, "a", 42) - let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into(); - - let st_out: mainpod::Statement = - Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into(); - let mut op_args: Vec<_> = iter::repeat(OperationArg::None) - .take(BASE_PARAMS.max_statement_args + 1) - .collect(); - op_args[1] = OperationArg::Index(2); - op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ReplaceValueWithEntry), - op_args, - OperationAux::None, - ); - - let prev_statements = vec![Statement::None.into(), st_in, st_entry]; - operation_verify(st_out, op, prev_statements, Aux::default()) -} - -fn helper_statement_arg_from_template( - params: &Params, - st_tmpl_arg: StatementTmplArg, - args: Vec, - expected_st_arg: StatementArg, -) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_arg_target = make_statement_arg_from_template_circuit( - params, - &mut builder, - &st_tmpl_arg_target, - &args_target, - ); - // TODO: Instead of connect, assign witness to result - let expected_st_arg_target = builder.add_virtual_statement_arg(); - builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); - - let mut pw = PartialWitness::::new(); - - st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?; - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, arg)?; - } - expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); - - Ok(()) -} - -#[test] -fn test_statement_arg_from_template() -> Result<()> { - let params = Params::default(); - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - // case: None - let st_tmpl_arg = StatementTmplArg::None; - let args = vec![Value::from(1), Value::from(2), Value::from(3)]; - let expected_st_arg = StatementArg::None; - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: Literal - let st_tmpl_arg = StatementTmplArg::Literal(Value::from("foo")); - let args = vec![Value::from(1), Value::from(2), Value::from(3)]; - let expected_st_arg = StatementArg::Literal(Value::from("foo")); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: AnchoredKey(id_wildcard, key_literal) - let st_tmpl_arg = - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("foo")); - let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; - let expected_st_arg = StatementArg::Key(AnchoredKey::new(dict, Key::from("foo"))); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: WildcardLiteral(wildcard) - let st_tmpl_arg = StatementTmplArg::Wildcard(Wildcard::new("a".to_string(), 1)); - let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; - let expected_st_arg = StatementArg::Literal(Value::from("key")); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - Ok(()) -} - -fn helper_statement_from_template( - params: &Params, - st_tmpl: StatementTmpl, - args: Vec, - expected_st: Statement, -) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_tmpl_target = builder.add_virtual_statement_tmpl(false); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_target = - make_statement_from_template_circuit(params, &mut builder, &st_tmpl_target, &args_target); - // TODO: Instead of connect, assign witness to result - let expected_st_target = builder.add_virtual_statement(false); - builder.connect_flattenable(&expected_st_target, &st_target); - - let mut pw = PartialWitness::::new(); - - st_tmpl_target.set_targets(&mut pw, &st_tmpl)?; - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, arg)?; - } - expected_st_target.set_targets(&mut pw, &expected_st.into())?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); - - Ok(()) -} - -#[test] -fn test_statement_from_template() -> Result<()> { - let params = Params::default(); - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - let st_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)), - args: vec![ - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), - StatementTmplArg::Literal(Value::from("value")), - ], - }; - let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; - let expected_st = Statement::equal( - AnchoredKey::new(dict, Key::from("key")), - Value::from("value"), - ); - helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; - - let st_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)), - args: vec![ - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), - StatementTmplArg::Literal(Value::from("value")), - ], - }; - let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(); - let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; - let expected_st = Statement::not_equal( - AnchoredKey::new(dict, Key::from("key")), - Value::from("value"), - ); - helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; - - Ok(()) -} - -fn helper_custom_operation_verify_gadget( - params: &Params, - custom_predicate: CustomPredicateRef, - mut op_args: Vec, - mut args: Vec, - expected_st: Option, -) -> Result<()> { - // Pad - for _ in op_args.len()..BASE_PARAMS.max_operation_args { - op_args.push(Statement::None); - } - for _ in args.len()..params.max_custom_predicate_wildcards { - args.push(Value::from(EMPTY_VALUE)); - } - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let custom_predicate_target = builder.add_virtual_custom_predicate_entry(); - let op_args_target: Vec<_> = (0..op_args.len()) - .map(|_| builder.add_virtual_statement(false)) - .collect(); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let (st_target, op_type_target) = make_custom_statement_circuit( - params, - &mut builder, - &custom_predicate_target, - &op_args_target, - &args_target, - )?; - - let mut pw = PartialWitness::::new(); - - // Input - custom_predicate_target.set_targets(&mut pw, &custom_predicate)?; - for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { - op_arg_target.set_targets(&mut pw, &op_arg.into())?; - } - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; - } - // Expected Output - if let Some(expected_st) = expected_st { - st_target.set_targets(&mut pw, &expected_st.into())?; - } - - let expected_op_type = OperationType::Custom(custom_predicate); - op_type_target.set_targets(&mut pw, &expected_op_type)?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - Ok(data.verify(proof.clone())?) -} - -fn value_ref(v: impl Into) -> ValueRef { - v.into() -} - -// TODO: Add negative tests -#[test] -fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { - let params = Params::default(); - - use NativePredicate as NP; - use StatementTmplBuilder as STB; - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new_from_pred(NP::Equal) - .arg(("id", "score")) - .arg(literal(42)); - let stb1 = STB::new_from_pred(NP::Equal) - .arg(("id", "key")) - .arg("secret"); - let _ = builder.predicate_and( - "pred_and", - &["id"], - &["secret"], - &[stb0.clone(), stb1.clone()], - )?; - let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; - let batch = builder.finish()?; - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - // AND - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), - ]; - let args = vec![Value::from(dict), Value::from(1234)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // OR (1) - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::None, - ]; - let args = vec![Value::from(dict), Value::from(0)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // OR (2) - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![ - Statement::None, - Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), - ]; - let args = vec![Value::from(dict), Value::from(1234)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - Ok(()) -} - -#[test] -fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { - let params = Params::default(); - - use NativePredicate as NP; - use StatementTmplBuilder as STB; - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new_from_pred(NP::Equal) - .arg(("id", "score")) - .arg(literal(42)); - let stb1 = STB::new_from_pred(NP::Equal) - .arg(("secret_id", "key")) - .arg(("id", "score")); - let _ = builder.predicate_and( - "pred_and", - &["id"], - &["secret_id"], - &[stb0.clone(), stb1.clone()], - )?; - let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; - let batch = builder.finish()?; - - let dict = Hash([F(1), F(2), F(3), F(4)]); - let secret_dict = Hash([F(6), F(7), F(8), F(9)]); - - // AND (0) Sanity check with correct values - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // AND (1) Different dict for same wildcard - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(Hash([F(0), F(5), F(1), F(6)]), Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!( - helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) - .is_err() - ); - - // AND (2) key doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("BAD")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!( - helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) - .is_err() - ); - - // AND (3) literal doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal( - AnchoredKey::new(dict, Key::from("score")), - Value::from(0xbad), - ), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!( - helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) - .is_err() - ); - - // AND (4) predicate doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::not_equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!( - helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) - .is_err() - ); - - // OR (1) Two Nones - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![Statement::None, Statement::None]; - let args = vec![Value::from(dict), Value::from(0)]; - - assert!( - helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None) - .is_err() - ); - - Ok(()) -} - -fn helper_calculate_statements_hash(params: &Params, statements: &[Statement]) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let statements_target = (0..params.max_public_statements) - .map(|_| builder.add_virtual_statement(false)) - .collect_vec(); - let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target); - - let mut pw = PartialWitness::::new(); - - // Input - let statements = statements - .iter() - .map(|st| { - let mut st = mainpod::Statement::from(st.clone()); - pad_statement(&mut st); - st - }) - .collect_vec(); - for (st_target, st) in statements_target.iter().zip(statements.iter()) { - st_target.set_targets(&mut pw, st)?; - } - // Expected Output - let expected_sts_hash = calculate_statements_hash(&statements); - pw.set_hash_target( - sts_hash_target, - HashOut { - elements: expected_sts_hash.0, - }, - )?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - Ok(data.verify(proof.clone())?) -} - -#[test] -fn test_calculate_sts_hash() -> frontend::Result<()> { - assert_eq!(Params::num_public_statements_hash(), 16); - // Case with no public public statements - let params = Params { - max_public_statements: 0, - ..Default::default() - }; - - helper_calculate_statements_hash(¶ms, &[]).unwrap(); - - // Case with number of statements for the sts_hash equal to number of public statements - let params = Params { - max_public_statements: Params::num_public_statements_hash(), - ..Default::default() - }; - - let dict = Hash([F(1), F(2), F(3), F(4)]); - let statements = (0..Params::num_public_statements_hash()) - .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64))) - .collect_vec(); - - helper_calculate_statements_hash(¶ms, &statements).unwrap(); - - // Case with more statements for the sts_hash than the number of public statements - let params = Params { - max_public_statements: 4, - ..Default::default() - }; - - let dict2 = Hash([F(5), F(6), F(7), F(8)]); - let statements = [ - Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)), - Statement::equal( - AnchoredKey::from((dict, "bar")), - AnchoredKey::from((dict, "baz")), - ), - Statement::lt( - AnchoredKey::from((dict2, "one")), - AnchoredKey::from((dict2, "two")), - ), - ] - .into_iter() - .chain(iter::repeat(Statement::None)) - .take(params.max_public_statements) - .collect_vec(); - - helper_calculate_statements_hash(¶ms, &statements).unwrap(); - - Ok(()) -} - -#[test] -fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> { - let params = Params::default(); - - // Build a batch with two predicates: - // pred_A: Equal(x, y) - // pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash - use NativePredicate as NP; - let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal) - .arg("x") - .arg("y"); - cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a]) - .unwrap(); - - // Build pred_B's template manually with SelfPredicateHash(0) - let stb_b_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)), - args: vec![ - StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)), - StatementTmplArg::SelfPredicateHash(0), - ], - }; - let pred_b = CustomPredicate::new( - ¶ms, - "pred_B".into(), - true, - vec![stb_b_tmpl], - 1, - vec!["x".to_string()], - ) - .unwrap(); - cpb.predicates.push(pred_b); - let batch = cpb.finish().unwrap(); - - // Compute the expected resolved hash of pred_A - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_a_hash = Predicate::Custom(pred_a_ref).hash(); - let expected_pred_a_value = Value::from(pred_a_hash); - - // Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to - // Literal(pred_a_hash). Then make_statement_from_template_circuit should produce - // a statement with that literal value. - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let pred_b_tmpl = &pred_b_ref.predicate().statements[0]; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - // Create the template target and batch id target - let st_tmpl_target = builder.add_virtual_statement_tmpl(true); - let batch_id = builder.add_virtual_hash(); - - // Normalize the template (this is what we're testing) - let normalized = normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id); - - // Feed normalized template into statement generation - let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_target = - make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target); - - // Connect to expected output - let expected_st_target = builder.add_virtual_statement(false); - builder.connect_flattenable(&expected_st_target, &st_target); - - // Set witness - let mut pw = PartialWitness::::new(); - st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?; - pw.set_target_arr(&batch_id.elements, &batch.id().0)?; - - let some_value = Value::from(42); - // args: first wildcard is "x" = some_value, rest are padding - let mut args_values = vec![some_value.clone()]; - for _ in 1..params.max_custom_predicate_wildcards { - args_values.push(Value::from(EMPTY_VALUE)); - } - for (target, value) in args_target.iter().zip(args_values.iter()) { - target.set_targets(&mut pw, value)?; - } - - // Expected statement: Equal(Literal(some_value), Literal(pred_a_hash)) - let expected_st: crate::backends::plonky2::mainpod::Statement = - Statement::equal(some_value, expected_pred_a_value).into(); - expected_st_target.set_targets(&mut pw, &expected_st)?; - - // Build and verify - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - - Ok(()) -} diff --git a/src/backends/plonky2/circuits/mux_table.rs b/src/backends/plonky2/circuits/mux_table.rs index c93d0e8..110dac9 100644 --- a/src/backends/plonky2/circuits/mux_table.rs +++ b/src/backends/plonky2/circuits/mux_table.rs @@ -107,11 +107,11 @@ impl MuxTableTarget { rev_resolved_tagged_flattened.reverse(); let resolved_tagged_flattened = rev_resolved_tagged_flattened; - builder.add_simple_generator(TableGetGenerator::new( - index.clone(), - self.tagged_entries.clone(), - resolved_tagged_flattened.clone(), - )); + builder.add_simple_generator(TableGetGenerator { + index: index.clone(), + tagged_entries: self.tagged_entries.clone(), + get_tagged_entry: resolved_tagged_flattened.clone(), + }); measure_gates_end!(builder, measure); TableEntryTarget { params: self.params.clone(), @@ -123,18 +123,8 @@ impl MuxTableTarget { #[derive(Debug, Clone, Default)] pub struct TableGetGenerator { index: IndexTarget, - entries: Vec>, - revealed_entry: Vec, -} - -impl TableGetGenerator { - pub fn new(index: IndexTarget, entries: Vec>, revealed_entry: Vec) -> Self { - Self { - index, - entries, - revealed_entry, - } - } + tagged_entries: Vec>, + get_tagged_entry: Vec, } impl, const D: usize> SimpleGenerator for TableGetGenerator { @@ -145,7 +135,7 @@ impl, const D: usize> SimpleGenerator for Tab fn dependencies(&self) -> Vec { [self.index.low, self.index.high] .into_iter() - .chain(self.entries.iter().flatten().copied()) + .chain(self.tagged_entries.iter().flatten().copied()) .collect() } @@ -158,12 +148,12 @@ impl, const D: usize> SimpleGenerator for Tab let index_high = witness.get_target(self.index.high); let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64(); - let entry = witness.get_targets(&self.entries[index as usize]); + let entry = witness.get_targets(&self.tagged_entries[index as usize]); - for (target, value) in self.revealed_entry.iter().zip( + for (target, value) in self.get_tagged_entry.iter().zip( entry .iter() - .chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())), + .chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())), ) { out_buffer.set_target(*target, *value)?; } @@ -176,12 +166,12 @@ impl, const D: usize> SimpleGenerator for Tab dst.write_target(self.index.low)?; dst.write_target(self.index.high)?; - dst.write_usize(self.entries.len())?; - for entry in &self.entries { - dst.write_target_vec(entry)?; + dst.write_usize(self.tagged_entries.len())?; + for tagged_entry in &self.tagged_entries { + dst.write_target_vec(tagged_entry)?; } - dst.write_target_vec(&self.revealed_entry) + dst.write_target_vec(&self.get_tagged_entry) } fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { @@ -191,16 +181,16 @@ impl, const D: usize> SimpleGenerator for Tab high: src.read_target()?, }; let len = src.read_usize()?; - let mut entries = Vec::with_capacity(len); + let mut tagged_entries = Vec::with_capacity(len); for _ in 0..len { - entries.push(src.read_target_vec()?); + tagged_entries.push(src.read_target_vec()?); } - let revealed_entry = src.read_target_vec()?; + let get_tagged_entry = src.read_target_vec()?; Ok(Self { index, - entries, - revealed_entry, + tagged_entries, + get_tagged_entry, }) } } diff --git a/src/backends/plonky2/error.rs b/src/backends/plonky2/error.rs index 6d57568..355eaf1 100644 --- a/src/backends/plonky2/error.rs +++ b/src/backends/plonky2/error.rs @@ -61,8 +61,8 @@ macro_rules! new { } use InnerError::*; impl Error { - pub fn custom(s: impl Into) -> Self { - new!(Custom(s.into())) + pub fn custom(s: String) -> Self { + new!(Custom(s)) } pub fn plonky2_proof_fail(context: impl Into, e: anyhow::Error) -> Self { Self::Plonky2ProofFail(context.into(), e) diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 513b1da..341e295 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1,5 +1,5 @@ pub mod operation; -use crate::middleware::{wildcard_values_from_op_st, PodType, BASE_PARAMS}; +use crate::middleware::{wildcard_values_from_op_st, PodType}; pub mod statement; use std::iter; @@ -39,7 +39,7 @@ use crate::{ middleware::{ self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, - ToFields, VDSet, Value, ValueRef, + ToFields, VDSet, Value, }, timed, }; @@ -104,20 +104,8 @@ pub(crate) fn extract_custom_predicate_verifications( if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Statement::Custom(st_cpr, st_args) = st { assert_eq!(cpr, st_cpr); - // The custom operation outputs statements with literal arguments. They can be - // replaced by references later with ReplaceValueWithEntry. - let st_args = st_args - .iter() - .map(|arg| match arg { - ValueRef::Literal(v) => Ok(v.clone()), - _ => Err(Error::custom( - "custom operation cannot output entries as arguments", - )), - }) - .collect::>>()?; - let normalized_pred = cpr.normalized_predicate(); let wildcard_values = - wildcard_values_from_op_st(params, &normalized_pred, sts, &st_args) + wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) .expect("resolved wildcards"); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let custom_predicate_table_index = custom_predicates @@ -148,20 +136,14 @@ pub(crate) fn extract_custom_predicate_verifications( Ok(table) } -#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct MerkleProofs { - pub(crate) medium: Vec, - pub(crate) small: Vec, -} - /// Extracts Merkle proofs from Contains/NotContains ops. pub(crate) fn extract_merkle_proofs( params: &Params, aux_list: &mut [OperationAux], operations: &[middleware::Operation], statements: &[middleware::Statement], -) -> Result { - let mut tables = MerkleProofs::default(); +) -> Result> { + let mut table = Vec::new(); for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() { let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); let (root, key, value, pf) = match (op, st) { @@ -184,42 +166,31 @@ pub(crate) fn extract_merkle_proofs( } _ => continue, }; - let claim_proof = MerkleClaimAndProof::new(Hash::from(root), key, value, pf.clone()); - if pf.existence - // TODO: Make sure there's no off-by-one error here - && pf.siblings.len() <= params.containers.max_depth_small - && tables.small.len() < params.containers.state.max_small - { - aux_list[i] = OperationAux::MerkleProofIndex(Size::Small, tables.small.len()); - tables.small.push(claim_proof); - } else { - aux_list[i] = OperationAux::MerkleProofIndex(Size::Medium, tables.medium.len()); - tables.medium.push(claim_proof); - } + aux_list[i] = OperationAux::MerkleProofIndex(table.len()); + table.push(MerkleClaimAndProof::new( + Hash::from(root), + key, + value, + pf.clone(), + )); } - if tables.medium.len() > params.containers.state.max_medium { + if table.len() > params.max_merkle_proofs_containers { return Err(Error::custom(format!( "The number of required Merkle proofs ({}) exceeds the maximum number ({}).", - tables.medium.len(), - params.containers.state.max_medium + table.len(), + params.max_merkle_proofs_containers ))); } - Ok(tables) -} - -#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct MerkleTransitionProofs { - pub(crate) medium: Vec, - pub(crate) small: Vec, + Ok(table) } /// Extracts Merkle state transition proofs from container update ops. -pub(crate) fn extract_merkle_transition_proofs( +pub(crate) fn extract_merkle_tree_state_transition_proofs( params: &Params, aux_list: &mut [OperationAux], operations: &[middleware::Operation], -) -> Result { - let mut tables = MerkleTransitionProofs::default(); +) -> Result> { + let mut table = Vec::new(); for (i, op) in operations.iter().enumerate() { let pf = match op { middleware::Operation::ContainerInsertFromEntries(_, _, _, _, pf) @@ -227,27 +198,17 @@ pub(crate) fn extract_merkle_transition_proofs( | middleware::Operation::ContainerDeleteFromEntries(_, _, _, pf) => pf.clone(), _ => continue, }; - if pf.op_proof.existence - // TODO: Make sure there's no off-by-one error here - && pf.siblings.len() <= params.containers.max_depth_small - && tables.small.len() < params.containers.transition.max_small - { - aux_list[i] = OperationAux::MerkleTransitionProofIndex(Size::Small, tables.small.len()); - tables.small.push(pf); - } else { - aux_list[i] = - OperationAux::MerkleTransitionProofIndex(Size::Medium, tables.medium.len()); - tables.medium.push(pf); - } + aux_list[i] = OperationAux::MerkleTreeStateTransitionProofIndex(table.len()); + table.push(pf); } - if tables.medium.len() > params.containers.transition.max_medium { + if table.len() > params.max_merkle_tree_state_transition_proofs_containers { return Err(Error::custom(format!( "The number of required Merkle proofs ({}) exceeds the maximum number ({}).", - tables.medium.len(), - params.containers.transition.max_medium + table.len(), + params.max_merkle_tree_state_transition_proofs_containers ))); } - Ok(tables) + Ok(table) } pub(crate) fn extract_public_key_of( @@ -264,10 +225,11 @@ pub(crate) fn extract_public_key_of( ) = (op, st) { let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); - let value = value_from_op(sk_s, sk_ref).ok_or_else(deduction_err)?; - let sk = value - .as_secret_key() - .ok_or_else(|| Error::custom("{value} not SecretKey"))?; + let sk = SecretKey::try_from( + value_from_op(sk_s, sk_ref) + .ok_or_else(deduction_err)? + .typed(), + )?; aux_list[i] = OperationAux::PublicKeyOfIndex(table.len()); table.push(sk); } @@ -321,9 +283,7 @@ pub(crate) fn extract_signatures( aux_list[i] = OperationAux::SignedByIndex(table.len()); table.push(SignedBy { msg: msg.raw(), - pk: pk - .as_public_key() - .ok_or_else(|| Error::custom(format!("{pk} is not PublicKey")))?, + pk: PublicKey::try_from(pk.typed())?, sig: sig.clone(), }); } @@ -367,8 +327,8 @@ pub fn pad_statement(s: &mut Statement) { fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args()) } -fn pad_operation_args(args: &mut Vec) { - fill_pad(args, OperationArg::None, BASE_PARAMS.max_operation_args) +fn pad_operation_args(params: &Params, args: &mut Vec) { + fill_pad(args, OperationArg::None, params.max_operation_args) } /// Returns the statements from the given MainPodInputs, padding to the respective max lengths @@ -466,7 +426,7 @@ pub(crate) fn process_private_statements_operations( .map(|mid_arg| find_op_arg(statements, mid_arg)) .collect::>>()?; - pad_operation_args(&mut args); + pad_operation_args(params, &mut args); operations.push(Operation(op.op_type(), args, *aux)); } Ok(operations) @@ -497,11 +457,7 @@ pub(crate) fn process_public_statements_operations( OperationAux::None, ) }; - fill_pad( - &mut op.1, - OperationArg::None, - BASE_PARAMS.max_operation_args, - ); + fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); operations.push(op); } Ok(operations) @@ -511,7 +467,6 @@ pub struct Prover {} impl MainPodProver for Prover { fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result> { - assert_eq!(inputs.statements.len(), inputs.operations.len()); // Pad input recursive pods with empty pods if necessary let empty_pod = if inputs.pods.len() == params.max_input_pods { // We don't need padding so we skip creating an EmptyPod @@ -540,8 +495,6 @@ impl MainPodProver for Prover { let mut aux_list = vec![OperationAux::None; params.max_priv_statements()]; let merkle_proofs = extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; - let merkle_transition_proofs = - extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?; let custom_predicates = extract_custom_predicates(params, inputs.operations)?; let custom_predicate_verifications = extract_custom_predicate_verifications( params, @@ -566,6 +519,9 @@ impl MainPodProver for Prover { let signed_bys = extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; + let merkle_tree_state_transition_proofs = + extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?; + let (statements, public_statements) = layout_statements(params, false, &inputs)?; let operations = process_private_statements_operations( params, @@ -598,15 +554,20 @@ impl MainPodProver for Prover { .collect_vec(); let mut vd_mt_proofs = Vec::with_capacity(inputs.pods.len()); - let pad_vd_mt_proof = inputs.vd_set.get_vds_proof_0(); for (pod, vd) in inputs.pods.iter().zip(&verifier_datas) { vd_mt_proofs.push(if pod.is_main() { - inputs.vd_set.get_vds_proof(vd)? + (true, inputs.vd_set.get_vds_proof(vd)?) } else { // For intro pods we don't verify inclusion of their vk into the vd set, so we - // use a valid vds proof that matches the expected root but not the value to pass - // the constraints - pad_vd_mt_proof.clone() + // generate a dummy mt proof with expected root and value to pass some constraints + ( + false, + MerkleClaimAndProof { + root: inputs.vd_set.root(), + value: RawValue::from(pod.verifier_data_hash()), + ..MerkleClaimAndProof::empty() + }, + ) }); } @@ -619,7 +580,7 @@ impl MainPodProver for Prover { merkle_proofs, public_key_of_sks, signed_bys, - merkle_transition_proofs, + merkle_tree_state_transition_proofs, custom_predicates_with_mpt_proofs, custom_predicate_verifications, }; @@ -1006,18 +967,7 @@ pub mod tests { max_statements: 2, max_public_statements: 1, max_input_pods_public_statements: 0, - containers: middleware::ParamsContainers { - state: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: 0, - }, - transition: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: 0, - }, - max_depth_small: 8, - max_depth_medium: 32, - }, + max_merkle_proofs_containers: 0, max_public_key_of: 0, max_custom_predicate_verifications: 0, max_custom_predicates: 0, @@ -1053,23 +1003,15 @@ pub mod tests { max_input_pods_public_statements: 2, max_statements: 5, max_public_statements: 2, + max_operation_args: 5, max_custom_predicates: 2, max_custom_predicate_verifications: 2, max_custom_predicate_wildcards: 3, + max_merkle_proofs_containers: 2, + max_merkle_tree_state_transition_proofs_containers: 2, max_public_key_of: 2, + max_depth_mt_containers: 4, max_depth_mt_vds: 6, - containers: middleware::ParamsContainers { - state: middleware::ParamsMerkleProofs { - max_small: 2, - max_medium: 2, - }, - transition: middleware::ParamsMerkleProofs { - max_small: 2, - max_medium: 2, - }, - max_depth_small: 2, - max_depth_medium: 4, - }, }; let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); @@ -1126,20 +1068,11 @@ pub mod tests { max_input_pods: 0, max_statements: 9, max_public_statements: 4, + max_operation_args: 5, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, - containers: middleware::ParamsContainers { - state: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: 3, - }, - transition: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: 0, - }, - max_depth_small: 8, - max_depth_medium: 32, - }, + max_merkle_proofs_containers: 3, + max_merkle_tree_state_transition_proofs_containers: 0, ..Default::default() }; println!("{:#?}", params); @@ -1162,7 +1095,7 @@ pub mod tests { &[stb0.clone(), stb1.clone()], )?; let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?; - let cpb = cpb_builder.finish()?; + let cpb = cpb_builder.finish(); let cpb_and = CustomPredicateRef::new(cpb.clone(), 0); let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1); @@ -1196,72 +1129,6 @@ pub mod tests { Ok(pod.verify()?) } - #[test] - fn test_main_self_predicate_hash() -> frontend::Result<()> { - use frontend::BuilderArg; - - let params = Params { - max_signed_by: 0, - max_input_pods: 0, - max_statements: 6, - max_public_statements: 2, - max_custom_predicate_wildcards: 4, - max_custom_predicate_verifications: 2, - containers: middleware::ParamsContainers { - state: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: 0, - }, - transition: middleware::ParamsMerkleProofs { - max_small: 0, - max_medium: 0, - }, - max_depth_small: 8, - max_depth_medium: 32, - }, - ..Default::default() - }; - let mut vds = DEFAULT_VD_LIST.clone(); - vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); - let vd_set = VDSet::new(&vds); - - // Build a batch: pred_A references pred_B's hash, pred_B references pred_A's hash - let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb_a = STB::new_from_pred(NP::Equal) - .arg("x") - .arg(BuilderArg::SelfPredicateHash("pred_B".into())); - cpb.predicate_and("pred_A", &["x"], &[], &[stb_a])?; - - let stb_b = STB::new_from_pred(NP::Equal) - .arg("x") - .arg(BuilderArg::SelfPredicateHash("pred_A".into())); - cpb.predicate_and("pred_B", &["x"], &[], &[stb_b])?; - - let batch = cpb.finish()?; - - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let pred_b_hash = middleware::Value::from(middleware::Predicate::Custom(pred_b_ref).hash()); - - // Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash) - let mut pod_builder = MainPodBuilder::new(¶ms, &vd_set); - let eq_st = - pod_builder.priv_op(frontend::Operation::eq(pred_b_hash.clone(), pred_b_hash))?; - pod_builder.pub_op(frontend::Operation::custom(pred_a_ref, [eq_st]))?; - - // Mock - let prover = MockProver {}; - let pod = pod_builder.prove(&prover)?; - assert!(pod.pod.verify().is_ok()); - - // Real - let prover = Prover {}; - let pod = pod_builder.prove(&prover)?; - let pod = (pod.pod as Box).downcast::().unwrap(); - - Ok(pod.verify()?) - } - #[test] fn test_set_contains() -> frontend::Result<()> { let params = Params::default(); @@ -1325,108 +1192,10 @@ pub mod tests { ); let st = middleware::Statement::Custom( cpr, - [1, 1, 2] - .into_iter() - .map(middleware::ValueRef::from) - .collect(), + [1, 1, 2].into_iter().map(middleware::Value::from).collect(), ); - builder.insert((st.clone(), op)).unwrap(); - builder.reveal(&st).unwrap(); + builder.insert(true, (st, op)).unwrap(); let prover = Prover {}; builder.prove(&prover).unwrap(); } - - #[test] - fn test_replace_value_with_entry() { - let params = middleware::Params::default(); - let vd_set = &*DEFAULT_VD_SET; - let mut builder = MainPodBuilder::new(¶ms, vd_set); - let d = dict!({"a" => 42, "b" => 33}); - builder - .priv_op(frontend::Operation::dict_contains(d.clone(), "a", 42)) - .unwrap(); - let st = builder.priv_op(frontend::Operation::lt(5, 42)).unwrap(); - // Transform `Lt(5, 42)` into `Lt(5, d.a)` by using `DictContains(d, "a", 42)` - builder - .pub_op(frontend::Operation::replace_value_with_entry( - vec![None, Some((&d, "a"))], - st, - )) - .unwrap(); - - // Mock - let prover = MockProver {}; - let pod = builder.prove(&prover).unwrap(); - pod.pod.verify().unwrap(); - assert_eq!( - middleware::Statement::Lt( - middleware::ValueRef::Literal(Value::from(5)), - middleware::ValueRef::Key(middleware::AnchoredKey { - root: d.commitment(), - key: middleware::Key::from("a") - }) - ), - pod.public_statements[0] - ); - - // Real - let prover = Prover {}; - let pod = builder.prove(&prover).unwrap(); - pod.pod.verify().unwrap() - } - - #[test] - fn test_entry_custom_statement_arg() { - let params = middleware::Params::default(); - let vd_set = &*DEFAULT_VD_SET; - let input = r#" - PredA(x) = AND( - Lt(x, 100) - ) - - PredB(d) = AND( - PredA(d.x) - ) - "#; - let module = load_module(input, "my_mod", ¶ms, &[]).expect("lang parse"); - let pred_a = module.batch.predicate_ref_by_name("PredA").unwrap(); - let pred_b = module.batch.predicate_ref_by_name("PredB").unwrap(); - - let mut builder = MainPodBuilder::new(¶ms, vd_set); - let d = dict!({"x" => 42, "y" => 33}); - - let st_lt = builder.priv_op(frontend::Operation::lt(42, 100)).unwrap(); - let st_a = builder - .priv_op(frontend::Operation::custom(pred_a, [st_lt])) - .unwrap(); - builder - .priv_op(frontend::Operation::dict_contains(d.clone(), "x", 42)) - .unwrap(); - // Transform `PredA(42)` into `PredA(d.x)` by using `DictContains(d, "x", 42)` - let st_a1 = builder - .priv_op(frontend::Operation::replace_value_with_entry( - vec![Some((&d, "x"))], - st_a, - )) - .unwrap(); - - builder - .pub_op(frontend::Operation::custom(pred_b.clone(), [st_a1])) - .unwrap(); - - // Mock - let prover = MockProver {}; - let pod = builder.prove(&prover).unwrap(); - pod.pod.verify().unwrap(); - let expected = middleware::Statement::Custom( - pred_b, - vec![middleware::ValueRef::Literal(Value::from(d))], - ); - assert_eq!(expected, pod.public_statements[0]); - - // Real - let prover = Prover {}; - let pod = builder.prove(&prover).unwrap(); - pod.pod.verify().unwrap() - } } diff --git a/src/backends/plonky2/mainpod/operation.rs b/src/backends/plonky2/mainpod/operation.rs index 2060ac7..d7b44bb 100644 --- a/src/backends/plonky2/mainpod/operation.rs +++ b/src/backends/plonky2/mainpod/operation.rs @@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ error::{Error, Result}, - mainpod::{MerkleProofs, MerkleTransitionProofs, SignedBy, Statement}, + mainpod::{SignedBy, Statement}, + primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof}, }, middleware::{self, OperationType, Params}, }; @@ -29,89 +30,50 @@ impl OperationArg { } } -#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] -pub enum Size { - Small, - Medium, -} - -impl fmt::Display for Size { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Small => write!(f, "small"), - Self::Medium => write!(f, "medium"), - } - } -} - -impl Size { - pub const fn min() -> Self { - Self::Small - } - pub const fn max() -> Self { - Self::Medium - } -} - #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub enum OperationAux { None, - MerkleProofIndex(Size, usize), - MerkleTransitionProofIndex(Size, usize), + MerkleProofIndex(usize), PublicKeyOfIndex(usize), SignedByIndex(usize), + MerkleTreeStateTransitionProofIndex(usize), CustomPredVerifyIndex(usize), } impl OperationAux { - fn table_offset_merkle_proof(params: &Params, size: Size) -> usize { - match size { - // At index 0 we store a zero entry - Size::Small => 1, - Size::Medium => { - Self::table_offset_merkle_proof(params, Size::Small) - + params.containers.state.max_small - } - } - } - fn table_offset_merkle_transition_proof(params: &Params, size: Size) -> usize { - match size { - Size::Small => { - Self::table_offset_merkle_proof(params, Size::min()) - + params.containers.state.max_total() - } - Size::Medium => { - Self::table_offset_merkle_transition_proof(params, Size::Small) - + params.containers.transition.max_small - } - } - } - fn table_offset_custom_pred_verify(params: &Params) -> usize { - Self::table_offset_merkle_transition_proof(params, Size::min()) - + params.containers.transition.max_total() + fn table_offset_merkle_proof(_params: &Params) -> usize { + // At index 0 we store a zero entry + 1 } fn table_offset_public_key_of(params: &Params) -> usize { - Self::table_offset_custom_pred_verify(params) + params.max_custom_predicate_verifications + Self::table_offset_merkle_proof(params) + params.max_merkle_proofs_containers } fn table_offset_signed_by(params: &Params) -> usize { Self::table_offset_public_key_of(params) + params.max_public_key_of } + fn table_offset_merkle_tree_state_transition_proof(params: &Params) -> usize { + Self::table_offset_signed_by(params) + params.max_signed_by + } + fn table_offset_custom_pred_verify(params: &Params) -> usize { + Self::table_offset_merkle_tree_state_transition_proof(params) + + params.max_merkle_tree_state_transition_proofs_containers + } pub(crate) fn table_size(params: &Params) -> usize { - 1 + params.containers.state.max_total() - + params.containers.transition.max_total() - + params.max_custom_predicate_verifications + 1 + params.max_merkle_proofs_containers + params.max_public_key_of + params.max_signed_by + + params.max_merkle_tree_state_transition_proofs_containers + + params.max_custom_predicate_verifications } pub fn table_index(&self, params: &Params) -> usize { match self { Self::None => 0, - Self::MerkleProofIndex(size, i) => Self::table_offset_merkle_proof(params, *size) + *i, - Self::MerkleTransitionProofIndex(size, i) => { - Self::table_offset_merkle_transition_proof(params, *size) + *i - } + Self::MerkleProofIndex(i) => Self::table_offset_merkle_proof(params) + *i, Self::PublicKeyOfIndex(i) => Self::table_offset_public_key_of(params) + *i, Self::SignedByIndex(i) => Self::table_offset_signed_by(params) + *i, + Self::MerkleTreeStateTransitionProofIndex(i) => { + Self::table_offset_merkle_tree_state_transition_proof(params) + *i + } Self::CustomPredVerifyIndex(i) => Self::table_offset_custom_pred_verify(params) + *i, } } @@ -134,8 +96,8 @@ impl Operation { &self, statements: &[Statement], signatures: &[SignedBy], - merkle_proofs: &MerkleProofs, - merkle_transition_proofs: &MerkleTransitionProofs, + merkle_proofs: &[MerkleClaimAndProof], + merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProof], ) -> Result { let deref_args = self .1 @@ -151,26 +113,17 @@ impl Operation { .collect::>>()?; let deref_aux = match self.2 { OperationAux::None => crate::middleware::OperationAux::None, - OperationAux::MerkleProofIndex(size, i) => { - let table = match size { - Size::Small => &merkle_proofs.small, - Size::Medium => &merkle_proofs.medium, - }; - crate::middleware::OperationAux::MerkleProof( - table - .get(i) - .ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))? - .proof - .clone(), - ) - } - OperationAux::MerkleTransitionProofIndex(size, i) => { - let table = match size { - Size::Small => &merkle_transition_proofs.small, - Size::Medium => &merkle_transition_proofs.medium, - }; + OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, + OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof( + merkle_proofs + .get(i) + .ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))? + .proof + .clone(), + ), + OperationAux::MerkleTreeStateTransitionProofIndex(i) => { crate::middleware::OperationAux::MerkleTreeStateTransitionProof( - table + merkle_tree_state_transition_proofs .get(i) .ok_or(Error::custom(format!( "Missing Merkle state transition proof index {}", @@ -179,7 +132,6 @@ impl Operation { .clone(), ) } - OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, OperationAux::SignedByIndex(i) => crate::middleware::OperationAux::Signature( signatures .get(i) @@ -213,14 +165,12 @@ impl fmt::Display for Operation { } match self.2 { OperationAux::None => (), - OperationAux::MerkleProofIndex(size, i) => { - write!(f, " {}_merkle_proof_{:02}", size, i)? - } + OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?, OperationAux::PublicKeyOfIndex(i) => write!(f, " public_key_of_{:02}", i)?, OperationAux::SignedByIndex(i) => write!(f, " signed_by_{:02}", i)?, - OperationAux::MerkleTransitionProofIndex(size, i) => { - write!(f, " {}_merkle_transition_proof_{:02}", size, i)? + OperationAux::MerkleTreeStateTransitionProofIndex(i) => { + write!(f, " merkle_tree_state_transition_proof_{:02}", i)? } } Ok(()) diff --git a/src/backends/plonky2/mainpod/statement.rs b/src/backends/plonky2/mainpod/statement.rs index 64fe675..27776a6 100644 --- a/src/backends/plonky2/mainpod/statement.rs +++ b/src/backends/plonky2/mainpod/statement.rs @@ -4,9 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::error::{Error, Result}, - middleware::{ - self, NativePredicate, Predicate, StatementArg, ToFields, Value, ValueRef, BASE_PARAMS, - }, + middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS}, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -98,15 +96,15 @@ impl TryFrom for middleware::Statement { )))?, }, Predicate::Custom(cpr) => { - let args: Vec = proper_args + let vs: Vec = proper_args .into_iter() .filter_map(|arg| match arg { - StatementArg::Literal(v) => Some(ValueRef::Literal(v)), - StatementArg::Key(k) => Some(ValueRef::Key(k)), - StatementArg::None => None, + SA::None => None, + SA::Literal(v) => Some(v), + _ => unreachable!(), }) .collect(); - S::Custom(cpr, args) + S::Custom(cpr, vs) } Predicate::Intro(ir) => { let vs: Vec = proper_args diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index 8dd710a..dcb1355 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -11,12 +11,13 @@ use crate::{ basetypes::{Proof, VerifierOnlyCircuitData}, error::{Error, Result}, mainpod::{ - calculate_statements_hash, extract_merkle_proofs, extract_merkle_transition_proofs, - extract_signatures, layout_statements, process_private_statements_operations, - process_public_statements_operations, MerkleProofs, MerkleTransitionProofs, Operation, + calculate_statements_hash, extract_merkle_proofs, + extract_merkle_tree_state_transition_proofs, extract_signatures, layout_statements, + process_private_statements_operations, process_public_statements_operations, Operation, OperationAux, SignedBy, Statement, }, mock::emptypod::MockEmptyPod, + primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof}, recursion::hash_verifier_data, }, middleware::{ @@ -44,10 +45,10 @@ pub struct MockMainPod { operations: Vec, // public subset of the `statements` vector public_statements: Vec, - // All Merkle proofs for containers - merkle_proofs: MerkleProofs, - // All Merkle tree state transition proofs for containers - merkle_transition_proofs: MerkleTransitionProofs, + // All Merkle proofs + merkle_proofs_containers: Vec, + // All Merkle tree state transition proofs + merkle_tree_state_transition_proofs_containers: Vec, // All verified signatures signatures: Vec, } @@ -123,8 +124,8 @@ struct Data { public_statements: Vec, operations: Vec, statements: Vec, - merkle_proofs: MerkleProofs, - merkle_transition_proofs: MerkleTransitionProofs, + merkle_proofs: Vec, + merkle_tree_state_transition_proofs: Vec, signatures: Vec, input_pods: Vec<(usize, Params, Hash, VDSet, serde_json::Value)>, } @@ -152,8 +153,8 @@ impl MockMainPod { let merkle_proofs = extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; // Similarly for Merkle state transition proofs. - let merkle_transition_proofs = - extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?; + let merkle_tree_state_transition_proofs = + extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?; let signatures = extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; @@ -184,8 +185,8 @@ impl MockMainPod { public_statements, statements, operations, - merkle_proofs, - merkle_transition_proofs, + merkle_proofs_containers: merkle_proofs, + merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, signatures, }) } @@ -259,8 +260,8 @@ impl Pod for MockMainPod { .deref( &self.statements[..input_statement_offset + i], &self.signatures, - &self.merkle_proofs, - &self.merkle_transition_proofs, + &self.merkle_proofs_containers, + &self.merkle_tree_state_transition_proofs_containers, )? .check_and_log(&self.params, &s.clone().try_into()?) .map_err(|e| e.into()) @@ -320,8 +321,10 @@ impl Pod for MockMainPod { public_statements: self.public_statements.clone(), operations: self.operations.clone(), statements: self.statements.clone(), - merkle_proofs: self.merkle_proofs.clone(), - merkle_transition_proofs: self.merkle_transition_proofs.clone(), + merkle_proofs: self.merkle_proofs_containers.clone(), + merkle_tree_state_transition_proofs: self + .merkle_tree_state_transition_proofs_containers + .clone(), signatures: self.signatures.clone(), input_pods, }) @@ -341,7 +344,7 @@ impl Pod for MockMainPod { operations, statements, merkle_proofs, - merkle_transition_proofs, + merkle_tree_state_transition_proofs, signatures, input_pods, } = serde_json::from_value(data)?; @@ -359,8 +362,8 @@ impl Pod for MockMainPod { public_statements, operations, statements, - merkle_proofs, - merkle_transition_proofs, + merkle_proofs_containers: merkle_proofs, + merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, signatures, }) } @@ -377,8 +380,7 @@ pub mod tests { great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, MOCK_VD_SET, }, - frontend::{self}, - middleware, + frontend, middleware, middleware::{Signer as _, Value}, }; diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index 67b7513..caf3727 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -207,7 +207,7 @@ impl Point { u: *u, }); points.find(|p| p.is_in_subgroup()).ok_or(Error::custom( - "One of the points must lie in the EC subgroup.", + "One of the points must lie in the EC subgroup.".into(), )) } pub fn as_bytes_from_subgroup(&self) -> Result, Error> { diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index f53a143..0c5978f 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -32,7 +32,7 @@ use crate::{ circuits::common::{CircuitBuilderPod, ValueTarget}, error::{Error, Result}, primitives::merkletree::{ - MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, MAX_DEPTH, + MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, }, }, measure_gates_begin, measure_gates_end, @@ -42,6 +42,8 @@ use crate::{ #[derive(Clone, Debug, Serialize, Deserialize)] pub struct MerkleClaimAndProofTarget { pub(crate) max_depth: usize, + // `enabled` determines if the merkleproof verification is enabled + pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -119,9 +121,16 @@ pub fn verify_merkle_proof_circuit( let obtained_root = compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); - // check that obtained_root==root (from inputs) + // check that obtained_root==root (from inputs), when enabled==true + let zero = builder.zero(); + let expected_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) + .collect(); + let computed_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) + .collect(); for j in 0..HASH_SIZE { - builder.connect(obtained_root.elements[j], proof.root.elements[j]); + builder.connect(computed_root[j], expected_root[j]); } measure_gates_end!(builder, measure); } @@ -130,6 +139,7 @@ impl MerkleClaimAndProofTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { MerkleClaimAndProofTarget { max_depth, + enabled: builder.add_virtual_bool_target_safe(), root: builder.add_virtual_hash(), key: builder.add_virtual_value(), value: builder.add_virtual_value(), @@ -144,7 +154,12 @@ impl MerkleClaimAndProofTarget { } /// assigns the given values to the targets #[allow(clippy::too_many_arguments)] - pub fn set_targets(&self, pw: &mut PartialWitness, mp: &MerkleClaimAndProof) -> Result<()> { + pub fn set_targets( + &self, + pw: &mut PartialWitness, + enabled: bool, + mp: &MerkleClaimAndProof, + ) -> Result<()> { if mp.proof.siblings.len() > self.max_depth { return Err(Error::Tree(TreeError::circuit_depth_too_small( self.max_depth, @@ -152,6 +167,7 @@ impl MerkleClaimAndProofTarget { ))); } + pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?; @@ -191,6 +207,8 @@ impl MerkleClaimAndProofTarget { #[derive(Clone, Serialize, Deserialize)] pub struct MerkleProofExistenceTarget { max_depth: usize, + // `enabled` determines if the merkleproof verification is enabled + pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -218,9 +236,16 @@ pub fn verify_merkle_proof_existence_circuit( let obtained_root = compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); - // check that obtained_root==root (from inputs) + // check that obtained_root==root (from inputs), when enabled==true + let zero = builder.zero(); + let expected_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) + .collect(); + let computed_root: Vec = (0..HASH_SIZE) + .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) + .collect(); for j in 0..HASH_SIZE { - builder.connect(obtained_root.elements[j], proof.root.elements[j]); + builder.connect(computed_root[j], expected_root[j]); } measure_gates_end!(builder, measure); @@ -231,6 +256,7 @@ impl MerkleProofExistenceTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { MerkleProofExistenceTarget { max_depth, + enabled: builder.add_virtual_bool_target_safe(), root: builder.add_virtual_hash(), key: builder.add_virtual_value(), value: builder.add_virtual_value(), @@ -239,7 +265,12 @@ impl MerkleProofExistenceTarget { } } /// assigns the given values to the targets - pub fn set_targets(&self, pw: &mut PartialWitness, mp: &MerkleClaimAndProof) -> Result<()> { + pub fn set_targets( + &self, + pw: &mut PartialWitness, + enabled: bool, + mp: &MerkleClaimAndProof, + ) -> Result<()> { assert!(mp.proof.existence); // sanity check if mp.proof.siblings.len() > self.max_depth { return Err(Error::Tree(TreeError::circuit_depth_too_small( @@ -248,6 +279,7 @@ impl MerkleProofExistenceTarget { ))); } + pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?; @@ -424,6 +456,8 @@ fn hash_with_flag_target>( #[derive(Clone, Serialize, Deserialize)] pub struct MerkleTreeStateTransitionProofTarget { pub(crate) max_depth: usize, + // `enabled` determines if the merkleproof state transition verification is enabled + pub(crate) enabled: BoolTarget, pub(crate) op: Target, pub(crate) old_root: HashOutTarget, pub(crate) op_proof: MerkleClaimAndProofTarget, @@ -477,6 +511,7 @@ pub fn verify_merkle_state_transition_circuit( }; let new_key_proof = MerkleProofExistenceTarget { max_depth: proof.max_depth, + enabled: proof.enabled, root, key: proof.op_key, value: proof.op_value, @@ -488,7 +523,13 @@ pub fn verify_merkle_state_transition_circuit( // Insert/Delete: Non-existence // Update: Existence let proof_type = is_update; - builder.connect(proof.op_proof.existence.target, proof_type.target); + builder.conditional_assert_eq( + proof.enabled.target, + proof.op_proof.existence.target, + proof_type.target, + ); + // 3.2) assert that proof.enabled matches with op_proof.enabled + builder.connect(proof.op_proof.enabled.target, proof.enabled.target); // 4) assert proof_non_existence.root corresponds to the root // specified by the op (old_root for Insert/Update and new_root @@ -504,9 +545,17 @@ pub fn verify_merkle_state_transition_circuit( }; for j in 0..HASH_SIZE { // 4.1) assert that proof.proof_non_existence.root == proof.old_root - builder.connect(proof.op_proof.root.elements[j], claim_root.elements[j]); + builder.conditional_assert_eq( + proof.enabled.target, + proof.op_proof.root.elements[j], + claim_root.elements[j], + ); // 4.2) assert that the non-existence proof uses the op_key (value not needed). - builder.connect(proof.op_proof.key.elements[j], proof.op_key.elements[j]); + builder.conditional_assert_eq( + proof.enabled.target, + proof.op_proof.key.elements[j], + proof.op_key.elements[j], + ); } // prepare value for check 5.2) @@ -544,7 +593,7 @@ pub fn verify_merkle_state_transition_circuit( .map(|j| builder.select(is_divergence_level, zero, new_siblings[i].elements[j])) .collect(); for j in 0..HASH_SIZE { - builder.connect(old_sibling_i[j], new_sibling_i[j]); + builder.conditional_assert_eq(proof.enabled.target, old_sibling_i[j], new_sibling_i[j]); } // 5.2) when i==d && if old_siblings[i] != new_siblings[i], check that: @@ -562,7 +611,7 @@ pub fn verify_merkle_state_transition_circuit( let in_case_5_2 = builder.and(old_is_noteq_new, is_divergence_level); // do the case2's checks - let sel = in_case_5_2; + let sel = builder.and(proof.enabled, in_case_5_2); for j in 0..HASH_SIZE { builder.conditional_assert_eq(sel.target, old_siblings[i].elements[j], zero); builder.conditional_assert_eq( @@ -592,6 +641,7 @@ impl MerkleTreeStateTransitionProofTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { Self { max_depth, + enabled: builder.add_virtual_bool_target_safe(), op: builder.add_virtual_target(), old_root: builder.add_virtual_hash(), @@ -611,6 +661,7 @@ impl MerkleTreeStateTransitionProofTarget { pub fn set_targets( &self, pw: &mut PartialWitness, + enabled: bool, mp: &MerkleTreeStateTransitionProof, ) -> Result<()> { let new_siblings = mp.siblings.clone(); @@ -621,11 +672,13 @@ impl MerkleTreeStateTransitionProofTarget { ))); } + pw.set_bool_target(self.enabled, enabled)?; pw.set_target(self.op, F::from_canonical_u8(mp.op as u8))?; pw.set_hash_target(self.old_root, HashOut::from_vec(mp.old_root.0.to_vec()))?; self.op_proof.set_targets( pw, + enabled, &MerkleClaimAndProof { root: if mp.op == MerkleTreeOp::Delete { mp.new_root @@ -650,13 +703,10 @@ impl MerkleTreeStateTransitionProofTarget { { pw.set_hash_target(self.siblings[i], HashOut::from_vec(sibling.0.to_vec()))?; } - let div_lvl = if new_siblings.is_empty() { - // don't subtract since it would underflow, use MAX_DEPTH - MAX_DEPTH as u64 - } else { - (new_siblings.len() - 1) as u64 - }; - pw.set_target(self.divergence_level, F::from_canonical_u64(div_lvl))?; + pw.set_target( + self.divergence_level, + F::from_canonical_u64((new_siblings.len() - 1) as u64), + )?; Ok(()) } @@ -806,6 +856,7 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, + true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -817,42 +868,6 @@ pub mod tests { Ok(()) } - #[test] - fn test_merkleproof_pad_valid() -> Result<()> { - // circuit - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleClaimAndProofTarget::new_virtual(32, &mut builder); - verify_merkle_proof_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, &MerkleClaimAndProof::pad())?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - - Ok(()) - } - - #[test] - fn test_merkleproof_transition_pad_valid() -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleTreeStateTransitionProofTarget::new_virtual(32, &mut builder); - verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, &MerkleTreeStateTransitionProof::pad())?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - Ok(()) - } - #[test] fn test_merkleproof_only_existence_verify() -> Result<()> { for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { @@ -888,6 +903,7 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, + true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -963,6 +979,7 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, + true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -1008,15 +1025,32 @@ pub mod tests { let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); verify_merkle_proof_circuit(&mut builder, &targets); - // proof of existence + // verification enabled & proof of existence let mp = MerkleClaimAndProof::new(tree2.root(), key, Some(value), proof); - targets.set_targets(&mut pw, &mp)?; + targets.set_targets(&mut pw, true, &mp)?; // generate proof, expecting it to fail (since we're using the wrong // root) let data = builder.build::(); assert!(data.prove(pw).is_err()); + // Now generate a new proof, using `enabled=false`, which should pass the verification + // despite containing 'wrong' witness. + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::::new(); + + let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); + // verification disabled & proof of existence + targets.set_targets(&mut pw, false, &mp)?; + + // generate proof, should pass despite using wrong witness, since the + // `enabled=false` + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) } @@ -1039,7 +1073,7 @@ pub mod tests { let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, state_transition_proof)?; + targets.set_targets(&mut pw, true, state_transition_proof)?; // generate & verify proof let data = builder.build::(); @@ -1236,4 +1270,71 @@ pub mod tests { assert_ne!(state_transition_proof.new_root, tree.root()); // Tamper check Ok(()) } + + #[test] + fn test_state_transition_gadget_disabled() -> Result<()> { + let max_depth: usize = 32; + let mut kvs = HashMap::new(); + for i in 0..8 { + kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); + } + let mut tree = MerkleTree::new(&kvs); + + let key = RawValue::from(37); + let value = RawValue::from(1037); + let _ = tree.insert(&key, &value)?; + + let key = RawValue::from(21); + let value = RawValue::from(1021); + let original_state_transition_proof = tree.insert(&key, &value)?; + + let mut state_transition_proof = original_state_transition_proof.clone(); + + // modify the proof, so that it should fail when `enabled=true`, by + // changing the new_root + state_transition_proof.new_root = state_transition_proof.old_root; + + run_circuit_disabled(max_depth, &state_transition_proof)?; + + // modify the proof, so that it should fail when `enabled=true`, by + // changing the new_sibling at the divergence level, which should not + // pass the verification in the case where we're inserting key=21 + let mut state_transition_proof = original_state_transition_proof.clone(); + state_transition_proof.siblings[4] = EMPTY_HASH; + + run_circuit_disabled(max_depth, &state_transition_proof)?; + + Ok(()) + } + + fn run_circuit_disabled( + max_depth: usize, + state_transition_proof: &MerkleTreeStateTransitionProof, + ) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::::new(); + + let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_state_transition_circuit(&mut builder, &targets); + targets.set_targets(&mut pw, true, state_transition_proof)?; + + // generate proof, and expect it to fail + let data = builder.build::(); + assert!(data.prove(pw).is_err()); // expect prove to fail + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::::new(); + + let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); + verify_merkle_state_transition_circuit(&mut builder, &targets); + targets.set_targets(&mut pw, false, state_transition_proof)?; + + // generate and expect it to pass + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } } diff --git a/src/backends/plonky2/primitives/merkletree/db/mod.rs b/src/backends/plonky2/primitives/merkletree/db/mod.rs deleted file mode 100644 index 7082eaa..0000000 --- a/src/backends/plonky2/primitives/merkletree/db/mod.rs +++ /dev/null @@ -1,97 +0,0 @@ -//! Module that implements the key-value DB used at the MerkleTree module. - -use std::{ - collections::HashMap, - fmt::Debug, - sync::{Arc, Mutex}, -}; - -use anyhow::{anyhow, Result}; -use dyn_clone::DynClone; - -use crate::{ - backends::plonky2::primitives::merkletree::{Intermediate, Node}, - middleware::{Hash, EMPTY_HASH}, -}; - -#[cfg(feature = "db_rocksdb")] -pub mod rocks; - -pub trait DB: Debug + DynClone + Sync + Send { - /// Must always return the empty intermediate node when hash is EMPTY_HASH - fn load_node(&self, hash: Hash) -> Result>; - fn store_node(&mut self, node: Node) -> Result<()>; -} -dyn_clone::clone_trait_object!(DB); - -/// MemDB implements the DB trait in a in-memory HashMap. -#[derive(Clone, Debug, Default)] -pub(crate) struct MemDB { - inner: Arc>>, -} - -impl MemDB { - pub fn new() -> Self { - Self::default() - } -} - -impl DB for MemDB { - fn load_node(&self, hash: Hash) -> Result> { - let db = self - .inner - .lock() - .map_err(|e| anyhow!("failed to acquire memdb lock for read: {}", e))?; - - if hash == EMPTY_HASH { - return Ok(Some(Node::Intermediate(Intermediate::new( - EMPTY_HASH, EMPTY_HASH, - )))); - } - Ok(db.get(&hash).cloned()) - } - - fn store_node(&mut self, node: Node) -> Result<()> { - let mut db = self - .inner - .lock() - .map_err(|e| anyhow!("failed to acquire memdb lock for write: {}", e))?; - db.insert(node.hash(), node); - Ok(()) - } -} - -#[cfg(test)] -pub mod tests { - - use super::{super::Leaf, *}; - - #[test] - fn test_db() -> Result<()> { - let mut db = MemDB::new(); - test_db_opt(&mut db)?; - - #[cfg(feature = "db_rocksdb")] - { - let path = "/tmp/rocksdb"; - let mut db = rocks::RocksDB::open(path)?; - test_db_opt(&mut db)?; - } - - Ok(()) - } - - fn test_db_opt(db: &mut dyn DB) -> Result<()> { - let node = Leaf::new(1.into(), 1.into()); - db.store_node(Node::Leaf(node.clone()))?; - - let obtained_node = db.load_node(node.hash)?.unwrap(); - let leaf = match obtained_node { - Node::Leaf(l) => l, - _ => panic!("expected a leaf"), - }; - assert_eq!(leaf.hash, node.hash); - - Ok(()) - } -} diff --git a/src/backends/plonky2/primitives/merkletree/db/rocks.rs b/src/backends/plonky2/primitives/merkletree/db/rocks.rs deleted file mode 100644 index 0601983..0000000 --- a/src/backends/plonky2/primitives/merkletree/db/rocks.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::{fmt, path::Path, sync::Arc}; - -use anyhow::{anyhow, Result}; -use rocksdb::{Options, TransactionDB, TransactionDBOptions}; - -use crate::{ - backends::plonky2::primitives::merkletree::{self, db}, - middleware::{Hash, RawValue, EMPTY_HASH}, -}; - -#[derive(Clone)] -pub struct RocksDB(Arc); - -#[allow(dead_code)] -impl RocksDB { - pub fn open(path: impl AsRef) -> Result { - let mut options = Options::default(); - options.create_if_missing(true); - let txn_options = TransactionDBOptions::default(); - let inner = - TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?; - Ok(Self(Arc::new(inner))) - } -} - -impl fmt::Debug for RocksDB { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "RocksDB") - } -} - -impl db::DB for RocksDB { - fn load_node(&self, hash: Hash) -> Result> { - if hash == EMPTY_HASH { - return Ok(Some(merkletree::Node::Intermediate( - merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH), - ))); - } - - match self - .0 - .get(RawValue::from(hash).to_bytes()) - .map_err(|e| anyhow!("rocksdb: get failed: {e}"))? - { - None => Ok(None), - Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)), - } - } - - fn store_node(&mut self, node: merkletree::Node) -> Result<()> { - self.0 - .put(RawValue::from(node.hash()).to_bytes(), node.encode()?) - .map_err(|e| anyhow!("rocksdb transaction put failed: {e}")) - } -} diff --git a/src/backends/plonky2/primitives/merkletree/error.rs b/src/backends/plonky2/primitives/merkletree/error.rs index 9345700..2eb3198 100644 --- a/src/backends/plonky2/primitives/merkletree/error.rs +++ b/src/backends/plonky2/primitives/merkletree/error.rs @@ -2,16 +2,12 @@ use std::{backtrace::Backtrace, fmt::Debug}; -use crate::middleware::Hash; - pub type TreeResult = core::result::Result; #[derive(Debug, thiserror::Error)] pub enum TreeInnerError { #[error("key not found")] KeyNotFound, - #[error("node with hash {0} not found")] - NodeNotFound(Hash), #[error("key already exists")] KeyExists, #[error("max depth reached")] @@ -26,9 +22,6 @@ pub enum TreeInnerError { StateTransitionProofFail(String), #[error("circuit max_depth {0} is smaller than proof depth {1}")] CircuitDepthTooSmall(usize, usize), - // Other - #[error("{0}")] - Custom(String), } #[derive(thiserror::Error)] @@ -38,8 +31,8 @@ pub enum TreeError { inner: Box, backtrace: Box, }, - #[error("database error: {0}")] - Database(anyhow::Error), + #[error("anyhow::Error: {0}")] + Anyhow(#[from] anyhow::Error), } impl Debug for TreeError { @@ -67,9 +60,6 @@ impl TreeError { pub(crate) fn key_not_found() -> Self { new!(KeyNotFound) } - pub(crate) fn node_not_found(hash: Hash) -> Self { - new!(NodeNotFound(hash)) - } pub(crate) fn key_exists() -> Self { new!(KeyExists) } @@ -91,7 +81,4 @@ impl TreeError { pub(crate) fn circuit_depth_too_small(circuit_depth: usize, proof_depth: usize) -> Self { new!(CircuitDepthTooSmall(circuit_depth, proof_depth)) } - pub(crate) fn custom(s: impl Into) -> Self { - new!(Custom(s.into())) - } } diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index e84da20..35c4c11 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -2,7 +2,6 @@ //! . use std::{collections::HashMap, fmt, iter::IntoIterator}; -use anyhow::anyhow; use itertools::zip_eq; use plonky2::{ field::types::Field, @@ -16,15 +15,8 @@ use crate::middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F}; pub mod circuit; pub use circuit::*; -pub mod db; -pub use db::DB; pub mod error; pub use error::{TreeError, TreeResult}; -use error::{TreeError as Error, TreeResult as Result}; - -// TODO: Replace all `&RawValue` for `RawValue`. This type is very small and `Copy` so there's -// no benefit in passing a reference instead of a copy. Moreover, most of the times the value is -// being copied in methods that receive the reference: see all `*key` and `*value` in the code. /// Theoretical max depth of a merkle tree. This limits appears because we store keys of 256 bits. const MAX_DEPTH: usize = 256; @@ -33,8 +25,7 @@ const MAX_DEPTH: usize = 256; /// #[derive(Clone, Debug)] pub struct MerkleTree { - root: Hash, - db: Box, + root: Node, } impl PartialEq for MerkleTree { @@ -44,226 +35,42 @@ impl PartialEq for MerkleTree { } impl Eq for MerkleTree {} -pub(crate) fn load_node(db: &dyn DB, hash: Hash) -> Result { - match db.load_node(hash) { - Err(e) => Err(Error::Database(e)), - Ok(None) => Err(Error::node_not_found(hash)), - Ok(Some(node)) => Ok(node), - } -} -fn store_node(db: &mut dyn DB, node: Node) -> Result<()> { - match db.store_node(node) { - Ok(_) => Ok(()), - Err(e) => Err(Error::Database(e)), - } -} - impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values pub fn new(kvs: &HashMap) -> Self { - let db = db::MemDB::new(); - Self::new_with_db(Box::new(db), kvs).unwrap() - } - pub fn new_with_db(db: Box, kvs: &HashMap) -> Result { // Start with an empty node as root. - let (root, db) = { - let mut db = db; + let mut root = Node::None; - // Iterate over key-value pairs (if any) and add them. - let mut root = EMPTY_HASH; - for (k, v) in kvs.iter() { - root = Self::apply_op(db.as_mut(), MerkleTreeOp::Insert, root, *k, Some(*v))?; - } - (root, db) - }; + // Iterate over key-value pairs (if any) and add them. + for (k, v) in kvs.iter() { + root.apply_op(MerkleTreeOp::Insert, *k, Some(*v)).unwrap(); + } - Ok(Self { root, db }) - } - - pub fn empty_with_db(db: Box) -> Self { - Self::from_db(EMPTY_HASH, db) - } - - pub fn from_db(root: Hash, db: Box) -> Self { - Self { root, db } + // Fill in hashes. + let _ = root.compute_hash(); + Self { root } } /// returns the root of the tree pub fn root(&self) -> Hash { - self.root - } - - /// Goes down from the current node until it encounters a terminal node, - /// viz. a leaf or empty node, or until it reaches the maximum depth. The - /// `siblings` parameter is used to store the siblings while going down to - /// the leaf, if the given parameter is set to `None`, then no siblings are - /// stored. In this way, the same method `down` can be used by MerkleTree - /// methods `get`, `contains`, `prove` and `prove_nonexistence`. - /// - /// Be aware that this method will return the found leaf at the given path, - /// which may contain a different key and value than the expected one. And - /// while it does not return explicitly a `siblings` variable, the input - /// `siblings` is modified adding there the siblings found along the path. - fn down( - db: &dyn DB, - path_and_lvl: (Vec, usize), // path and lvl - curr_node_hash: Hash, // hash of current level node - new_key: RawValue, // key to be added/found at the leaf - mut siblings: Option<&mut Vec>, - op: MerkleTreeOp, - ) -> Result> { - let (path, lvl) = path_and_lvl; - - if lvl > MAX_DEPTH { - return Err(Error::max_depth()); - } - - if curr_node_hash == EMPTY_HASH { - return Ok(None); - } - - let node = load_node(db, curr_node_hash)?; - match node { - Node::Intermediate(n) => { - if path[lvl] { - if let Some(s) = siblings.as_mut() { - s.push(n.left); - } - Self::down(db, (path, lvl + 1), n.right, new_key, siblings, op) - } else { - if let Some(s) = siblings.as_mut() { - s.push(n.right); - } - Self::down(db, (path, lvl + 1), n.left, new_key, siblings, op) - } - } - Node::Leaf(old_leaf) => { - if op == MerkleTreeOp::ReadOnly { - return Ok(Some((old_leaf.key, old_leaf.value))); - } - - if new_key == old_leaf.key { - if op == MerkleTreeOp::Insert { - // in Insert, key should not exist - return Err(Error::key_exists()); - } - // we're at the operation Update/Delete case - return Ok(Some((old_leaf.key, old_leaf.value))); - } - - Self::down_till_divergence( - lvl, - curr_node_hash.into(), - old_leaf.path, - path, - siblings.ok_or(Error::custom("expected siblings, got None"))?, - )?; - Ok(Some((old_leaf.key, old_leaf.value))) - } - } - } - - /// goes down through a 'virtual' path till finding a divergence. This - /// method is used for when adding a new leaf another already existing leaf - /// is found, so that both leaves (new and old) are pushed down the path - /// till their keys diverge. - fn down_till_divergence( - lvl: usize, - old_key: RawValue, - old_path: Vec, - new_path: Vec, - siblings: &mut Vec, - ) -> Result<()> { - if lvl > MAX_DEPTH { - return Err(Error::max_depth()); - } - if old_path[lvl] == new_path[lvl] { - siblings.push(EMPTY_HASH); - return Self::down_till_divergence(lvl + 1, old_key, old_path, new_path, siblings); - } - // reached the divergence - siblings.push(old_key.into()); - Ok(()) - } - - /// go up recursively updating the intermediate nodes - fn up( - db: &mut dyn DB, - path: Vec, - curr_lvl: usize, - key: Hash, - siblings: Vec, - op: MerkleTreeOp, - // first_zeroes should be set to `true` when calling `up` from outside - // the method itself. It is used internally to know when to go up - // 'virtually' for the first batch of zeroes. - first_zeroes: bool, - ) -> Result { - // recall, in the delete case, the `key` is the `remaining_key` - let key_node = load_node(db, key)?; - if op == MerkleTreeOp::Delete - && first_zeroes - && matches!(key_node, Node::Leaf(..)) - && siblings[curr_lvl] == EMPTY_HASH - { - // - if we're at operation delete, the node that we're holding is a leaf, - // and we're at the first consecutive zero siblings - // - in operation Delete, go up till the first non-zero sibling and - // pair the given key with that sibling. - // This is only done for the first batch of zero siblings, that is, - // after a non-zero sibling, no matter how many zero siblings it - // has, don't do this logic anymore. - if curr_lvl == 0 { - return Ok(key); - } - return Self::up(db, path, curr_lvl - 1, key, siblings, op, true); - } - - let node = if path[curr_lvl] { - Intermediate::new(siblings[curr_lvl], key) - } else { - Intermediate::new(key, siblings[curr_lvl]) - }; - let node_hash = node.hash; // variable to avoid cloning `node` later - - // store in db - store_node(db, Node::Intermediate(node))?; - - if curr_lvl == 0 { - return Ok(node_hash); - } - Self::up(db, path, curr_lvl - 1, node_hash, siblings, op, false) + self.root.hash() } /// returns the value at the given key - pub fn get(&self, key: &RawValue) -> Result> { + pub fn get(&self, key: &RawValue) -> TreeResult { let path = keypath(*key); - let key_resolution = Self::down( - self.db.as_ref(), - (path, 0), - self.root, - *key, - None, - MerkleTreeOp::ReadOnly, - )?; + let (key_resolution, _) = self.root.down(0, path, None); match key_resolution { - Some((k, v)) if &k == key => Ok(Some(v)), - _ => Ok(None), + Some((k, v)) if &k == key => Ok(v), + _ => Err(TreeError::key_not_found()), } } /// returns a boolean indicating whether the key exists in the tree - pub fn contains(&self, key: &RawValue) -> Result { + pub fn contains(&self, key: &RawValue) -> TreeResult { let path = keypath(*key); - match Self::down( - self.db.as_ref(), - (path, 0), - self.root, - *key, - None, - MerkleTreeOp::ReadOnly, - )? { - Some((k, _)) if &k == key => Ok(true), + match self.root.down(0, path, None) { + (Some((k, _)), _) if &k == key => Ok(true), _ => Ok(false), } } @@ -272,18 +79,13 @@ impl MerkleTree { &mut self, key: &RawValue, value: &RawValue, - ) -> Result { + ) -> TreeResult { let proof_non_existence = self.prove_nonexistence(key)?; - let old_root: Hash = self.root; - - self.root = Self::apply_op( - self.db.as_mut(), - MerkleTreeOp::Insert, - self.root, - *key, - Some(*value), - )?; + let old_root: Hash = self.root.hash(); + self.root + .apply_op(MerkleTreeOp::Insert, *key, Some(*value))?; + let new_root = self.root.compute_hash(); let (v, proof) = self.prove(key)?; assert!(proof.existence); @@ -294,7 +96,7 @@ impl MerkleTree { op: MerkleTreeOp::Insert, // insertion old_root, op_proof: proof_non_existence, - new_root: self.root, + new_root, op_key: *key, op_value: *value, value: None, @@ -306,17 +108,13 @@ impl MerkleTree { &mut self, key: &RawValue, value: &RawValue, - ) -> Result { + ) -> TreeResult { let (old_value, old_proof) = self.prove(key)?; - let old_root: Hash = self.root; - self.root = Self::apply_op( - self.db.as_mut(), - MerkleTreeOp::Update, - self.root, - *key, - Some(*value), - )?; + let old_root: Hash = self.root.hash(); + self.root + .apply_op(MerkleTreeOp::Update, *key, Some(*value))?; + let new_root = self.root.compute_hash(); let (v, proof) = self.prove(key)?; assert!(proof.existence); @@ -327,7 +125,7 @@ impl MerkleTree { op: MerkleTreeOp::Update, old_root, op_proof: old_proof, - new_root: self.root, + new_root, op_key: *key, op_value: *value, value: Some(old_value), @@ -335,17 +133,12 @@ impl MerkleTree { }) } - pub fn delete(&mut self, key: &RawValue) -> Result { + pub fn delete(&mut self, key: &RawValue) -> TreeResult { let (value, proof_existence) = self.prove(key)?; - let old_root: Hash = self.root; - self.root = Self::apply_op( - self.db.as_mut(), - MerkleTreeOp::Delete, - self.root, - *key, - None, - )?; + let old_root: Hash = self.root.hash(); + self.root.apply_op(MerkleTreeOp::Delete, *key, None)?; + let new_root = self.root.compute_hash(); let proof = self.prove_nonexistence(key)?; assert!(!proof.existence); @@ -354,7 +147,7 @@ impl MerkleTree { op: MerkleTreeOp::Delete, old_root, op_proof: proof, - new_root: self.root, + new_root, op_key: *key, op_value: value, value: None, @@ -365,19 +158,13 @@ impl MerkleTree { /// returns a proof of existence, which proves that the given key exists in /// the tree. It returns the `value` of the leaf at the given `key`, and the /// `MerkleProof`. - pub fn prove(&self, key: &RawValue) -> Result<(RawValue, MerkleProof)> { + pub fn prove(&self, key: &RawValue) -> TreeResult<(RawValue, MerkleProof)> { let path = keypath(*key); let mut siblings: Vec = Vec::new(); - match Self::down( - self.db.as_ref(), - (path, 0), - self.root, - *key, - Some(&mut siblings), - MerkleTreeOp::ReadOnly, - )? { - Some((k, v)) if &k == key => Ok(( + + match self.root.down(0, path, Some(&mut siblings)) { + (Some((k, v)), _) if &k == key => Ok(( v, MerkleProof { existence: true, @@ -385,7 +172,7 @@ impl MerkleTree { other_leaf: None, }, )), - _ => Err(Error::key_not_found()), + _ => Err(TreeError::key_not_found()), } } @@ -393,43 +180,41 @@ impl MerkleTree { /// `key` does not exist in the tree. The return value specifies /// the key-value pair in the leaf reached as a result of /// resolving `key` as well as a `MerkleProof`. - pub fn prove_nonexistence(&self, key: &RawValue) -> Result { + pub fn prove_nonexistence(&self, key: &RawValue) -> TreeResult { let path = keypath(*key); let mut siblings: Vec = Vec::new(); // note: non-existence of a key can be in 2 cases: - match Self::down( - self.db.as_ref(), - (path, 0), - self.root, - *key, - Some(&mut siblings), - MerkleTreeOp::ReadOnly, - )? { + match self.root.down(0, path, Some(&mut siblings)) { // case i) the expected leaf does not exist - None => Ok(MerkleProof { + (None, _) => Ok(MerkleProof { existence: false, siblings, other_leaf: None, }), // case ii) the expected leaf does exist in the tree, but it has a different `key` - Some((k, v)) if &k != key => Ok(MerkleProof { + (Some((k, v)), _) if &k != key => Ok(MerkleProof { existence: false, siblings, other_leaf: Some((k, v)), }), - _ => Err(Error::key_exists()), + _ => Err(TreeError::key_exists()), } // both cases prove that the given key don't exist in the tree. } /// verifies an inclusion proof for the given `key` and `value` - pub fn verify(root: Hash, proof: &MerkleProof, key: &RawValue, value: &RawValue) -> Result<()> { + pub fn verify( + root: Hash, + proof: &MerkleProof, + key: &RawValue, + value: &RawValue, + ) -> TreeResult<()> { let h = proof.compute_root_from_leaf(key, Some(*value))?; if h != root { - Err(Error::proof_fail("inclusion".to_string())) + Err(TreeError::proof_fail("inclusion".to_string())) } else { Ok(()) } @@ -437,16 +222,18 @@ impl MerkleTree { /// verifies a non-inclusion proof for the given `key`, that is, the given /// `key` does not exist in the tree - pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &RawValue) -> Result<()> { + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &RawValue) -> TreeResult<()> { match proof.other_leaf { - Some((k, _v)) if &k == key => Err(Error::invalid_proof("non-existence".to_string())), + Some((k, _v)) if &k == key => { + Err(TreeError::invalid_proof("non-existence".to_string())) + } _ => { let k = proof.other_leaf.map(|(k, _)| k).unwrap_or(*key); let v: Option = proof.other_leaf.map(|(_, v)| v); let h = proof.compute_root_from_leaf(&k, v)?; if h != root { - Err(Error::proof_fail("exclusion".to_string())) + Err(TreeError::proof_fail("exclusion".to_string())) } else { Ok(()) } @@ -454,7 +241,7 @@ impl MerkleTree { } } - pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> TreeResult<()> { let mut old_siblings = proof.op_proof.siblings.clone(); let new_siblings = proof.siblings.clone(); @@ -470,17 +257,12 @@ impl MerkleTree { Self::verify_state_transition(&equivalent_insertion_proof) } MerkleTreeOp::Update => { - if proof.value.is_none() { - return Err(Error::state_transition_fail( - "Invalid proof of update: proof.value should not be None".to_string(), - )); - } // check that for the old_root, (op_key, value) *does* exist in the tree Self::verify( proof.old_root, &proof.op_proof, &proof.op_key, - &proof.value.unwrap(), // unrawp is safe due prev `is_none` check + &proof.value.unwrap(), )?; // check that for the new_root, (op_key, op_value) *does* exist in the tree Self::verify( @@ -497,7 +279,7 @@ impl MerkleTree { // All siblings should agree (proof.siblings == proof.op_proof.siblings) .then_some(()) - .ok_or(Error::state_transition_fail(format!( + .ok_or(TreeError::state_transition_fail(format!( "Invalid proof of update for key {}: Siblings don't match.", proof.op_key ))) @@ -526,11 +308,11 @@ impl MerkleTree { let divergence_lvl: usize = match zip_eq(old_path, new_path).position(|(x, y)| x != y) { Some(d) => d, - None => return Err(Error::max_depth()), + None => return Err(TreeError::max_depth()), }; if divergence_lvl != new_siblings.len() - 1 { - return Err(Error::state_transition_fail( + return Err(TreeError::state_transition_fail( "paths divergence does not match".to_string(), )); } @@ -546,7 +328,7 @@ impl MerkleTree { if new_siblings.is_empty() { return (old_siblings.is_empty() && proof.old_root == EMPTY_HASH) .then_some(()) - .ok_or(Error::state_transition_fail( + .ok_or(TreeError::state_transition_fail( "new tree has no siblings yet old tree is not the empty tree" .to_string(), )); @@ -556,14 +338,14 @@ impl MerkleTree { old_siblings.resize(d + 1, EMPTY_HASH); for i in 0..d { if old_siblings[i] != new_siblings[i] { - return Err(Error::state_transition_fail( + return Err(TreeError::state_transition_fail( "siblings don't match: old[i]!=new[i] ∀ i (except at i==d)".to_string(), )); } } if old_siblings[d] != new_siblings[d] { if old_siblings[d] != EMPTY_HASH { - return Err(Error::state_transition_fail( + return Err(TreeError::state_transition_fail( "siblings don't match: old[d]!=empty".to_string(), )); } @@ -571,137 +353,27 @@ impl MerkleTree { .op_proof .other_leaf .map(|(k, _)| k) - .ok_or(Error::state_transition_fail( + .ok_or(TreeError::state_transition_fail( "proof.proof_non_existence.other_leaf can not be empty for the case old_siblings[d]!=new_siblings[d]".to_string() ))?; let v: Option = proof.op_proof.other_leaf.map(|(_, v)| v); let old_leaf_hash = kv_hash(&k, v); if new_siblings[d] != old_leaf_hash { - return Err(Error::state_transition_fail( + return Err(TreeError::state_transition_fail( "siblings don't match: new[d]!=old_leaf_hash".to_string(), )); } } Ok(()) } - _ => Err(Error::invalid_proof("proof.op".to_string())), } } -} -// auxiliary methods -impl MerkleTree { - /// Applies given Merkle tree op. - pub(crate) fn apply_op( - db: &mut dyn DB, - op: MerkleTreeOp, - root: Hash, - k: RawValue, - maybe_value: Option, - ) -> Result { - // Rule out invalid arguments - match (op, maybe_value) { - (MerkleTreeOp::Insert, None) | (MerkleTreeOp::Update, None) => { - Err(Error::invalid_state_transition_proof_arg(format!( - "{:?} op requires a value argument.", - op - ))) - } - (MerkleTreeOp::Delete, Some(_)) => { - Err(Error::invalid_state_transition_proof_arg(format!( - "{:?} op requires no value argument, yet one was provided.", - op - ))) - } - (MerkleTreeOp::ReadOnly, _) => Err(Error::invalid_state_transition_proof_arg(format!( - "{:?} 'read only' op should not reach the 'apply_op' method", - op - ))), - _ => Ok(()), - }?; - - // go down, update the leaf, go up storing new hashes in the db - let path = keypath(k); - let mut siblings: Vec = Vec::new(); - let _ = Self::down( - db, - (path.clone(), 0), // from lvl 0 - root, - k, - Some(&mut siblings), - op, - )?; - - let node: Node = match (op, maybe_value) { - (MerkleTreeOp::Insert, Some(value)) | (MerkleTreeOp::Update, Some(value)) => { - Node::Leaf(Leaf::new(k, value)) - } - (MerkleTreeOp::Delete, None) => { - // return a node whose hash is 'empty', to indicate that there is no leaf - Node::Intermediate(Intermediate { - hash: EMPTY_HASH, - left: EMPTY_HASH, - right: EMPTY_HASH, - }) - } - _ => { - return Err(Error::invalid_state_transition_proof_arg(format!( - "{:?} op has invalid value type: {:?}", - op, maybe_value - ))) - } - }; - let node_hash = node.hash(); // variable to avoid cloning `leaf` later - store_node(db, node)?; - if siblings.is_empty() { - // return the leaf's hash as root - return Ok(node_hash); + /// returns an iterator over the leaves of the tree + pub fn iter(&self) -> Iter<'_> { + Iter { + state: vec![&self.root], } - - let new_root = if op == MerkleTreeOp::Delete { - if siblings.len() == 1 { - // we're at the root-1 level, there is only a sibling, and we're - // removing the current leaf. - // If the sibling is a Leaf, the sibling (leaf) is now the new root - let sibling_node = load_node(db, siblings[0])?; - if matches!(sibling_node, Node::Leaf(..)) { - return Ok(siblings[0]); - } - // if the sibling is an Intermediate node, it means that the - // branch goes deeper, so don't short the path going up and pair - // it with an empty hash. - let node = if path[0] { - Intermediate::new(siblings[0], EMPTY_HASH) - } else { - Intermediate::new(EMPTY_HASH, siblings[0]) - }; - let node_hash = node.hash; // variable to avoid cloning `node` later - - // store in db - store_node(db, Node::Intermediate(node))?; - return Ok(node_hash); - } - // use the last sibling as the key that we will push up from - let l = siblings.len() - 1; - let remaining_key = siblings[l]; - siblings[l] = EMPTY_HASH; - // invert the last sibling level - let mut path = path.clone(); - path[siblings.len() - 1] = !path[siblings.len() - 1]; - Self::up( - db, - path, - siblings.len() - 1, - remaining_key, - siblings, - op, - true, - )? - } else { - Self::up(db, path, siblings.len() - 1, node_hash, siblings, op, true)? - }; - - Ok(new_root) } } @@ -750,51 +422,14 @@ fn hash_with_flag(flag: F, inputs: &[F]) -> Hash { } } -impl MerkleTree { - /// returns an iterator over the leaves of the tree - pub fn iter(&self) -> Iter { - Iter { - state: if self.root == EMPTY_HASH { - vec![] - } else { - vec![self.root] - }, - db: self.db.clone(), - } - } -} -impl IntoIterator for &MerkleTree { - type Item = (RawValue, RawValue); - type IntoIter = Iter; +impl<'a> IntoIterator for &'a MerkleTree { + type Item = (&'a RawValue, &'a RawValue); + type IntoIter = Iter<'a>; + fn into_iter(self) -> Self::IntoIter { self.iter() } } -pub struct Iter { - state: Vec, - db: Box, -} -impl Iterator for Iter { - type Item = (RawValue, RawValue); - fn next(&mut self) -> Option { - let node_hash = self.state.pop()?; - - // Inspect node - let node = load_node(self.db.as_ref(), node_hash).ok()?; - - match node { - Node::Leaf(Leaf { key, value, .. }) => Some((key, value)), - Node::Intermediate(Intermediate { left, right, .. }) => { - [right, left].into_iter().for_each(|h| { - if h != EMPTY_HASH { - self.state.push(h) - } - }); - self.next() - } - } - } -} impl fmt::Display for MerkleTree { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -804,53 +439,11 @@ impl fmt::Display for MerkleTree { )?; writeln!(f, "digraph hierarchy {{")?; writeln!(f, "node [fontname=Monospace,fontsize=10,shape=box]")?; - print_graph_viz(f, self.db.as_ref(), self.root)?; + write!(f, "{}", self.root)?; writeln!(f, "\n}}\n-----") } } -fn print_graph_viz(f: &mut fmt::Formatter<'_>, db: &dyn DB, hash: Hash) -> fmt::Result { - if hash == EMPTY_HASH { - return Ok(()); - } - - let node = load_node(db, hash).map_err(|_| fmt::Error)?; - match node { - Node::Intermediate(n) => { - let left_hash: String = if n.left == EMPTY_HASH { - writeln!( - f, - "\"{}_child_of_{}\" [label=\"{}\"]", - n.left, n.hash, n.left - )?; - format!("\"{}_child_of_{}\"", n.left, n.hash) - } else { - writeln!(f, "\"{}\"", n.left)?; - format!("\"{}\"", n.left) - }; - let right_hash = if n.right == EMPTY_HASH { - writeln!( - f, - "\"{}_child_of_{}\" [label=\"{}\"]", - n.right, n.hash, n.right - )?; - format!("\"{}_child_of_{}\"", n.right, n.hash) - } else { - writeln!(f, "\"{}\"", n.right,)?; - format!("\"{}\"", n.right) - }; - writeln!(f, "\"{}\" -> {{ {} {} }}", n.hash, left_hash, right_hash,)?; - print_graph_viz(f, db, n.left)?; - print_graph_viz(f, db, n.right) - } - Node::Leaf(l) => { - writeln!(f, "\"{}\" [style=filled]", l.hash)?; - writeln!(f, "\"k:{}\\nv:{}\" [style=dashed]", l.key, l.value)?; - writeln!(f, "\"{}\" -> {{ \"k:{}\\nv:{}\" }}", l.hash, l.key, l.value,) - } - } -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct MerkleProof { // note: currently we don't use the `_existence` field, we would use if we merge the methods @@ -878,15 +471,12 @@ impl MerkleProof { /// Computes the root of the Merkle tree suggested by a Merkle proof given a /// key & value. If a value is not provided, the terminal node is assumed to /// be empty. - fn compute_root_from_leaf(&self, key: &RawValue, value: Option) -> Result { + fn compute_root_from_leaf(&self, key: &RawValue, value: Option) -> TreeResult { let path = keypath(*key); let h = kv_hash(key, value); self.compute_root_from_node(&h, path) } - fn compute_root_from_node(&self, node_hash: &Hash, path: Vec) -> Result { - if self.siblings.len() > MAX_DEPTH { - return Err(Error::max_depth()); - } + fn compute_root_from_node(&self, node_hash: &Hash, path: Vec) -> TreeResult { let mut h = *node_hash; for (i, sibling) in self.siblings.iter().enumerate().rev() { let input: Vec = if path[i] { @@ -921,21 +511,6 @@ impl MerkleClaimAndProof { }, } } - /// Value used for padding. This is a valid merkle proof. - pub fn pad() -> Self { - let [key, value] = [EMPTY_VALUE, EMPTY_VALUE]; - let root = kv_hash(&key, Some(value)); - Self { - root, - key, - value, - proof: MerkleProof { - existence: true, - siblings: vec![], - other_leaf: None, - }, - } - } pub fn new(root: Hash, key: RawValue, value: Option, proof: MerkleProof) -> Self { Self { root, @@ -957,7 +532,6 @@ pub enum MerkleTreeOp { Insert = 0, Update, Delete, - ReadOnly, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -989,6 +563,7 @@ pub struct MerkleTreeStateTransitionProof { } impl MerkleTreeStateTransitionProof { + /// Value used for padding. pub fn empty() -> Self { let empty_proof_and_claim = MerkleClaimAndProof::empty(); Self { @@ -1002,83 +577,380 @@ impl MerkleTreeStateTransitionProof { siblings: vec![], } } - /// Value used for padding. This is a valid transition proof. - pub fn pad() -> Self { - let pad_proof_and_claim = MerkleClaimAndProof::pad(); - Self { - op: MerkleTreeOp::Update, - old_root: pad_proof_and_claim.root, - op_proof: pad_proof_and_claim.proof, - new_root: pad_proof_and_claim.root, - op_key: pad_proof_and_claim.key, - op_value: pad_proof_and_claim.value, - value: Some(pad_proof_and_claim.value), - siblings: vec![], - } - } } -// NOTE: currently we use automatic serialization/deserialization, which is -// used when storing the node into the DB; but we could manually implement it -// for more disk-space efficiency. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub enum Node { +#[derive(Clone, Debug)] +enum Node { + None, Leaf(Leaf), Intermediate(Intermediate), } -impl Node { - pub fn hash(&self) -> Hash { + +impl fmt::Display for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Node::Leaf(Leaf { hash, .. }) => *hash, - Node::Intermediate(Intermediate { hash, .. }) => *hash, + Self::Intermediate(n) => { + let left_hash: String = if n.left.is_empty() { + writeln!( + f, + "\"{}_child_of_{}\" [label=\"{}\"]", + n.left.hash(), + n.hash(), + n.left.hash() + )?; + format!("\"{}_child_of_{}\"", n.left.hash(), n.hash()) + } else { + writeln!(f, "\"{}\"", n.left.hash(),)?; + format!("\"{}\"", n.left.hash()) + }; + let right_hash = if n.right.is_empty() { + writeln!( + f, + "\"{}_child_of_{}\" [label=\"{}\"]", + n.right.hash(), + n.hash(), + n.right.hash() + )?; + format!("\"{}_child_of_{}\"", n.right.hash(), n.hash()) + } else { + writeln!(f, "\"{}\"", n.right.hash(),)?; + format!("\"{}\"", n.right.hash()) + }; + writeln!(f, "\"{}\" -> {{ {} {} }}", n.hash(), left_hash, right_hash,)?; + write!(f, "{}", n.left)?; + write!(f, "{}", n.right) + } + Self::Leaf(l) => { + writeln!(f, "\"{}\" [style=filled]", l.hash())?; + writeln!(f, "\"k:{}\\nv:{}\" [style=dashed]", l.key, l.value)?; + writeln!( + f, + "\"{}\" -> {{ \"k:{}\\nv:{}\" }}", + l.hash(), + l.key, + l.value, + ) + } + Self::None => Ok(()), } } - // NOTE: this can be replaced by `.to_bytes` & `from_bytes` optimized methods at `Node` - pub fn encode(&self) -> Result, anyhow::Error> { - serde_json::to_vec(self).map_err(|e| anyhow!("failed to serialize node: {e}")) - } - pub fn decode(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(|e| anyhow!("failed to deserialize node: {e}")) - } } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Intermediate { - hash: Hash, - left: Hash, - right: Hash, -} -impl Intermediate { - pub fn new(left: Hash, right: Hash) -> Self { - if left == EMPTY_HASH && right == EMPTY_HASH { - return Self { - hash: EMPTY_HASH, +impl Node { + fn is_empty(&self) -> bool { + match self { + Self::None => true, + Self::Leaf(_l) => false, + Self::Intermediate(_n) => false, + } + } + fn is_intermediate(&self) -> bool { + match self { + Self::None => false, + Self::Leaf(_l) => false, + Self::Intermediate(_n) => true, + } + } + fn compute_hash(&mut self) -> Hash { + match self { + Self::None => EMPTY_HASH, + Self::Leaf(l) => l.compute_hash(), + Self::Intermediate(n) => n.compute_hash(), + } + } + fn hash(&self) -> Hash { + match self { + Self::None => EMPTY_HASH, + Self::Leaf(l) => l.hash(), + Self::Intermediate(n) => n.hash(), + } + } + + /// Goes down from the current node until it encounters a terminal node, + /// viz. a leaf or empty node, or until it reaches the maximum depth. The + /// `siblings` parameter is used to store the siblings while going down to + /// the leaf, if the given parameter is set to `None`, then no siblings are + /// stored. In this way, the same method `down` can be used by MerkleTree + /// methods `get`, `contains`, `prove` and `prove_nonexistence`. + /// + /// Be aware that this method will return the found leaf at the given path, + /// which may contain a different key and value than the expected one. + fn down( + &self, + lvl: usize, + path: Vec, + mut siblings: Option<&mut Vec>, + ) -> (Option<(RawValue, RawValue)>, usize) { + match self { + Self::Intermediate(n) => { + if path[lvl] { + if let Some(s) = siblings.as_mut() { + s.push(n.left.hash()); + } + n.right.down(lvl + 1, path, siblings) + } else { + if let Some(s) = siblings.as_mut() { + s.push(n.right.hash()); + } + n.left.down(lvl + 1, path, siblings) + } + } + Self::Leaf(Leaf { + key, + value, + path: _p, + hash: _h, + }) => (Some((*key, *value)), lvl), + _ => (None, lvl), + } + } + + /// Applies given Merkle tree op without computing hashes. + pub(crate) fn apply_op( + &mut self, + op: MerkleTreeOp, + key: RawValue, + maybe_value: Option, + ) -> TreeResult<()> { + let key_path = keypath(key); + // Rule out invalid arguments + match (op, maybe_value) { + (MerkleTreeOp::Insert, None) | (MerkleTreeOp::Update, None) => { + Err(TreeError::invalid_state_transition_proof_arg(format!( + "{:?} op requires a value argument.", + op + ))) + } + (MerkleTreeOp::Delete, Some(_)) => { + Err(TreeError::invalid_state_transition_proof_arg(format!( + "{:?} op requires no value argument, yet one was provided.", + op + ))) + } + _ => Ok(()), + }?; + + // Loop through to leaf. + self.apply_op_loop(0, op, key, &key_path, maybe_value)?; + + // If we are dealing with a deletion, normalise along key + // path. + if let MerkleTreeOp::Delete = op { + self.normalise_path(&key_path); + } + + Ok(()) + } + + /// Normalises a Merkle tree along a specified path. Useful + /// post-deletion. + fn normalise_path(&mut self, key_path: &[bool]) { + match self { + Self::Leaf(_) | Self::None => (), + Self::Intermediate(Intermediate { + hash: _h, left, right, - }; + }) => { + if key_path[0] { + right.normalise_path(&key_path[1..]); + } else { + left.normalise_path(&key_path[1..]); + } + + // If we have a branch with children (NIL, X) or (X, + // NIL) where X is not a branch, then replace with X. + if left.is_empty() && !right.is_intermediate() { + *self = *right.clone(); + } else if right.is_empty() && !left.is_intermediate() { + *self = *left.clone(); + } + } } - let input: Vec = [left.0.to_vec(), right.0.to_vec()].concat(); - let hash = hash_with_flag(F::TWO, &input); - Self { hash, left, right } + } + + fn apply_op_loop( + &mut self, + lvl: usize, + op: MerkleTreeOp, + key: RawValue, + key_path: &[bool], + maybe_value: Option, + ) -> TreeResult<()> { + match self { + Self::Intermediate(n) => { + if key_path[lvl] { + n.right + .apply_op_loop(lvl + 1, op, key, key_path, maybe_value) + } else { + n.left + .apply_op_loop(lvl + 1, op, key, key_path, maybe_value) + } + } + _ => { + *self = Self::op_node_check(lvl, self, op, key, key_path, maybe_value)?; + Ok(()) + } + } + } + + /// Checks the terminal node against the desired op and returns a + /// suitable replacement. + /// + /// - Insertion => Node should be empty or contain a different + /// key. A leaf is inserted in the right place. + /// - Update/Deletion => Node should contain the given key. The + /// value is replaced in the case of an update and the leaf removed + /// in the case of a deletion. + pub(crate) fn op_node_check( + lvl: usize, + node: &Node, + op: MerkleTreeOp, + key: RawValue, + key_path: &[bool], + maybe_value: Option, + ) -> TreeResult { + use MerkleTreeOp::*; + + // Invalid args are assumed to have been ruled out. + match (op, node, maybe_value) { + // Insertion case + (Insert, Node::None, Some(value)) => Ok(Node::Leaf(Leaf::new(key, value))), + (Insert, Node::Leaf(l), Some(value)) => { + // in this case, it means that we found a leaf in the new-leaf + // path, thus we need to push both leaves (old-leaf and + // new-leaf) down the path till their paths diverge. + + // first check that keys of both leaves are different + // (l=old-leaf, leaf=new-leaf) + if l.key == key { + // Note: current approach returns an error when trying to + // add to a leaf where the key already exists. We could also + // ignore it if needed. + Err(TreeError::key_exists()) + } else { + let old_leaf = l.clone(); + // set new node as an intermediate node + let mut new_node = Node::Intermediate(Intermediate::empty()); + new_node.down_till_divergence( + lvl, + old_leaf, + Leaf { + hash: None, + path: key_path.to_vec(), + key, + value, + }, + )?; + Ok(new_node) + } + } + // Update case + (Update, Node::Leaf(l), Some(value)) if l.key == key => { + Ok(Node::Leaf(Leaf::new(key, value))) + } + // Deletion case + (Delete, Node::Leaf(l), None) if l.key == key => Ok(Node::None), + // Case of terminal node that does not match. + _ => Err(TreeError::state_transition_fail(format!( + "{:?} op requires key {} to be present in the tree, yet it is not.", + op, key + ))), + } + } + + /// goes down through a 'virtual' path till finding a divergence. This + /// method is used for when adding a new leaf another already existing leaf + /// is found, so that both leaves (new and old) are pushed down the path + /// till their keys diverge. + fn down_till_divergence( + &mut self, + lvl: usize, + old_leaf: Leaf, + new_leaf: Leaf, + ) -> TreeResult<()> { + if let Node::Intermediate(ref mut n) = self { + if old_leaf.path[lvl] != new_leaf.path[lvl] { + // reached divergence in next level, set the leaves as children + // at the current node + if new_leaf.path[lvl] { + n.left = Box::new(Node::Leaf(old_leaf)); + n.right = Box::new(Node::Leaf(new_leaf)); + } else { + n.left = Box::new(Node::Leaf(new_leaf)); + n.right = Box::new(Node::Leaf(old_leaf)); + } + return Ok(()); + } + + // no divergence yet, continue going down + if new_leaf.path[lvl] { + n.right = Box::new(Node::Intermediate(Intermediate::empty())); + return n.right.down_till_divergence(lvl + 1, old_leaf, new_leaf); + } else { + n.left = Box::new(Node::Intermediate(Intermediate::empty())); + return n.left.down_till_divergence(lvl + 1, old_leaf, new_leaf); + } + } + Ok(()) } } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Leaf { - pub(crate) hash: Hash, +#[derive(Clone, Debug)] +struct Intermediate { + hash: Option, + left: Box, + right: Box, +} +impl Intermediate { + fn empty() -> Self { + Self { + hash: None, + left: Box::new(Node::None), + right: Box::new(Node::None), + } + } + fn compute_hash(&mut self) -> Hash { + if self.left.clone().is_empty() && self.right.clone().is_empty() { + self.hash = Some(EMPTY_HASH); + return EMPTY_HASH; + } + let l_hash = self.left.compute_hash(); + let r_hash = self.right.compute_hash(); + let input: Vec = [l_hash.0.to_vec(), r_hash.0.to_vec()].concat(); + let h = hash_with_flag(F::TWO, &input); + self.hash = Some(h); + h + } + fn hash(&self) -> Hash { + self.hash.expect("Hash has not been computed.") + } +} + +#[derive(Clone, Debug)] +pub(crate) struct Leaf { + pub(crate) hash: Option, pub(crate) path: Vec, pub(crate) key: RawValue, pub(crate) value: RawValue, } impl Leaf { - pub fn new(key: RawValue, value: RawValue) -> Self { + fn new(key: RawValue, value: RawValue) -> Self { Self { - hash: kv_hash(&key, Some(value)), + hash: None, path: keypath(key), key, value, } } + fn compute_hash(&mut self) -> Hash { + let h = kv_hash(&self.key, Some(self.value)); + self.hash = Some(h); + h + } + fn hash(&self) -> Hash { + self.hash.expect("Hash has not been computed.") + } } // NOTE 1: think if maybe the length of the returned vector can be <256 @@ -1096,6 +968,37 @@ pub(crate) fn keypath(k: RawValue) -> Vec { .collect() } +pub struct Iter<'a> { + state: Vec<&'a Node>, +} + +impl<'a> Iterator for Iter<'a> { + type Item = (&'a RawValue, &'a RawValue); + + fn next(&mut self) -> Option { + let node = self.state.pop(); + match node { + Some(Node::None) => self.next(), + Some(Node::Leaf(Leaf { + hash: _, + path: _, + key, + value, + })) => Some((key, value)), + Some(Node::Intermediate(Intermediate { + hash: _, + left, + right, + })) => { + self.state.push(right); + self.state.push(left); + self.next() + } + _ => None, + } + } +} + #[cfg(test)] pub mod tests { use std::cmp::Ordering; @@ -1105,21 +1008,7 @@ pub mod tests { use super::*; #[test] - fn test_merkletree() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_merkletree_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_merkletree_opt(db)?; - } - - Ok(()) - } - fn test_merkletree_opt(db: Box) -> Result<()> { + fn test_merkletree() -> TreeResult<()> { let mut kvs = HashMap::new(); for i in 0..8 { if i == 1 { @@ -1131,7 +1020,7 @@ pub mod tests { let value = RawValue::from(1013); kvs.insert(key, value); - let tree = MerkleTree::new_with_db(db, &kvs)?; + let tree = MerkleTree::new(&kvs); // when printing the tree, it should print the same tree as in // https://0xparc.github.io/pod2/merkletree.html#example-2 println!("{}", tree); @@ -1184,8 +1073,8 @@ pub mod tests { }; let sorted_kvs = kvs - .into_iter() - .sorted_by(|(k1, _), (k2, _)| cmp(*k1, *k2)) + .iter() + .sorted_by(|(k1, _), (k2, _)| cmp(**k1, **k2)) .collect::>(); assert_eq!(collected_kvs, sorted_kvs); @@ -1194,326 +1083,13 @@ pub mod tests { } #[test] - fn test_merkletree_pad() { - let claim = MerkleClaimAndProof::pad(); - MerkleTree::verify(claim.root, &claim.proof, &claim.key, &claim.value).unwrap(); - - let proof = MerkleTreeStateTransitionProof::pad(); - MerkleTree::verify_state_transition(&proof).unwrap(); - } - - #[test] - fn test_key_not_found() -> Result<()> { - let db = Box::new(db::MemDB::new()); - let mut tree = MerkleTree::empty_with_db(db.clone()); - let value_option = tree.get(&RawValue::from(5)).unwrap(); - assert_eq!(None, value_option); - - tree.insert(&RawValue::from(1), &RawValue::from(42))?; - let value_option = tree.get(&RawValue::from(5)).unwrap(); - assert_eq!(None, value_option); - - // If the root doesn't exist there should be an error - let tree = MerkleTree::from_db(Hash::from(RawValue::from(42)), db); - let result = tree.get(&RawValue::from(5)); - assert!(result.is_err()); - - Ok(()) - } - - #[test] - fn test_delete_to_empty() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_delete_to_empty_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_delete_to_empty_opt(db)?; - } - - Ok(()) - } - fn test_delete_to_empty_opt(db: Box) -> Result<()> { - let mut tree = MerkleTree::new_with_db(db, &HashMap::new())?; - - let (key, value) = (RawValue::from(2), RawValue::from(1002)); - let _ = tree.insert(&key, &value)?; - - let (key, value) = (RawValue::from(6), RawValue::from(1006)); - let _ = tree.insert(&key, &value)?; - - let (key, value) = (RawValue::from(3), RawValue::from(1003)); - let _ = tree.insert(&key, &value)?; - - let (key, value) = (RawValue::from(7), RawValue::from(1007)); - let _ = tree.insert(&key, &value)?; - - let _ = tree.delete(&RawValue::from(3))?; - let _ = tree.delete(&RawValue::from(7))?; - let _ = tree.delete(&RawValue::from(6))?; - assert_eq!( - tree.root, - Leaf::new(RawValue::from(2), RawValue::from(1002)).hash - ); - - let _ = tree.delete(&RawValue::from(2))?; - assert_eq!(tree.root, EMPTY_HASH); - - Ok(()) - } - - #[test] - fn test_prove_verify() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_prove_verify_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_prove_verify_opt(db)?; - } - - Ok(()) - } - fn test_prove_verify_opt(db: Box) -> Result<()> { - let kvs = [ - (1.into(), 55.into()), - (2.into(), 88.into()), - (175.into(), 0.into()), - ] - .into_iter() - .collect(); - let tree = MerkleTree::new_with_db(db, &kvs)?; - - let (key, value) = (175.into(), 0.into()); - let (v, proof) = tree.prove(&key)?; - assert_eq!(v, value); - MerkleTree::verify(tree.root(), &proof, &key, &value)?; - - let (key, value) = (2.into(), 88.into()); - let (v, proof) = tree.prove(&key)?; - assert_eq!(v, value); - MerkleTree::verify(tree.root(), &proof, &key, &value)?; - - let (key, value) = (175.into(), 0.into()); - let (v, proof) = tree.prove(&key)?; - assert_eq!(v, value); - MerkleTree::verify(tree.root(), &proof, &key, &value)?; - - Ok(()) - } - - #[test] - fn test_update_leaf() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_update_leaf_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_update_leaf_opt(db)?; - } - - Ok(()) - } - fn test_update_leaf_opt(db: Box) -> Result<()> { - let kvs = [ - (1.into(), 1.into()), - (9.into(), 9.into()), - (7.into(), 7.into()), - (15.into(), 15.into()), - ] - .into_iter() - .collect(); - let mut tree = MerkleTree::new_with_db(db.clone(), &kvs)?; - let state_transition_proof = tree.update(&7.into(), &0.into())?; - MerkleTree::verify_state_transition(&state_transition_proof)?; - - let kvs = [ - (1.into(), 1.into()), - (9.into(), 9.into()), - (7.into(), 0.into()), - (15.into(), 15.into()), - ] - .into_iter() - .collect(); - let tree2 = MerkleTree::new_with_db(db, &kvs)?; - - assert_eq!(tree.root, tree2.root); - - // update the other leaves - let state_transition_proof = tree.update(&1.into(), &0.into())?; - MerkleTree::verify_state_transition(&state_transition_proof)?; - let state_transition_proof = tree.update(&9.into(), &0.into())?; - MerkleTree::verify_state_transition(&state_transition_proof)?; - let state_transition_proof = tree.update(&15.into(), &0.into())?; - MerkleTree::verify_state_transition(&state_transition_proof) - } - - #[test] - fn test_update_delete_leaf() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_update_delete_leaf_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_update_delete_leaf_opt(db)?; - } - - Ok(()) - } - fn test_update_delete_leaf_opt(db: Box) -> Result<()> { - let kvs: HashMap = (0..10) - .map(|i| (i.into(), i.into())) - .collect::>(); - let mut mt = MerkleTree::new_with_db(db, &kvs)?; - - // insert - (11..20) - .map(|i| (i.into(), i.into())) - .try_for_each(|(k, v)| { - let mtp = mt.insert(&k, &v).unwrap(); - MerkleTree::verify_state_transition(&mtp) - })?; - // update - (11..20) - .map(|i| (i.into(), (i + 1).into())) - .try_for_each(|(k, v)| { - let mtp = mt.update(&k, &v).unwrap(); - MerkleTree::verify_state_transition(&mtp) - })?; - // delete - (11..20).map(|i| i.into()).try_for_each(|k| { - let mtp = mt.delete(&k).unwrap(); - MerkleTree::verify_state_transition(&mtp) - })?; - - Ok(()) - } - - #[test] - fn test_delete_leaf() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_delete_leaf_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_delete_leaf_opt(db)?; - } - - Ok(()) - } - fn test_delete_leaf_opt(db: Box) -> Result<()> { - let kvs = [(1.into(), 1.into()), (9.into(), 9.into())] - .into_iter() - .collect(); - let tree = MerkleTree::new_with_db(db.clone(), &kvs)?; - let expected_root = tree.root; - - let kvs = [ - (1.into(), 1.into()), - (9.into(), 9.into()), - (7.into(), 7.into()), - (15.into(), 15.into()), - ] - .into_iter() - .collect(); - let mut tree = MerkleTree::new_with_db(db.clone(), &kvs)?; - let state_transition_proof = tree.delete(&15.into())?; - MerkleTree::verify_state_transition(&state_transition_proof)?; - - let kvs = [ - (1.into(), 1.into()), - (9.into(), 9.into()), - (7.into(), 7.into()), - ] - .into_iter() - .collect(); - let tree2 = MerkleTree::new_with_db(db, &kvs)?; - - assert_eq!(tree.root, tree2.root); - - // delete the leaf '7', which when deleted will leave an entire branch - // empty - let state_transition_proof = tree.delete(&7.into())?; - MerkleTree::verify_state_transition(&state_transition_proof)?; - - assert_eq!(tree.root, expected_root); - - Ok(()) - } - - #[test] - fn test_delete_from_two_leaves() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_delete_from_two_leaves_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_delete_from_two_leaves_opt(db)?; - } - - Ok(()) - } - fn test_delete_from_two_leaves_opt(db: Box) -> Result<()> { - // tree with two leaves whose keys diverge at the first bit, so that when - // deleting one key leads to a tree with a single Leaf as a root - let mut kvs = HashMap::new(); - kvs.insert(RawValue::from(0), RawValue::from(1000)); - kvs.insert(RawValue::from(1), RawValue::from(1001)); - - let mut tree = MerkleTree::new_with_db(db.clone(), &kvs)?; - tree.delete(&RawValue::from(1))?; - - // the expected_tree has a single leaf, which should match the tree that - // started from two leaves and got one removed - let expected = [(RawValue::from(0), RawValue::from(1000))] - .into_iter() - .collect::>(); - let expected_tree = MerkleTree::new_with_db(db, &expected)?; - - assert_eq!(tree.root(), expected_tree.root()); - Ok(()) - } - - #[test] - fn test_state_transition() -> Result<()> { - let db = Box::new(db::MemDB::new()); - test_state_transition_opt(db)?; - - #[cfg(feature = "db_rocksdb")] - { - let db = Box::new( - db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), - ); - test_state_transition_opt(db)?; - } - - Ok(()) - } - fn test_state_transition_opt(db: Box) -> Result<()> { + fn test_state_transition() -> TreeResult<()> { let mut kvs = HashMap::new(); for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let mut tree = MerkleTree::new_with_db(db, &kvs)?; + let mut tree = MerkleTree::new(&kvs); let old_root = tree.root(); // key=37 shares path with key=5, till the level 6, needing 2 extra diff --git a/src/examples/mod.rs b/src/examples/mod.rs index 0780c7e..2b490f9 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -180,7 +180,11 @@ impl EthDosHelper { }; assert_eq!(int, Value::from(int_attestation.public_key)); - let n_i64 = n.as_int().unwrap(); + let n_i64 = if let TypedValue::Int(x) = n.typed() { + *x + } else { + panic!("distance value is not Int") + }; // eth_dos src->dst dist=n+1 self.n_plus_1(&mut pod, eth_dos_int_to_dst, int_attestation, n_i64)?; diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 8de6871..92fdc4f 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -18,8 +18,6 @@ pub enum BuilderArg { /// Key: (origin, key), where origin is Wildcard and key is Key Key(String, String), WildcardLiteral(String), - /// Reference to a same-batch predicate's identity hash (resolved by name in finish()). - SelfPredicateHash(String), } /// When defining a `BuilderArg`, it can be done from 3 different inputs: @@ -132,8 +130,6 @@ pub struct CustomPredicateBatchBuilder { params: Params, pub name: String, pub predicates: Vec, - /// Forward references to resolve in finish(): (predicate_idx, statement_idx, arg_idx, name) - pending_self_pred_hashes: Vec<(usize, usize, usize, String)>, } impl CustomPredicateBatchBuilder { @@ -142,7 +138,6 @@ impl CustomPredicateBatchBuilder { params, name, predicates: Vec::new(), - pending_self_pred_hashes: Vec::new(), } } @@ -176,12 +171,6 @@ impl CustomPredicateBatchBuilder { priv_args: &[&str], sts: &[StatementTmplBuilder], ) -> Result { - if self.predicates.iter().any(|p| p.name == name) { - return Err(Error::custom(format!( - "Duplicate predicate name '{}' in batch", - name - ))); - } if self.predicates.len() >= Params::max_custom_batch_size() { return Err(Error::max_length( "self.predicates.len".to_string(), @@ -205,18 +194,14 @@ impl CustomPredicateBatchBuilder { )); } - let pred_idx = self.predicates.len(); - let mut pending = Vec::new(); let statements = sts .iter() - .enumerate() - .map(|(stmt_idx, sb)| { + .map(|sb| { let stb = sb.clone().desugar(); let st_tmpl_args = stb .args .iter() - .enumerate() - .map(|(arg_idx, a)| { + .map(|a| { Ok::<_, Error>(match a { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey( @@ -226,22 +211,6 @@ impl CustomPredicateBatchBuilder { BuilderArg::WildcardLiteral(v) => { StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?) } - BuilderArg::SelfPredicateHash(pred_name) => { - // Try backward reference first - match self.predicates.iter().position(|p| p.name == *pred_name) { - Some(index) => StatementTmplArg::SelfPredicateHash(index), - None => { - // Forward reference - placeholder, resolved in finish() - pending.push(( - pred_idx, - stmt_idx, - arg_idx, - pred_name.clone(), - )); - StatementTmplArg::SelfPredicateHash(0) - } - } - } }) }) .collect::>()?; @@ -271,27 +240,11 @@ impl CustomPredicateBatchBuilder { .collect(), )?; self.predicates.push(custom_predicate); - self.pending_self_pred_hashes.extend(pending); Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } - pub fn finish(mut self) -> Result> { - // Resolve forward references for SelfPredicateHash - for (pred_idx, stmt_idx, arg_idx, ref name) in &self.pending_self_pred_hashes { - let target_idx = self - .predicates - .iter() - .position(|p| p.name == *name) - .ok_or_else(|| { - Error::custom(format!( - "SelfPredicateHash references unknown predicate '{}'", - name - )) - })?; - self.predicates[*pred_idx].statements[*stmt_idx].args[*arg_idx] = - StatementTmplArg::SelfPredicateHash(target_idx); - } - Ok(CustomPredicateBatch::new(self.name, self.predicates)) + pub fn finish(self) -> Arc { + CustomPredicateBatch::new(self.name, self.predicates) } } @@ -316,9 +269,7 @@ mod tests { backends::plonky2::mock::mainpod::MockProver, examples::{custom::eth_dos_batch, MOCK_VD_SET}, frontend::{MainPodBuilder, Operation}, - middleware::{ - self, containers::Set, CustomPredicateRef, Params, PodType, ValueRef, DEFAULT_VD_SET, - }, + middleware::{self, containers::Set, CustomPredicateRef, Params, PodType, DEFAULT_VD_SET}, }; #[test] @@ -355,7 +306,7 @@ mod tests { .arg("s2"); builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?; - let batch = builder.finish()?; + let batch = builder.finish(); let batch_clone = batch.clone(); let gt_custom_pred = CustomPredicateRef::new(batch, 0); @@ -405,7 +356,7 @@ mod tests { &[], &[set_contains_stb], )?; - let batch = builder.finish()?; + let batch = builder.finish(); let batch_clone = batch.clone(); let mut mp_builder = MainPodBuilder::new(¶ms, vd_set); @@ -435,83 +386,4 @@ mod tests { Ok(()) } - - #[test] - fn test_builder_self_predicate_hash_unknown_ref() { - let params = Params::default(); - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - - let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) - .arg("x") - .arg(BuilderArg::SelfPredicateHash("nonexistent".into())); - builder - .predicate_and("pred_A", &["x"], &[], &[stb]) - .unwrap(); - - // finish() should fail because "nonexistent" was never defined - assert!(builder.finish().is_err()); - } - - /// Tests cyclic SelfPredicateHash references end-to-end: - /// pred_A references pred_B's hash (forward ref), pred_B references pred_A's hash (backward - /// ref). Exercises forward reference resolution in finish(), then builds and verifies a POD - /// using pred_A via MockProver. - #[test] - fn test_builder_self_predicate_hash_e2e() -> Result<()> { - let params = Params::default(); - let vd_set = &*MOCK_VD_SET; - - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - - // pred_A references pred_B's hash (forward ref, pred_B not yet defined) - let stb_a = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) - .arg("x") - .arg(BuilderArg::SelfPredicateHash("pred_B".into())); - builder.predicate_and("pred_A", &["x"], &[], &[stb_a])?; - - // pred_B references pred_A's hash (backward ref, pred_A already defined) - let stb_b = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) - .arg("x") - .arg(BuilderArg::SelfPredicateHash("pred_A".into())); - builder.predicate_and("pred_B", &["x"], &[], &[stb_b])?; - - let batch = builder.finish()?; - - // Verify resolution: pred_A references pred_B (index 1), pred_B references pred_A (index 0) - assert_eq!( - batch.predicates()[0].statements[0].args[1], - StatementTmplArg::SelfPredicateHash(1) - ); - assert_eq!( - batch.predicates()[1].statements[0].args[1], - StatementTmplArg::SelfPredicateHash(0) - ); - - // Compute concrete hashes - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); - - // Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash) - let mut mp_builder = MainPodBuilder::new(¶ms, vd_set); - let eq_st = mp_builder.priv_op(Operation::eq(pred_b_hash.clone(), pred_b_hash.clone()))?; - mp_builder.pub_op(Operation::custom(pred_a_ref, [eq_st]))?; - - // Prove and verify - let prover = MockProver {}; - let proof = mp_builder.prove(&prover)?; - proof.pod.verify()?; - - // Verify the public statement contains pred_b_hash as its argument - let pub_sts = proof.pod.pub_self_statements(); - let custom_st = pub_sts - .iter() - .find(|s| matches!(s, middleware::Statement::Custom(_, _))) - .expect("should have a custom statement"); - if let middleware::Statement::Custom(_, args) = custom_st { - assert_eq!(args[0], ValueRef::Literal(pred_b_hash)); - } - - Ok(()) - } } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index b6e8691..04fe1ed 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -4,7 +4,7 @@ use std::{ collections::{HashMap, HashSet}, convert::From, - fmt, iter, + fmt, }; use itertools::Itertools; @@ -13,12 +13,10 @@ use serde::{Deserialize, Serialize}; pub use serialization::SerializedMainPod; use crate::middleware::{ - self, check_custom_pred, - containers::{Container, Dictionary}, - fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key, - MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, - RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS, - EMPTY_VALUE, + self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op, + prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, + OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, + StatementArg, VDSet, Value, ValueRef, }; mod custom; @@ -94,11 +92,8 @@ impl fmt::Display for SignedDict { // https://0xparc.github.io/pod2/merkletree.html will not need it since it will be // deterministic based on the keys values not on the order of the keys when added into the // tree. - for kv in self.dict.iter() { - match kv { - Ok((k, v)) => writeln!(f, " - {} = {}", k, v)?, - Err(e) => writeln!(f, " - ERR: {}", e)?, - } + for (k, v) in self.dict.kvs().iter().sorted_by_key(|kv| kv.0.hash()) { + writeln!(f, " - {} = {}", k, v)?; } Ok(()) } @@ -111,13 +106,16 @@ impl SignedDict { .then_some(()) .ok_or(Error::custom("Invalid signature!")) } - pub fn get(&self, key: impl Into) -> Option { - self.dict.get(&key.into()).unwrap() + pub fn kvs(&self) -> &HashMap { + self.dict.kvs() + } + pub fn get(&self, key: impl Into) -> Option<&Value> { + self.kvs().get(&key.into()) } // Returns the Contains statement that defines key if it exists. pub fn get_statement(&self, key: impl Into) -> Option { let key: Key = key.into(); - self.dict.get(&key).unwrap().map(|value| { + self.kvs().get(&key).map(|value| { Statement::Contains( ValueRef::Literal(Value::from(self.dict.clone())), ValueRef::Literal(Value::from(key.name())), @@ -138,7 +136,7 @@ pub struct MainPodBuilder { pub operations: Vec, pub public_statements: Vec, // Internal state - contains: Vec<(RawValue, RawValue)>, // (root, key) + dict_contains: Vec<(Value, Value)>, // (root, key) } impl fmt::Display for MainPodBuilder { @@ -158,11 +156,6 @@ impl fmt::Display for MainPodBuilder { } } -fn as_container_or_err(v: &Value) -> Result { - v.as_container() - .ok_or_else(|| Error::custom(format!("{v} not a container"))) -} - impl MainPodBuilder { pub fn new(params: &Params, vd_set: &VDSet) -> Self { Self { @@ -172,16 +165,10 @@ impl MainPodBuilder { statements: Vec::new(), operations: Vec::new(), public_statements: Vec::new(), - contains: Vec::new(), + dict_contains: Vec::new(), } } - pub fn stmt_len(&self) -> usize { - self.statements.len() - } pub fn add_pod(&mut self, pod: MainPod) -> Result<()> { - for st in &pod.public_statements { - self.track_contains(st); - } self.input_pods.push(pod); match self.input_pods.len() > self.params.max_input_pods { true => Err(Error::too_many_input_pods( @@ -191,26 +178,31 @@ impl MainPodBuilder { _ => Ok(()), } } + pub fn insert(&mut self, public: bool, st_op: (Statement, Operation)) -> Result<()> { + // TODO: Do error handling instead of panic + let (st, op) = st_op; - // If we're adding a Contains statement with literal arguments (an Entry), track it in - // `dict_contains` to avoid adding it again via `Self::add_entries_contains`. - fn track_contains(&mut self, st: &Statement) { + // If we're adding a Contains statement with literal arguments (an Entry), track it in + // `dict_contains` to avoid adding it again via `Self::add_entries_contains`. if let Statement::Contains( ValueRef::Literal(dict), ValueRef::Literal(key), ValueRef::Literal(_), ) = &st { - let root_key = (dict.raw(), key.raw()); - self.contains.push(root_key); + let root_key = (dict.clone(), key.clone()); + self.dict_contains.push(root_key); } - } - - pub fn insert(&mut self, st_op: (Statement, Operation)) -> Result<()> { - // TODO: Do error handling instead of panic - let (st, op) = st_op; - self.track_contains(&st); + if public { + self.public_statements.push(st.clone()); + } + if self.public_statements.len() > self.params.max_public_statements { + return Err(Error::too_many_public_statements( + self.public_statements.len(), + self.params.max_public_statements, + )); + } self.statements.push(st); self.operations.push(op); if self.statements.len() > self.params.max_statements { @@ -355,12 +347,11 @@ impl MainPodBuilder { .ok_or(Error::custom(format!( "Invalid key argument for op {}.", op - )))? - .raw(); + )))?; let proof = if op_type == &Native(ContainsFromEntries) { - as_container_or_err(container)?.prove(key)?.1 + container.prove_existence(key)?.1 } else { - as_container_or_err(container)?.prove_nonexistence(key)? + container.prove_nonexistence(key)? }; Ok(Operation(op_type.clone(), op.1, OpAux::MerkleProof(proof))) } @@ -384,16 +375,18 @@ impl MainPodBuilder { let value = op.1.get(3) .and_then(|arg| arg.value()) - .cloned() - .unwrap_or(Value::from(EMPTY_VALUE)); + .ok_or(Error::custom(format!( + "Invalid key argument for op {}.", + op + ))); let proof = match op_type { Native(ContainerInsertFromEntries) => { - as_container_or_err(old_container)?.insert(key.clone(), value)? + old_container.prove_insertion(key, value?)? } Native(ContainerUpdateFromEntries) => { - as_container_or_err(old_container)?.update(key.raw(), value)? + old_container.prove_update(key, value?)? } - _ => as_container_or_err(old_container)?.delete(key.raw())?, + _ => old_container.prove_deletion(key)?, }; Ok(Operation( op_type.clone(), @@ -406,7 +399,7 @@ impl MainPodBuilder { } fn op_statement( - &self, + &mut self, wildcard_values: Vec<(usize, Value)>, op: Operation, ) -> Result { @@ -567,37 +560,6 @@ impl MainPodBuilder { // TODO: validate proof Statement::ContainerDelete(r1, r2, r3) } - (ReplaceValueWithEntry, &args, _) => { - let mut args = args.to_vec(); - if args.len() != BASE_PARAMS.max_statement_args + 1 { - return Err(Error::custom(format!( - "ReplaceValueWithEntry requires exactly {} args but {} were found", - BASE_PARAMS.max_statement_args + 1, - args.len() - ))); - } - let st = match args.pop().expect("valid vec len") { - OperationArg::Statement(st) => st, - _ => return Err(Error::custom("expected statement")), - }; - let new_st_args = iter::zip(st.args().into_iter(), args) - .map(|(st_arg, arg)| match (st_arg, arg) { - (st_arg, OperationArg::Statement(Statement::None)) => Ok(st_arg), - ( - StatementArg::Literal(st_arg_v), - OperationArg::Statement(Statement::Contains( - ValueRef::Literal(root), - ValueRef::Literal(key), - ValueRef::Literal(v), - )), - ) if st_arg_v == v => root_key_to_ak(&root, &key) - .map(StatementArg::Key) - .ok_or_else(native_arg_error), - _ => Err(Error::custom("unexpected operation argument")), - }) - .collect::>>()?; - Statement::from_args(st.predicate(), new_st_args)? - } (t, _, _) => { if t.is_syntactic_sugar() { return Err(Error::custom(format!( @@ -611,7 +573,7 @@ impl MainPodBuilder { } } OperationType::Custom(cpr) => { - let pred = cpr.normalized_predicate(); + let pred = &cpr.batch.predicates()[cpr.index]; if pred.statements.len() != op.1.len() { return Err(Error::custom(format!( "Custom predicate operation needs {} statements but has {}.", @@ -639,7 +601,7 @@ impl MainPodBuilder { } wildcard_map[index] = Some(value); } - fill_wildcard_values(&pred, &args, &mut wildcard_map)?; + fill_wildcard_values(pred, &args, &mut wildcard_map)?; let v_default = Value::from(0); let st_args: Vec<_> = wildcard_map .into_iter() @@ -647,14 +609,14 @@ impl MainPodBuilder { .map(|v| v.unwrap_or_else(|| v_default.clone())) .collect(); check_custom_pred(&self.params, &cpr, &args, &st_args)?; - Statement::Custom(cpr, st_args.into_iter().map(ValueRef::Literal).collect()) + Statement::Custom(cpr, st_args) } }; Ok(st) } /// For every operation that has Entry statements as arguments we add a Contains statement to - /// open the dictionary (unless such Contains already exists). + /// open the dictionary. fn add_entries_contains(&mut self, op: &Operation) -> Result<()> { for arg in &op.1 { if let OperationArg::Statement(Statement::Contains( @@ -663,9 +625,9 @@ impl MainPodBuilder { ValueRef::Literal(v), )) = arg { - let root_key = (dict.raw(), key.raw()); - if !self.contains.contains(&root_key) { - self.contains.push(root_key); + let root_key = (dict.clone(), key.clone()); + if !self.dict_contains.contains(&root_key) { + self.dict_contains.push(root_key); self.priv_op(Operation::dict_contains(dict, key, v))?; } } @@ -683,28 +645,13 @@ impl MainPodBuilder { self.add_entries_contains(&op)?; let op = Self::fill_in_aux(Self::lower_op(op)?)?; let st = self.op_statement(wildcard_values, op.clone())?; - // Skip adding the statement and operation if it already exists - if !self.statements.contains(&st) { - self.insert((st.clone(), op))?; - } - if public { - self.reveal(&st)?; - } + self.insert(public, (st, op))?; - Ok(st) + Ok(self.statements[self.statements.len() - 1].clone()) } - pub fn reveal(&mut self, st: &Statement) -> Result<()> { - if !self.public_statements.contains(st) { - self.public_statements.push(st.clone()); - } - if self.public_statements.len() > self.params.max_public_statements { - return Err(Error::too_many_public_statements( - self.public_statements.len(), - self.params.max_public_statements, - )); - } - Ok(()) + pub fn reveal(&mut self, st: &Statement) { + self.public_statements.push(st.clone()); } pub fn prove(&self, prover: &dyn MainPodProver) -> Result { @@ -1399,9 +1346,11 @@ pub mod tests { OperationAux::None, ); builder - .insert((value_of_a.clone(), op_contains.clone())) + .insert(false, (value_of_a.clone(), op_contains.clone())) + .unwrap(); + builder + .insert(false, (value_of_b.clone(), op_contains)) .unwrap(); - builder.insert((value_of_b.clone(), op_contains)).unwrap(); let st = Statement::equal( AnchoredKey::from((&local, "a")), AnchoredKey::from((&local, "b")), @@ -1414,7 +1363,7 @@ pub mod tests { ], OperationAux::None, ); - builder.insert((st, op)).unwrap(); + builder.insert(false, (st, op)).unwrap(); let prover = MockProver {}; let pod = builder.prove(&prover).unwrap(); diff --git a/src/frontend/multi_pod/cost.rs b/src/frontend/multi_pod/cost.rs index 2839ea8..a5d89da 100644 --- a/src/frontend/multi_pod/cost.rs +++ b/src/frontend/multi_pod/cost.rs @@ -6,20 +6,60 @@ use std::collections::BTreeSet; use crate::{ - frontend::Operation, - middleware::{CustomPredicateRef, Hash, NativeOperation, OperationType, Predicate}, + frontend::{Operation, OperationArg}, + middleware::{ + CustomPredicateBatch, Hash, NativeOperation, OperationType, RawValue, Statement, ValueRef, + }, }; -/// Unique identifier for a custom predicate in a module. +/// Unique identifier for a custom predicate batch. /// -/// Uses the predicate's cryptographic hash as identifier. Two predicates with the same +/// Uses the batch's cryptographic hash as identifier. Two batches with the same /// hash are considered identical for resource counting purposes. #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct CustomPredicateId(pub Hash); +pub struct CustomBatchId(pub Hash); -impl From<&CustomPredicateRef> for CustomPredicateId { - fn from(predicate: &CustomPredicateRef) -> Self { - Self(Predicate::Custom(predicate.clone()).hash()) +impl From<&CustomPredicateBatch> for CustomBatchId { + fn from(batch: &CustomPredicateBatch) -> Self { + Self(batch.id()) + } +} + +/// Unique identifier for an anchored key (dict, key) pair. +/// +/// When a Contains statement is used as an argument to operations like gt(), eq(), etc., +/// the value is accessed via an "anchored key" - a reference to a specific key in a +/// specific dictionary. Each unique anchored key used in a POD requires a Contains +/// statement to be present in that POD (auto-inserted by MainPodBuilder if needed). +/// +/// We use the raw values of the dict and key for comparison, as they uniquely identify +/// the anchored key regardless of the specific Value types involved. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct AnchoredKeyId { + /// The dictionary root value (raw representation for Ord). + pub dict: RawValue, + /// The key within the dictionary (raw representation for Ord). + pub key: RawValue, +} + +impl AnchoredKeyId { + /// Create a new anchored key ID from raw values. + pub fn new(dict: RawValue, key: RawValue) -> Self { + Self { dict, key } + } + + /// Try to extract an anchored key ID from a Contains statement with all literal values. + pub fn from_contains_statement(stmt: &Statement) -> Option { + if let Statement::Contains( + ValueRef::Literal(dict), + ValueRef::Literal(key), + ValueRef::Literal(_value), + ) = stmt + { + Some(Self::new(dict.raw(), key.raw())) + } else { + None + } } } @@ -48,9 +88,17 @@ pub struct StatementCost { /// Limit: `params.max_public_key_of` pub public_key_of: usize, - /// Custom predicates used (for custom predicate cardinality constraint). - /// Limit: `params.max_custom_predicates` distinct custom predicates per POD. - pub custom_predicates_ids: BTreeSet, + /// Custom predicate batches used (for batch cardinality constraint). + /// Limit: `params.max_custom_predicate_batches` distinct batches per POD. + pub custom_batch_ids: BTreeSet, + + /// Anchored keys referenced by this operation. + /// + /// When a Contains statement with all literal values is used as an argument, + /// the operation references an "anchored key" (dict, key pair). Each unique + /// anchored key used in a POD incurs an additional Contains statement cost, + /// as MainPodBuilder::add_entries_contains will auto-insert it if not already present. + pub anchored_keys: BTreeSet, } impl StatementCost { @@ -111,14 +159,25 @@ impl StatementCost { // Syntactic sugar variants (lowered before proving) | NativeOperation::GtEqFromEntries | NativeOperation::GtFromEntries - | NativeOperation::GtToNotEqual - | NativeOperation::ReplaceValueWithEntry => {} + | NativeOperation::GtToNotEqual => {} } } OperationType::Custom(cpr) => { cost.custom_pred_verifications = 1; - cost.custom_predicates_ids - .insert(CustomPredicateId::from(cpr)); + cost.custom_batch_ids + .insert(CustomBatchId::from(&*cpr.batch)); + } + } + + // Extract anchored keys from operation arguments. + // Any argument that is a Contains statement with all literal values + // represents an anchored key reference that will require a Contains + // statement in the POD (auto-inserted by MainPodBuilder if needed). + for arg in &op.1 { + if let OperationArg::Statement(stmt) = arg { + if let Some(anchored_key) = AnchoredKeyId::from_contains_statement(stmt) { + cost.anchored_keys.insert(anchored_key); + } } } diff --git a/src/frontend/multi_pod/deps.rs b/src/frontend/multi_pod/deps.rs index 9472a1f..97b4ef4 100644 --- a/src/frontend/multi_pod/deps.rs +++ b/src/frontend/multi_pod/deps.rs @@ -5,6 +5,7 @@ use std::collections::HashMap; +use super::cost::AnchoredKeyId; use crate::{ frontend::{Operation, OperationArg}, middleware::{Hash, Statement}, @@ -99,6 +100,11 @@ impl DependencyGraph { pod_hash, statement: dep_stmt.clone(), })); + } else if AnchoredKeyId::from_contains_statement(dep_stmt).is_some() { + // Anchored-key Contains args may be implicit requirements that are + // auto-materialized by MainPodBuilder. They are handled by anchored-key + // resource accounting, not by statement dependency edges. + continue; } else { // Statement arguments should either be internal (created earlier) // or from external PODs (except anchored-key implicit Contains). @@ -122,8 +128,9 @@ impl DependencyGraph { mod tests { use super::*; use crate::{ + dict, frontend::Operation as FrontendOp, - middleware::{NativeOperation, OperationAux, OperationType, Value, ValueRef}, + middleware::{AnchoredKey, NativeOperation, OperationAux, OperationType, Value, ValueRef}, }; fn equal_stmt(n: i64) -> Statement { @@ -188,4 +195,32 @@ mod tests { assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]); assert_eq!(graph.statement_deps[2], vec![StatementSource::Internal(0)]); } + + #[test] + fn test_anchored_key_contains_arg_is_treated_as_implicit_requirement() { + // A literal Contains statement can be used as an anchored-key argument even when + // no explicit producer statement exists in internal/external statements, because + // MainPodBuilder auto-inserts Contains statements for anchored keys. + let dict = dict!({ + "k" => 7_i64 + }); + + let anchored_contains = Statement::Contains( + ValueRef::Literal(Value::from(dict.clone())), + ValueRef::Literal(Value::from("k")), + ValueRef::Literal(Value::from(7_i64)), + ); + let ak = AnchoredKey::from((&dict, "k")); + let produced_statement = Statement::Equal(ValueRef::Key(ak.clone()), ValueRef::Key(ak)); + + // Use a typical frontend operation that consumes entry-like args. + // We're only testing the dependency graph, not the actual proof, so the operation + // just needs to have the right arguments to test what we're looking for. + let statements = vec![produced_statement]; + let operations = vec![FrontendOp::eq(anchored_contains.clone(), anchored_contains)]; + + let graph = DependencyGraph::build(&statements, &operations, &HashMap::new()); + + assert!(graph.statement_deps[0].is_empty()); + } } diff --git a/src/frontend/multi_pod/diagnostics.rs b/src/frontend/multi_pod/diagnostics.rs deleted file mode 100644 index f56778f..0000000 --- a/src/frontend/multi_pod/diagnostics.rs +++ /dev/null @@ -1,466 +0,0 @@ -//! Diagnostic utilities for multi-POD resource analysis. -//! -//! Provides two views: -//! - [`ResourceSummary`]: Pre-solve aggregate resource demand vs. per-POD limits. -//! Shows which resource category is the bottleneck (requires the most PODs). -//! - [`SolutionBreakdown`]: Post-solve per-POD utilization showing how full each POD is. - -use std::{collections::BTreeSet, fmt}; - -use super::cost::StatementCost; -use crate::middleware::Params; - -/// A single resource category's usage vs. per-POD limit. -/// -/// Used both for pre-solve aggregate demand (in [`ResourceSummary`]) where -/// `used` is the total across all statements, and for post-solve per-POD -/// breakdown (in [`PodUtilization`]) where `used` is the POD's consumption. -#[derive(Clone, Debug)] -pub struct UtilizationRow { - pub name: &'static str, - pub used: usize, - pub limit: usize, -} - -impl UtilizationRow { - /// Utilization as a fraction (0.0 to 1.0). - pub fn utilization(&self) -> f64 { - if self.limit == 0 { - if self.used == 0 { - 0.0 - } else { - f64::INFINITY - } - } else { - self.used as f64 / self.limit as f64 - } - } - - /// Minimum PODs needed for this resource alone: `ceil(used / limit)`. - /// `None` if `limit` is 0 and `used > 0` (infeasible). - pub fn min_pods(&self) -> Option { - lower_bound(self.used, self.limit) - } -} - -/// Aggregate resource usage over a set of statement costs into per-category rows. -/// -/// Single source of truth for the resource categories and their corresponding -/// `Params` limits. Used both for pre-solve totals and per-POD breakdowns. -fn aggregate_rows<'a>( - costs: impl IntoIterator, - params: &Params, -) -> (Vec, usize) { - let mut num_stmts = 0usize; - let mut merkle_proofs = 0usize; - let mut merkle_state_transitions = 0usize; - let mut custom_pred_verifications = 0usize; - let mut signed_by = 0usize; - let mut public_key_of = 0usize; - let mut custom_pred_ids = BTreeSet::new(); - - for c in costs { - num_stmts += 1; - merkle_proofs += c.merkle_proofs; - merkle_state_transitions += c.merkle_state_transitions; - custom_pred_verifications += c.custom_pred_verifications; - signed_by += c.signed_by; - public_key_of += c.public_key_of; - custom_pred_ids.extend(c.custom_predicates_ids.iter().cloned()); - } - - let rows = vec![ - UtilizationRow { - name: "private statements", - used: num_stmts, - limit: params.max_priv_statements(), - }, - UtilizationRow { - name: "merkle proofs", - used: merkle_proofs, - limit: params.containers.state.max_medium, - }, - UtilizationRow { - name: "merkle state transitions", - used: merkle_state_transitions, - limit: params.containers.transition.max_medium, - }, - UtilizationRow { - name: "custom pred verifications", - used: custom_pred_verifications, - limit: params.max_custom_predicate_verifications, - }, - UtilizationRow { - name: "signed_by", - used: signed_by, - limit: params.max_signed_by, - }, - UtilizationRow { - name: "public_key_of", - used: public_key_of, - limit: params.max_public_key_of, - }, - UtilizationRow { - name: "distinct custom predicates", - used: custom_pred_ids.len(), - limit: params.max_custom_predicates, - }, - ]; - - (rows, num_stmts) -} - -/// Pre-solve aggregate resource summary. -/// -/// Shows total resource demand across all operations and the minimum PODs -/// each resource category would require independently. -#[derive(Clone, Debug)] -pub struct ResourceSummary { - pub rows: Vec, - pub num_statements: usize, -} - -impl ResourceSummary { - /// Compute a resource summary from per-statement costs and params. - pub fn from_costs(costs: &[StatementCost], params: &Params) -> Self { - let (rows, num_statements) = aggregate_rows(costs.iter(), params); - Self { - rows, - num_statements, - } - } - - /// The resource category requiring the most PODs (the bottleneck). - /// Returns `None` only if there are no statements. - pub fn bottleneck(&self) -> Option<&UtilizationRow> { - self.rows - .iter() - .filter(|r| r.used > 0) - .max_by_key(|r| r.min_pods().unwrap_or(usize::MAX)) - } -} - -impl fmt::Display for ResourceSummary { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "Resource Summary ({} statements)", self.num_statements)?; - writeln!( - f, - " {:<30} {:>5} {:>9} {:>8}", - "Category", "Total", "Limit/POD", "Min PODs" - )?; - - let bottleneck_name = self.bottleneck().map(|r| r.name); - - for row in &self.rows { - let min_pods_str = match row.min_pods() { - Some(n) => format!("{}", n), - None => "inf".to_string(), - }; - let marker = if Some(row.name) == bottleneck_name && row.used > 0 { - " <<<" - } else { - "" - }; - writeln!( - f, - " {:<30} {:>5} {:>9} {:>8}{}", - row.name, row.used, row.limit, min_pods_str, marker - )?; - } - - Ok(()) - } -} - -/// Per-POD resource utilization in a solved solution. -#[derive(Clone, Debug)] -pub struct PodUtilization { - /// POD index. - pub pod_idx: usize, - /// Whether this is the output POD (last). - pub is_output: bool, - /// Number of statements in this POD. - pub num_statements: usize, - /// Resource usage vs. limits for each category. - pub resources: Vec, -} - -/// Post-solve per-POD resource breakdown. -#[derive(Clone, Debug)] -pub struct SolutionBreakdown { - pub pods: Vec, - pub num_statements: usize, - pub pod_count: usize, -} - -impl SolutionBreakdown { - /// Compute a solution breakdown from per-statement costs, the solution's - /// pod_statements assignment, and params. - pub fn from_solution( - costs: &[StatementCost], - pod_statements: &[Vec], - pod_count: usize, - num_statements: usize, - params: &Params, - ) -> Self { - let pods = (0..pod_count) - .map(|pod_idx| { - let stmts = &pod_statements[pod_idx]; - let (resources, num_stmts) = - aggregate_rows(stmts.iter().map(|&s| &costs[s]), params); - PodUtilization { - pod_idx, - is_output: pod_idx == pod_count - 1, - num_statements: num_stmts, - resources, - } - }) - .collect(); - - Self { - pods, - num_statements, - pod_count, - } - } -} - -impl fmt::Display for SolutionBreakdown { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "Solution Breakdown ({} statements -> {} PODs)", - self.num_statements, self.pod_count - )?; - - for pod in &self.pods { - let role = if pod.is_output { - "output" - } else { - "intermediate" - }; - writeln!(f, " POD {} ({}):", pod.pod_idx, role)?; - - for row in &pod.resources { - // Only show rows with nonzero usage to reduce noise - if row.used > 0 { - let pct = if row.limit > 0 { - format!("({:>3}%)", (row.used * 100) / row.limit) - } else { - "".to_string() - }; - writeln!( - f, - " {:<30} {:>3}/{:<3} {}", - row.name, row.used, row.limit, pct - )?; - } - } - writeln!(f)?; - } - - Ok(()) - } -} - -fn lower_bound(used: usize, limit: usize) -> Option { - if used == 0 { - Some(0) - } else if limit == 0 { - None - } else { - Some(used.div_ceil(limit)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - frontend::multi_pod::cost::CustomPredicateId, - middleware::{Hash, ParamsContainers, ParamsMerkleProofs, RawValue}, - }; - - fn default_params() -> Params { - Params { - max_statements: 48, - max_public_statements: 8, - containers: ParamsContainers { - state: ParamsMerkleProofs { - max_small: 0, - max_medium: 8, - }, - transition: ParamsMerkleProofs { - max_small: 0, - max_medium: 4, - }, - ..Default::default() - }, - max_custom_predicate_verifications: 10, - max_custom_predicates: 2, - max_signed_by: 4, - max_public_key_of: 4, - ..Params::default() - } - } - - #[test] - fn test_resource_summary_bottleneck() { - let params = default_params(); - // max_priv = 48 - 8 = 40 - - // 6 merkle proofs, 3 state transitions, rest zero-cost - let costs: Vec = (0..14) - .map(|i| { - let mut c = StatementCost::default(); - if i < 6 { - c.merkle_proofs = 1; - } else if i < 9 { - c.merkle_state_transitions = 1; - } - c - }) - .collect(); - - let summary = ResourceSummary::from_costs(&costs, ¶ms); - - // 14 statements / 40 per pod = 1 pod for statements - // 6 merkle proofs / 8 per pod = 1 pod - // 3 state transitions / 4 per pod = 1 pod - // All categories need 1 pod, so bottleneck is whichever has the highest min_pods. - // They're all 1, so the first with total > 0 wins in max_by_key (stable). - let bottleneck = summary.bottleneck().unwrap(); - assert_eq!(bottleneck.min_pods(), Some(1)); - - // Verify display doesn't panic - let display = format!("{}", summary); - assert!(display.contains("Resource Summary (14 statements)")); - assert!(display.contains("merkle proofs")); - } - - #[test] - fn test_resource_summary_signed_by_bottleneck() { - let params = Params { - max_statements: 48, - max_public_statements: 8, - max_signed_by: 2, - ..Params::default() - }; - // max_priv = 40 - - // 6 signed_by operations - let costs: Vec = (0..6) - .map(|_| StatementCost { - signed_by: 1, - ..Default::default() - }) - .collect(); - - let summary = ResourceSummary::from_costs(&costs, ¶ms); - let bottleneck = summary.bottleneck().unwrap(); - - assert_eq!(bottleneck.name, "signed_by"); - // 6 / 2 = 3 pods - assert_eq!(bottleneck.min_pods(), Some(3)); - } - - #[test] - fn test_resource_summary_custom_predicates_bottleneck() { - let params = Params { - max_statements: 48, - max_public_statements: 8, - max_custom_predicates: 1, // Only 1 distinct predicate per POD - max_custom_predicate_verifications: 10, - ..Params::default() - }; - - // 3 statements using 3 different custom predicates - let costs: Vec = (0..3) - .map(|i| { - let mut ids = std::collections::BTreeSet::new(); - ids.insert(CustomPredicateId(Hash::from(RawValue::from(i as i64)))); - StatementCost { - custom_pred_verifications: 1, - custom_predicates_ids: ids, - ..Default::default() - } - }) - .collect(); - - let summary = ResourceSummary::from_costs(&costs, ¶ms); - let bottleneck = summary.bottleneck().unwrap(); - - assert_eq!(bottleneck.name, "distinct custom predicates"); - // 3 distinct predicates / 1 per pod = 3 pods - assert_eq!(bottleneck.min_pods(), Some(3)); - } - - #[test] - fn test_solution_breakdown_display() { - let params = default_params(); - - let costs: Vec = (0..8) - .map(|i| { - let mut c = StatementCost::default(); - if i < 4 { - c.merkle_proofs = 1; - } else { - c.merkle_state_transitions = 1; - } - c - }) - .collect(); - - let pod_statements = vec![ - vec![0, 1, 2, 3], // POD 0: 4 merkle proofs - vec![4, 5, 6, 7], // POD 1: 4 state transitions - ]; - - let breakdown = SolutionBreakdown::from_solution(&costs, &pod_statements, 2, 8, ¶ms); - - assert_eq!(breakdown.pods.len(), 2); - assert!(!breakdown.pods[0].is_output); - assert!(breakdown.pods[1].is_output); - - // POD 0 should have 4 merkle proofs - let mp = breakdown.pods[0] - .resources - .iter() - .find(|r| r.name == "merkle proofs") - .unwrap(); - assert_eq!(mp.used, 4); - assert_eq!(mp.limit, 8); - - // POD 1 should have 4 state transitions - let mst = breakdown.pods[1] - .resources - .iter() - .find(|r| r.name == "merkle state transitions") - .unwrap(); - assert_eq!(mst.used, 4); - assert_eq!(mst.limit, 4); - - // Verify display doesn't panic and contains expected content - let display = format!("{}", breakdown); - assert!(display.contains("Solution Breakdown (8 statements -> 2 PODs)")); - assert!(display.contains("POD 0 (intermediate)")); - assert!(display.contains("POD 1 (output)")); - } - - #[test] - fn test_utilization_row_fraction() { - let row = UtilizationRow { - name: "test", - used: 3, - limit: 4, - }; - assert!((row.utilization() - 0.75).abs() < f64::EPSILON); - - let zero_row = UtilizationRow { - name: "test", - used: 0, - limit: 4, - }; - assert!((zero_row.utilization()).abs() < f64::EPSILON); - } -} diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index 813e333..d25fcce 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -48,23 +48,21 @@ //! [`MainPodBuilder`]: crate::frontend::MainPodBuilder use std::{ - collections::{BTreeSet, HashMap}, + collections::{BTreeMap, BTreeSet, HashMap}, fmt, }; use crate::{ - frontend::{MainPod, MainPodBuilder, Operation}, + frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value}, }; mod cost; mod deps; -pub mod diagnostics; mod solver; -use cost::StatementCost; +use cost::{AnchoredKeyId, StatementCost}; use deps::{DependencyGraph, StatementSource}; -pub use diagnostics::{ResourceSummary, SolutionBreakdown}; pub use solver::MultiPodSolution; /// Error type for multi-POD operations. @@ -170,8 +168,12 @@ pub struct MultiPodBuilder { options: Options, /// External input PODs (already proved). input_pods: Vec, + /// Statements created by this builder. + statements: Vec, + /// Operations that produce each statement. + operations: Vec, /// Optional initial wildcard values for custom operations - operations_wildcard_values: HashMap>, + operations_wildcard_values: Vec>, /// Indices of statements that should be public in output PODs. /// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted. output_public_indices: Vec, @@ -191,7 +193,7 @@ pub struct SolvedMultiPod { statements: Vec, operations: Vec, output_public_indices: Vec, - operations_wildcard_values: HashMap>, + operations_wildcard_values: Vec>, solution: MultiPodSolution, deps: DependencyGraph, } @@ -202,22 +204,6 @@ impl SolvedMultiPod { &self.solution } - /// Compute a post-solve per-POD resource utilization breakdown. - pub fn solution_breakdown(&self) -> SolutionBreakdown { - let costs: Vec = self - .operations - .iter() - .map(StatementCost::from_operation) - .collect(); - SolutionBreakdown::from_solution( - &costs, - &self.solution.pod_statements, - self.solution.pod_count, - self.statements.len(), - &self.params, - ) - } - /// Build and prove all PODs. /// /// Builds PODs in dependency order (0, 1, ..., k) and proves each one. @@ -274,27 +260,56 @@ impl SolvedMultiPod { let statements_sorted: BTreeSet = statements_in_this_pod.iter().copied().collect(); let public_set = &solution.pod_public_statements[pod_idx]; + // Track statements proved locally in this POD for argument remapping. + // We index by statement content so duplicate statements can reuse a single + // built statement slot in MainPodBuilder. + let mut added_statements_by_content: HashMap = HashMap::new(); + for &stmt_idx in &statements_sorted { - let op = self.operations[stmt_idx].clone(); - let wildcard_values = self - .operations_wildcard_values - .get(&stmt_idx) - .cloned() - .unwrap_or_default(); + let original_stmt = self.statements[stmt_idx].clone(); + + // If this statement content was already built in this POD, reuse it instead + // of replaying the operation. If any duplicate is public, reveal the + // already-built statement. + if let Some(_existing_stmt) = added_statements_by_content.get(&original_stmt) { + continue; + } + + let mut op = self.operations[stmt_idx].clone(); + let wildcard_values = self.operations_wildcard_values[stmt_idx].clone(); + + // Remap Statement arguments that reference locally-proved statements. + // For external dependencies (from input PODs including earlier generated PODs), + // the original Statement is used directly - MainPodBuilder will find it in + // the input POD's public statements via find_op_arg. + for arg in &mut op.1 { + if let OperationArg::Statement(ref orig_stmt) = arg { + if let Some(remapped_stmt) = added_statements_by_content.get(orig_stmt) { + *arg = OperationArg::Statement(remapped_stmt.clone()); + } + } + } let stmt = builder.op(false, wildcard_values, op)?; - assert_eq!(stmt, self.statements[stmt_idx]); // Sanity check + + added_statements_by_content.insert(original_stmt, stmt); } // For the output pod, make statements public in the original order. // Intermediate pods use the solver-selected public set. if pod_idx == solution.pod_count - 1 { for idx in &self.output_public_indices { - builder.reveal(&self.statements[*idx])?; + let stmt = added_statements_by_content + .get(&self.statements[*idx]) + .expect("exists"); + builder.reveal(stmt); } } else { for idx in public_set { - builder.reveal(&self.statements[*idx])?; + let stmt = added_statements_by_content + .get(&self.statements[*idx]) + .expect("exists"); + builder.reveal(stmt); } } @@ -302,7 +317,7 @@ impl SolvedMultiPod { // for this POD. These do not require local proving in this POD. for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] { let ext_premise = &solution.external_premises[*ext_premise_idx]; - builder.reveal(&ext_premise.statement)?; + builder.reveal(&ext_premise.statement); } // Step 4: Prove the POD @@ -441,7 +456,9 @@ impl MultiPodBuilder { options, builder, input_pods: Vec::new(), - operations_wildcard_values: HashMap::new(), + statements: Vec::new(), + operations: Vec::new(), + operations_wildcard_values: Vec::new(), output_public_indices: Vec::new(), } } @@ -463,16 +480,6 @@ impl MultiPodBuilder { self.op(false, vec![], op) } - // Find the index of a statement that has been added. Panics if the statement doesn't - // exist. - fn stmt_index(&self, stmt: &Statement) -> usize { - self.builder - .statements - .iter() - .position(|s| s == stmt) - .expect("exists") - } - pub fn op( &mut self, public: bool, @@ -481,10 +488,8 @@ impl MultiPodBuilder { ) -> Result { let stmt = self.add_operation(wildcard_values, op)?; if public { - let index = self.stmt_index(&stmt); - if !self.output_public_indices.contains(&index) { - self.output_public_indices.push(index); - } + // Index is always new (just added), so push without duplicate check + self.output_public_indices.push(self.statements.len() - 1); } Ok(stmt) } @@ -505,8 +510,10 @@ impl MultiPodBuilder { let stmt = self .builder .op(false, wildcard_values.clone(), op.clone())?; - self.operations_wildcard_values - .insert(self.stmt_index(&stmt), wildcard_values.clone()); + + self.statements.push(stmt.clone()); + self.operations.push(op); + self.operations_wildcard_values.push(wildcard_values); Ok(stmt) } @@ -516,7 +523,7 @@ impl MultiPodBuilder { /// Returns an error if the statement was not found in the builder. /// Calling this multiple times on the same statement is idempotent. pub fn reveal(&mut self, stmt: &Statement) -> Result<()> { - if let Some(idx) = self.builder.statements.iter().position(|s| s == stmt) { + if let Some(idx) = self.statements.iter().position(|s| s == stmt) { if !self.output_public_indices.contains(&idx) { self.output_public_indices.push(idx); } @@ -529,22 +536,8 @@ impl MultiPodBuilder { } /// Get the number of statements. - pub fn stmt_len(&self) -> usize { - self.builder.stmt_len() - } - - /// Compute a pre-solve resource summary showing aggregate demand vs. per-POD limits. - /// - /// This is useful for understanding which resource category is the bottleneck - /// before running the solver, especially when debugging solver performance issues. - pub fn resource_summary(&self) -> ResourceSummary { - let costs: Vec = self - .builder - .operations - .iter() - .map(StatementCost::from_operation) - .collect(); - ResourceSummary::from_costs(&costs, &self.params) + pub fn num_statements(&self) -> usize { + self.statements.len() } /// Solve the packing problem and return a solved builder ready for proving. @@ -552,31 +545,66 @@ impl MultiPodBuilder { /// This runs the MILP solver to find the optimal POD assignment. /// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved. pub fn solve(self) -> Result { - let MainPodBuilder { - statements, - operations, - .. - } = self.builder; // Compute costs for each statement - let costs: Vec = operations + let costs: Vec = self + .operations .iter() .map(StatementCost::from_operation) .collect(); + // Collect all unique anchored keys from the costs + let all_anchored_keys: Vec = costs + .iter() + .flat_map(|c| c.anchored_keys.iter().cloned()) + .collect::>() + .into_iter() + .collect(); + + // Build map from anchored key to its producing statement index (if any). + // A Contains statement with literal (dict, key, value) "produces" that anchored key. + let mut ak_to_producer: HashMap = HashMap::new(); + for (stmt_idx, stmt) in self.statements.iter().enumerate() { + if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) { + // First producer wins (shouldn't have duplicates in practice) + ak_to_producer.entry(ak).or_insert(stmt_idx); + } + } + + // Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i] + let anchored_key_producers: Vec> = all_anchored_keys + .iter() + .map(|ak| ak_to_producer.get(ak).copied()) + .collect(); + // Build external POD statement mapping let external_pod_statements = build_external_statement_map(&self.input_pods); // Build dependency graph - let deps = DependencyGraph::build(&statements, &operations, &external_pod_statements); + let deps = + DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements); + + // Build statement content groups for deduplication. + // Statements with identical content share a single slot in the POD. + // Keep groups ordered by first occurrence index for deterministic solver input. + let mut first_idx_by_stmt: HashMap<&Statement, usize> = HashMap::new(); + let mut groups_by_first_idx: BTreeMap> = BTreeMap::new(); + for (idx, stmt) in self.statements.iter().enumerate() { + let first_idx = *first_idx_by_stmt.entry(stmt).or_insert(idx); + groups_by_first_idx.entry(first_idx).or_default().push(idx); + } + let statement_content_groups: Vec> = groups_by_first_idx.into_values().collect(); // Run solver let input = solver::SolverInput { - num_statements: statements.len(), + num_statements: self.statements.len(), costs: &costs, deps: &deps, output_public_indices: &self.output_public_indices, params: &self.params, max_pods: self.options.max_pods, + all_anchored_keys: &all_anchored_keys, + anchored_key_producers: &anchored_key_producers, + statement_content_groups: &statement_content_groups, }; let solution = solver::solve(&input)?; @@ -585,8 +613,8 @@ impl MultiPodBuilder { params: self.params, vd_set: self.vd_set, input_pods: self.input_pods, - statements, - operations, + statements: self.statements, + operations: self.operations, output_public_indices: self.output_public_indices, operations_wildcard_values: self.operations_wildcard_values, solution, @@ -817,13 +845,33 @@ mod tests { let solution = solved.solution(); // Expected: exactly 2 PODs - // Solution A: - // - POD 0 (intermediate): public statements 0 (contains) - // - POD 1 (output): inherits statement 0 (contains) from POD0, statement 1 (a_out), - // public statement 2 (b_out) - // Solution B: - // - POD 0 (intermediate): statements 0 (contains), public statement 1 (a_out) - // - POD 1 (output): inherits statement 1 (a_out) from POD0, public statement 2 (b_out) + // - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public + // - POD 1 (output): statement 2 (b_out); b_out is public + // The output POD accesses a_out from POD 0 to satisfy b_out's dependency. + assert_eq!( + solution.pod_count, 2, + "Expected exactly 2 PODs for 3-statement chain with max_priv=2" + ); + + // POD 0 should contain statements 0 and 1 (contains and a_out) + assert!( + solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), + "POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}", + solution.pod_statements[0] + ); + + // Statement 1 (a_out) should be public in POD 0 so POD 1 can access it + assert!( + solution.pod_public_statements[0].contains(&1), + "Statement 1 (a_out) should be public in POD 0" + ); + + // POD 1 (output) should contain statement 2 (b_out) + assert!( + solution.pod_statements[1].contains(&2), + "POD 1 should contain statement 2 (b_out), got {:?}", + solution.pod_statements[1] + ); // Statement 2 (b_out) should be public in POD 1 (it's output-public) assert!( diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs index 8d81ab3..9a24fb0 100644 --- a/src/frontend/multi_pod/solver.rs +++ b/src/frontend/multi_pod/solver.rs @@ -52,7 +52,7 @@ use itertools::Itertools; use super::Result; use crate::{ frontend::multi_pod::{ - cost::{CustomPredicateId, StatementCost}, + cost::{AnchoredKeyId, CustomBatchId, StatementCost}, deps::{DependencyGraph, ExternalDependency, StatementSource}, }, middleware::{Hash, Params}, @@ -95,6 +95,7 @@ struct DependencyStats { struct SolveDebugContext { dep_stats: DependencyStats, batch_memberships: usize, + anchored_key_memberships: usize, } #[derive(Clone, Copy, Debug, Default)] @@ -104,8 +105,10 @@ struct ModelSizeEstimate { vars_public_external: usize, vars_pod_used: usize, vars_batch_used: usize, + vars_anchored_key_used: usize, vars_uses_input: usize, vars_uses_external: usize, + vars_content_group_used: usize, vars_total: usize, c1_coverage: usize, c2_output_public: usize, @@ -117,6 +120,7 @@ struct ModelSizeEstimate { c6_pre_content_group: usize, c6_resource_limits: usize, c7_batch_cardinality: usize, + c7b_anchored_key_tracking: usize, c8a_internal_inputs: usize, c8b_external_dep_inputs: usize, c8c_external_forward_inputs: usize, @@ -137,6 +141,8 @@ impl ModelSizeEstimate { debug_ctx: &SolveDebugContext, ) -> Self { let n = input.num_statements; + let num_groups = input.statement_content_groups.len(); + let num_anchored_keys = input.all_anchored_keys.len(); let triangular_k = target_pods * target_pods.saturating_sub(1) / 2; let vars_prove = n * target_pods; @@ -144,15 +150,19 @@ impl ModelSizeEstimate { let vars_public_external = external_premises_len * target_pods; let vars_pod_used = target_pods; let vars_batch_used = all_batches_len * target_pods; + let vars_anchored_key_used = num_anchored_keys * target_pods; let vars_uses_input = triangular_k; let vars_uses_external = external_pods_len * target_pods; + let vars_content_group_used = num_groups * target_pods; let vars_total = vars_prove + vars_public + vars_public_external + vars_pod_used + vars_batch_used + + vars_anchored_key_used + vars_uses_input - + vars_uses_external; + + vars_uses_external + + vars_content_group_used; let c1_coverage = n; let c2_output_public = input.output_public_indices.len(); @@ -161,10 +171,12 @@ impl ModelSizeEstimate { let c4_pod_existence = n * target_pods; let c5_internal_dependencies = debug_ctx.dep_stats.internal_edges * target_pods; let c5_external_dependencies = debug_ctx.dep_stats.external_edges * target_pods; - let c6_pre_content_group = n * target_pods; + let c6_pre_content_group = (n * target_pods) + (num_groups * target_pods); let c6_resource_limits = 7 * target_pods; let c7_batch_cardinality = (debug_ctx.batch_memberships * target_pods) + (all_batches_len * target_pods); + let c7b_anchored_key_tracking = + (debug_ctx.anchored_key_memberships * target_pods) + (num_anchored_keys * target_pods); let c8a_internal_inputs = debug_ctx.dep_stats.internal_edges * triangular_k; let c8b_external_dep_inputs = debug_ctx.dep_stats.external_edges * triangular_k; let c8c_external_forward_inputs = external_premises_len * triangular_k; @@ -182,6 +194,7 @@ impl ModelSizeEstimate { + c6_pre_content_group + c6_resource_limits + c7_batch_cardinality + + c7b_anchored_key_tracking + c8a_internal_inputs + c8b_external_dep_inputs + c8c_external_forward_inputs @@ -196,8 +209,10 @@ impl ModelSizeEstimate { vars_public_external, vars_pod_used, vars_batch_used, + vars_anchored_key_used, vars_uses_input, vars_uses_external, + vars_content_group_used, vars_total, c1_coverage, c2_output_public, @@ -209,6 +224,7 @@ impl ModelSizeEstimate { c6_pre_content_group, c6_resource_limits, c7_batch_cardinality, + c7b_anchored_key_tracking, c8a_internal_inputs, c8b_external_dep_inputs, c8c_external_forward_inputs, @@ -284,7 +300,6 @@ pub struct MultiPodSolution { } /// Input to the MILP solver. -#[derive(Debug)] pub struct SolverInput<'a> { /// Number of statements. pub num_statements: usize, @@ -303,6 +318,28 @@ pub struct SolverInput<'a> { /// Maximum number of PODs the solver will consider. pub max_pods: usize, + + /// All unique anchored keys referenced by any statement. + /// + /// Each unique (dict, key) pair that is used as an anchored key reference + /// in any operation. When a Contains statement with literal values is used + /// as an argument, it creates an anchored key reference. + pub all_anchored_keys: &'a [AnchoredKeyId], + + /// For each anchored key, the statement index that produces it (if any). + /// + /// When a Contains statement with literal (dict, key, value) args is explicitly + /// added, it "produces" that anchored key. If the producer is in the same POD + /// as statements using the anchored key, no auto-insertion is needed. + /// `anchored_key_producers[i]` corresponds to `all_anchored_keys[i]`. + pub anchored_key_producers: &'a [Option], + + /// Statement content groups for deduplication. + /// + /// Each inner Vec contains statement indices that have identical content. + /// When multiple statements with the same content are proved in the same POD, + /// they only use one statement slot (the POD deduplicates identical statements). + pub statement_content_groups: &'a [Vec], } /// Solve the MILP problem to find optimal POD packing. @@ -349,11 +386,11 @@ pub fn solve(input: &SolverInput) -> Result { ))); } - // Collect all unique custom predicate IDs used - let all_custom_predicates: Vec = input + // Collect all unique custom batch IDs used + let all_batches: Vec = input .costs .iter() - .flat_map(|c| c.custom_predicates_ids.iter().cloned()) + .flat_map(|c| c.custom_batch_ids.iter().cloned()) .unique() .collect(); @@ -380,26 +417,27 @@ pub fn solve(input: &SolverInput) -> Result { } let dep_stats = dependency_stats(input.deps); - let batch_memberships: usize = input - .costs - .iter() - .map(|c| c.custom_predicates_ids.len()) - .sum(); + let batch_memberships: usize = input.costs.iter().map(|c| c.custom_batch_ids.len()).sum(); + let anchored_key_memberships: usize = input.costs.iter().map(|c| c.anchored_keys.len()).sum(); let debug_ctx = SolveDebugContext { dep_stats, batch_memberships, + anchored_key_memberships, }; if log::log_enabled!(log::Level::Debug) { let resource_totals = ResourceTotals::from_costs(input.costs); - let lb_statement_groups = lower_bound_from_total(input.num_statements, max_stmts_per_pod); + let lb_statement_groups = + lower_bound_from_total(input.statement_content_groups.len(), max_stmts_per_pod); let lb_merkle = lower_bound_from_total( resource_totals.merkle_proofs, - input.params.containers.state.max_medium, + input.params.max_merkle_proofs_containers, ); let lb_merkle_transitions = lower_bound_from_total( resource_totals.merkle_state_transitions, - input.params.containers.transition.max_medium, + input + .params + .max_merkle_tree_state_transition_proofs_containers, ); let lb_custom_pred_verifications = lower_bound_from_total( resource_totals.custom_pred_verifications, @@ -425,12 +463,14 @@ pub fn solve(input: &SolverInput) -> Result { .expect("non-empty lower-bound candidate list"); log::debug!( - "MILP summary: statements={} output_public={} \ - custom_predicates={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \ + "MILP summary: statements={} output_public={} content_groups={} anchored_keys={} \ + batches={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \ external_premises={} search_min_pods={} max_pods={}", n, num_output_public, - all_custom_predicates.len(), + input.statement_content_groups.len(), + input.all_anchored_keys.len(), + all_batches.len(), dep_stats.internal_edges, dep_stats.external_edges, external_pods.len(), @@ -441,13 +481,14 @@ pub fn solve(input: &SolverInput) -> Result { log::debug!( "MILP resource totals: merkle_proofs={} merkle_state_transitions={} \ custom_pred_verifications={} signed_by={} public_key_of={} \ - batch_memberships={}", + batch_memberships={} anchored_key_memberships={}", resource_totals.merkle_proofs, resource_totals.merkle_state_transitions, resource_totals.custom_pred_verifications, resource_totals.signed_by, resource_totals.public_key_of, batch_memberships, + anchored_key_memberships ); log::debug!( "MILP lower bounds (pods): statements_raw={} statements_dedup={} merkle_proofs={} \ @@ -472,7 +513,7 @@ pub fn solve(input: &SolverInput) -> Result { if let Some(solution) = try_solve_with_pods( input, target_pods, - &all_custom_predicates, + &all_batches, &external_pods, &external_premises, &debug_ctx, @@ -499,7 +540,7 @@ pub fn solve(input: &SolverInput) -> Result { fn try_solve_with_pods( input: &SolverInput, target_pods: usize, - all_custom_predicates: &[CustomPredicateId], + all_batches: &[CustomBatchId], external_pods: &[Hash], external_premises: &[ExternalDependency], debug_ctx: &SolveDebugContext, @@ -533,8 +574,21 @@ fn try_solve_with_pods( .map(|_| vars.add(variable().binary())) .collect(); - // custom_predicates[b][p] - custom predicate b is used in POD p - let custom_predicate_used: Vec> = (0..all_custom_predicates.len()) + // batch_used[b][p] - custom batch b is used in POD p + let batch_used: Vec> = (0..all_batches.len()) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // anchored_key_used[ak][p] - anchored key ak is used in POD p + // When a statement references an anchored key (via a Contains statement argument), + // that POD must have a Contains statement for that (dict, key) pair. + // MainPodBuilder::add_entries_contains auto-inserts these, and we must account + // for them in the statement count. + let anchored_key_used: Vec> = (0..input.all_anchored_keys.len()) .map(|_| { (0..target_pods) .map(|_| vars.add(variable().binary())) @@ -579,19 +633,31 @@ fn try_solve_with_pods( .map(|(i, ext)| (ext.clone(), i)) .collect(); + // content_group_used[g][p] - content group g has at least one statement proved in POD p + // When multiple statements have identical content, they share a slot in the POD. + // This variable tracks whether at least one statement from each content group is proved. + let num_groups = input.statement_content_groups.len(); + let content_group_used: Vec> = (0..num_groups) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + if log::log_enabled!(log::Level::Debug) { let estimate = ModelSizeEstimate::for_target_pods( input, target_pods, - all_custom_predicates.len(), + all_batches.len(), external_pods.len(), external_premises.len(), debug_ctx, ); log::debug!( "MILP(k={}) model estimate vars_total={} [prove={} public={} pod_used={} \ - public_external={} batch_used={} uses_input={} \ - uses_external={}]", + public_external={} batch_used={} anchored_key_used={} uses_input={} \ + uses_external={} content_group_used={}]", target_pods, estimate.vars_total, estimate.vars_prove, @@ -599,12 +665,14 @@ fn try_solve_with_pods( estimate.vars_pod_used, estimate.vars_public_external, estimate.vars_batch_used, + estimate.vars_anchored_key_used, estimate.vars_uses_input, estimate.vars_uses_external, + estimate.vars_content_group_used ); log::debug!( "MILP(k={}) model estimate constraints_total={} [c1={} c2={} c2b={} c3={} c4={} \ - c5i={} c5e={} c6_pre={} c6_limits={} c7={} c8a={} c8b={} c8c={} \ + c5i={} c5e={} c6_pre={} c6_limits={} c7={} c7b={} c8a={} c8b={} c8c={} \ c8d={} c9={} c10={} c10b={}]", target_pods, estimate.constraints_total, @@ -618,6 +686,7 @@ fn try_solve_with_pods( estimate.c6_pre_content_group, estimate.c6_resource_limits, estimate.c7_batch_cardinality, + estimate.c7b_anchored_key_tracking, estimate.c8a_internal_inputs, estimate.c8b_external_dep_inputs, estimate.c8c_external_forward_inputs, @@ -729,11 +798,35 @@ fn try_solve_with_pods( } } + // Constraint 6: Resource limits per POD + // + // 6a-pre: Content group tracking for statement deduplication + // When multiple statement indices have identical content, they share a single slot in the POD. + // content_group_used[g][p] = 1 iff at least one statement from group g is proved in POD p. + for (g, group) in input.statement_content_groups.iter().enumerate() { + for p in 0..target_pods { + // Lower bound: if any statement in the group is proved, the group is used + for &s in group { + model.add_constraint(constraint!(content_group_used[g][p] >= prove[s][p])); + } + // Upper bound: if no statements in the group are proved, the group is not used + let group_prove_sum: Expression = group.iter().map(|&s| prove[s][p]).sum(); + model.add_constraint(constraint!(content_group_used[g][p] <= group_prove_sum)); + } + } + for p in 0..target_pods { - // 6a: Statement count - let stmt_sum: Expression = (0..n).map(|g| prove[g][p]).sum(); + // 6a: Unique statement count (unique content groups + anchored key Contains) + // Statements with identical content share a slot, so we count content groups, not indices. + // Anchored key Contains statements are auto-inserted by MainPodBuilder when needed. + // The total must not exceed max_priv_statements (= max_statements - max_public_statements). + let unique_stmt_sum: Expression = (0..num_groups).map(|g| content_group_used[g][p]).sum(); + let anchored_key_sum: Expression = (0..input.all_anchored_keys.len()) + .map(|ak| anchored_key_used[ak][p]) + .sum(); model.add_constraint(constraint!( - stmt_sum <= (input.params.max_priv_statements() as f64) * pod_used[p] + unique_stmt_sum + anchored_key_sum + <= (input.params.max_priv_statements() as f64) * pod_used[p] )); // 6b: Public statement count (internal public statements + forwarded external premises) @@ -751,7 +844,7 @@ fn try_solve_with_pods( .map(|s| (input.costs[s].merkle_proofs as f64) * prove[s][p]) .sum(); model.add_constraint(constraint!( - merkle_sum <= (input.params.containers.state.max_medium as f64) * pod_used[p] + merkle_sum <= (input.params.max_merkle_proofs_containers as f64) * pod_used[p] )); // 6d: Merkle state transitions @@ -759,7 +852,11 @@ fn try_solve_with_pods( .map(|s| (input.costs[s].merkle_state_transitions as f64) * prove[s][p]) .sum(); model.add_constraint(constraint!( - mst_sum <= (input.params.containers.transition.max_medium as f64) * pod_used[p] + mst_sum + <= (input + .params + .max_merkle_tree_state_transition_proofs_containers as f64) + * pod_used[p] )); // 6e: Custom predicate verifications @@ -788,31 +885,67 @@ fn try_solve_with_pods( } // Constraint 7: Batch cardinality - // custom_predicate_used[b][p] >= prove[s][p] for all s that use custom predicate b (custom - // predicate is used if any statement uses it) - // custom_predicate_used[b][p] <= sum of prove[s][p] for all s using custom predicate b (custom - // predicate is 0 if no statements use it) - for (b, predicate_id) in all_custom_predicates.iter().enumerate() { + // batch_used[b][p] >= prove[s][p] for all s that use batch b (batch is used if any statement uses it) + // batch_used[b][p] <= sum of prove[s][p] for all s using batch b (batch is 0 if no statements use it) + for (b, batch_id) in all_batches.iter().enumerate() { for p in 0..target_pods { let mut sum: Expression = 0.into(); for s in 0..n { - if input.costs[s].custom_predicates_ids.contains(predicate_id) { - model.add_constraint(constraint!(custom_predicate_used[b][p] >= prove[s][p])); + if input.costs[s].custom_batch_ids.contains(batch_id) { + model.add_constraint(constraint!(batch_used[b][p] >= prove[s][p])); sum += prove[s][p]; } } - model.add_constraint(constraint!(custom_predicate_used[b][p] <= sum)); + model.add_constraint(constraint!(batch_used[b][p] <= sum)); } } - // Custom predicate count per POD - for p in 0..target_pods { - let custom_predicate_sum: Expression = (0..all_custom_predicates.len()) - .map(|b| custom_predicate_used[b][p]) - .sum(); - model.add_constraint(constraint!( - custom_predicate_sum <= (input.params.max_custom_predicates as f64) * pod_used[p] - )); + // Constraint 7b: Anchored key tracking + // + // anchored_key_used[ak][p] = 1 when auto-insertion of a Contains is needed for anchored key ak in POD p. + // This happens when: some statement using ak is in POD p, AND the producing Contains is NOT in POD p. + // + // If a Contains statement explicitly produces ak (anchored_key_producers[ak] = Some(prod_idx)): + // - Lower: anchored_key_used[ak][p] >= prove[s][p] - prove[prod_idx][p] for all s using ak + // - Upper: anchored_key_used[ak][p] <= 1 - prove[prod_idx][p] + // This ensures overhead is 0 when the producer is in the same POD. + // + // If no Contains produces ak (anchored_key_producers[ak] = None): + // - Lower: anchored_key_used[ak][p] >= prove[s][p] for all s using ak + // - Upper: anchored_key_used[ak][p] <= sum of prove[s][p] for all s using ak + // Auto-insertion is always needed when any user is present. + for (ak_idx, ak) in input.all_anchored_keys.iter().enumerate() { + let producer = input.anchored_key_producers[ak_idx]; + + for p in 0..target_pods { + let mut user_sum: Expression = 0.into(); + for s in 0..n { + if input.costs[s].anchored_keys.contains(ak) { + if let Some(prod_idx) = producer { + // Producer exists: only count overhead if producer not in this POD + model.add_constraint(constraint!( + anchored_key_used[ak_idx][p] >= prove[s][p] - prove[prod_idx][p] + )); + } else { + // No producer: always need auto-insertion if user is present + model.add_constraint(constraint!( + anchored_key_used[ak_idx][p] >= prove[s][p] + )); + } + user_sum += prove[s][p]; + } + } + + if let Some(prod_idx) = producer { + // If producer is in POD, no auto-insertion needed (overhead = 0) + model.add_constraint(constraint!( + anchored_key_used[ak_idx][p] <= 1 - prove[prod_idx][p] + )); + } else { + // No producer: overhead is bounded by whether any user is present + model.add_constraint(constraint!(anchored_key_used[ak_idx][p] <= user_sum)); + } + } } // Constraint 8a: Internal input POD tracking using uses_input. @@ -1014,6 +1147,9 @@ mod tests { output_public_indices: &[], params: ¶ms, max_pods: 20, + all_anchored_keys: &[], + anchored_key_producers: &[], + statement_content_groups: &[], }; let result = solve(&input); @@ -1059,6 +1195,7 @@ mod tests { }; let costs = vec![StatementCost::default(), StatementCost::default()]; + let statement_content_groups = vec![vec![0], vec![1]]; let output_public = vec![1]; let input = SolverInput { @@ -1068,6 +1205,9 @@ mod tests { output_public_indices: &output_public, params: ¶ms, max_pods: 4, + all_anchored_keys: &[], + anchored_key_producers: &[], + statement_content_groups: &statement_content_groups, }; let solution = solve(&input).expect("solver should find a feasible forwarding layout"); diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 9794e60..a61623c 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,10 +1,10 @@ -use std::{fmt, iter}; +use std::fmt; use crate::{ frontend::SignedDict, middleware::{ containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux, - OperationType, Signature, Statement, Value, ValueRef, BASE_PARAMS, + OperationType, Signature, Statement, TypedValue, Value, ValueRef, }, }; @@ -39,9 +39,10 @@ impl OperationArg { } pub(crate) fn int_value_and_ref(&self) -> Option<(ValueRef, i64)> { - self.value_and_ref() - .and_then(|(r, v)| v.as_int().map(|i| Some((r, i)))) - .flatten() + self.value_and_ref().and_then(|(r, v)| match v.typed() { + &TypedValue::Int(i) => Some((r, i)), + _ => None, + }) } } @@ -70,7 +71,7 @@ impl From<&Value> for OperationArg { impl From<(&Dictionary, &str)> for OperationArg { fn from((dict, key): (&Dictionary, &str)) -> Self { // TODO: Use TryFrom - let value = dict.get(&key.into()).unwrap().unwrap(); + let value = dict.get(&key.into()).cloned().unwrap(); Self::Statement(Statement::Contains( dict.clone().into(), key.into(), @@ -219,24 +220,6 @@ impl Operation { op_impl_oa!(set_insert, SetInsertFromEntries, 3); op_impl_oa!(set_delete, SetDeleteFromEntries, 3); op_impl_oa!(array_update, ArrayUpdateFromEntries, 4); - pub fn replace_value_with_entry(args: Vec>, st: Statement) -> Self { - assert!(args.len() <= BASE_PARAMS.max_statement_args); - let args = args - .into_iter() - .chain(iter::repeat(None)) - .take(BASE_PARAMS.max_statement_args) - .map(|a| match a { - None => OperationArg::Statement(Statement::None), - Some((dict, key)) => OperationArg::from((dict, key)), - }) - .chain(iter::once(OperationArg::Statement(st))) - .collect(); - Self( - OperationType::Native(NativeOperation::ReplaceValueWithEntry), - args, - OperationAux::None, - ) - } pub fn signed_by( msg: impl Into, pk: impl Into, diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 1def7c3..8a47db3 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -83,7 +83,7 @@ mod tests { middleware::{ self, containers::{Array, Dictionary, Set}, - Params, Signer as _, Value, DEFAULT_VD_LIST, + Params, Signer as _, TypedValue, DEFAULT_VD_LIST, }, }; @@ -91,46 +91,48 @@ mod tests { fn test_value_serialization() { // Pairs of values and their expected serialized representations let values = vec![ - (Value::from("hello"), "\"hello\""), - (Value::from(42), "{\"Int\":\"42\"}"), - (Value::from(true), r#"{"Int":"1"}"#), + (TypedValue::String("hello".to_string()), "\"hello\""), + (TypedValue::Int(42), "{\"Int\":\"42\"}"), + (TypedValue::Bool(true), "true"), ( - Value::from(Array::new(vec![Value::from("foo"), Value::from(false)])), - r#"{"inner":[[{"Int":"0"},"foo"],[{"Int":"1"},{"Int":"0"}]]}"#, + TypedValue::Array(Array::new(vec!["foo".into(), false.into()])), + "{\"array\":[\"foo\",false]}", ), ( - Value::from(Dictionary::new(HashMap::from([ - // The set of valid keys is equal to the set of valid JSON keys - ("foo".into(), 123.into()), - // Empty strings are valid JSON keys - (("".into()), "baz".into()), - // Keys can contain whitespace - ((" hi".into()), false.into()), - // Keys can contain special characters - (("!@£$%^&&*()".into()), "".into()), - // Keys can contain _very_ special characters - (("\0".into()), "".into()), - // Keys can contain emojis - (("🥳".into()), "party time!".into()), - ]))), - r#"{"inner":[["!@£$%^&&*()",""],["🥳","party time!"],[" hi",{"Int":"0"}],["foo",{"Int":"123"}],["\u0000",""],["","baz"]]}"#, + TypedValue::Dictionary( + Dictionary::new(HashMap::from([ + // The set of valid keys is equal to the set of valid JSON keys + ("foo".into(), 123.into()), + // Empty strings are valid JSON keys + (("".into()), "baz".into()), + // Keys can contain whitespace + ((" hi".into()), false.into()), + // Keys can contain special characters + (("!@£$%^&&*()".into()), "".into()), + // Keys can contain _very_ special characters + (("\0".into()), "".into()), + // Keys can contain emojis + (("🥳".into()), "party time!".into()), + ])) + ), + "{\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}", ), ( - Value::from(Set::new(HashSet::from(["foo".into(), "bar".into()]))), - r#"{"inner":[["bar"],["foo"]]}"#, + TypedValue::Set(Set::new(HashSet::from(["foo".into(), "bar".into()]))), + "{\"set\":[\"bar\",\"foo\"]}", ), ]; for (value, expected) in values { let serialized = serde_json::to_string(&value).unwrap(); assert_eq!(serialized, expected); - let deserialized: Value = serde_json::from_str(&serialized).unwrap(); + let deserialized: TypedValue = serde_json::from_str(&serialized).unwrap(); assert_eq!( value, deserialized, "value {:#?} should equal deserialized {:#?}", value, deserialized ); - let expected_deserialized: Value = serde_json::from_str(expected).unwrap(); + let expected_deserialized: TypedValue = serde_json::from_str(expected).unwrap(); assert_eq!(value, expected_deserialized); } } @@ -175,10 +177,7 @@ mod tests { "deserialized: {}", serde_json::to_string_pretty(&deserialized).unwrap() ); - assert_eq!( - signed_dict.dict.dump().unwrap(), - deserialized.dict.dump().unwrap() - ); + assert_eq!(signed_dict.dict.kvs(), deserialized.dict.kvs()); assert_eq!(signed_dict.public_key, deserialized.public_key); assert_eq!(signed_dict.signature, deserialized.signature); assert_eq!(signed_dict.verify().is_ok(), deserialized.verify().is_ok()); diff --git a/src/lang/diagnostics.rs b/src/lang/diagnostics.rs index 7807318..ea528ef 100644 --- a/src/lang/diagnostics.rs +++ b/src/lang/diagnostics.rs @@ -174,6 +174,18 @@ fn render_validation_error( "second REQUEST here", ), + ValidationError::InvalidArgumentType { predicate, span } => { + let title = format!("invalid argument type for `{}`", predicate); + render_with_optional_span( + renderer, + source, + path, + &title, + span.as_ref(), + "anchored keys not allowed here", + ) + } + ValidationError::DuplicateWildcard { name, span } => { let title = format!("duplicate wildcard: {}", name); render_with_optional_span( @@ -275,17 +287,6 @@ fn render_validation_error( ValidationError::NoRequestBlock => { render_title_only(renderer, "requests must contain a REQUEST block") } - - ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { span } => { - render_with_optional_span( - renderer, - source, - path, - "self-referential predicate literal not allowed in requests", - span.as_ref(), - "not allowed here", - ) - } } } diff --git a/src/lang/error.rs b/src/lang/error.rs index 792d4d8..944988c 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -135,6 +135,12 @@ pub enum ValidationError { span: Option, }, + #[error("Invalid argument type for {predicate}: anchored keys not allowed")] + InvalidArgumentType { + predicate: String, + span: Option, + }, + #[error("Duplicate wildcard in predicate arguments: {name}")] DuplicateWildcard { name: String, span: Option }, @@ -159,9 +165,6 @@ pub enum ValidationError { #[error("Modules must contain at least one predicate definition")] NoPredicatesInModule, - #[error("Self-referential predicate literal not allowed in requests")] - SelfReferentialPredicateLiteralNotAllowedInRequests { span: Option }, - #[error("Requests must contain a REQUEST block")] NoRequestBlock, } diff --git a/src/lang/frontend_ast.rs b/src/lang/frontend_ast.rs index dd0052c..4ca7fe4 100644 --- a/src/lang/frontend_ast.rs +++ b/src/lang/frontend_ast.rs @@ -116,8 +116,6 @@ pub enum StatementTmplArg { Literal(LiteralValue), Wildcard(Identifier), AnchoredKey(AnchoredKey), - /// Hash of a same-module predicate, resolved at batch finalization time. - SelfPredicateHash(Identifier), } /// Anchored key: Var["key"] or Var.key @@ -170,13 +168,6 @@ pub enum LiteralValue { Array(LiteralArray), Set(LiteralSet), Dict(LiteralDict), - /// Hash of a native predicate (resolved immediately). - NativePredicateHash(Identifier), - /// Hash of an external module's predicate (resolved immediately). - ExternalPredicateHash { - module: Identifier, - predicate: Identifier, - }, } /// Integer literal @@ -400,9 +391,6 @@ impl fmt::Display for StatementTmplArg { StatementTmplArg::Literal(lit) => write!(f, "{}", lit), StatementTmplArg::Wildcard(id) => write!(f, "{}", id), StatementTmplArg::AnchoredKey(ak) => write!(f, "{}", ak), - StatementTmplArg::SelfPredicateHash(id) => { - write!(f, "@self_predicate({})", id) - } } } } @@ -434,12 +422,6 @@ impl fmt::Display for LiteralValue { LiteralValue::Array(a) => write!(f, "{}", a), LiteralValue::Set(s) => write!(f, "{}", s), LiteralValue::Dict(d) => write!(f, "{}", d), - LiteralValue::NativePredicateHash(id) => { - write!(f, "@native_predicate({})", id) - } - LiteralValue::ExternalPredicateHash { - module, predicate, .. - } => write!(f, "@external_predicate({}, {})", module, predicate), } } } @@ -787,10 +769,6 @@ pub mod parse { let inner = pair.into_inner().next().unwrap(); match inner.as_rule() { - Rule::predicate_hash_self => { - let id = parse_identifier(inner.into_inner().next().unwrap()); - Ok(StatementTmplArg::SelfPredicateHash(id)) - } Rule::literal_value => Ok(StatementTmplArg::Literal(parse_literal_value(inner)?)), Rule::identifier => Ok(StatementTmplArg::Wildcard(parse_identifier(inner))), Rule::anchored_key => Ok(StatementTmplArg::AnchoredKey(parse_anchored_key(inner)?)), @@ -845,16 +823,6 @@ pub mod parse { Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)), Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)), Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)), - Rule::predicate_hash_native => { - let id = parse_identifier(inner.into_inner().next().unwrap()); - Ok(LiteralValue::NativePredicateHash(id)) - } - Rule::predicate_hash_external => { - let mut parts = inner.into_inner(); - let module = parse_identifier(parts.next().unwrap()); - let predicate = parse_identifier(parts.next().unwrap()); - Ok(LiteralValue::ExternalPredicateHash { module, predicate }) - } _ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()), } } @@ -1136,7 +1104,6 @@ mod tests { AnchoredKeyPath::Dot(id) => id.span = None, } } - StatementTmplArg::SelfPredicateHash(id) => id.span = None, } } } @@ -1172,13 +1139,6 @@ mod tests { clear_literal_spans(&mut pair.value); } } - LiteralValue::NativePredicateHash(id) => id.span = None, - LiteralValue::ExternalPredicateHash { - module, predicate, .. - } => { - module.span = None; - predicate.span = None; - } } } diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index fb00def..b429f4a 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -157,10 +157,8 @@ fn resolve_local_predicate( /// Lower a literal value from AST to middleware Value. /// -/// This is a pure conversion that cannot fail for context-free literals. -/// Panics on ExternalPredicateHash — use `lower_literal_with_context` when -/// external predicate references may appear (e.g. inside containers). -pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { +/// This is a pure conversion that cannot fail. +pub fn lower_literal(lit: &LiteralValue) -> Value { match lit { LiteralValue::Int(i) => Value::from(i.value), LiteralValue::Bool(b) => Value::from(b.value), @@ -192,83 +190,13 @@ pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { let dict = containers::Dictionary::new(pairs); Value::from(dict) } - LiteralValue::NativePredicateHash(id) => { - let np = NativePredicate::from_str(&id.name).expect("validated native predicate"); - Value::from(Predicate::Native(np).hash()) - } - LiteralValue::ExternalPredicateHash { .. } => { - unreachable!( - "ExternalPredicateHash must be lowered with context via lower_literal_with_context" - ) - } - } -} - -/// Lower a literal value, resolving external predicate references using the symbol table. -pub fn lower_literal_with_context( - lit: &LiteralValue, - symbols: &SymbolTable, - context: &ResolutionContext, -) -> Result { - match lit { - LiteralValue::ExternalPredicateHash { module, predicate } => { - let pred_or_wc = resolve_predicate_ref( - &PredicateRef::Qualified { - module: module.clone(), - predicate: predicate.clone(), - }, - symbols, - context, - ) - .ok_or_else(|| LoweringError::PredicateNotFound { - name: format!("{}::{}", module.name, predicate.name), - })?; - let pred = match pred_or_wc { - crate::frontend::PredicateOrWildcard::Predicate(p) => p, - _ => unreachable!( - "`resolve_predicate_ref` always returns `PredicateOrWildcard::Predicate` on `PredicateRef::Qualified`" - ) - }; - Ok(Value::from(pred.hash())) - } - LiteralValue::Array(a) => { - let elements: Vec<_> = a - .elements - .iter() - .map(|e| lower_literal_with_context(e, symbols, context)) - .collect::>()?; - Ok(Value::from(containers::Array::new(elements))) - } - LiteralValue::Set(s) => { - let elements: std::collections::HashSet<_> = s - .elements - .iter() - .map(|e| lower_literal_with_context(e, symbols, context)) - .collect::>()?; - Ok(Value::from(containers::Set::new(elements))) - } - LiteralValue::Dict(d) => { - let pairs: HashMap<_, _> = d - .pairs - .iter() - .map(|pair| { - let key = Key::from(pair.key.value.as_str()); - let value = lower_literal_with_context(&pair.value, symbols, context)?; - Ok((key, value)) - }) - .collect::>()?; - Ok(Value::from(containers::Dictionary::new(pairs))) - } - // All other variants are context-free - other => Ok(lower_literal(other)), } } /// Lower a statement argument from AST to BuilderArg. /// -/// Context-free for most arg types. Panics on ExternalPredicateHash inside literals — -/// use `lower_statement_arg_with_context` when external predicate references may appear. -pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { +/// This is a pure conversion that cannot fail. +pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { match arg { StatementTmplArg::Literal(lit) => { let value = lower_literal(lit); @@ -282,25 +210,6 @@ pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { }; BuilderArg::Key(ak.root.name.clone(), key_str) } - StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()), - } -} - -/// Lower a statement argument, resolving external predicate references using the symbol table. -pub fn lower_statement_arg_with_context( - arg: &StatementTmplArg, - symbols: &SymbolTable, - context: &ResolutionContext, -) -> Result { - match arg { - StatementTmplArg::Literal(lit) => { - let value = lower_literal_with_context(lit, symbols, context)?; - Ok(BuilderArg::Literal(value)) - } - StatementTmplArg::SelfPredicateHash(id) => { - Ok(BuilderArg::SelfPredicateHash(id.name.clone())) - } - other => Ok(lower_statement_arg(other)), } } @@ -415,7 +324,7 @@ impl<'a> Lowerer<'a> { // Create a builder with the resolved predicate and desugar let mut builder = StatementTmplBuilder::new(predicate.clone()); for arg in &stmt.args { - let builder_arg = lower_statement_arg_with_context(arg, symbols, &context)?; + let builder_arg = lower_statement_arg(arg); builder = builder.arg(builder_arg); } let desugared = builder.desugar(); @@ -437,9 +346,6 @@ impl<'a> Lowerer<'a> { let key = Key::from(key_str.as_str()); MWStatementTmplArg::AnchoredKey(wildcard, key) } - BuilderArg::SelfPredicateHash(_) => { - unreachable!("SelfPredicateHash should not appear in request lowering") - } }; mw_args.push(mw_arg); } @@ -493,7 +399,7 @@ impl<'a> Lowerer<'a> { names.push(ak.root.name.clone()); } } - StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} + StatementTmplArg::Literal(_) => {} } } } diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index 482db7a..0d17217 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -123,7 +123,7 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet { StatementTmplArg::AnchoredKey(ak) => { wildcards.insert(ak.root.name.clone()); } - StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} + StatementTmplArg::Literal(_) => {} } } diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index ef3d395..49575b5 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -522,7 +522,7 @@ impl Validator { } // Validate arguments - self.validate_statement_args(stmt, wildcard_context)?; + self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?; Ok(()) } @@ -530,117 +530,71 @@ impl Validator { fn validate_statement_args( &self, stmt: &StatementTmpl, + pred_info: Option<&PredicateInfo>, wildcard_context: Option<(&str, &WildcardScope)>, ) -> Result<(), ValidationError> { - for arg in &stmt.args { - match arg { - StatementTmplArg::Wildcard(id) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&id.name) { - return Err(ValidationError::UndefinedWildcard { - name: id.name.clone(), - pred_name: pred_name.to_string(), - span: id.span, - }); + // For custom predicates, only wildcards and literals are allowed + if matches!( + pred_info.map(|i| &i.kind), + Some(PredicateKind::Custom { .. }) + | Some(PredicateKind::BatchImported { .. }) + | Some(PredicateKind::ModuleImported { .. }) + ) { + for arg in &stmt.args { + match arg { + StatementTmplArg::AnchoredKey(_) => { + return Err(ValidationError::InvalidArgumentType { + predicate: stmt.predicate.predicate_name().to_string(), + span: stmt.span, + }); + } + StatementTmplArg::Wildcard(id) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&id.name) { + return Err(ValidationError::UndefinedWildcard { + name: id.name.clone(), + pred_name: pred_name.to_string(), + span: id.span, + }); + } } } + StatementTmplArg::Literal(_) => {} } - StatementTmplArg::AnchoredKey(ak) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&ak.root.name) { - return Err(ValidationError::UndefinedWildcard { - name: ak.root.name.clone(), - pred_name: pred_name.to_string(), - span: ak.root.span, - }); + } + } else { + // Native predicates can have anchored keys + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&id.name) { + return Err(ValidationError::UndefinedWildcard { + name: id.name.clone(), + pred_name: pred_name.to_string(), + span: id.span, + }); + } } } - } - StatementTmplArg::Literal(lit) => { - self.validate_literal_value(lit)?; - } - StatementTmplArg::SelfPredicateHash(id) => { - self.validate_self_predicate_hash(id, wildcard_context)?; + StatementTmplArg::AnchoredKey(ak) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&ak.root.name) { + return Err(ValidationError::UndefinedWildcard { + name: ak.root.name.clone(), + pred_name: pred_name.to_string(), + span: ak.root.span, + }); + } + } + } + StatementTmplArg::Literal(_) => {} } } } Ok(()) } - - /// Validate a @self_predicate reference: the name must be a custom predicate in this module. - fn validate_self_predicate_hash( - &self, - id: &Identifier, - wildcard_context: Option<(&str, &WildcardScope)>, - ) -> Result<(), ValidationError> { - // @self_predicate only makes sense inside module predicate definitions - if wildcard_context.is_none() { - return Err( - ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { - span: id.span, - }, - ); - } - // Must refer to a custom predicate defined in this module (not intro/imported) - match self.symbols.predicates.get(&id.name) { - Some(info) if matches!(info.kind, PredicateKind::Custom { .. }) => Ok(()), - _ => Err(ValidationError::UndefinedPredicate { - name: id.name.clone(), - span: id.span, - }), - } - } - - /// Recursively validate a literal value, checking predicate hash references. - fn validate_literal_value(&self, lit: &LiteralValue) -> Result<(), ValidationError> { - match lit { - LiteralValue::NativePredicateHash(id) => { - if NativePredicate::from_str(&id.name).is_err() { - return Err(ValidationError::UndefinedPredicate { - name: id.name.clone(), - span: id.span, - }); - } - Ok(()) - } - LiteralValue::ExternalPredicateHash { module, predicate } => { - if let Some(imported) = self.symbols.imported_modules.get(&module.name) { - if !imported.predicate_index.contains_key(&predicate.name) { - return Err(ValidationError::UndefinedPredicate { - name: format!("{}::{}", module.name, predicate.name), - span: predicate.span, - }); - } - } else { - return Err(ValidationError::ModuleNotFound { - name: module.name.clone(), - span: module.span, - }); - } - Ok(()) - } - LiteralValue::Array(a) => { - for elem in &a.elements { - self.validate_literal_value(elem)?; - } - Ok(()) - } - LiteralValue::Set(s) => { - for elem in &s.elements { - self.validate_literal_value(elem)?; - } - Ok(()) - } - LiteralValue::Dict(d) => { - for pair in &d.pairs { - self.validate_literal_value(&pair.value)?; - } - Ok(()) - } - _ => Ok(()), - } - } } #[cfg(test)] @@ -801,7 +755,10 @@ mod tests { module_hash ); let result = parse_and_validate_request(&input, &available_modules); - assert!(result.is_ok()); + assert!(matches!( + result, + Err(ValidationError::InvalidArgumentType { .. }) + )); } #[test] diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index 1c11baa..3002d15 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -49,14 +49,7 @@ custom_predicate_def = { statement_list = { statement+ } -// Predicate hash literals: resolve to the predicate's identity hash as a value. -// @native_predicate and @external_predicate are in literal_value (usable in containers). -// @self_predicate is only in statement_arg (not in containers — deferred resolution). -predicate_hash_native = { "@native_predicate" ~ "(" ~ identifier ~ ")" } -predicate_hash_external = { "@external_predicate" ~ "(" ~ identifier ~ "," ~ identifier ~ ")" } -predicate_hash_self = { "@self_predicate" ~ "(" ~ identifier ~ ")" } - -statement_arg = { predicate_hash_self | literal_value | anchored_key | identifier } +statement_arg = { literal_value | anchored_key | identifier } statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* } // Predicate reference: either qualified (module::predicate) or local (predicate) @@ -81,8 +74,6 @@ literal_value = { literal_bool | literal_raw | literal_string | - predicate_hash_native | - predicate_hash_external | literal_int } diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 291f7a6..5674f53 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -578,6 +578,7 @@ mod tests { max_input_pods: 3, max_statements: 31, max_public_statements: 10, + max_operation_args: 5, max_custom_predicate_wildcards: 12, ..Default::default() }; diff --git a/src/lang/module.rs b/src/lang/module.rs index b926871..3ff3d6b 100644 --- a/src/lang/module.rs +++ b/src/lang/module.rs @@ -11,9 +11,7 @@ use crate::{ lang::{ error::BatchingError, frontend_ast::{ConjunctionType, CustomPredicateDef}, - frontend_ast_lower::{ - lower_statement_arg_with_context, resolve_predicate_ref, ResolutionContext, - }, + frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext}, frontend_ast_split::{SplitChainInfo, SplitResult}, frontend_ast_validate::SymbolTable, }, @@ -347,9 +345,7 @@ fn build_single_batch( })?; } - builder.finish().map_err(|e| BatchingError::Internal { - message: format!("Failed to finalize batch '{}': {}", batch_name, e), - }) + Ok(builder.finish()) } /// Build a statement template with properly resolved predicate references @@ -376,13 +372,7 @@ fn build_statement_with_resolved_refs( let mut builder = StatementTmplBuilder::new(pred_or_wc); for arg in &stmt.args { - let builder_arg = - lower_statement_arg_with_context(arg, symbols, &context).map_err(|e| { - BatchingError::Internal { - message: format!("Failed to lower argument: {}", e), - } - })?; - builder = builder.arg(builder_arg); + builder = builder.arg(lower_statement_arg(arg)); } Ok(builder) @@ -678,110 +668,4 @@ mod tests { PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref)) ); } - - #[test] - fn test_self_predicate_hash_podlang() { - let params = Params::default(); - let module = load_module( - r#" - pred_A(x, y) = AND( - Equal(x, y) - ) - - pred_B(x) = AND( - Equal(x, @self_predicate(pred_A)) - ) - "#, - "test", - ¶ms, - &[], - ) - .unwrap(); - - let batch = &module.batch; - - // pred_B is at index 1, its template should have SelfPredicateHash(0) resolved - // to a Literal containing pred_A's hash after normalization - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_a_hash = crate::middleware::Value::from(Predicate::Custom(pred_a_ref).hash()); - - // Use normalized_predicate to resolve - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let normalized = pred_b_ref.normalized_predicate(); - assert_eq!( - normalized.statements[0].args[1], - crate::middleware::StatementTmplArg::Literal(pred_a_hash) - ); - } - - #[test] - fn test_self_predicate_hash_podlang_cyclic() { - let params = Params::default(); - let module = load_module( - r#" - pred_A(x) = AND( - Equal(x, @self_predicate(pred_B)) - ) - - pred_B(x) = AND( - Equal(x, @self_predicate(pred_A)) - ) - "#, - "test", - ¶ms, - &[], - ) - .unwrap(); - - let batch = &module.batch; - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let pred_a_hash = - crate::middleware::Value::from(Predicate::Custom(pred_a_ref.clone()).hash()); - let pred_b_hash = - crate::middleware::Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); - - // pred_A's normalized form should contain pred_B's hash - let norm_a = pred_a_ref.normalized_predicate(); - assert_eq!( - norm_a.statements[0].args[1], - crate::middleware::StatementTmplArg::Literal(pred_b_hash) - ); - - // pred_B's normalized form should contain pred_A's hash - let norm_b = pred_b_ref.normalized_predicate(); - assert_eq!( - norm_b.statements[0].args[1], - crate::middleware::StatementTmplArg::Literal(pred_a_hash) - ); - } - - #[test] - fn test_native_predicate_hash_podlang() { - let params = Params::default(); - let module = load_module( - r#" - pred_C(x) = AND( - Equal(x, @native_predicate(Equal)) - ) - "#, - "test", - ¶ms, - &[], - ) - .unwrap(); - - let batch = &module.batch; - let pred_c_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_c = pred_c_ref.predicate(); - - // The second arg should be a Literal containing Equal's predicate hash - let equal_hash = crate::middleware::Value::from( - Predicate::Native(crate::middleware::NativePredicate::Equal).hash(), - ); - assert_eq!( - pred_c.statements[0].args[1], - crate::middleware::StatementTmplArg::Literal(equal_hash) - ); - } } diff --git a/src/lang/parser.rs b/src/lang/parser.rs index 1a29113..000e683 100644 --- a/src/lang/parser.rs +++ b/src/lang/parser.rs @@ -137,9 +137,6 @@ mod tests { assert_inner(&Rule::anchored_key, "someVar[\"key\"]"); assert_inner(&Rule::literal_value, "true"); assert_inner(&Rule::literal_value, "PublicKey(abc)"); - assert_inner(&Rule::predicate_hash_self, "@self_predicate(foo)"); - assert_inner(&Rule::literal_value, "@native_predicate(Equal)"); - assert_inner(&Rule::literal_value, "@external_predicate(mod_a, pred_b)"); } #[test] @@ -210,33 +207,6 @@ mod tests { "{ \"raw_val\": Raw(0x0000000000000000000000000000000000000000000000000000000000000000) } ", ); assert_fails(Rule::literal_dict, "{ name: \"Alice\" }"); // Key must be string literal with quotes - - // Predicate hash literals - assert_parses(Rule::predicate_hash_native, "@native_predicate(Equal)"); - assert_parses(Rule::predicate_hash_native, "@native_predicate(Lt)"); - assert_parses( - Rule::predicate_hash_external, - "@external_predicate(my_module, my_pred)", - ); - assert_parses(Rule::predicate_hash_self, "@self_predicate(local_pred)"); - - // Predicate hashes inside containers (native and external only) - assert_parses( - Rule::literal_array, - "[1, @native_predicate(Equal), @external_predicate(m, p)]", - ); - assert_parses( - Rule::literal_set, - "#[@native_predicate(Equal), @native_predicate(Lt)]", - ); - assert_parses( - Rule::literal_dict, - "{ \"pred\": @external_predicate(m, p) }", - ); - - // @self_predicate is NOT a literal_value, so it cannot appear inside containers - assert_fails(Rule::test_literal_value, "@self_predicate(local_pred)"); - assert_fails(Rule::literal_array, "[@self_predicate(foo)]"); } #[test] diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs index 8e4819d..efca5c9 100644 --- a/src/lang/pretty_print.rs +++ b/src/lang/pretty_print.rs @@ -92,7 +92,7 @@ impl StatementTmpl { if i > 0 { write!(w, ", ")?; } - arg.fmt_podlang_with_batch_context(w, batch_context)?; + arg.fmt_podlang(w)?; } write!(w, ")")?; @@ -102,30 +102,7 @@ impl StatementTmpl { impl PrettyPrint for StatementTmplArg { fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { - self.fmt_podlang_with_batch_context(w, None) - } -} - -impl StatementTmplArg { - fn fmt_podlang_with_batch_context( - &self, - w: &mut dyn Write, - batch_context: Option<&CustomPredicateBatch>, - ) -> std::fmt::Result { - match self { - StatementTmplArg::SelfPredicateHash(index) => { - if let Some(batch) = batch_context { - if let Some(predicate) = batch.predicates().get(*index) { - write!(w, "@self_predicate({})", predicate.name) - } else { - write!(w, "@self_predicate(self_{})", index) - } - } else { - write!(w, "@self_predicate(self_{})", index) - } - } - other => write!(w, "{}", other), - } + write!(w, "{}", self) } } @@ -154,7 +131,7 @@ impl CustomPredicateBatch { impl PrettyPrint for Value { fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { - write!(w, "{}", self.typed) + write!(w, "{}", self.typed()) } } @@ -563,34 +540,6 @@ mod tests { assert_round_trip(&input); } - #[test] - fn test_round_trip_self_predicate_hash() { - let input = r#" - pred_A(x, y) = AND( - Equal(x, y) - ) - - pred_B(x) = AND( - Equal(x, @self_predicate(pred_A)) - ) - "#; - assert_round_trip(input); - } - - #[test] - fn test_round_trip_self_predicate_hash_cyclic() { - let input = r#" - pred_A(x) = AND( - Equal(x, @self_predicate(pred_B)) - ) - - pred_B(x) = AND( - Equal(x, @self_predicate(pred_A)) - ) - "#; - assert_round_trip(input); - } - #[test] fn test_pretty_print_demonstration() { let input = r#" diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index 0012251..e6af211 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -169,12 +169,6 @@ pub struct Hash( pub [F; HASH_SIZE], ); -impl Hash { - pub fn raw(self) -> RawValue { - RawValue::from(self) - } -} - impl From for HashOut { fn from(hash: Hash) -> HashOut { HashOut { elements: hash.0 } diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index 7c8e744..d01f43f 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -1,260 +1,29 @@ //! This file implements the types defined at //! . -use std::{ - collections::{HashMap, HashSet}, - fmt::{self, Debug}, -}; +use std::collections::{HashMap, HashSet}; use schemars::JsonSchema; -use serde::{ - de::{Error as _, SeqAccess, Visitor}, - ser, Deserialize, Deserializer, Serialize, -}; +use serde::{Deserialize, Deserializer, Serialize}; +use super::serialization::{ordered_map, ordered_set}; #[cfg(feature = "backend_plonky2")] -use crate::backends::plonky2::primitives::merkletree::{self, MerkleProof, MerkleTree}; +use crate::backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree}; use crate::{ backends::plonky2::primitives::merkletree::MerkleTreeStateTransitionProof, - middleware::{ - db::{mem::MemDB, DB}, - Error, Hash, Key, RawValue, Result, TypedValue, Value, EMPTY_HASH, - }, + middleware::{Error, Hash, Key, RawValue, Result, Value}, }; -#[derive(Clone, Debug)] -pub struct Container { - root: Hash, - db: Box, -} - -impl JsonSchema for Container { - fn schema_name() -> String { - "Container".to_string() - } - - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - // Just use the schema of Vec> since that's what we're actually serializing - Vec::>::json_schema(gen) - } -} - -impl Serialize for Container { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let mut pairs = self - .iter() - .collect::>>() - .map_err(ser::Error::custom)?; - pairs.sort_by(|(k1, _), (k2, _)| k1.raw().cmp(&k2.raw())); - // Serialize as an array - use serde::ser::SerializeSeq; - let mut seq = serializer.serialize_seq(Some(pairs.len()))?; - for (k, v) in pairs { - if k == v { - seq.serialize_element(&[&v])?; - } else { - seq.serialize_element(&[&k, &v])?; - } - } - seq.end() - } -} - -struct ContainerVisitor; - -impl<'de> Visitor<'de> for ContainerVisitor { - type Value = HashMap; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a sequence of `[Value]` or `[Value, Value]`") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - let mut kvs = HashMap::::new(); - while let Some(mut elem) = seq.next_element::>()? { - match elem.len() { - 1 => { - let v = elem.pop().unwrap(); - kvs.insert(v.clone(), v); - } - 2 => { - let (v, k) = (elem.pop().unwrap(), elem.pop().unwrap()); - kvs.insert(k, v); - } - n => { - return Err(A::Error::custom(format!( - "invalid vec length of {n} in container entry" - ))) - } - } - } - - Ok(kvs) - } -} - -impl<'de> Deserialize<'de> for Container { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let kvs = deserializer.deserialize_seq(ContainerVisitor)?; - Ok(Container::new(kvs)) - } -} - -impl PartialEq for Container { - fn eq(&self, other: &Self) -> bool { - self.root == other.root - } -} -impl Eq for Container {} - -fn store_container_mt(db: &mut dyn DB, container: &Container) -> Result<()> { - match db.load_node(container.root) { - Err(e) => return Err(Error::Database(e)), - // Container already exists in the DB - Ok(Some(_)) => return Ok(()), - // Container not existing, we need to save it - Ok(None) => {} - }; - let mut container_copy = Container::empty_with_db(db.clone_box()); - for kv_result in container.iter() { - let (k, v) = kv_result?; - container_copy.insert(k, v)?; - } - Ok(()) -} - -fn store_value(db: &mut dyn DB, v: Value) -> Result<()> { - match &v.typed { - TypedValue::Set(Set { inner }) - | TypedValue::Dictionary(Dictionary { inner }) - | TypedValue::Array(Array { inner }) => { - if db.is_persistent() { - store_container_mt(db, inner)?; - } - db.store_value(v).map_err(Error::Database)? - } - _ => db.store_value(v).map_err(Error::Database)?, - } - Ok(()) -} - -fn load_value(db: &dyn DB, value_raw: RawValue) -> Result { - match db.load_value(value_raw) { - Err(e) => Err(Error::Database(e)), - Ok(Some(v)) => Ok(v), - Ok(None) => Err(Error::custom(format!( - "Value from {value_raw} not found in DB" - ))), - } -} - -impl Container { - fn mt(&self) -> MerkleTree { - MerkleTree::from_db(self.root, self.db.clone()) - } - pub fn new(kvs: HashMap) -> Self { - let db = Box::new(MemDB::new()); - let mut container = Self::empty_with_db(db); - for (k, v) in kvs { - container.insert(k, v).expect("no duplicates, no db errors"); - } - container - } - pub fn empty_with_db(db: Box) -> Self { - Self::from_db(EMPTY_HASH, db).expect("EMPTY_HASH exists implicitly") - } - pub fn from_db(root: Hash, db: Box) -> Result { - // Make sure the root exists in the db - let _ = merkletree::load_node(db.as_ref(), root)?; - Ok(Self { root, db }) - } - pub fn commitment(&self) -> Hash { - self.root - } - pub fn get(&self, key_raw: RawValue) -> Result> { - Ok(match self.mt().get(&key_raw)? { - Some(value_raw) => Some(load_value(self.db.as_ref(), value_raw)?), - None => None, - }) - } - pub fn prove(&self, key_raw: RawValue) -> Result<(Value, MerkleProof)> { - let (value_raw, mtp) = self.mt().prove(&key_raw)?; - let value = load_value(self.db.as_ref(), value_raw)?; - Ok((value, mtp)) - } - pub fn prove_nonexistence(&self, key_raw: RawValue) -> Result { - Ok(self.mt().prove_nonexistence(&key_raw)?) - } - pub fn insert(&mut self, key: Value, value: Value) -> Result { - let (key_raw, value_raw) = (key.raw(), value.raw()); - store_value(self.db.as_mut(), key)?; - store_value(self.db.as_mut(), value)?; - let mut mt = self.mt(); - let mtp = mt.insert(&key_raw, &value_raw)?; - self.root = mt.root(); - Ok(mtp) - } - pub fn update( - &mut self, - key_raw: RawValue, - value: Value, - ) -> Result { - let value_raw = value.raw(); - store_value(self.db.as_mut(), value)?; - let mut mt = self.mt(); - let mtp = mt.update(&key_raw, &value_raw)?; - self.root = mt.root(); - Ok(mtp) - } - pub fn delete(&mut self, key_raw: RawValue) -> Result { - let mut mt = self.mt(); - let mtp = mt.delete(&key_raw)?; - self.root = mt.root(); - Ok(mtp) - } - pub fn verify( - root: Hash, - proof: &MerkleProof, - key_raw: RawValue, - value_raw: RawValue, - ) -> Result<()> { - Ok(MerkleTree::verify(root, proof, &key_raw, &value_raw)?) - } - pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key_raw: RawValue) -> Result<()> { - Ok(MerkleTree::verify_nonexistence(root, proof, &key_raw)?) - } - pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { - MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) - } - pub fn iter(&self) -> impl Iterator> { - let db = self.db.clone(); - self.mt().iter().map(move |(key_raw, value_raw)| { - let key = load_value(db.as_ref(), key_raw)?; - let value = load_value(db.as_ref(), value_raw)?; - Ok((key, value)) - }) - } - /// This is an expensive operation - pub fn dump(&self) -> Result> { - self.iter().collect() - } -} - /// Dictionary: the user original keys and values are hashed to be used in the leaf. /// leaf.key=hash(original_key) /// leaf.value=hash(original_value) -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct Dictionary { - pub(crate) inner: Container, + #[serde(skip)] + #[schemars(skip)] + mt: MerkleTree, + #[serde(serialize_with = "ordered_map")] + kvs: HashMap, } #[macro_export] @@ -265,371 +34,255 @@ macro_rules! dict { ({ $($key:expr => $val:expr),* }) => ({ let mut map = ::std::collections::HashMap::new(); $( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )* - $crate::middleware::containers::Dictionary::new(map) + $crate::middleware::containers::Dictionary::new( map) }); } -// TODO: Replace all methods that receive a `&Key` by either `impl Into` for write -// methods and `impl AsRef` for read methods. -// TODO: Replace all methods that receive a `&Value` in write methods for `Value`. Consider a -// trait? - impl Dictionary { pub fn new(kvs: HashMap) -> Self { + let kvs_raw: HashMap = + kvs.iter().map(|(k, v)| (k.raw(), v.raw())).collect(); Self { - inner: Container::new( - kvs.into_iter() - .map(|(k, v)| (Value::from(k.name), v)) - .collect(), - ), + mt: MerkleTree::new(&kvs_raw), + kvs, } } - pub fn empty_with_db(db: Box) -> Self { - Self { - inner: Container::empty_with_db(db), - } - } - pub fn from_db(root: Hash, db: Box) -> Result { - Ok(Self { - inner: Container::from_db(root, db)?, - }) - } pub fn commitment(&self) -> Hash { - self.inner.commitment() + self.mt.root() } - pub fn get(&self, key: &Key) -> Result> { - self.inner.get(key.raw()) + pub fn get(&self, key: &Key) -> Result<&Value> { + self.kvs + .get(key) + .ok_or_else(|| Error::custom(format!("key \"{}\" not found", key.name()))) } - pub fn prove(&self, key: &Key) -> Result<(Value, MerkleProof)> { - self.inner.prove(key.raw()) + pub fn prove(&self, key: &Key) -> Result<(&Value, MerkleProof)> { + let (_, mtp) = self.mt.prove(&key.raw())?; + let value = self.kvs.get(key).expect("key exists"); + Ok((value, mtp)) } pub fn prove_nonexistence(&self, key: &Key) -> Result { - self.inner.prove_nonexistence(key.raw()) + Ok(self.mt.prove_nonexistence(&key.raw())?) } pub fn insert(&mut self, key: &Key, value: &Value) -> Result { - self.inner - .insert(Value::from(key.name.clone()), value.clone()) + let mtp = self.mt.insert(&key.raw(), &value.raw())?; + self.kvs.insert(key.clone(), value.clone()); + Ok(mtp) } pub fn update(&mut self, key: &Key, value: &Value) -> Result { - self.inner.update(key.raw(), value.clone()) + let mtp = self.mt.update(&key.raw(), &value.raw())?; + self.kvs.insert(key.clone(), value.clone()); + Ok(mtp) } pub fn delete(&mut self, key: &Key) -> Result { - self.inner.delete(key.raw()) + let mtp = self.mt.delete(&key.raw())?; + self.kvs.remove(key); + Ok(mtp) } pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> { - Container::verify(root, proof, key.raw(), value.raw()) + let key = key.raw(); + Ok(MerkleTree::verify(root, proof, &key, &value.raw())?) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> { - Container::verify_nonexistence(root, proof, key.raw()) + let key = key.raw(); + Ok(MerkleTree::verify_nonexistence(root, proof, &key)?) } pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { - Container::verify_state_transition(proof) + MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) } - pub fn iter(&self) -> impl Iterator> + use<'_> { - self.inner.iter().map(|r| match r { - Ok((key, value)) => Ok(( - key.as_string() - .ok_or_else(|| Error::custom("dictionary: key is not string"))?, - value, - )), - Err(e) => Err(e), - }) - } - /// This is an expensive operation - pub fn dump(&self) -> Result> { - self.iter().collect() + // TODO: Rename to dict to be consistent maybe? + pub fn kvs(&self) -> &HashMap { + &self.kvs } } impl PartialEq for Dictionary { fn eq(&self, other: &Self) -> bool { - self.inner.eq(&other.inner) + self.mt.root() == other.mt.root() } } impl Eq for Dictionary {} +impl<'de> Deserialize<'de> for Dictionary { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct Aux { + #[serde(serialize_with = "ordered_map")] + kvs: HashMap, + } + let aux = Aux::deserialize(deserializer)?; + Ok(Dictionary::new(aux.kvs)) + } +} + /// Set: the value field of the leaf is unused, and the key contains the hash of the element. /// leaf.key=hash(original_value) /// leaf.value=0 -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct Set { - pub(crate) inner: Container, + #[serde(skip)] + #[schemars(skip)] + mt: MerkleTree, + #[serde(serialize_with = "ordered_set")] + set: HashSet, } impl Set { pub fn new(set: HashSet) -> Self { + let kvs_raw: HashMap = set + .iter() + .map(|e| { + let rv = e.raw(); + (rv, rv) + }) + .collect(); Self { - inner: Container::new(set.into_iter().map(|v| (v.clone(), v)).collect()), + mt: MerkleTree::new(&kvs_raw), + set, } } - pub fn empty_with_db(db: Box) -> Self { - Self { - inner: Container::empty_with_db(db), - } - } - pub fn from_db(root: Hash, db: Box) -> Result { - Ok(Self { - inner: Container::from_db(root, db)?, - }) - } pub fn commitment(&self) -> Hash { - self.inner.commitment() + self.mt.root() } - pub fn contains(&self, value: &Value) -> Result { - Ok(self.inner.get(value.raw())?.is_some()) + pub fn contains(&self, value: &Value) -> bool { + self.set.contains(value) } pub fn prove(&self, value: &Value) -> Result { - let (_, proof) = self.inner.prove(value.raw())?; + let rv = value.raw(); + let (_, proof) = self.mt.prove(&rv)?; Ok(proof) } pub fn prove_nonexistence(&self, value: &Value) -> Result { - self.inner.prove_nonexistence(value.raw()) + let rv = value.raw(); + Ok(self.mt.prove_nonexistence(&rv)?) } pub fn insert(&mut self, value: &Value) -> Result { - self.inner.insert(value.clone(), value.clone()) + let raw_value = value.raw(); + let mtp = self.mt.insert(&raw_value, &raw_value)?; + self.set.insert(value.clone()); + Ok(mtp) } pub fn delete(&mut self, value: &Value) -> Result { - self.inner.delete(value.raw()) + let mtp = self.mt.delete(&value.raw())?; + self.set.remove(value); + Ok(mtp) } pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - Container::verify(root, proof, value.raw(), value.raw()) + let rv = value.raw(); + Ok(MerkleTree::verify(root, proof, &rv, &rv)?) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - Container::verify_nonexistence(root, proof, value.raw()) + let rv = value.raw(); + Ok(MerkleTree::verify_nonexistence(root, proof, &rv)?) } pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { - Container::verify_state_transition(proof) + MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) } - pub fn iter(&self) -> impl Iterator> + use<'_> { - self.inner.iter().map(|r| match r { - Ok((key, value)) => { - if key != value { - return Err(Error::custom("set: key != value")); - } - Ok(value) - } - Err(e) => Err(e), - }) - } - /// This is an expensive operation - pub fn dump(&self) -> Result> { - self.iter().collect() + pub fn set(&self) -> &HashSet { + &self.set } } impl PartialEq for Set { fn eq(&self, other: &Self) -> bool { - self.inner.eq(&other.inner) + self.mt.root() == other.mt.root() } } impl Eq for Set {} +impl<'de> Deserialize<'de> for Set { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize, JsonSchema)] + struct Aux { + #[serde(serialize_with = "ordered_set")] + set: HashSet, + } + let aux = Aux::deserialize(deserializer)?; + Ok(Set::new(aux.set)) + } +} + /// Array: the elements are placed at the value field of each leaf, and the key field is just the /// array index (integer). /// leaf.key=i /// leaf.value=original_value -/// Due to its construction this should be seen as a sparse array, where there can be gaps -/// (unused indices). -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, JsonSchema)] pub struct Array { - pub(crate) inner: Container, + #[serde(skip)] + #[schemars(skip)] + mt: MerkleTree, + array: Vec, } impl Array { pub fn new(array: Vec) -> Self { + let kvs_raw: HashMap = array + .iter() + .enumerate() + .map(|(i, e)| (RawValue::from(i as i64), e.raw())) + .collect(); + Self { - inner: Container::new( - array - .into_iter() - .enumerate() - .map(|(i, v)| (Value::from(i as i64), v)) - .collect(), - ), + mt: MerkleTree::new(&kvs_raw), + array, } } - pub fn empty_with_db(db: Box) -> Self { - Self { - inner: Container::empty_with_db(db), - } - } - pub fn from_db(root: Hash, db: Box) -> Result { - Ok(Self { - inner: Container::from_db(root, db)?, - }) - } pub fn commitment(&self) -> Hash { - self.inner.commitment() + self.mt.root() } - pub fn get(&self, i: usize) -> Result> { - self.inner.get(Value::from(i as i64).raw()) - } - pub fn prove(&self, i: usize) -> Result<(Value, MerkleProof)> { - self.inner.prove(Value::from(i as i64).raw()) - } - pub fn insert(&mut self, i: usize, value: Value) -> Result { - self.inner.insert(Value::from(i as i64), value) - } - pub fn delete(&mut self, i: usize) -> Result { - self.inner.delete(Value::from(i as i64).raw()) - } - pub fn update(&mut self, i: usize, value: &Value) -> Result { - self.inner - .update(Value::from(i as i64).raw(), value.clone()) - } - pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> { - Container::verify(root, proof, Value::from(i as i64).raw(), value.raw()) - } - pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { - Container::verify_state_transition(proof) - } - pub fn iter(&self) -> impl Iterator> + use<'_> { - self.inner.iter().map(|r| match r { - Ok((key, value)) => { - let index = key - .as_int() - .ok_or_else(|| Error::custom("array: key is not int"))?; - Ok((index as usize, value)) - } - Err(e) => Err(e), + pub fn get(&self, i: usize) -> Result<&Value> { + self.array.get(i).ok_or_else(|| { + Error::custom(format!("index {} out of bounds 0..{}", i, self.array.len())) }) } - /// This is an expensive operation - pub fn dump(&self) -> Result> { - self.iter().collect() + pub fn prove(&self, i: usize) -> Result<(&Value, MerkleProof)> { + let (_, mtp) = self.mt.prove(&RawValue::from(i as i64))?; + let value = self.array.get(i).expect("valid index"); + Ok((value, mtp)) + } + pub fn update(&mut self, i: usize, value: &Value) -> Result { + let mtp = self.mt.update(&(i as i64).into(), &value.raw())?; + self.array[i] = value.clone(); + Ok(mtp) + } + pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> { + Ok(MerkleTree::verify( + root, + proof, + &RawValue::from(i as i64), + &value.raw(), + )?) + } + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { + MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) + } + pub fn array(&self) -> &[Value] { + &self.array } } impl PartialEq for Array { fn eq(&self, other: &Self) -> bool { - self.inner.eq(&other.inner) + self.mt.root() == other.mt.root() } } impl Eq for Array {} -#[cfg(test)] -mod tests { - use super::*; - use crate::middleware::db::mem::MemDB; - - fn test_databases(test_fn: &dyn Fn(Box)) { - let db = MemDB::new(); - test_fn(Box::new(db)); - #[cfg(feature = "db_rocksdb")] - { - use crate::middleware::db; - let db = db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(); - test_fn(Box::new(db)); +impl<'de> Deserialize<'de> for Array { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize, JsonSchema)] + struct Aux { + array: Vec, } - } - - fn _test_dict(db: Box) { - let mut dict0 = Dictionary::empty_with_db(db.clone()); - dict0.insert(&Key::from("a"), &Value::from(1)).unwrap(); - dict0.insert(&Key::from("b"), &Value::from(2)).unwrap(); - dict0.update(&Key::from("a"), &Value::from(3)).unwrap(); - dict0.insert(&Key::from("c"), &Value::from(4)).unwrap(); - dict0.delete(&Key::from("c")).unwrap(); - let kvs0 = dict0.dump().unwrap(); - assert_eq!( - kvs0, - [ - ("a".to_string(), Value::from(3)), - ("b".to_string(), Value::from(2)) - ] - .into_iter() - .collect() - ); - let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap(); - let kvs1 = dict1.dump().unwrap(); - assert_eq!(kvs0, kvs1); - } - - fn _test_set(db: Box) { - let mut set0 = Set::empty_with_db(db.clone()); - set0.insert(&Value::from(1)).unwrap(); - set0.insert(&Value::from(2)).unwrap(); - set0.insert(&Value::from(3)).unwrap(); - set0.delete(&Value::from(2)).unwrap(); - - let s0 = set0.dump().unwrap(); - assert_eq!(s0, [Value::from(1), Value::from(3)].into_iter().collect()); - let set1 = Set::from_db(set0.commitment(), db).unwrap(); - let s1 = set1.dump().unwrap(); - assert_eq!(s0, s1); - } - - fn _test_array(db: Box) { - let mut arr0 = Array::empty_with_db(db.clone()); - arr0.insert(0, Value::from("a")).unwrap(); - arr0.insert(1, Value::from("b")).unwrap(); - arr0.insert(2, Value::from("c")).unwrap(); - arr0.delete(1).unwrap(); - - let a0 = arr0.dump().unwrap(); - assert_eq!( - a0, - [(0, Value::from("a")), (2, Value::from("c"))] - .into_iter() - .collect() - ); - let arr1 = Array::from_db(arr0.commitment(), db).unwrap(); - let a1 = arr1.dump().unwrap(); - assert_eq!(a0, a1); - } - - fn _test_nested(db: Box) { - let mut nested = Dictionary::empty_with_db(db.clone()); - nested.insert(&Key::from("a"), &Value::from(1)).unwrap(); - nested.insert(&Key::from("b"), &Value::from(2)).unwrap(); - let nested_kvs0 = nested.dump().unwrap(); - - let mut dict0 = Dictionary::empty_with_db(db.clone()); - dict0.insert(&Key::from("x"), &Value::from(1)).unwrap(); - dict0 - .insert(&Key::from("y"), &Value::from(nested.clone())) - .unwrap(); - let kvs0 = dict0.dump().unwrap(); - - assert_eq!( - kvs0, - [ - ("x".to_string(), Value::from(1)), - ("y".to_string(), Value::from(nested)) - ] - .into_iter() - .collect() - ); - - let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap(); - let kvs1 = dict1.dump().unwrap(); - assert_eq!(kvs0, kvs1); - - match &kvs1["y"].typed { - TypedValue::Dictionary(d) => { - let nested_kvs1 = d.dump().unwrap(); - assert_eq!(nested_kvs0, nested_kvs1); - } - _ => unreachable!(), - } - } - - #[test] - fn test_dict() { - test_databases(&_test_dict); - } - - #[test] - fn test_set() { - test_databases(&_test_set); - } - - #[test] - fn test_array() { - test_databases(&_test_array); - } - - #[test] - fn test_nested() { - test_databases(&_test_nested); + let aux = Aux::deserialize(deserializer)?; + Ok(Array::new(aux.array)) } } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index e5c7285..13cc387 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -49,9 +49,6 @@ pub enum StatementTmplArg { // AnchoredKey where the origin is a wildcard AnchoredKey(Wildcard, Key), Wildcard(Wildcard), - /// Reference to a same-batch predicate's identity hash, resolved at verification time. - /// The usize is the predicate index within the batch. - SelfPredicateHash(usize), } #[derive(Clone, Copy)] @@ -60,7 +57,6 @@ pub enum StatementTmplArgPrefix { Literal = 1, AnchoredKey = 2, WildcardLiteral = 3, - SelfPredicateHash = 4, } impl From for F { @@ -72,12 +68,11 @@ impl From for F { impl ToFields for StatementTmplArg { fn to_fields(&self) -> Vec { // Encoding: - // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) - // Literal(v) => (1, [v ], 0, 0, 0, 0) - // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) - // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) - // SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0) - // In all cases, we pad to 2 * hash_size + 1 = 9 field elements + // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) + // Literal(v) => (1, [v ], 0, 0, 0, 0) + // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) + // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) + // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements match self { StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None)) .chain(iter::repeat(F::ZERO)) @@ -102,13 +97,6 @@ impl ToFields for StatementTmplArg { .take(Params::statement_tmpl_arg_size()) .collect_vec() } - StatementTmplArg::SelfPredicateHash(index) => { - iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash)) - .chain(iter::once(F::from_canonical_usize(*index))) - .chain(iter::repeat(F::ZERO)) - .take(Params::statement_tmpl_arg_size()) - .collect_vec() - } } } } @@ -125,7 +113,6 @@ impl fmt::Display for StatementTmplArg { write!(f, "]") } Self::Wildcard(v) => v.fmt(f), - Self::SelfPredicateHash(i) => write!(f, "::self.{}", i), } } } @@ -436,7 +423,7 @@ impl fmt::Display for CustomPredicate { } } -#[derive(Clone, PartialEq, Eq, Serialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] enum CustomPredicateBatchData { Full { #[serde(skip)] @@ -449,20 +436,6 @@ enum CustomPredicateBatchData { }, } -// Explicit implementation of Debug to skip the merkle tree -impl fmt::Debug for CustomPredicateBatchData { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self { - Self::Full { mt, predicates } => f - .debug_struct("Full") - .field("id", &mt.root()) - .field("predicates", &predicates) - .finish(), - Self::Opaque { id } => f.debug_struct("Opaque").field("id", &id).finish(), - } - } -} - // TODO: Rename Batch for Module everywhere in the code base impl CustomPredicateBatchData { fn new_full(predicates: Vec) -> Self { @@ -596,44 +569,6 @@ impl CustomPredicateRef { pub fn predicate(&self) -> &CustomPredicate { &self.batch.predicates()[self.index] } - - /// Returns a copy of this predicate with all `SelfPredicateHash(i)` args - /// resolved to `Literal(hash(Custom(batch, i)))`. - pub fn normalized_predicate(&self) -> CustomPredicate { - let pred = self.predicate(); - let normalized_statements = pred - .statements - .iter() - .map(|st_tmpl| { - let args = st_tmpl - .args - .iter() - .map(|arg| match arg { - StatementTmplArg::SelfPredicateHash(i) => { - let pred_hash = Predicate::Custom(CustomPredicateRef { - batch: self.batch.clone(), - index: *i, - }) - .hash(); - StatementTmplArg::Literal(Value::from(pred_hash)) - } - other => other.clone(), - }) - .collect(); - StatementTmpl { - pred_or_wc: st_tmpl.pred_or_wc.clone(), - args, - } - }) - .collect(); - CustomPredicate { - name: pred.name.clone(), - conjunction: pred.conjunction, - statements: normalized_statements, - args_len: pred.args_len, - wildcard_names: pred.wildcard_names.clone(), - } - } } #[cfg(test)] @@ -644,7 +579,7 @@ mod tests { middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl, - StatementTmplArg, ValueRef, + StatementTmplArg, }, }; @@ -667,9 +602,6 @@ mod tests { fn names(names: &[&str]) -> Vec { names.iter().map(|s| s.to_string()).collect() } - fn value_ref(v: impl Into) -> ValueRef { - v.into() - } #[allow(clippy::upper_case_acronyms)] type STA = StatementTmplArg; @@ -718,7 +650,7 @@ mod tests { }); let custom_statement = Statement::Custom( CustomPredicateRef::new(cust_pred_batch.clone(), 0), - vec![value_ref(d0.clone())], + vec![Value::from(d0.clone())], ); let custom_deduction = Operation::Custom( @@ -850,7 +782,7 @@ mod tests { // Example statement let ethdos_example = Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), - vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)], + vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], ); // Copies should work. @@ -859,7 +791,7 @@ mod tests { // This could arise as the inductive step. let ethdos_ind_example = Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1), - vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)], + vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], ); assert!(Operation::Custom( @@ -874,12 +806,12 @@ mod tests { let ethdos_facts = vec![ Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), - vec![value_ref("Alice"), value_ref("Charlie"), value_ref(6)], + vec![Value::from("Alice"), Value::from("Charlie"), Value::from(6)], ), Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)), Statement::Custom( CustomPredicateRef::new(eth_friend_batch.clone(), 0), - vec![value_ref("Charlie"), value_ref("Bob")], + vec![Value::from("Charlie"), Value::from("Bob")], ), ]; @@ -891,173 +823,4 @@ mod tests { Ok(()) } - - #[test] - fn test_normalized_predicate() -> Result<()> { - let params = Params::default(); - - // Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0)) - let pred_a = CustomPredicate::and( - ¶ms, - "pred_A".into(), - vec![st( - P::Native(NP::Equal), - vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))], - )], - 2, - names(&["x", "y"]), - )?; - let pred_b = CustomPredicate::and( - ¶ms, - "pred_B".into(), - vec![st( - P::Native(NP::Equal), - vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], - )], - 1, - names(&["x"]), - )?; - let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); - - // Compute expected pred_A hash - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash()); - - // Normalize pred_B - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let normalized = pred_b_ref.normalized_predicate(); - - // The second arg should be resolved to Literal(pred_A_hash) - assert_eq!( - normalized.statements[0].args[1], - STA::Literal(expected_hash) - ); - - // First arg should be unchanged (still a wildcard) - assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0))); - - Ok(()) - } - - #[test] - fn test_self_predicate_hash_check() -> Result<()> { - let params = Params::default(); - - // Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0)) - let pred_a = CustomPredicate::and( - ¶ms, - "pred_A".into(), - vec![st( - P::Native(NP::Equal), - vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))], - )], - 2, - names(&["x", "y"]), - )?; - let pred_b = CustomPredicate::and( - ¶ms, - "pred_B".into(), - vec![st( - P::Native(NP::Equal), - vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], - )], - 1, - names(&["x"]), - )?; - let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); - - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash()); - - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - - // Construct a valid operation: Equal(some_value, pred_a_hash) - let some_value = Value::from(42); - let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())]; - - // The output statement - let output_st = Statement::Custom( - pred_b_ref.clone(), - vec![ValueRef::Literal(some_value.clone())], - ); - - // This should pass - assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(¶ms, &output_st)?); - - // Now try with wrong hash, should fail - let wrong_hash = Value::from(999); - let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)]; - assert!(Operation::Custom(pred_b_ref, bad_op_args) - .check(¶ms, &output_st) - .is_err()); - - Ok(()) - } - - #[test] - fn test_self_predicate_hash_cyclic() -> Result<()> { - let params = Params::default(); - - // Build a batch where pred_A references pred_B's hash and vice versa - // pred_A = Equal(x, SelfPredicateHash(1)) - // pred_B = Equal(x, SelfPredicateHash(0)) - let pred_a = CustomPredicate::and( - ¶ms, - "pred_A".into(), - vec![st( - P::Native(NP::Equal), - vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)], - )], - 1, - names(&["x"]), - )?; - let pred_b = CustomPredicate::and( - ¶ms, - "pred_B".into(), - vec![st( - P::Native(NP::Equal), - vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], - )], - 1, - names(&["x"]), - )?; - let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); - - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash()); - let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); - - // pred_A's normalized form should reference pred_B's hash - let norm_a = pred_a_ref.normalized_predicate(); - assert_eq!( - norm_a.statements[0].args[1], - STA::Literal(pred_b_hash.clone()) - ); - - // pred_B's normalized form should reference pred_A's hash - let norm_b = pred_b_ref.normalized_predicate(); - assert_eq!( - norm_b.statements[0].args[1], - STA::Literal(pred_a_hash.clone()) - ); - - // Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass - let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())]; - let st_a = Statement::Custom( - pred_a_ref.clone(), - vec![ValueRef::Literal(pred_b_hash.clone())], - ); - assert!(Operation::Custom(pred_a_ref, op_a).check(¶ms, &st_a)?); - - // Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass - let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())]; - let st_b = Statement::Custom( - pred_b_ref.clone(), - vec![ValueRef::Literal(pred_a_hash.clone())], - ); - assert!(Operation::Custom(pred_b_ref, op_b).check(¶ms, &st_b)?); - - Ok(()) - } } diff --git a/src/middleware/db/mem.rs b/src/middleware/db/mem.rs deleted file mode 100644 index 71211fa..0000000 --- a/src/middleware/db/mem.rs +++ /dev/null @@ -1,62 +0,0 @@ -use super::*; - -/// MemDB implements the DB trait in a in-memory HashMap. -#[derive(Clone, Debug, Default)] -pub struct MemDB { - nodes: Arc>>, - values: Arc>>, -} - -impl MemDB { - pub fn new() -> Self { - Self::default() - } -} - -impl merkletree::db::DB for MemDB { - fn load_node(&self, hash: Hash) -> anyhow::Result> { - let nodes = self.nodes.read().expect("lock not poisoned"); - - if hash == EMPTY_HASH { - return Ok(Some(merkletree::Node::Intermediate( - merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH), - ))); - } - - Ok(nodes.get(&hash).cloned()) - } - - fn store_node(&mut self, node: merkletree::Node) -> anyhow::Result<()> { - let mut nodes = self.nodes.write().expect("lock not poisoned"); - nodes.insert(node.hash(), node); - Ok(()) - } -} - -impl DB for MemDB { - fn load_value(&self, raw: RawValue) -> anyhow::Result> { - let values = self.values.read().expect("lock not poisoned"); - - Ok(values.get(&raw).cloned()) - } - fn store_value(&mut self, value: Value) -> anyhow::Result<()> { - let mut values = self.values.write().expect("lock not poisoned"); - let value_raw = value.raw(); - if let Some(old_value) = values.get(&value_raw) { - let old_is_raw = old_value.is_raw(); - // If we had a non-RawValue stored don't overwrite it (specially not with a - // RawValue). Also skip redundant RawValue overwrite. - if !old_is_raw || value.is_raw() { - return Ok(()); - } - } - values.insert(value_raw, value); - Ok(()) - } - fn is_persistent(&self) -> bool { - false - } - fn clone_box(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/src/middleware/db/mod.rs b/src/middleware/db/mod.rs deleted file mode 100644 index bb32a67..0000000 --- a/src/middleware/db/mod.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::{ - collections::HashMap, - fmt::Debug, - sync::{Arc, RwLock}, -}; - -use dyn_clone::DynClone; - -#[cfg(feature = "backend_plonky2")] -use crate::backends::plonky2::primitives::merkletree::{self}; -use crate::middleware::{Hash, RawValue, Value, EMPTY_HASH}; - -pub mod mem; -#[cfg(feature = "db_rocksdb")] -pub mod rocks; - -// Trait for database that stores values. Must be cheap to clone. -pub trait DB: Debug + DynClone + Sync + Send + merkletree::db::DB { - fn load_value(&self, raw: RawValue) -> anyhow::Result>; - // If the DB is persistent, for containers only the root needs to be stored because the - // Container type makes sure the underlying merkle tree is stored in the DB independently, so - // that it can be recovered back just with the root and the DB. - // If the value is RawValue and a previous non-RawValue exists, no store overwrite it. - // should be done. If the value is non-RawValue and a previous RawValue exists, store - // should overwrite it. - fn store_value(&mut self, value: Value) -> anyhow::Result<()>; - fn is_persistent(&self) -> bool; - fn clone_box(&self) -> Box; -} -dyn_clone::clone_trait_object!(DB); diff --git a/src/middleware/db/rocks.rs b/src/middleware/db/rocks.rs deleted file mode 100644 index be5ca4a..0000000 --- a/src/middleware/db/rocks.rs +++ /dev/null @@ -1,107 +0,0 @@ -use std::{fmt, path::Path, sync::Arc}; - -use anyhow::{anyhow, Result}; -use rocksdb::{Options, TransactionDB, TransactionDBOptions}; - -use super::*; - -fn node_key(hash: Hash) -> Vec { - let mut k = Vec::with_capacity(2 + 4); - k.extend_from_slice(b"n/"); - k.extend_from_slice(&RawValue::from(hash).to_bytes()); - k -} - -fn value_key(raw: RawValue) -> Vec { - let mut k = Vec::with_capacity(2 + 4); - k.extend_from_slice(b"v/"); - k.extend_from_slice(&raw.to_bytes()); - k -} - -#[derive(Clone)] -pub struct RocksDB { - db: Arc, -} - -impl fmt::Debug for RocksDB { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "RocksDB(path: {:?})", self.db.path()) - } -} - -impl RocksDB { - pub fn open(path: impl AsRef) -> Result { - let mut options = Options::default(); - options.create_if_missing(true); - let txn_options = TransactionDBOptions::default(); - let inner = - TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?; - Ok(Self { - db: Arc::new(inner), - }) - } -} - -impl merkletree::db::DB for RocksDB { - fn load_node(&self, hash: Hash) -> Result> { - if hash == EMPTY_HASH { - return Ok(Some(merkletree::Node::Intermediate( - merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH), - ))); - } - - match self.db.get(node_key(hash))? { - None => Ok(None), - Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)), - } - } - - fn store_node(&mut self, node: merkletree::Node) -> Result<()> { - self.db - .put(node_key(node.hash()), node.encode()?) - .map_err(|e| anyhow!("rocksdb transaction put failed: {e}")) - } -} - -impl DB for RocksDB { - fn load_value(&self, raw: RawValue) -> anyhow::Result> { - match self.db.get(value_key(raw))? { - None => Ok(None), - Some(bytes) => Ok(Some({ - if bytes.is_empty() { - Value::from(raw) - } else { - Value::from_bytes(bytes.as_ref(), self.clone_box())? - } - })), - } - } - fn store_value(&mut self, value: Value) -> anyhow::Result<()> { - let value_key = value_key(value.raw()); - let tx = self.db.transaction(); - if let Some(old_value_bytes) = tx.get_for_update(&value_key, true)? { - let is_raw = old_value_bytes.is_empty(); - // If we had a non-RawValue stored don't overwrite it (specially not with a - // RawValue). Also skip redundant RawValue overwrite. - if !is_raw || (is_raw && value.is_raw()) { - return Ok(()); - } - } - let value_bytes = if value.is_raw() { - // For RawValue we store an empty vector because it's a duplicate of the key. - // This way we can easily check for RawValue without decoding. - vec![] - } else { - Value::to_bytes(&value) - }; - tx.put(value_key, value_bytes)?; - Ok(tx.commit()?) - } - fn is_persistent(&self) -> bool { - true - } - fn clone_box(&self) -> Box { - Box::new(self.clone()) - } -} diff --git a/src/middleware/error.rs b/src/middleware/error.rs index f7ad765..74605da 100644 --- a/src/middleware/error.rs +++ b/src/middleware/error.rs @@ -72,10 +72,6 @@ pub enum Error { }, #[error(transparent)] Tree(#[from] crate::backends::plonky2::primitives::merkletree::error::TreeError), - #[error(transparent)] - Json(#[from] serde_json::Error), - #[error("database error: {0}")] - Database(anyhow::Error), } impl Debug for Error { @@ -168,7 +164,7 @@ impl Error { pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self { new!(UnsatisfiedCustomPredicateDisjunction(pred)) } - pub(crate) fn custom(s: impl Into) -> Self { - new!(Custom(s.into())) + pub(crate) fn custom(s: String) -> Self { + new!(Custom(s)) } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index d212ca8..542f5b2 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,13 +1,16 @@ //! The middleware includes the type definitions and the traits used to connect the frontend and //! the backend. +use std::sync::Arc; + use hex::ToHex; +use itertools::Itertools; use strum_macros::FromRepr; mod basetypes; use std::{cmp::PartialEq, hash}; -use containers::{Array, Container, Dictionary, Set}; +use containers::{Array, Dictionary, Set}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; pub mod containers; @@ -19,7 +22,6 @@ pub mod serialization; mod statement; use std::{any::Any, fmt}; -pub mod db; pub use basetypes::*; pub use custom::*; use dyn_clone::DynClone; @@ -29,10 +31,14 @@ pub use pod_deserialization::*; use serialization::*; pub use statement::*; +use crate::backends::plonky2::primitives::merkletree::{ + MerkleProof, MerkleTreeStateTransitionProof, +}; + // TODO: Move all value-related types to to `value.rs` #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] // TODO #[schemars(transform = serialization::transform_value_schema)] -pub(crate) enum TypedValue { +pub enum TypedValue { // Serde cares about the order of the enum variants, with untagged variants // appearing at the end. // Variants without "untagged" will be serialized as "tagged" values by @@ -67,6 +73,8 @@ pub(crate) enum TypedValue { Array(Array), #[serde(untagged)] String(String), + #[serde(untagged)] + Bool(bool), } impl From<&str> for TypedValue { @@ -89,11 +97,7 @@ impl From for TypedValue { impl From for TypedValue { fn from(b: bool) -> Self { - if b { - TypedValue::Int(1) - } else { - TypedValue::Int(0) - } + TypedValue::Bool(b) } } @@ -145,6 +149,70 @@ impl From for TypedValue { } } +impl TryFrom<&TypedValue> for i64 { + type Error = Error; + fn try_from(v: &TypedValue) -> std::result::Result { + if let TypedValue::Int(n) = v { + Ok(*n) + } else { + Err(Error::custom("Value not an int".to_string())) + } + } +} + +impl TryFrom<&TypedValue> for String { + type Error = Error; + fn try_from(tv: &TypedValue) -> Result { + match tv { + TypedValue::String(s) => Ok(s.clone()), + _ => Err(Error::custom(format!( + "Value {} cannot be converted to a string.", + tv + ))), + } + } +} + +impl TryFrom<&TypedValue> for Key { + type Error = Error; + fn try_from(tv: &TypedValue) -> Result { + Ok(Key::new(String::try_from(tv)?)) + } +} + +impl TryFrom<&TypedValue> for PublicKey { + type Error = Error; + fn try_from(v: &TypedValue) -> std::result::Result { + if let TypedValue::PublicKey(pk) = v { + Ok(*pk) + } else { + Err(Error::custom("Value not a public key".to_string())) + } + } +} + +impl TryFrom<&TypedValue> for SecretKey { + type Error = Error; + fn try_from(v: &TypedValue) -> std::result::Result { + if let TypedValue::SecretKey(sk) = v { + Ok(sk.clone()) + } else { + Err(Error::custom("Value not a secret key".to_string())) + } + } +} + +impl TryFrom<&TypedValue> for Predicate { + type Error = Error; + fn try_from(v: &TypedValue) -> std::result::Result { + if let TypedValue::Predicate(p) = v { + Ok(p.clone()) + } else { + Err(Error::custom("Value not a Predicate".to_string())) + } + } +} + impl fmt::Display for TypedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -156,54 +224,36 @@ impl fmt::Display for TypedValue { Err(_) => write!(f, "\"{}\"", s), } } + TypedValue::Bool(b) => write!(f, "{}", b), TypedValue::Array(a) => { write!(f, "[")?; - for (i, r) in a.iter().enumerate() { + for (i, v) in a.array().iter().enumerate() { if i > 0 { write!(f, ", ")?; } - if i == 8 { - write!(f, "…")?; - break; - } - match r { - Ok((index, value)) => write!(f, "{}: {}", index, value)?, - Err(e) => write!(f, "{e}")?, - } + write!(f, "{}", v)?; } write!(f, "]") } TypedValue::Dictionary(d) => { write!(f, "{{ ")?; - for (i, r) in d.iter().enumerate() { + let kvs: Vec<_> = d.kvs().iter().sorted_by_key(|(k, _)| k.name()).collect(); + for (i, (k, v)) in kvs.iter().enumerate() { if i > 0 { write!(f, ", ")?; } - if i == 8 { - write!(f, "…")?; - break; - } - match r { - Ok((key, value)) => write!(f, "{}: {}", key, value)?, - Err(e) => write!(f, "{e}")?, - } + write!(f, "{}: {}", k, v)?; } write!(f, " }}") } TypedValue::Set(s) => { write!(f, "#[")?; - for (i, r) in s.iter().enumerate() { + let values: Vec<_> = s.set().iter().sorted_by_key(|k| k.raw()).collect(); + for (i, v) in values.iter().enumerate() { if i > 0 { write!(f, ", ")?; } - if i == 8 { - write!(f, "…")?; - break; - } - match r { - Ok(value) => write!(f, "{}", value)?, - Err(e) => write!(f, "{e}")?, - } + write!(f, "{}", v)?; } write!(f, "]") } @@ -222,6 +272,7 @@ impl From<&TypedValue> for RawValue { match v { TypedValue::String(s) => RawValue::from(hash_str(s)), TypedValue::Int(v) => RawValue::from(*v), + TypedValue::Bool(b) => RawValue::from(*b as i64), TypedValue::Dictionary(d) => RawValue::from(d.commitment()), TypedValue::Set(s) => RawValue::from(s.commitment()), TypedValue::Array(a) => RawValue::from(a.commitment()), @@ -354,8 +405,9 @@ impl JsonSchema for TypedValue { #[derive(Clone, Debug)] pub struct Value { - pub(crate) typed: TypedValue, - pub(crate) raw: RawValue, + // The `TypedValue` is under `Arc` so that cloning a `Value` is cheap. + typed: Arc, + raw: RawValue, } // Values are serialized as their TypedValue. @@ -389,55 +441,6 @@ impl JsonSchema for Value { } } -/// Dual of TypedValue that is not recursive: for container types no entry only the commitment -/// (merkle tree root of underlying data) is available. Used for byte serialization for -/// persistent storage. -#[derive(Serialize, Deserialize)] -enum TypedValueNoRec { - Raw(RawValue), - Int(i64), - PublicKey(PublicKey), - SecretKey(SecretKey), - Predicate(Predicate), - Set(Hash), - Dictionary(Hash), - Array(Hash), - String(String), -} - -// NOTE: byte serialization is using json. Using a byte-native serialization would improve -// performance and storage usage. -impl Value { - pub fn to_bytes(&self) -> Vec { - let v = match &self.typed { - TypedValue::Int(v) => TypedValueNoRec::Int(*v), - TypedValue::Raw(v) => TypedValueNoRec::Raw(*v), - TypedValue::PublicKey(v) => TypedValueNoRec::PublicKey(*v), - TypedValue::SecretKey(v) => TypedValueNoRec::SecretKey(v.clone()), - TypedValue::Predicate(v) => TypedValueNoRec::Predicate(v.clone()), - TypedValue::Set(v) => TypedValueNoRec::Set(v.commitment()), - TypedValue::Dictionary(v) => TypedValueNoRec::Dictionary(v.commitment()), - TypedValue::Array(v) => TypedValueNoRec::Array(v.commitment()), - TypedValue::String(v) => TypedValueNoRec::String(v.clone()), - }; - serde_json::to_vec(&v).expect("json serialization succeeds") - } - pub fn from_bytes(bytes: &[u8], db: Box) -> Result { - let v: TypedValueNoRec = serde_json::from_slice(bytes)?; - Ok(match v { - TypedValueNoRec::Int(v) => Value::from(v), - TypedValueNoRec::Raw(v) => Value::from(v), - TypedValueNoRec::PublicKey(v) => Value::from(v), - TypedValueNoRec::SecretKey(v) => Value::from(v), - TypedValueNoRec::Predicate(v) => Value::from(v), - TypedValueNoRec::Set(v) => Value::from(Set::from_db(v, db)?), - TypedValueNoRec::Dictionary(v) => Value::from(Dictionary::from_db(v, db)?), - TypedValueNoRec::Array(v) => Value::from(Array::from_db(v, db)?), - TypedValueNoRec::String(v) => Value::from(v), - }) - } -} - impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { self.raw == other.raw @@ -459,110 +462,106 @@ impl fmt::Display for Value { } impl Value { - pub(crate) fn new(value: TypedValue) -> Self { + pub fn new(value: TypedValue) -> Self { let raw_value = RawValue::from(&value); Self { - typed: value, + typed: Arc::new(value), raw: raw_value, } } + pub fn typed(&self) -> &TypedValue { + &self.typed + } pub fn raw(&self) -> RawValue { self.raw } - /// Returns true if the typed value is RawValue, which means it's a generic value with no type - /// information and no extra value data. - pub fn is_raw(&self) -> bool { - matches!(self.typed, TypedValue::Raw(_)) - } - pub fn as_raw(&self) -> RawValue { - self.raw - } - pub fn as_int(&self) -> Option { - match self.typed { - TypedValue::Int(i) => Some(i), - _ => None, - } - } - pub fn as_public_key(&self) -> Option { - match &self.typed { - TypedValue::PublicKey(pk) => Some(*pk), - _ => None, - } - } - pub fn as_secret_key(&self) -> Option { - match &self.typed { - TypedValue::SecretKey(sk) => Some(sk.clone()), - _ => None, - } - } - pub fn as_predicate(&self) -> Option { - match &self.typed { - TypedValue::Predicate(p) => Some(p.clone()), - _ => None, - } - } - pub fn as_set(&self) -> Option { - match &self.typed { - TypedValue::Set(s) => Some(s.clone()), - TypedValue::Dictionary(d) => Some(Set { - inner: d.inner.clone(), - }), - TypedValue::Array(a) => Some(Set { - inner: a.inner.clone(), - }), - _ => None, - } - } - pub fn as_container(&self) -> Option { - match &self.typed { - TypedValue::Set(s) => Some(s.inner.clone()), - TypedValue::Dictionary(d) => Some(d.inner.clone()), - TypedValue::Array(a) => Some(a.inner.clone()), - _ => None, - } - } - pub fn as_dictionary(&self) -> Option { - match &self.typed { - TypedValue::Set(s) => Some(Dictionary { - inner: s.inner.clone(), - }), - TypedValue::Dictionary(d) => Some(d.clone()), - TypedValue::Array(a) => Some(Dictionary { - inner: a.inner.clone(), - }), - _ => None, - } - } - pub fn as_array(&self) -> Option { - match &self.typed { - TypedValue::Set(s) => Some(Array { - inner: s.inner.clone(), - }), - TypedValue::Dictionary(d) => Some(Array { - inner: d.inner.clone(), - }), - TypedValue::Array(a) => Some(a.clone()), - _ => None, - } - } - pub fn as_str(&self) -> Option<&str> { - match &self.typed { - TypedValue::String(s) => Some(s.as_str()), - _ => None, - } - } - pub fn as_string(&self) -> Option { - self.as_str().map(|s| s.to_string()) - } - pub fn as_bool(&self) -> Option { - match self.typed { - TypedValue::Int(i) => match i { - 0 => Some(false), - 1 => Some(true), - _ => None, + /// Determines Merkle existence proof for `key` in `self` (if applicable). + pub(crate) fn prove_existence<'a>( + &'a self, + key: &'a Value, + ) -> Result<(&'a Value, MerkleProof)> { + match &self.typed() { + TypedValue::Array(a) => match key.typed() { + TypedValue::Int(i) if i >= &0 => a.prove((*i) as usize), + _ => Err(Error::custom(format!( + "Invalid key {} for container {}.", + key, self + )))?, }, - _ => None, + TypedValue::Dictionary(d) => d.prove(&key.typed().try_into()?), + TypedValue::Set(s) => Ok((key, s.prove(key)?)), + _ => Err(Error::custom(format!( + "Invalid container value {}", + self.typed() + ))), + } + } + /// Determines Merkle non-existence proof for `key` in `self` (if applicable). + pub(crate) fn prove_nonexistence<'a>(&'a self, key: &'a Value) -> Result { + match &self.typed() { + TypedValue::Array(_) => Err(Error::custom( + "Arrays do not support `NotContains` operation.".to_string(), + )), + TypedValue::Dictionary(d) => d.prove_nonexistence(&key.typed().try_into()?), + TypedValue::Set(s) => s.prove_nonexistence(key), + _ => Err(Error::custom(format!( + "Invalid container value {}", + self.typed() + ))), + } + } + /// Returns a Merkle state transition proof for inserting a + /// key-value pair (if applicable). + pub(crate) fn prove_insertion( + &self, + key: &Value, + value: &Value, + ) -> Result { + let container = self.typed().clone(); + match container { + TypedValue::Dictionary(mut d) => d.insert(&key.typed().try_into()?, value), + TypedValue::Set(mut s) => s.insert(value), + _ => Err(Error::custom(format!( + "Invalid container value {}", + self.typed() + ))), + } + } + /// Returns a Merkle state transition proof for updating a + /// key-value pair (if applicable). + pub(crate) fn prove_update( + &self, + key: &Value, + value: &Value, + ) -> Result { + let container = self.typed().clone(); + match container { + TypedValue::Array(mut a) => match key.typed() { + TypedValue::Int(i) if i >= &0 => a.update(*i as usize, value), + _ => Err(Error::custom(format!( + "Invalid key {} for container {}.", + key, self + )))?, + }, + TypedValue::Dictionary(mut d) => d.update(&key.typed().try_into()?, value), + _ => Err(Error::custom(format!( + "Invalid container value {} for update op", + self.typed() + ))), + } + } + /// Returns a Merkle state transition proof for deleting a + /// key (if applicable). + pub(crate) fn prove_deletion(&self, key: &Value) -> Result { + let container = self.typed().clone(); + match container { + TypedValue::Dictionary(mut d) => d.delete(&key.typed().try_into()?), + TypedValue::Set(mut s) => s.delete(key), + _ => Err(Error::custom(format!( + "Invalid container value {}", + self.typed() + ))), } } } @@ -768,8 +767,6 @@ pub struct BaseParams { /// in a custom predicate pub max_custom_predicate_arity: usize, pub max_depth_custom_batch_mt: usize, - // This value depends on `max_custom_predicate_arity` - pub max_operation_args: usize, } pub const BASE_PARAMS: BaseParams = BaseParams { @@ -777,53 +774,8 @@ pub const BASE_PARAMS: BaseParams = BaseParams { max_statement_args: 5, max_custom_predicate_arity: 5, max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch - max_operation_args: 5 + 1, }; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] -#[serde(rename_all = "camelCase")] -pub struct ParamsMerkleProofs { - pub max_small: usize, - pub max_medium: usize, -} - -impl ParamsMerkleProofs { - pub fn max_total(&self) -> usize { - self.max_small + self.max_medium - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] -#[serde(rename_all = "camelCase")] -pub struct ParamsContainers { - // Parameters for exists/nonexists container operations. The small set only supports exists - pub state: ParamsMerkleProofs, - // Parameters for transition container operations (insert, delete, update). The small set only - // supports update. - pub transition: ParamsMerkleProofs, - // Max depth of small proofs - pub max_depth_small: usize, - // Max depth of medium proofs - pub max_depth_medium: usize, -} - -impl Default for ParamsContainers { - fn default() -> Self { - Self { - state: ParamsMerkleProofs { - max_small: 22, - max_medium: 8, - }, - transition: ParamsMerkleProofs { - max_small: 12, - max_medium: 6, - }, - max_depth_small: 8, - max_depth_medium: 32, - } - } -} - /// Params: non dynamic parameters that define the circuit. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] #[serde(rename_all = "camelCase")] @@ -832,12 +784,18 @@ pub struct Params { pub max_input_pods_public_statements: usize, pub max_statements: usize, pub max_public_statements: usize, + pub max_operation_args: usize, // max number of different custom predicates that can be used in a MainPod pub max_custom_predicates: usize, // max number of operations using custom predicates that can be verified in the MainPod pub max_custom_predicate_verifications: usize, pub max_custom_predicate_wildcards: usize, - pub containers: ParamsContainers, + // maximum number of merkle proofs used for container operations + pub max_merkle_proofs_containers: usize, + // maximum number of merkle tree state transition proofs used for container update operations + pub max_merkle_tree_state_transition_proofs_containers: usize, + // maximum depth for merkle tree gadget used for container operations + pub max_depth_mt_containers: usize, // maximum depth of the merkle tree gadget used for verifier_data membership // check. This allows creating verifying sets of pod circuits of size // 2^max_depth_mt_vds. Limits the number of container operations of the type Contains, @@ -856,10 +814,13 @@ impl Default for Params { max_input_pods_public_statements: 8, max_statements: 48, max_public_statements: 8, + max_operation_args: 5, max_custom_predicates: 8, max_custom_predicate_verifications: 8, max_custom_predicate_wildcards: 8, - containers: ParamsContainers::default(), + max_merkle_proofs_containers: 20, + max_merkle_tree_state_transition_proofs_containers: 6, + max_depth_mt_containers: 32, max_depth_mt_vds: 6, // up to 64 (2^6) different pod circuits max_public_key_of: 2, max_signed_by: 4, diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 8d3316c..526ff51 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -7,14 +7,17 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::primitives::{ - ec::{curve::GROUP_ORDER, schnorr::Signature}, + ec::{ + curve::{Point as PublicKey, GROUP_ORDER}, + schnorr::{SecretKey, Signature}, + }, merkletree::{MerkleProof, MerkleTree, MerkleTreeOp, MerkleTreeStateTransitionProof}, }, middleware::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result, - Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef, - Wildcard, BASE_PARAMS, F, + Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, + ValueRef, Wildcard, F, }, }; @@ -89,7 +92,6 @@ pub enum NativeOperation { ContainerInsertFromEntries = 16, ContainerUpdateFromEntries = 17, ContainerDeleteFromEntries = 18, - ReplaceValueWithEntry = 19, // Syntactic sugar operations. These operations are not supported by the backend. The // frontend compiler is responsible of translating these operations into the operations above. @@ -165,7 +167,6 @@ impl OperationType { NativeOperation::ContainerDeleteFromEntries => { Some(Predicate::Native(NativePredicate::ContainerDelete)) } - NativeOperation::ReplaceValueWithEntry => None, no => unreachable!("Unexpected syntactic sugar op {:?}", no), }, OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), @@ -221,10 +222,6 @@ pub enum Operation { /* key */ Statement, /* proof */ MerkleTreeStateTransitionProof, ), - ReplaceValueWithEntry( - /* Contains/None len=max_statement_args */ Vec, - /* to copy */ Statement, - ), Custom(CustomPredicateRef, Vec), } @@ -244,10 +241,6 @@ pub(crate) fn hash_op(x: Value, y: Value) -> Value { Value::from(hash_values(&[x, y])) } -fn ok_or_type_err(o: Option, v: &Value, typ: &'static str) -> Result { - o.ok_or_else(|| Error::custom(format!("{v} type is not {typ}"))) -} - impl Operation { pub fn op_type(&self) -> OperationType { type OT = OperationType; @@ -276,7 +269,6 @@ impl Operation { OT::Native(ContainerUpdateFromEntries) } Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries), - Self::ReplaceValueWithEntry(_, _) => OT::Native(ReplaceValueWithEntry), Self::Custom(cpr, _) => OT::Custom(cpr.clone()), } } @@ -302,11 +294,6 @@ impl Operation { Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3], - Self::ReplaceValueWithEntry(args, s) => { - let mut sts = args; - sts.push(s); - sts - } Self::Custom(_, args) => args, } } @@ -389,18 +376,6 @@ impl Operation { &[s1, s2, s3], OA::MerkleTreeStateTransitionProof(pf), ) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf), - (NO::ReplaceValueWithEntry, args, OA::None) => { - let mut args = args.to_vec(); - if args.len() != BASE_PARAMS.max_statement_args + 1 { - return Err(Error::custom(format!( - "ReplaceValueWithEntry requires exactly {} args but {} were found", - BASE_PARAMS.max_statement_args + 1, - args.len() - ))); - } - let st = args.pop().expect("valid vec len"); - Self::ReplaceValueWithEntry(args, st) - } _ => Err(Error::custom(format!( "Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.", op_code, @@ -429,55 +404,23 @@ impl Operation { v3: &Value, f: impl FnOnce(i64, i64) -> i64, ) -> Result { - let i1 = ok_or_type_err(v1.as_int(), v1, "Int")?; - let i2 = ok_or_type_err(v2.as_int(), v2, "Int")?; - let i3 = ok_or_type_err(v3.as_int(), v3, "Int")?; + let i1: i64 = v1.typed().try_into()?; + let i2: i64 = v2.typed().try_into()?; + let i3: i64 = v3.typed().try_into()?; Ok(i1 == f(i2, i3)) } pub(crate) fn check_public_key(v1: &Value, v2: &Value) -> Result { - let pk = ok_or_type_err(v1.as_public_key(), v1, "PublicKey")?; - let sk = ok_or_type_err(v2.as_secret_key(), v2, "SecretKey")?; + let pk: PublicKey = v1.typed().try_into()?; + let sk: SecretKey = v2.typed().try_into()?; Ok(sk.0 < *GROUP_ORDER && pk == sk.public_key()) } pub(crate) fn check_signed_by(msg: &Value, pk: &Value, sig: &Signature) -> Result { - let pk = ok_or_type_err(pk.as_public_key(), pk, "PublicKey")?; + let pk: PublicKey = pk.typed().try_into()?; Ok(sig.verify(pk, msg.raw())) } - fn check_replace_value_with_entry( - entries: &[Statement], - st_in: &Statement, - expected_st_out: &Statement, - ) -> Result { - if entries.len() != BASE_PARAMS.max_statement_args { - return Ok(false); - } - let args = iter::zip(st_in.args(), entries) - .map(|(arg_in, entry)| match (arg_in, entry) { - (arg_in, Statement::None) => Ok(arg_in), - ( - StatementArg::Literal(v_in), - Statement::Contains( - ValueRef::Literal(root), - ValueRef::Literal(key), - ValueRef::Literal(v), - ), - ) if v == &v_in => Ok(StatementArg::Key(AnchoredKey::new( - Hash::from(root.raw()), - Key::from(key.as_str().ok_or_else(|| Error::custom("not a string"))?), - ))), - _ => Err(Error::custom( - "invalid statement argument in ReplaceValueWithEntry", - )), - }) - .collect::>>()?; - - let st_out = Statement::from_args(st_in.predicate(), args)?; - Ok(&st_out == expected_st_out) - } - /// Checks the given operation against a statement. pub fn check(&self, params: &Params, output_statement: &Statement) -> Result { use Statement::*; @@ -485,8 +428,8 @@ impl Operation { let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err); let int_val = |v, s| { let v_op = value_from_op(s, v).ok_or_else(deduction_err)?; - match v_op.as_int() { - Some(i) => Ok(i), + match v_op.typed() { + &TypedValue::Int(i) => Ok(i), _ => Err(deduction_err()), } }; @@ -551,7 +494,8 @@ impl Operation { && pf.op_value == value.raw()) .then_some(()) .ok_or(Error::custom( - "The provided Merkle tree state transition proof does not match the claim.", + "The provided Merkle tree state transition proof does not match the claim." + .into(), ))?; MerkleTree::verify_state_transition(pf)?; true @@ -571,7 +515,8 @@ impl Operation { && pf.op_value == value.raw()) .then_some(()) .ok_or(Error::custom( - "The provided Merkle tree state transition proof does not match the claim.", + "The provided Merkle tree state transition proof does not match the claim." + .into(), ))?; MerkleTree::verify_state_transition(pf)?; true @@ -589,7 +534,8 @@ impl Operation { && pf.op_key == key.raw()) .then_some(()) .ok_or(Error::custom( - "The provided Merkle tree state transition proof does not match the claim.", + "The provided Merkle tree state transition proof does not match the claim." + .into(), ))?; MerkleTree::verify_state_transition(pf)?; true @@ -597,19 +543,7 @@ impl Operation { (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) if batch == &cpr.batch && index == &cpr.index => { - // The custom operation outputs statements with literal arguments. They can be - // replaced by references later with ReplaceValueWithEntry. - let s_args = s_args - .iter() - .map(|arg| match arg { - ValueRef::Literal(v) => Ok(v.clone()), - _ => Err(deduction_err()), - }) - .collect::>>()?; - check_custom_pred(params, cpr, args, &s_args).map(|_| true)? - } - (Self::ReplaceValueWithEntry(entries, st_in), st_out) => { - Self::check_replace_value_with_entry(entries, st_in, st_out)? + check_custom_pred(params, cpr, args, s_args).map(|_| true)? } _ => return Err(deduction_err()), }; @@ -663,11 +597,6 @@ pub fn check_st_tmpl( (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { wc_check_or_set(v.clone(), wc, wildcard_map) } - (StatementTmplArg::SelfPredicateHash(_), _) => { - unreachable!( - "SelfPredicateHash should be normalized to Literal before template matching" - ) - } _ => Err(Error::mismatched_statement_tmpl_arg( st_tmpl_arg.clone(), st_arg.clone(), @@ -716,9 +645,9 @@ pub fn wildcard_values_from_op_st( params: &Params, pred: &CustomPredicate, op_args: &[Statement], - resolved_st_args: &[Value], + st_args: &[Value], ) -> Result> { - let mut wildcard_map = resolved_st_args + let mut wildcard_map = st_args .iter() .map(|v| Some(v.clone())) .chain(core::iter::repeat(None)) @@ -785,7 +714,7 @@ pub(crate) fn check_custom_pred( args: &[Statement], s_args: &[Value], ) -> Result<()> { - let pred = custom_pred_ref.normalized_predicate(); + let pred = custom_pred_ref.predicate(); if pred.statements.len() != args.len() { return Err(Error::diff_amount( "custom predicate operation".to_string(), @@ -804,7 +733,7 @@ pub(crate) fn check_custom_pred( } // Check that the resolved wildcards match the statement arguments. - let wc_values = match wildcard_values_from_op_st(params, &pred, args, s_args) { + let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) { Ok(wc_values) => wc_values, Err(Error::Inner { inner, backtrace }) => match *inner { MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev) @@ -860,8 +789,9 @@ impl fmt::Display for Operation { pub(crate) fn root_key_to_ak(root: &Value, key: &Value) -> Option { let root_hash = Hash::from(root.raw()); - key.as_str() - .map(|s| AnchoredKey::new(root_hash, Key::from(s))) + Key::try_from(key.typed()) + .map(|key| AnchoredKey::new(root_hash, key)) + .ok() } /// Returns the value associated with `output_ref`. diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index b5c1f60..d3e0534 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -311,7 +311,7 @@ pub enum Statement { /* old_root */ ValueRef, /* key */ ValueRef, ), - Custom(CustomPredicateRef, Vec), + Custom(CustomPredicateRef, Vec), Intro(IntroPredicateRef, Vec), } @@ -407,7 +407,7 @@ impl Statement { vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()] } Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], - Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(StatementArg::from)), + Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Literal)), Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)), } } @@ -478,11 +478,14 @@ impl Statement { } (BatchSelf(_), _) => unreachable!(), (Custom(cpr), _) => { - let v_args = args + let v_args: Result> = args .iter() - .map(|x| x.try_into()) - .collect::>>()?; - Self::Custom(cpr, v_args) + .map(|x| match x { + StatementArg::Literal(v) => Ok(v.clone()), + _ => Err(Error::incorrect_statements_args()), + }) + .collect(); + Self::Custom(cpr, v_args?) } (Intro(ir), _) => { let v_args: Result> = args