Compare commits

...

10 commits

Author SHA1 Message Date
5e3ac9a101
Support mixed depth container merkle proofs (#508)
Some checks failed
Rust Build with features / Rust tests (push) Has been cancelled
Clippy Check / Rust formatting (push) Has been cancelled
Publish MainPod circuit info / Update Wiki with new MainPod circuit info (push) Has been cancelled
Check mdbook compilation / compile (push) Has been cancelled
Publish mdbook / build (push) Has been cancelled
Rustfmt Check / Rust formatting (push) Has been cancelled
Rust Tests / Rust tests (push) Has been cancelled
typos / Spell Check with Typos (push) Has been cancelled
Publish mdbook / deploy (push) Has been cancelled
* remove enabled flag from merkle tree proofs

* add small existence mpt proofs in MainPod

* refactor params, add small transition proofs

* complete

* fix edge case in vdset

* fix: use existence only proof for vdset

* use consistent order for aux table
2026-05-06 12:39:27 +02:00
Rob Knight
111b132a00
Use projected statement lookup for op arg resolution (#503)
* Use projected statement lookup for op arg resolution

* Add projected op-arg index coverage test

* Tidying and reorganising
2026-04-29 09:56:39 +02:00
Rob Knight
8844fe124c
Diagnostics for MultiPodBuilder (#500)
* Diagnostics for MultiPodBuilder

* Reduce duplication
2026-04-23 01:41:29 -07:00
Rob Knight
3203c883e5
Use windowed ECAddXuGate for PublicKeyOf (#501) 2026-04-20 00:07:20 -07:00
dbd958dcca
Allow entries as args in custom statements (#498)
- Introduce a new operation ReplaceValueWithEntry that allows taking any statement and replacing literal arguments with entries given a matching Contains statement.
- Allow entries as args in custom statements
- Circuit optimization: For the public statements slots in the circuit we only support None and Copy which take at most 1 argument; but we were still doing max_statement_args random accesses per slot; so I reduced that to just 1 random access to a previous statement.
2026-04-01 23:49:29 +02:00
Rob Knight
22d25e5cb2
Podlang syntax for quoted predicates (#495) 2026-03-30 15:16:19 +01:00
a4069bcc55
Fix pod builder (#496)
Several fixes and code simplifications:
- MainPodBuilder
  - Fix: It was not tracking Contains statements inherited via input pods (via public statements) when automatically generating Contains statements for Entry arguments.
  - Enhancement: Deduplicate statements
- MultiPodBuilder
    - Simplify: Remove the "statement groups" logic and instead deduplicate statements in the MainPodBuilder (which is much simpler to do)
    - Remove the "anchored key" explicit dependency tracking and instead rely on regular dependency tracking by using all the implicit operations and statements generated by MainPodBuilder as input to the solver.
    - Fix: Count and constrain custom predicates used in a pod instead of batches used
2026-03-25 18:48:28 +01:00
Rob Knight
1e592e11cf
Self-referential predicate hashes as statement template args (#494)
* Support quoted predicate hashes, including self-referential predicates

* Clippy

* Review feedback
2026-03-24 07:25:11 -07:00
13cabdb511
Support persistent storage in Containers (#493)
Extend the work of https://github.com/0xPARC/pod2/pull/487 to the Containers (Dictionary, Set, Array).

The merkle tree only stores `RawValue` for both the key and the value, so it is the responsibility of the Container to store the rich value.

In order to handle containers with persistent storage efficiently (which means, cloning them or updating them should not cause an O(n) data copy) I figured we need to have a database of `Value`s indexed by their raw value; as this gives us deduplication and free cloning of containers.
The issue with this approach is that in the current design we have collisions between Value's of different types: https://github.com/0xPARC/pod2/issues/426 and the current API relies on the single type of values.

To resolve this issue I decided to change the API, instead of assuming that a Value has a fixed type, let the value be possibly multiple compatible types and let the user of the library try casting the Value to a particular type.
For this I deprecated the public access of everything related to `TypedValue` and I propose for it to be considered an implementation detail and a blackbox from the external developer point of view.  The `Value` type is now used like this:
- To create a new Value use `Value::from(...)` where you can pass any compatible type (the same types as before)
- To access the Value in typed form you cast it like `value.as_foo()` which returns `Option<Foo>`.

Previously we had a collision between `true` and `1` (and `false` and `0`).  Now it doesn't matter whether a value holds a `true` or a `1`, both should be seen as the same and both return `Some` when doing `as_int` and `as_bool`.

Similarly we had collisions with containers.  For example `set(0, 1, 2) == array[0, 1, 2]` and `set("a", "b") = dict("a": "a", "b": "b")`.  Now any container can be casted to any of `set, array, dict`.  There's a caveat here: each of these types expects a particular encoding of keys, so casting to the wrong type will return errors on some operations.

With this design it no longer matters what is being stored and recovered because the API requires the user to express the expected type and any type with collisions for particular values can be casted to the right type.

There's only one case where it's not desirable to swap one `TypedValue` for another: the `TypedValue::Raw`.  If a non-`RawValue` in the DB is replaced by the corresponding `RawValue` we erase the required information to recover the rich value.  For this reason the implementations of the database treat the `RawValue` as a special case: if an value is stored in non-`RawValue`, the corresponding `RawValue` can never overwrite it.  If a value is stored in `RawValue`, a matching non-`RawValue` will overwrite it (promoting it to a rich value).  This way we never lose data.

A consequence of this is that the serialization, `Display` and `Debug` of a container is not stable.  At any point any of the entries can be swapped for a "compatible" one if they share the storage with other containers that introduce collisions.

I rewrote all containers as wrapper to a generic `Container` which holds a `Map` from `Value` to `Value`.  The serialization of each container now uses the single implementation of the generic `Container`.
2026-03-23 12:31:28 +01:00
arnaucube
32f45872d7
Re-implement merkletree with persistent storage (key-value db) (#487)
* refactor merkletree to work with disk keyvalue database (wip)

* various fixes post reimplementation; pending delete leaf

* add delete operation case for the new in db tree approach

* polish tree update & delete; everything works (pending polishing)

* polish panics into errs, prints, etc

* Implement iterator

* Lint

* fix case no-siblings

* case delete with semi-empty branch

* polishing

* starting to add rocksdb & heeddb for the DB & Txn traits

* Satisfy the borrow checker

* abstract merkletree tests to use the various available DBs

* update store_node interface (rm hash input), rm heed.rs

* polishing

* typos

* Ditch transactions

* add feature for rocksdb, return errs at new_with_db, remove empty leaf case in Leaf::new

* intermediate instead of leaf in empty node when deleting leaf

---------

Co-authored-by: Ahmad <root@ahmadafuni.com>
2026-03-11 16:32:42 +01:00
50 changed files with 6517 additions and 3603 deletions

View file

@ -24,6 +24,8 @@ jobs:
run: cargo build --features metrics run: cargo build --features metrics
- name: Build time - name: Build time
run: cargo build --features time run: cargo build --features time
- name: Build db_rocksdb
run: cargo build --features db_rocksdb
- name: Build disk_cache - name: Build disk_cache
run: cargo build --no-default-features --features backend_plonky2,zk,disk_cache run: cargo build --no-default-features --features backend_plonky2,zk,disk_cache

View file

@ -17,4 +17,5 @@ jobs:
- name: Set up Rust - name: Set up Rust
uses: actions-rust-lang/setup-rust-toolchain@v1 uses: actions-rust-lang/setup-rust-toolchain@v1
- name: Run tests - name: Run tests
run: cargo test --release # RocksDB is disabled by default but we still want to test it.
run: cargo test --release --features db_rocksdb

View file

@ -48,6 +48,7 @@ good_lp = { version = "1.8", default-features = false, features = [
"scip_bundled", "scip_bundled",
] } ] }
annotate-snippets = "0.11" annotate-snippets = "0.11"
rocksdb = { version = "0.24.0", optional = true } # keyvalue database for merkletree
# Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory.
# [patch."https://github.com/0xPARC/plonky2"] # [patch."https://github.com/0xPARC/plonky2"]
@ -57,6 +58,7 @@ annotate-snippets = "0.11"
pretty_assertions = "1.4.1" pretty_assertions = "1.4.1"
# Used only for testing JSON Schema generation and validation. # Used only for testing JSON Schema generation and validation.
jsonschema = "0.30.0" jsonschema = "0.30.0"
tempfile = "3"
[build-dependencies] [build-dependencies]
vergen-gitcl = { version = "1.0.0", features = ["build"] } vergen-gitcl = { version = "1.0.0", features = ["build"] }
@ -70,6 +72,7 @@ time = []
examples = [] examples = []
disk_cache = ["directories", "minicbor-serde"] disk_cache = ["directories", "minicbor-serde"]
mem_cache = [] mem_cache = []
db_rocksdb = ["rocksdb"]
# Uncomment in order to enable debug information in the release builds. This allows getting panic backtraces with a performance similar to regular release. # Uncomment in order to enable debug information in the release builds. This allows getting panic backtraces with a performance similar to regular release.
# [profile.release] # [profile.release]

View file

@ -51,7 +51,7 @@ use crate::{
mainpod::cache_get_rec_main_pod_verifier_circuit_data, mainpod::cache_get_rec_main_pod_verifier_circuit_data,
primitives::merkletree::MerkleClaimAndProof, primitives::merkletree::MerkleClaimAndProof,
}, },
middleware::{containers::Array, Hash, Params, RawValue, Result, Value}, middleware::{containers::Array, Hash, Params, RawValue, Result, Value, EMPTY_HASH},
}; };
pub static DEFAULT_VD_LIST: LazyLock<Vec<VerifierOnlyCircuitData>> = LazyLock::new(|| { pub static DEFAULT_VD_LIST: LazyLock<Vec<VerifierOnlyCircuitData>> = LazyLock::new(|| {
@ -95,6 +95,12 @@ impl Eq for VDSet {}
impl VDSet { impl VDSet {
fn new_from_vds_hashes(mut vds_hashes: Vec<Hash>) -> Self { fn new_from_vds_hashes(mut vds_hashes: Vec<Hash>) -> Self {
// If vds_hashes is empty we add an zero entry to be used as padding when verifying merkle
// proofs of inclusion in the vds set. This zero entry can't be abused because no circuit
// exists with a vds_hash = 0.
if vds_hashes.is_empty() {
vds_hashes.push(EMPTY_HASH);
}
// before using the hash values, sort them, so that each set of // before using the hash values, sort them, so that each set of
// verifier_datas gets the same VDSet root // verifier_datas gets the same VDSet root
vds_hashes.sort(); vds_hashes.sort();
@ -150,6 +156,9 @@ impl VDSet {
))? ))?
.clone()) .clone())
} }
pub fn get_vds_proof_0(&self) -> MerkleClaimAndProof {
self.proofs_map[&self.vds_hashes[0]].clone()
}
/// Returns true if the `verifier_data_hash` is in the set /// Returns true if the `verifier_data_hash` is in the set
pub fn contains(&self, verifier_data_hash: HashOut) -> bool { pub fn contains(&self, verifier_data_hash: HashOut) -> bool {
self.proofs_map self.proofs_map

View file

@ -25,20 +25,20 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
backends::plonky2::{ backends::plonky2::{
basetypes::{CircuitBuilder, CommonCircuitData, D}, basetypes::{CircuitBuilder, CommonCircuitData, D},
circuits::mainpod::CustomPredicateVerification, circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator},
error::Result, error::Result,
mainpod::{Operation, OperationArg, OperationAux, Statement}, mainpod::{Operation, OperationArg, OperationAux, Statement},
primitives::merkletree::{ primitives::merkletree::{
verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget, verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget,
MerkleProof, MerkleTreeStateTransitionProofTarget, MerkleProof, MerkleProofExistenceTarget, MerkleTreeStateTransitionProofTarget,
}, },
}, },
middleware::{ middleware::{
hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate, hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate,
OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix, OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix,
PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg, PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg,
StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN, StatementTmplArgPrefix, ToFields, Value, BASE_PARAMS, EMPTY_VALUE, F, HASH_SIZE,
VALUE_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
}, },
}; };
@ -103,6 +103,20 @@ pub struct StatementArgTarget {
pub elements: [Target; STATEMENT_ARG_F_LEN], pub elements: [Target; STATEMENT_ARG_F_LEN],
} }
impl Flattenable for StatementArgTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self {
elements: vs.try_into().expect("STATEMENT_ARG_F_LEN elements"),
}
}
fn size(_params: &Params) -> usize {
STATEMENT_ARG_F_LEN
}
}
impl StatementArgTarget { impl StatementArgTarget {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, arg: &StatementArg) -> Result<()> { pub fn set_targets(&self, pw: &mut PartialWitness<F>, arg: &StatementArg) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?) Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?)
@ -318,7 +332,7 @@ impl OperationTarget {
.args() .args()
.iter() .iter()
.chain(iter::repeat(&OperationArg::None)) .chain(iter::repeat(&OperationArg::None))
.take(params.max_operation_args) .take(BASE_PARAMS.max_operation_args)
.enumerate() .enumerate()
{ {
self.args[i].set_targets(pw, arg.as_usize())?; self.args[i].set_targets(pw, arg.as_usize())?;
@ -328,7 +342,7 @@ impl OperationTarget {
fn size(params: &Params) -> usize { fn size(params: &Params) -> usize {
OperationTypeTarget::size(params) OperationTypeTarget::size(params)
+ params.max_operation_args * IndexTarget::size(params) + BASE_PARAMS.max_operation_args * IndexTarget::size(params)
+ IndexTarget::size(params) + IndexTarget::size(params)
} }
} }
@ -711,7 +725,6 @@ impl CustomPredicateInBatchTarget {
let mtp = let mtp =
MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder); MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder);
let _true = builder._true(); let _true = builder._true();
builder.connect(_true.target, mtp.enabled.target);
builder.connect(_true.target, mtp.existence.target); builder.connect(_true.target, mtp.existence.target);
let zero = builder.constant(F(0)); let zero = builder.constant(F(0));
let key = ValueTarget { let key = ValueTarget {
@ -749,7 +762,7 @@ impl CustomPredicateInBatchTarget {
value: RawValue::from(hash_fields(&predicate.to_fields())), value: RawValue::from(hash_fields(&predicate.to_fields())),
proof: mtp.clone(), proof: mtp.clone(),
}; };
self.mtp.set_targets(pw, true, &mtp_claim)?; self.mtp.set_targets(pw, &mtp_claim)?;
Ok(()) Ok(())
} }
} }
@ -771,7 +784,8 @@ impl CustomPredicateEntryTarget {
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?; pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
// Replace statement templates of batch-self with (id,index) // Replace BatchSelf predicates with Custom(batch, i), and
// SelfPredicateHash args with Literal(hash(Custom(batch, i)))
let batch = &predicate.batch; let batch = &predicate.batch;
let predicate = predicate.predicate(); let predicate = predicate.predicate();
let statements = predicate let statements = predicate
@ -788,10 +802,22 @@ impl CustomPredicateEntryTarget {
} }
x => x.clone(), x => x.clone(),
}; };
StatementTmpl { let args = st_tmpl
pred_or_wc, .args
args: st_tmpl.args, .into_iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: batch.clone(),
index: i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
} }
other => other,
})
.collect();
StatementTmpl { pred_or_wc, args }
}) })
.collect_vec(); .collect_vec();
let predicate = CustomPredicate { let predicate = CustomPredicate {
@ -855,7 +881,7 @@ impl CustomPredicateVerifyEntryTarget {
args: (0..params.max_custom_predicate_wildcards) args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value()) .map(|_| builder.add_virtual_value())
.collect(), .collect(),
op_args: (0..params.max_operation_args) op_args: (0..BASE_PARAMS.max_operation_args)
.map(|_| builder.add_virtual_statement(false)) .map(|_| builder.add_virtual_statement(false))
.collect(), .collect(),
} }
@ -885,7 +911,7 @@ impl CustomPredicateVerifyEntryTarget {
cpv.op_args cpv.op_args
.iter() .iter()
.chain(iter::repeat(&pad_op_arg)) .chain(iter::repeat(&pad_op_arg))
.take(params.max_operation_args), .take(BASE_PARAMS.max_operation_args),
) { ) {
op_arg_target.set_targets(pw, op_arg)? op_arg_target.set_targets(pw, op_arg)?
} }
@ -928,7 +954,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
.expect("len = operation_type_size"), .expect("len = operation_type_size"),
}; };
let (pos, size) = (pos + size, StatementTarget::size(params)); let (pos, size) = (pos + size, StatementTarget::size(params));
let op_args = (0..params.max_operation_args) let op_args = (0..BASE_PARAMS.max_operation_args)
.map(|i| { .map(|i| {
StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size]) StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size])
}) })
@ -940,7 +966,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
} }
} }
fn size(params: &Params) -> usize { fn size(params: &Params) -> usize {
StatementTarget::size(params) * (1 + params.max_operation_args) StatementTarget::size(params) * (1 + BASE_PARAMS.max_operation_args)
+ OperationTarget::size(params) + OperationTarget::size(params)
} }
} }
@ -960,7 +986,6 @@ pub trait Flattenable {
/// elsewhere. /// elsewhere.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MerkleClaimTarget { pub struct MerkleClaimTarget {
pub(crate) enabled: BoolTarget,
pub(crate) root: HashOutTarget, pub(crate) root: HashOutTarget,
pub(crate) key: ValueTarget, pub(crate) key: ValueTarget,
pub(crate) value: ValueTarget, pub(crate) value: ValueTarget,
@ -970,7 +995,6 @@ pub struct MerkleClaimTarget {
impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget { impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
fn from(pf: MerkleClaimAndProofTarget) -> Self { fn from(pf: MerkleClaimAndProofTarget) -> Self {
Self { Self {
enabled: pf.enabled,
root: pf.root, root: pf.root,
key: pf.key, key: pf.key,
value: pf.value, value: pf.value,
@ -979,12 +1003,25 @@ impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
} }
} }
impl MerkleClaimTarget {
pub fn from_proof_existence(
builder: &mut CircuitBuilder,
pf: MerkleProofExistenceTarget,
) -> Self {
Self {
root: pf.root,
key: pf.key,
value: pf.value,
existence: builder._true(),
}
}
}
/// For the purpose of op verification, we need only look up the /// For the purpose of op verification, we need only look up the
/// Merkle state transition claim rather than the Merkle state /// Merkle state transition claim rather than the Merkle state
/// transition proof since it is verified elsewhere. /// transition proof since it is verified elsewhere.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct MerkleTreeStateTransitionClaimTarget { pub struct MerkleTreeStateTransitionClaimTarget {
pub(crate) enabled: BoolTarget,
pub(crate) op: Target, pub(crate) op: Target,
pub(crate) old_root: HashOutTarget, pub(crate) old_root: HashOutTarget,
pub(crate) new_root: HashOutTarget, pub(crate) new_root: HashOutTarget,
@ -995,7 +1032,6 @@ pub struct MerkleTreeStateTransitionClaimTarget {
impl From<MerkleTreeStateTransitionProofTarget> for MerkleTreeStateTransitionClaimTarget { impl From<MerkleTreeStateTransitionProofTarget> for MerkleTreeStateTransitionClaimTarget {
fn from(pf: MerkleTreeStateTransitionProofTarget) -> Self { fn from(pf: MerkleTreeStateTransitionProofTarget) -> Self {
Self { Self {
enabled: pf.enabled,
op: pf.op, op: pf.op,
old_root: pf.old_root, old_root: pf.old_root,
new_root: pf.new_root, new_root: pf.new_root,
@ -1036,7 +1072,6 @@ impl Flattenable for ValueTarget {
impl Flattenable for MerkleClaimTarget { impl Flattenable for MerkleClaimTarget {
fn flatten(&self) -> Vec<Target> { fn flatten(&self) -> Vec<Target> {
[ [
vec![self.enabled.target],
self.root.elements.to_vec(), self.root.elements.to_vec(),
self.key.elements.to_vec(), self.key.elements.to_vec(),
self.value.elements.to_vec(), self.value.elements.to_vec(),
@ -1048,31 +1083,28 @@ impl Flattenable for MerkleClaimTarget {
fn from_flattened(params: &Params, vs: &[Target]) -> Self { fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params)); assert_eq!(vs.len(), Self::size(params));
Self { Self {
enabled: BoolTarget::new_unsafe(vs[0]), root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()),
root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()), key: ValueTarget::from_slice(&vs[NUM_HASH_OUT_ELTS..NUM_HASH_OUT_ELTS + VALUE_SIZE]),
key: ValueTarget::from_slice(
&vs[1 + NUM_HASH_OUT_ELTS..1 + NUM_HASH_OUT_ELTS + VALUE_SIZE],
),
value: ValueTarget::from_slice( value: ValueTarget::from_slice(
&vs[1 + NUM_HASH_OUT_ELTS + VALUE_SIZE..1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], &vs[NUM_HASH_OUT_ELTS + VALUE_SIZE..NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
), ),
existence: BoolTarget::new_unsafe(vs[1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), existence: BoolTarget::new_unsafe(vs[NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]),
} }
} }
fn size(params: &Params) -> usize { fn size(params: &Params) -> usize {
2 + HashOutTarget::size(params) + 2 * ValueTarget::size(params) HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1
} }
} }
impl Flattenable for MerkleTreeStateTransitionClaimTarget { impl Flattenable for MerkleTreeStateTransitionClaimTarget {
fn flatten(&self) -> Vec<Target> { fn flatten(&self) -> Vec<Target> {
[ [
vec![self.enabled.target, self.op],
self.old_root.elements.to_vec(), self.old_root.elements.to_vec(),
self.new_root.elements.to_vec(), self.new_root.elements.to_vec(),
self.op_key.elements.to_vec(), self.op_key.elements.to_vec(),
self.op_value.elements.to_vec(), self.op_value.elements.to_vec(),
vec![self.op],
] ]
.concat() .concat()
} }
@ -1080,24 +1112,22 @@ impl Flattenable for MerkleTreeStateTransitionClaimTarget {
fn from_flattened(params: &Params, vs: &[Target]) -> Self { fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params)); assert_eq!(vs.len(), Self::size(params));
Self { Self {
enabled: BoolTarget::new_unsafe(vs[0]), old_root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()),
op: vs[1],
old_root: HashOutTarget::from_vec(vs[2..2 + NUM_HASH_OUT_ELTS].to_vec()),
new_root: HashOutTarget::from_vec( new_root: HashOutTarget::from_vec(
vs[2 + NUM_HASH_OUT_ELTS..2 * (1 + NUM_HASH_OUT_ELTS)].to_vec(), vs[NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS].to_vec(),
), ),
op_key: ValueTarget::from_slice( op_key: ValueTarget::from_slice(
&vs[2 * (1 + NUM_HASH_OUT_ELTS)..2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE], &vs[2 * NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS + VALUE_SIZE],
), ),
op_value: ValueTarget::from_slice( op_value: ValueTarget::from_slice(
&vs[2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE &vs[2 * NUM_HASH_OUT_ELTS + VALUE_SIZE..2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
..2 * (1 + NUM_HASH_OUT_ELTS) + 2 * VALUE_SIZE],
), ),
op: vs[2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
} }
} }
fn size(params: &Params) -> usize { fn size(params: &Params) -> usize {
2 * (1 + HashOutTarget::size(params)) + 2 * ValueTarget::size(params) 2 * HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1
} }
} }
@ -1335,6 +1365,18 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T; fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target` /// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T; fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
/// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then
/// materializes the selected row via a witness generator and constrains its hash. Cheaper than
/// `vec_ref` when each entry has many fields, since random access runs only over the 4-field
/// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and
/// reusing the same slices across multiple lookups.
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T;
fn select_flattenable<T: Flattenable>( fn select_flattenable<T: Flattenable>(
&mut self, &mut self,
params: &Params, params: &Params,
@ -1412,7 +1454,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
OperationTarget { OperationTarget {
op_type: self.add_virtual_operation_type(), op_type: self.add_virtual_operation_type(),
args: (0..params.max_operation_args) args: (0..BASE_PARAMS.max_operation_args)
.map(|_| IndexTarget::new_virtual(params.statement_table_size(), self)) .map(|_| IndexTarget::new_virtual(params.statement_table_size(), self))
.collect(), .collect(),
aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self), aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self),
@ -1722,7 +1764,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
let num_chunks = array.len().div_ceil(CHUNK_LEN); let num_chunks = array.len().div_ceil(CHUNK_LEN);
for chunk in array.chunks(CHUNK_LEN) { for chunk in array.chunks(CHUNK_LEN) {
let mut index_chunk = i.low; let mut index_chunk = i.low;
// I we have several chunks and the last one is smaller (it's index needs less than 6 // If we have several chunks and the last one is smaller (it's index needs less than 6
// bits), make it zero except when it's used so that the range check over the index // bits), make it zero except when it's used so that the range check over the index
// passes. // passes.
if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 { if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 {
@ -1737,12 +1779,6 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
self.random_access(i.high, chunk_res) self.random_access(i.high, chunk_res)
} }
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
// The idea would be the following: Take the array `ts` and hash each element. Then do the
// random access on the hash result. Finally "unhash" to recover the resolved element.
// We don't want to hash each element from the array each time, so we should cache the hashed
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
// then do `ts: &[HashCache<T>]`.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T { fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| { let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
let num_rows = m.len(); let num_rows = m.len();
@ -1766,6 +1802,28 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i)) T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
} }
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T {
assert_eq!(ts_flattened.len(), ts_hashes.len());
let selected_hash = self.vec_ref(params, ts_hashes, i);
let selected_flattened = self.add_virtual_targets(T::size(params));
let selected_flattened_hash =
self.hash_n_to_hash_no_pad::<PoseidonHash>(selected_flattened.clone());
self.connect_hashes(selected_hash, selected_flattened_hash);
let result = T::from_flattened(params, &selected_flattened);
self.add_simple_generator(TableGetGenerator::new(
i.clone(),
ts_flattened.to_vec(),
selected_flattened,
));
result
}
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T { fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
let zero = self.zero(); let zero = self.zero();
self.vec_ref( self.vec_ref(
@ -2012,7 +2070,7 @@ pub(crate) mod tests {
// Empty case // Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?; _ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish(); let custom_predicate_batch = cpb_builder.finish()?;
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
// Some cases from the examples // Some cases from the examples

File diff suppressed because it is too large Load diff

View file

@ -107,11 +107,11 @@ impl MuxTableTarget {
rev_resolved_tagged_flattened.reverse(); rev_resolved_tagged_flattened.reverse();
let resolved_tagged_flattened = rev_resolved_tagged_flattened; let resolved_tagged_flattened = rev_resolved_tagged_flattened;
builder.add_simple_generator(TableGetGenerator { builder.add_simple_generator(TableGetGenerator::new(
index: index.clone(), index.clone(),
tagged_entries: self.tagged_entries.clone(), self.tagged_entries.clone(),
get_tagged_entry: resolved_tagged_flattened.clone(), resolved_tagged_flattened.clone(),
}); ));
measure_gates_end!(builder, measure); measure_gates_end!(builder, measure);
TableEntryTarget { TableEntryTarget {
params: self.params.clone(), params: self.params.clone(),
@ -123,8 +123,18 @@ impl MuxTableTarget {
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct TableGetGenerator { pub struct TableGetGenerator {
index: IndexTarget, index: IndexTarget,
tagged_entries: Vec<Vec<Target>>, entries: Vec<Vec<Target>>,
get_tagged_entry: Vec<Target>, revealed_entry: Vec<Target>,
}
impl TableGetGenerator {
pub fn new(index: IndexTarget, entries: Vec<Vec<Target>>, revealed_entry: Vec<Target>) -> Self {
Self {
index,
entries,
revealed_entry,
}
}
} }
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator { impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator {
@ -135,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
fn dependencies(&self) -> Vec<Target> { fn dependencies(&self) -> Vec<Target> {
[self.index.low, self.index.high] [self.index.low, self.index.high]
.into_iter() .into_iter()
.chain(self.tagged_entries.iter().flatten().copied()) .chain(self.entries.iter().flatten().copied())
.collect() .collect()
} }
@ -148,12 +158,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
let index_high = witness.get_target(self.index.high); let index_high = witness.get_target(self.index.high);
let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64(); let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64();
let entry = witness.get_targets(&self.tagged_entries[index as usize]); let entry = witness.get_targets(&self.entries[index as usize]);
for (target, value) in self.get_tagged_entry.iter().zip( for (target, value) in self.revealed_entry.iter().zip(
entry entry
.iter() .iter()
.chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())), .chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())),
) { ) {
out_buffer.set_target(*target, *value)?; out_buffer.set_target(*target, *value)?;
} }
@ -166,12 +176,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
dst.write_target(self.index.low)?; dst.write_target(self.index.low)?;
dst.write_target(self.index.high)?; dst.write_target(self.index.high)?;
dst.write_usize(self.tagged_entries.len())?; dst.write_usize(self.entries.len())?;
for tagged_entry in &self.tagged_entries { for entry in &self.entries {
dst.write_target_vec(tagged_entry)?; dst.write_target_vec(entry)?;
} }
dst.write_target_vec(&self.get_tagged_entry) dst.write_target_vec(&self.revealed_entry)
} }
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> { fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
@ -181,16 +191,16 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
high: src.read_target()?, high: src.read_target()?,
}; };
let len = src.read_usize()?; let len = src.read_usize()?;
let mut tagged_entries = Vec::with_capacity(len); let mut entries = Vec::with_capacity(len);
for _ in 0..len { for _ in 0..len {
tagged_entries.push(src.read_target_vec()?); entries.push(src.read_target_vec()?);
} }
let get_tagged_entry = src.read_target_vec()?; let revealed_entry = src.read_target_vec()?;
Ok(Self { Ok(Self {
index, index,
tagged_entries, entries,
get_tagged_entry, revealed_entry,
}) })
} }
} }

View file

@ -61,8 +61,8 @@ macro_rules! new {
} }
use InnerError::*; use InnerError::*;
impl Error { impl Error {
pub fn custom(s: String) -> Self { pub fn custom(s: impl Into<String>) -> Self {
new!(Custom(s)) new!(Custom(s.into()))
} }
pub fn plonky2_proof_fail(context: impl Into<String>, e: anyhow::Error) -> Self { pub fn plonky2_proof_fail(context: impl Into<String>, e: anyhow::Error) -> Self {
Self::Plonky2ProofFail(context.into(), e) Self::Plonky2ProofFail(context.into(), e)

View file

@ -1,5 +1,5 @@
pub mod operation; pub mod operation;
use crate::middleware::{wildcard_values_from_op_st, PodType}; use crate::middleware::{wildcard_values_from_op_st, PodType, BASE_PARAMS};
pub mod statement; pub mod statement;
use std::iter; use std::iter;
@ -39,7 +39,7 @@ use crate::{
middleware::{ middleware::{
self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs, self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs,
MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg,
ToFields, VDSet, Value, ToFields, VDSet, Value, ValueRef,
}, },
timed, timed,
}; };
@ -104,8 +104,20 @@ pub(crate) fn extract_custom_predicate_verifications(
if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Operation::Custom(cpr, sts) = op {
if let middleware::Statement::Custom(st_cpr, st_args) = st { if let middleware::Statement::Custom(st_cpr, st_args) = st {
assert_eq!(cpr, st_cpr); assert_eq!(cpr, st_cpr);
// The custom operation outputs statements with literal arguments. They can be
// replaced by references later with ReplaceValueWithEntry.
let st_args = st_args
.iter()
.map(|arg| match arg {
ValueRef::Literal(v) => Ok(v.clone()),
_ => Err(Error::custom(
"custom operation cannot output entries as arguments",
)),
})
.collect::<Result<Vec<_>>>()?;
let normalized_pred = cpr.normalized_predicate();
let wildcard_values = let wildcard_values =
wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) wildcard_values_from_op_st(params, &normalized_pred, sts, &st_args)
.expect("resolved wildcards"); .expect("resolved wildcards");
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
let custom_predicate_table_index = custom_predicates let custom_predicate_table_index = custom_predicates
@ -136,14 +148,20 @@ pub(crate) fn extract_custom_predicate_verifications(
Ok(table) Ok(table)
} }
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MerkleProofs {
pub(crate) medium: Vec<MerkleClaimAndProof>,
pub(crate) small: Vec<MerkleClaimAndProof>,
}
/// Extracts Merkle proofs from Contains/NotContains ops. /// Extracts Merkle proofs from Contains/NotContains ops.
pub(crate) fn extract_merkle_proofs( pub(crate) fn extract_merkle_proofs(
params: &Params, params: &Params,
aux_list: &mut [OperationAux], aux_list: &mut [OperationAux],
operations: &[middleware::Operation], operations: &[middleware::Operation],
statements: &[middleware::Statement], statements: &[middleware::Statement],
) -> Result<Vec<MerkleClaimAndProof>> { ) -> Result<MerkleProofs> {
let mut table = Vec::new(); let mut tables = MerkleProofs::default();
for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() { for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() {
let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone());
let (root, key, value, pf) = match (op, st) { let (root, key, value, pf) = match (op, st) {
@ -166,31 +184,42 @@ pub(crate) fn extract_merkle_proofs(
} }
_ => continue, _ => continue,
}; };
aux_list[i] = OperationAux::MerkleProofIndex(table.len()); let claim_proof = MerkleClaimAndProof::new(Hash::from(root), key, value, pf.clone());
table.push(MerkleClaimAndProof::new( if pf.existence
Hash::from(root), // TODO: Make sure there's no off-by-one error here
key, && pf.siblings.len() <= params.containers.max_depth_small
value, && tables.small.len() < params.containers.state.max_small
pf.clone(), {
)); aux_list[i] = OperationAux::MerkleProofIndex(Size::Small, tables.small.len());
tables.small.push(claim_proof);
} else {
aux_list[i] = OperationAux::MerkleProofIndex(Size::Medium, tables.medium.len());
tables.medium.push(claim_proof);
} }
if table.len() > params.max_merkle_proofs_containers { }
if tables.medium.len() > params.containers.state.max_medium {
return Err(Error::custom(format!( return Err(Error::custom(format!(
"The number of required Merkle proofs ({}) exceeds the maximum number ({}).", "The number of required Merkle proofs ({}) exceeds the maximum number ({}).",
table.len(), tables.medium.len(),
params.max_merkle_proofs_containers params.containers.state.max_medium
))); )));
} }
Ok(table) Ok(tables)
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MerkleTransitionProofs {
pub(crate) medium: Vec<MerkleTreeStateTransitionProof>,
pub(crate) small: Vec<MerkleTreeStateTransitionProof>,
} }
/// Extracts Merkle state transition proofs from container update ops. /// Extracts Merkle state transition proofs from container update ops.
pub(crate) fn extract_merkle_tree_state_transition_proofs( pub(crate) fn extract_merkle_transition_proofs(
params: &Params, params: &Params,
aux_list: &mut [OperationAux], aux_list: &mut [OperationAux],
operations: &[middleware::Operation], operations: &[middleware::Operation],
) -> Result<Vec<MerkleTreeStateTransitionProof>> { ) -> Result<MerkleTransitionProofs> {
let mut table = Vec::new(); let mut tables = MerkleTransitionProofs::default();
for (i, op) in operations.iter().enumerate() { for (i, op) in operations.iter().enumerate() {
let pf = match op { let pf = match op {
middleware::Operation::ContainerInsertFromEntries(_, _, _, _, pf) middleware::Operation::ContainerInsertFromEntries(_, _, _, _, pf)
@ -198,17 +227,27 @@ pub(crate) fn extract_merkle_tree_state_transition_proofs(
| middleware::Operation::ContainerDeleteFromEntries(_, _, _, pf) => pf.clone(), | middleware::Operation::ContainerDeleteFromEntries(_, _, _, pf) => pf.clone(),
_ => continue, _ => continue,
}; };
aux_list[i] = OperationAux::MerkleTreeStateTransitionProofIndex(table.len()); if pf.op_proof.existence
table.push(pf); // TODO: Make sure there's no off-by-one error here
&& pf.siblings.len() <= params.containers.max_depth_small
&& tables.small.len() < params.containers.transition.max_small
{
aux_list[i] = OperationAux::MerkleTransitionProofIndex(Size::Small, tables.small.len());
tables.small.push(pf);
} else {
aux_list[i] =
OperationAux::MerkleTransitionProofIndex(Size::Medium, tables.medium.len());
tables.medium.push(pf);
} }
if table.len() > params.max_merkle_tree_state_transition_proofs_containers { }
if tables.medium.len() > params.containers.transition.max_medium {
return Err(Error::custom(format!( return Err(Error::custom(format!(
"The number of required Merkle proofs ({}) exceeds the maximum number ({}).", "The number of required Merkle proofs ({}) exceeds the maximum number ({}).",
table.len(), tables.medium.len(),
params.max_merkle_tree_state_transition_proofs_containers params.containers.transition.max_medium
))); )));
} }
Ok(table) Ok(tables)
} }
pub(crate) fn extract_public_key_of( pub(crate) fn extract_public_key_of(
@ -225,11 +264,10 @@ pub(crate) fn extract_public_key_of(
) = (op, st) ) = (op, st)
{ {
let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone());
let sk = SecretKey::try_from( let value = value_from_op(sk_s, sk_ref).ok_or_else(deduction_err)?;
value_from_op(sk_s, sk_ref) let sk = value
.ok_or_else(deduction_err)? .as_secret_key()
.typed(), .ok_or_else(|| Error::custom("{value} not SecretKey"))?;
)?;
aux_list[i] = OperationAux::PublicKeyOfIndex(table.len()); aux_list[i] = OperationAux::PublicKeyOfIndex(table.len());
table.push(sk); table.push(sk);
} }
@ -283,7 +321,9 @@ pub(crate) fn extract_signatures(
aux_list[i] = OperationAux::SignedByIndex(table.len()); aux_list[i] = OperationAux::SignedByIndex(table.len());
table.push(SignedBy { table.push(SignedBy {
msg: msg.raw(), msg: msg.raw(),
pk: PublicKey::try_from(pk.typed())?, pk: pk
.as_public_key()
.ok_or_else(|| Error::custom(format!("{pk} is not PublicKey")))?,
sig: sig.clone(), sig: sig.clone(),
}); });
} }
@ -327,8 +367,8 @@ pub fn pad_statement(s: &mut Statement) {
fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args()) fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args())
} }
fn pad_operation_args(params: &Params, args: &mut Vec<OperationArg>) { fn pad_operation_args(args: &mut Vec<OperationArg>) {
fill_pad(args, OperationArg::None, params.max_operation_args) fill_pad(args, OperationArg::None, BASE_PARAMS.max_operation_args)
} }
/// Returns the statements from the given MainPodInputs, padding to the respective max lengths /// Returns the statements from the given MainPodInputs, padding to the respective max lengths
@ -426,7 +466,7 @@ pub(crate) fn process_private_statements_operations(
.map(|mid_arg| find_op_arg(statements, mid_arg)) .map(|mid_arg| find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
pad_operation_args(params, &mut args); pad_operation_args(&mut args);
operations.push(Operation(op.op_type(), args, *aux)); operations.push(Operation(op.op_type(), args, *aux));
} }
Ok(operations) Ok(operations)
@ -457,7 +497,11 @@ pub(crate) fn process_public_statements_operations(
OperationAux::None, OperationAux::None,
) )
}; };
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); fill_pad(
&mut op.1,
OperationArg::None,
BASE_PARAMS.max_operation_args,
);
operations.push(op); operations.push(op);
} }
Ok(operations) Ok(operations)
@ -467,6 +511,7 @@ pub struct Prover {}
impl MainPodProver for Prover { impl MainPodProver for Prover {
fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result<Box<dyn Pod>> { fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result<Box<dyn Pod>> {
assert_eq!(inputs.statements.len(), inputs.operations.len());
// Pad input recursive pods with empty pods if necessary // Pad input recursive pods with empty pods if necessary
let empty_pod = if inputs.pods.len() == params.max_input_pods { let empty_pod = if inputs.pods.len() == params.max_input_pods {
// We don't need padding so we skip creating an EmptyPod // We don't need padding so we skip creating an EmptyPod
@ -495,6 +540,8 @@ impl MainPodProver for Prover {
let mut aux_list = vec![OperationAux::None; params.max_priv_statements()]; let mut aux_list = vec![OperationAux::None; params.max_priv_statements()];
let merkle_proofs = let merkle_proofs =
extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?;
let merkle_transition_proofs =
extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?;
let custom_predicates = extract_custom_predicates(params, inputs.operations)?; let custom_predicates = extract_custom_predicates(params, inputs.operations)?;
let custom_predicate_verifications = extract_custom_predicate_verifications( let custom_predicate_verifications = extract_custom_predicate_verifications(
params, params,
@ -519,9 +566,6 @@ impl MainPodProver for Prover {
let signed_bys = let signed_bys =
extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?;
let merkle_tree_state_transition_proofs =
extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?;
let (statements, public_statements) = layout_statements(params, false, &inputs)?; let (statements, public_statements) = layout_statements(params, false, &inputs)?;
let operations = process_private_statements_operations( let operations = process_private_statements_operations(
params, params,
@ -554,20 +598,15 @@ impl MainPodProver for Prover {
.collect_vec(); .collect_vec();
let mut vd_mt_proofs = Vec::with_capacity(inputs.pods.len()); let mut vd_mt_proofs = Vec::with_capacity(inputs.pods.len());
let pad_vd_mt_proof = inputs.vd_set.get_vds_proof_0();
for (pod, vd) in inputs.pods.iter().zip(&verifier_datas) { for (pod, vd) in inputs.pods.iter().zip(&verifier_datas) {
vd_mt_proofs.push(if pod.is_main() { vd_mt_proofs.push(if pod.is_main() {
(true, inputs.vd_set.get_vds_proof(vd)?) inputs.vd_set.get_vds_proof(vd)?
} else { } else {
// For intro pods we don't verify inclusion of their vk into the vd set, so we // For intro pods we don't verify inclusion of their vk into the vd set, so we
// generate a dummy mt proof with expected root and value to pass some constraints // use a valid vds proof that matches the expected root but not the value to pass
( // the constraints
false, pad_vd_mt_proof.clone()
MerkleClaimAndProof {
root: inputs.vd_set.root(),
value: RawValue::from(pod.verifier_data_hash()),
..MerkleClaimAndProof::empty()
},
)
}); });
} }
@ -580,7 +619,7 @@ impl MainPodProver for Prover {
merkle_proofs, merkle_proofs,
public_key_of_sks, public_key_of_sks,
signed_bys, signed_bys,
merkle_tree_state_transition_proofs, merkle_transition_proofs,
custom_predicates_with_mpt_proofs, custom_predicates_with_mpt_proofs,
custom_predicate_verifications, custom_predicate_verifications,
}; };
@ -967,7 +1006,18 @@ pub mod tests {
max_statements: 2, max_statements: 2,
max_public_statements: 1, max_public_statements: 1,
max_input_pods_public_statements: 0, max_input_pods_public_statements: 0,
max_merkle_proofs_containers: 0, containers: middleware::ParamsContainers {
state: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
transition: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
max_depth_small: 8,
max_depth_medium: 32,
},
max_public_key_of: 0, max_public_key_of: 0,
max_custom_predicate_verifications: 0, max_custom_predicate_verifications: 0,
max_custom_predicates: 0, max_custom_predicates: 0,
@ -1003,15 +1053,23 @@ pub mod tests {
max_input_pods_public_statements: 2, max_input_pods_public_statements: 2,
max_statements: 5, max_statements: 5,
max_public_statements: 2, max_public_statements: 2,
max_operation_args: 5,
max_custom_predicates: 2, max_custom_predicates: 2,
max_custom_predicate_verifications: 2, max_custom_predicate_verifications: 2,
max_custom_predicate_wildcards: 3, max_custom_predicate_wildcards: 3,
max_merkle_proofs_containers: 2,
max_merkle_tree_state_transition_proofs_containers: 2,
max_public_key_of: 2, max_public_key_of: 2,
max_depth_mt_containers: 4,
max_depth_mt_vds: 6, max_depth_mt_vds: 6,
containers: middleware::ParamsContainers {
state: middleware::ParamsMerkleProofs {
max_small: 2,
max_medium: 2,
},
transition: middleware::ParamsMerkleProofs {
max_small: 2,
max_medium: 2,
},
max_depth_small: 2,
max_depth_medium: 4,
},
}; };
let mut vds = DEFAULT_VD_LIST.clone(); let mut vds = DEFAULT_VD_LIST.clone();
vds.push(rec_main_pod_circuit_data(&params).1.verifier_only.clone()); vds.push(rec_main_pod_circuit_data(&params).1.verifier_only.clone());
@ -1068,11 +1126,20 @@ pub mod tests {
max_input_pods: 0, max_input_pods: 0,
max_statements: 9, max_statements: 9,
max_public_statements: 4, max_public_statements: 4,
max_operation_args: 5,
max_custom_predicate_wildcards: 4, max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2, max_custom_predicate_verifications: 2,
max_merkle_proofs_containers: 3, containers: middleware::ParamsContainers {
max_merkle_tree_state_transition_proofs_containers: 0, state: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 3,
},
transition: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
max_depth_small: 8,
max_depth_medium: 32,
},
..Default::default() ..Default::default()
}; };
println!("{:#?}", params); println!("{:#?}", params);
@ -1095,7 +1162,7 @@ pub mod tests {
&[stb0.clone(), stb1.clone()], &[stb0.clone(), stb1.clone()],
)?; )?;
let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?; let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?;
let cpb = cpb_builder.finish(); let cpb = cpb_builder.finish()?;
let cpb_and = CustomPredicateRef::new(cpb.clone(), 0); let cpb_and = CustomPredicateRef::new(cpb.clone(), 0);
let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1); let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1);
@ -1129,6 +1196,72 @@ pub mod tests {
Ok(pod.verify()?) Ok(pod.verify()?)
} }
#[test]
fn test_main_self_predicate_hash() -> frontend::Result<()> {
use frontend::BuilderArg;
let params = Params {
max_signed_by: 0,
max_input_pods: 0,
max_statements: 6,
max_public_statements: 2,
max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2,
containers: middleware::ParamsContainers {
state: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
transition: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
max_depth_small: 8,
max_depth_medium: 32,
},
..Default::default()
};
let mut vds = DEFAULT_VD_LIST.clone();
vds.push(rec_main_pod_circuit_data(&params).1.verifier_only.clone());
let vd_set = VDSet::new(&vds);
// Build a batch: pred_A references pred_B's hash, pred_B references pred_A's hash
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb_a = STB::new_from_pred(NP::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
cpb.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
let stb_b = STB::new_from_pred(NP::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
cpb.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
let batch = cpb.finish()?;
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_hash = middleware::Value::from(middleware::Predicate::Custom(pred_b_ref).hash());
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
let mut pod_builder = MainPodBuilder::new(&params, &vd_set);
let eq_st =
pod_builder.priv_op(frontend::Operation::eq(pred_b_hash.clone(), pred_b_hash))?;
pod_builder.pub_op(frontend::Operation::custom(pred_a_ref, [eq_st]))?;
// Mock
let prover = MockProver {};
let pod = pod_builder.prove(&prover)?;
assert!(pod.pod.verify().is_ok());
// Real
let prover = Prover {};
let pod = pod_builder.prove(&prover)?;
let pod = (pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
Ok(pod.verify()?)
}
#[test] #[test]
fn test_set_contains() -> frontend::Result<()> { fn test_set_contains() -> frontend::Result<()> {
let params = Params::default(); let params = Params::default();
@ -1192,10 +1325,108 @@ pub mod tests {
); );
let st = middleware::Statement::Custom( let st = middleware::Statement::Custom(
cpr, cpr,
[1, 1, 2].into_iter().map(middleware::Value::from).collect(), [1, 1, 2]
.into_iter()
.map(middleware::ValueRef::from)
.collect(),
); );
builder.insert(true, (st, op)).unwrap(); builder.insert((st.clone(), op)).unwrap();
builder.reveal(&st).unwrap();
let prover = Prover {}; let prover = Prover {};
builder.prove(&prover).unwrap(); builder.prove(&prover).unwrap();
} }
#[test]
fn test_replace_value_with_entry() {
let params = middleware::Params::default();
let vd_set = &*DEFAULT_VD_SET;
let mut builder = MainPodBuilder::new(&params, vd_set);
let d = dict!({"a" => 42, "b" => 33});
builder
.priv_op(frontend::Operation::dict_contains(d.clone(), "a", 42))
.unwrap();
let st = builder.priv_op(frontend::Operation::lt(5, 42)).unwrap();
// Transform `Lt(5, 42)` into `Lt(5, d.a)` by using `DictContains(d, "a", 42)`
builder
.pub_op(frontend::Operation::replace_value_with_entry(
vec![None, Some((&d, "a"))],
st,
))
.unwrap();
// Mock
let prover = MockProver {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap();
assert_eq!(
middleware::Statement::Lt(
middleware::ValueRef::Literal(Value::from(5)),
middleware::ValueRef::Key(middleware::AnchoredKey {
root: d.commitment(),
key: middleware::Key::from("a")
})
),
pod.public_statements[0]
);
// Real
let prover = Prover {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap()
}
#[test]
fn test_entry_custom_statement_arg() {
let params = middleware::Params::default();
let vd_set = &*DEFAULT_VD_SET;
let input = r#"
PredA(x) = AND(
Lt(x, 100)
)
PredB(d) = AND(
PredA(d.x)
)
"#;
let module = load_module(input, "my_mod", &params, &[]).expect("lang parse");
let pred_a = module.batch.predicate_ref_by_name("PredA").unwrap();
let pred_b = module.batch.predicate_ref_by_name("PredB").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
let d = dict!({"x" => 42, "y" => 33});
let st_lt = builder.priv_op(frontend::Operation::lt(42, 100)).unwrap();
let st_a = builder
.priv_op(frontend::Operation::custom(pred_a, [st_lt]))
.unwrap();
builder
.priv_op(frontend::Operation::dict_contains(d.clone(), "x", 42))
.unwrap();
// Transform `PredA(42)` into `PredA(d.x)` by using `DictContains(d, "x", 42)`
let st_a1 = builder
.priv_op(frontend::Operation::replace_value_with_entry(
vec![Some((&d, "x"))],
st_a,
))
.unwrap();
builder
.pub_op(frontend::Operation::custom(pred_b.clone(), [st_a1]))
.unwrap();
// Mock
let prover = MockProver {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap();
let expected = middleware::Statement::Custom(
pred_b,
vec![middleware::ValueRef::Literal(Value::from(d))],
);
assert_eq!(expected, pod.public_statements[0]);
// Real
let prover = Prover {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap()
}
} }

View file

@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
backends::plonky2::{ backends::plonky2::{
error::{Error, Result}, error::{Error, Result},
mainpod::{SignedBy, Statement}, mainpod::{MerkleProofs, MerkleTransitionProofs, SignedBy, Statement},
primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof},
}, },
middleware::{self, OperationType, Params}, middleware::{self, OperationType, Params},
}; };
@ -30,50 +29,89 @@ impl OperationArg {
} }
} }
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum Size {
Small,
Medium,
}
impl fmt::Display for Size {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Small => write!(f, "small"),
Self::Medium => write!(f, "medium"),
}
}
}
impl Size {
pub const fn min() -> Self {
Self::Small
}
pub const fn max() -> Self {
Self::Medium
}
}
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum OperationAux { pub enum OperationAux {
None, None,
MerkleProofIndex(usize), MerkleProofIndex(Size, usize),
MerkleTransitionProofIndex(Size, usize),
PublicKeyOfIndex(usize), PublicKeyOfIndex(usize),
SignedByIndex(usize), SignedByIndex(usize),
MerkleTreeStateTransitionProofIndex(usize),
CustomPredVerifyIndex(usize), CustomPredVerifyIndex(usize),
} }
impl OperationAux { impl OperationAux {
fn table_offset_merkle_proof(_params: &Params) -> usize { fn table_offset_merkle_proof(params: &Params, size: Size) -> usize {
match size {
// At index 0 we store a zero entry // At index 0 we store a zero entry
1 Size::Small => 1,
Size::Medium => {
Self::table_offset_merkle_proof(params, Size::Small)
+ params.containers.state.max_small
}
}
}
fn table_offset_merkle_transition_proof(params: &Params, size: Size) -> usize {
match size {
Size::Small => {
Self::table_offset_merkle_proof(params, Size::min())
+ params.containers.state.max_total()
}
Size::Medium => {
Self::table_offset_merkle_transition_proof(params, Size::Small)
+ params.containers.transition.max_small
}
}
}
fn table_offset_custom_pred_verify(params: &Params) -> usize {
Self::table_offset_merkle_transition_proof(params, Size::min())
+ params.containers.transition.max_total()
} }
fn table_offset_public_key_of(params: &Params) -> usize { fn table_offset_public_key_of(params: &Params) -> usize {
Self::table_offset_merkle_proof(params) + params.max_merkle_proofs_containers Self::table_offset_custom_pred_verify(params) + params.max_custom_predicate_verifications
} }
fn table_offset_signed_by(params: &Params) -> usize { fn table_offset_signed_by(params: &Params) -> usize {
Self::table_offset_public_key_of(params) + params.max_public_key_of Self::table_offset_public_key_of(params) + params.max_public_key_of
} }
fn table_offset_merkle_tree_state_transition_proof(params: &Params) -> usize {
Self::table_offset_signed_by(params) + params.max_signed_by
}
fn table_offset_custom_pred_verify(params: &Params) -> usize {
Self::table_offset_merkle_tree_state_transition_proof(params)
+ params.max_merkle_tree_state_transition_proofs_containers
}
pub(crate) fn table_size(params: &Params) -> usize { pub(crate) fn table_size(params: &Params) -> usize {
1 + params.max_merkle_proofs_containers 1 + params.containers.state.max_total()
+ params.containers.transition.max_total()
+ params.max_custom_predicate_verifications
+ params.max_public_key_of + params.max_public_key_of
+ params.max_signed_by + params.max_signed_by
+ params.max_merkle_tree_state_transition_proofs_containers
+ params.max_custom_predicate_verifications
} }
pub fn table_index(&self, params: &Params) -> usize { pub fn table_index(&self, params: &Params) -> usize {
match self { match self {
Self::None => 0, Self::None => 0,
Self::MerkleProofIndex(i) => Self::table_offset_merkle_proof(params) + *i, Self::MerkleProofIndex(size, i) => Self::table_offset_merkle_proof(params, *size) + *i,
Self::MerkleTransitionProofIndex(size, i) => {
Self::table_offset_merkle_transition_proof(params, *size) + *i
}
Self::PublicKeyOfIndex(i) => Self::table_offset_public_key_of(params) + *i, Self::PublicKeyOfIndex(i) => Self::table_offset_public_key_of(params) + *i,
Self::SignedByIndex(i) => Self::table_offset_signed_by(params) + *i, Self::SignedByIndex(i) => Self::table_offset_signed_by(params) + *i,
Self::MerkleTreeStateTransitionProofIndex(i) => {
Self::table_offset_merkle_tree_state_transition_proof(params) + *i
}
Self::CustomPredVerifyIndex(i) => Self::table_offset_custom_pred_verify(params) + *i, Self::CustomPredVerifyIndex(i) => Self::table_offset_custom_pred_verify(params) + *i,
} }
} }
@ -96,8 +134,8 @@ impl Operation {
&self, &self,
statements: &[Statement], statements: &[Statement],
signatures: &[SignedBy], signatures: &[SignedBy],
merkle_proofs: &[MerkleClaimAndProof], merkle_proofs: &MerkleProofs,
merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProof], merkle_transition_proofs: &MerkleTransitionProofs,
) -> Result<crate::middleware::Operation> { ) -> Result<crate::middleware::Operation> {
let deref_args = self let deref_args = self
.1 .1
@ -113,17 +151,26 @@ impl Operation {
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let deref_aux = match self.2 { let deref_aux = match self.2 {
OperationAux::None => crate::middleware::OperationAux::None, OperationAux::None => crate::middleware::OperationAux::None,
OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, OperationAux::MerkleProofIndex(size, i) => {
OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof( let table = match size {
merkle_proofs Size::Small => &merkle_proofs.small,
Size::Medium => &merkle_proofs.medium,
};
crate::middleware::OperationAux::MerkleProof(
table
.get(i) .get(i)
.ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))? .ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))?
.proof .proof
.clone(), .clone(),
), )
OperationAux::MerkleTreeStateTransitionProofIndex(i) => { }
OperationAux::MerkleTransitionProofIndex(size, i) => {
let table = match size {
Size::Small => &merkle_transition_proofs.small,
Size::Medium => &merkle_transition_proofs.medium,
};
crate::middleware::OperationAux::MerkleTreeStateTransitionProof( crate::middleware::OperationAux::MerkleTreeStateTransitionProof(
merkle_tree_state_transition_proofs table
.get(i) .get(i)
.ok_or(Error::custom(format!( .ok_or(Error::custom(format!(
"Missing Merkle state transition proof index {}", "Missing Merkle state transition proof index {}",
@ -132,6 +179,7 @@ impl Operation {
.clone(), .clone(),
) )
} }
OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None,
OperationAux::SignedByIndex(i) => crate::middleware::OperationAux::Signature( OperationAux::SignedByIndex(i) => crate::middleware::OperationAux::Signature(
signatures signatures
.get(i) .get(i)
@ -165,12 +213,14 @@ impl fmt::Display for Operation {
} }
match self.2 { match self.2 {
OperationAux::None => (), OperationAux::None => (),
OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, OperationAux::MerkleProofIndex(size, i) => {
write!(f, " {}_merkle_proof_{:02}", size, i)?
}
OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?, OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?,
OperationAux::PublicKeyOfIndex(i) => write!(f, " public_key_of_{:02}", i)?, OperationAux::PublicKeyOfIndex(i) => write!(f, " public_key_of_{:02}", i)?,
OperationAux::SignedByIndex(i) => write!(f, " signed_by_{:02}", i)?, OperationAux::SignedByIndex(i) => write!(f, " signed_by_{:02}", i)?,
OperationAux::MerkleTreeStateTransitionProofIndex(i) => { OperationAux::MerkleTransitionProofIndex(size, i) => {
write!(f, " merkle_tree_state_transition_proof_{:02}", i)? write!(f, " {}_merkle_transition_proof_{:02}", size, i)?
} }
} }
Ok(()) Ok(())

View file

@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
backends::plonky2::error::{Error, Result}, backends::plonky2::error::{Error, Result},
middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS}, middleware::{
self, NativePredicate, Predicate, StatementArg, ToFields, Value, ValueRef, BASE_PARAMS,
},
}; };
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@ -96,15 +98,15 @@ impl TryFrom<Statement> for middleware::Statement {
)))?, )))?,
}, },
Predicate::Custom(cpr) => { Predicate::Custom(cpr) => {
let vs: Vec<Value> = proper_args let args: Vec<ValueRef> = proper_args
.into_iter() .into_iter()
.filter_map(|arg| match arg { .filter_map(|arg| match arg {
SA::None => None, StatementArg::Literal(v) => Some(ValueRef::Literal(v)),
SA::Literal(v) => Some(v), StatementArg::Key(k) => Some(ValueRef::Key(k)),
_ => unreachable!(), StatementArg::None => None,
}) })
.collect(); .collect();
S::Custom(cpr, vs) S::Custom(cpr, args)
} }
Predicate::Intro(ir) => { Predicate::Intro(ir) => {
let vs: Vec<Value> = proper_args let vs: Vec<Value> = proper_args

View file

@ -11,13 +11,12 @@ use crate::{
basetypes::{Proof, VerifierOnlyCircuitData}, basetypes::{Proof, VerifierOnlyCircuitData},
error::{Error, Result}, error::{Error, Result},
mainpod::{ mainpod::{
calculate_statements_hash, extract_merkle_proofs, calculate_statements_hash, extract_merkle_proofs, extract_merkle_transition_proofs,
extract_merkle_tree_state_transition_proofs, extract_signatures, layout_statements, extract_signatures, layout_statements, process_private_statements_operations,
process_private_statements_operations, process_public_statements_operations, Operation, process_public_statements_operations, MerkleProofs, MerkleTransitionProofs, Operation,
OperationAux, SignedBy, Statement, OperationAux, SignedBy, Statement,
}, },
mock::emptypod::MockEmptyPod, mock::emptypod::MockEmptyPod,
primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof},
recursion::hash_verifier_data, recursion::hash_verifier_data,
}, },
middleware::{ middleware::{
@ -45,10 +44,10 @@ pub struct MockMainPod {
operations: Vec<Operation>, operations: Vec<Operation>,
// public subset of the `statements` vector // public subset of the `statements` vector
public_statements: Vec<Statement>, public_statements: Vec<Statement>,
// All Merkle proofs // All Merkle proofs for containers
merkle_proofs_containers: Vec<MerkleClaimAndProof>, merkle_proofs: MerkleProofs,
// All Merkle tree state transition proofs // All Merkle tree state transition proofs for containers
merkle_tree_state_transition_proofs_containers: Vec<MerkleTreeStateTransitionProof>, merkle_transition_proofs: MerkleTransitionProofs,
// All verified signatures // All verified signatures
signatures: Vec<SignedBy>, signatures: Vec<SignedBy>,
} }
@ -124,8 +123,8 @@ struct Data {
public_statements: Vec<Statement>, public_statements: Vec<Statement>,
operations: Vec<Operation>, operations: Vec<Operation>,
statements: Vec<Statement>, statements: Vec<Statement>,
merkle_proofs: Vec<MerkleClaimAndProof>, merkle_proofs: MerkleProofs,
merkle_tree_state_transition_proofs: Vec<MerkleTreeStateTransitionProof>, merkle_transition_proofs: MerkleTransitionProofs,
signatures: Vec<SignedBy>, signatures: Vec<SignedBy>,
input_pods: Vec<(usize, Params, Hash, VDSet, serde_json::Value)>, input_pods: Vec<(usize, Params, Hash, VDSet, serde_json::Value)>,
} }
@ -153,8 +152,8 @@ impl MockMainPod {
let merkle_proofs = let merkle_proofs =
extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?;
// Similarly for Merkle state transition proofs. // Similarly for Merkle state transition proofs.
let merkle_tree_state_transition_proofs = let merkle_transition_proofs =
extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?; extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?;
let signatures = let signatures =
extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?;
@ -185,8 +184,8 @@ impl MockMainPod {
public_statements, public_statements,
statements, statements,
operations, operations,
merkle_proofs_containers: merkle_proofs, merkle_proofs,
merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, merkle_transition_proofs,
signatures, signatures,
}) })
} }
@ -260,8 +259,8 @@ impl Pod for MockMainPod {
.deref( .deref(
&self.statements[..input_statement_offset + i], &self.statements[..input_statement_offset + i],
&self.signatures, &self.signatures,
&self.merkle_proofs_containers, &self.merkle_proofs,
&self.merkle_tree_state_transition_proofs_containers, &self.merkle_transition_proofs,
)? )?
.check_and_log(&self.params, &s.clone().try_into()?) .check_and_log(&self.params, &s.clone().try_into()?)
.map_err(|e| e.into()) .map_err(|e| e.into())
@ -321,10 +320,8 @@ impl Pod for MockMainPod {
public_statements: self.public_statements.clone(), public_statements: self.public_statements.clone(),
operations: self.operations.clone(), operations: self.operations.clone(),
statements: self.statements.clone(), statements: self.statements.clone(),
merkle_proofs: self.merkle_proofs_containers.clone(), merkle_proofs: self.merkle_proofs.clone(),
merkle_tree_state_transition_proofs: self merkle_transition_proofs: self.merkle_transition_proofs.clone(),
.merkle_tree_state_transition_proofs_containers
.clone(),
signatures: self.signatures.clone(), signatures: self.signatures.clone(),
input_pods, input_pods,
}) })
@ -344,7 +341,7 @@ impl Pod for MockMainPod {
operations, operations,
statements, statements,
merkle_proofs, merkle_proofs,
merkle_tree_state_transition_proofs, merkle_transition_proofs,
signatures, signatures,
input_pods, input_pods,
} = serde_json::from_value(data)?; } = serde_json::from_value(data)?;
@ -362,8 +359,8 @@ impl Pod for MockMainPod {
public_statements, public_statements,
operations, operations,
statements, statements,
merkle_proofs_containers: merkle_proofs, merkle_proofs,
merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, merkle_transition_proofs,
signatures, signatures,
}) })
} }
@ -380,7 +377,8 @@ pub mod tests {
great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request, great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request,
zu_kyc_sign_dict_builders, MOCK_VD_SET, zu_kyc_sign_dict_builders, MOCK_VD_SET,
}, },
frontend, middleware, frontend::{self},
middleware,
middleware::{Signer as _, Value}, middleware::{Signer as _, Value},
}; };

View file

@ -207,7 +207,7 @@ impl Point {
u: *u, u: *u,
}); });
points.find(|p| p.is_in_subgroup()).ok_or(Error::custom( points.find(|p| p.is_in_subgroup()).ok_or(Error::custom(
"One of the points must lie in the EC subgroup.".into(), "One of the points must lie in the EC subgroup.",
)) ))
} }
pub fn as_bytes_from_subgroup(&self) -> Result<Vec<u8>, Error> { pub fn as_bytes_from_subgroup(&self) -> Result<Vec<u8>, Error> {

View file

@ -32,7 +32,7 @@ use crate::{
circuits::common::{CircuitBuilderPod, ValueTarget}, circuits::common::{CircuitBuilderPod, ValueTarget},
error::{Error, Result}, error::{Error, Result},
primitives::merkletree::{ primitives::merkletree::{
MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, MAX_DEPTH,
}, },
}, },
measure_gates_begin, measure_gates_end, measure_gates_begin, measure_gates_end,
@ -42,8 +42,6 @@ use crate::{
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MerkleClaimAndProofTarget { pub struct MerkleClaimAndProofTarget {
pub(crate) max_depth: usize, pub(crate) max_depth: usize,
// `enabled` determines if the merkleproof verification is enabled
pub(crate) enabled: BoolTarget,
pub(crate) root: HashOutTarget, pub(crate) root: HashOutTarget,
pub(crate) key: ValueTarget, pub(crate) key: ValueTarget,
pub(crate) value: ValueTarget, pub(crate) value: ValueTarget,
@ -121,16 +119,9 @@ pub fn verify_merkle_proof_circuit(
let obtained_root = let obtained_root =
compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings);
// check that obtained_root==root (from inputs), when enabled==true // check that obtained_root==root (from inputs)
let zero = builder.zero();
let expected_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, proof.root.elements[j], zero))
.collect();
let computed_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero))
.collect();
for j in 0..HASH_SIZE { for j in 0..HASH_SIZE {
builder.connect(computed_root[j], expected_root[j]); builder.connect(obtained_root.elements[j], proof.root.elements[j]);
} }
measure_gates_end!(builder, measure); measure_gates_end!(builder, measure);
} }
@ -139,7 +130,6 @@ impl MerkleClaimAndProofTarget {
pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self {
MerkleClaimAndProofTarget { MerkleClaimAndProofTarget {
max_depth, max_depth,
enabled: builder.add_virtual_bool_target_safe(),
root: builder.add_virtual_hash(), root: builder.add_virtual_hash(),
key: builder.add_virtual_value(), key: builder.add_virtual_value(),
value: builder.add_virtual_value(), value: builder.add_virtual_value(),
@ -154,12 +144,7 @@ impl MerkleClaimAndProofTarget {
} }
/// assigns the given values to the targets /// assigns the given values to the targets
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn set_targets( pub fn set_targets(&self, pw: &mut PartialWitness<F>, mp: &MerkleClaimAndProof) -> Result<()> {
&self,
pw: &mut PartialWitness<F>,
enabled: bool,
mp: &MerkleClaimAndProof,
) -> Result<()> {
if mp.proof.siblings.len() > self.max_depth { if mp.proof.siblings.len() > self.max_depth {
return Err(Error::Tree(TreeError::circuit_depth_too_small( return Err(Error::Tree(TreeError::circuit_depth_too_small(
self.max_depth, self.max_depth,
@ -167,7 +152,6 @@ impl MerkleClaimAndProofTarget {
))); )));
} }
pw.set_bool_target(self.enabled, enabled)?;
pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?;
pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.key.elements, &mp.key.0)?;
pw.set_target_arr(&self.value.elements, &mp.value.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?;
@ -207,8 +191,6 @@ impl MerkleClaimAndProofTarget {
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct MerkleProofExistenceTarget { pub struct MerkleProofExistenceTarget {
max_depth: usize, max_depth: usize,
// `enabled` determines if the merkleproof verification is enabled
pub(crate) enabled: BoolTarget,
pub(crate) root: HashOutTarget, pub(crate) root: HashOutTarget,
pub(crate) key: ValueTarget, pub(crate) key: ValueTarget,
pub(crate) value: ValueTarget, pub(crate) value: ValueTarget,
@ -236,16 +218,9 @@ pub fn verify_merkle_proof_existence_circuit(
let obtained_root = let obtained_root =
compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings);
// check that obtained_root==root (from inputs), when enabled==true // check that obtained_root==root (from inputs)
let zero = builder.zero();
let expected_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, proof.root.elements[j], zero))
.collect();
let computed_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero))
.collect();
for j in 0..HASH_SIZE { for j in 0..HASH_SIZE {
builder.connect(computed_root[j], expected_root[j]); builder.connect(obtained_root.elements[j], proof.root.elements[j]);
} }
measure_gates_end!(builder, measure); measure_gates_end!(builder, measure);
@ -256,7 +231,6 @@ impl MerkleProofExistenceTarget {
pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self {
MerkleProofExistenceTarget { MerkleProofExistenceTarget {
max_depth, max_depth,
enabled: builder.add_virtual_bool_target_safe(),
root: builder.add_virtual_hash(), root: builder.add_virtual_hash(),
key: builder.add_virtual_value(), key: builder.add_virtual_value(),
value: builder.add_virtual_value(), value: builder.add_virtual_value(),
@ -265,12 +239,7 @@ impl MerkleProofExistenceTarget {
} }
} }
/// assigns the given values to the targets /// assigns the given values to the targets
pub fn set_targets( pub fn set_targets(&self, pw: &mut PartialWitness<F>, mp: &MerkleClaimAndProof) -> Result<()> {
&self,
pw: &mut PartialWitness<F>,
enabled: bool,
mp: &MerkleClaimAndProof,
) -> Result<()> {
assert!(mp.proof.existence); // sanity check assert!(mp.proof.existence); // sanity check
if mp.proof.siblings.len() > self.max_depth { if mp.proof.siblings.len() > self.max_depth {
return Err(Error::Tree(TreeError::circuit_depth_too_small( return Err(Error::Tree(TreeError::circuit_depth_too_small(
@ -279,7 +248,6 @@ impl MerkleProofExistenceTarget {
))); )));
} }
pw.set_bool_target(self.enabled, enabled)?;
pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?;
pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.key.elements, &mp.key.0)?;
pw.set_target_arr(&self.value.elements, &mp.value.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?;
@ -456,8 +424,6 @@ fn hash_with_flag_target<H: AlgebraicHasher<F>>(
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
pub struct MerkleTreeStateTransitionProofTarget { pub struct MerkleTreeStateTransitionProofTarget {
pub(crate) max_depth: usize, pub(crate) max_depth: usize,
// `enabled` determines if the merkleproof state transition verification is enabled
pub(crate) enabled: BoolTarget,
pub(crate) op: Target, pub(crate) op: Target,
pub(crate) old_root: HashOutTarget, pub(crate) old_root: HashOutTarget,
pub(crate) op_proof: MerkleClaimAndProofTarget, pub(crate) op_proof: MerkleClaimAndProofTarget,
@ -511,7 +477,6 @@ pub fn verify_merkle_state_transition_circuit(
}; };
let new_key_proof = MerkleProofExistenceTarget { let new_key_proof = MerkleProofExistenceTarget {
max_depth: proof.max_depth, max_depth: proof.max_depth,
enabled: proof.enabled,
root, root,
key: proof.op_key, key: proof.op_key,
value: proof.op_value, value: proof.op_value,
@ -523,13 +488,7 @@ pub fn verify_merkle_state_transition_circuit(
// Insert/Delete: Non-existence // Insert/Delete: Non-existence
// Update: Existence // Update: Existence
let proof_type = is_update; let proof_type = is_update;
builder.conditional_assert_eq( builder.connect(proof.op_proof.existence.target, proof_type.target);
proof.enabled.target,
proof.op_proof.existence.target,
proof_type.target,
);
// 3.2) assert that proof.enabled matches with op_proof.enabled
builder.connect(proof.op_proof.enabled.target, proof.enabled.target);
// 4) assert proof_non_existence.root corresponds to the root // 4) assert proof_non_existence.root corresponds to the root
// specified by the op (old_root for Insert/Update and new_root // specified by the op (old_root for Insert/Update and new_root
@ -545,17 +504,9 @@ pub fn verify_merkle_state_transition_circuit(
}; };
for j in 0..HASH_SIZE { for j in 0..HASH_SIZE {
// 4.1) assert that proof.proof_non_existence.root == proof.old_root // 4.1) assert that proof.proof_non_existence.root == proof.old_root
builder.conditional_assert_eq( builder.connect(proof.op_proof.root.elements[j], claim_root.elements[j]);
proof.enabled.target,
proof.op_proof.root.elements[j],
claim_root.elements[j],
);
// 4.2) assert that the non-existence proof uses the op_key (value not needed). // 4.2) assert that the non-existence proof uses the op_key (value not needed).
builder.conditional_assert_eq( builder.connect(proof.op_proof.key.elements[j], proof.op_key.elements[j]);
proof.enabled.target,
proof.op_proof.key.elements[j],
proof.op_key.elements[j],
);
} }
// prepare value for check 5.2) // prepare value for check 5.2)
@ -593,7 +544,7 @@ pub fn verify_merkle_state_transition_circuit(
.map(|j| builder.select(is_divergence_level, zero, new_siblings[i].elements[j])) .map(|j| builder.select(is_divergence_level, zero, new_siblings[i].elements[j]))
.collect(); .collect();
for j in 0..HASH_SIZE { for j in 0..HASH_SIZE {
builder.conditional_assert_eq(proof.enabled.target, old_sibling_i[j], new_sibling_i[j]); builder.connect(old_sibling_i[j], new_sibling_i[j]);
} }
// 5.2) when i==d && if old_siblings[i] != new_siblings[i], check that: // 5.2) when i==d && if old_siblings[i] != new_siblings[i], check that:
@ -611,7 +562,7 @@ pub fn verify_merkle_state_transition_circuit(
let in_case_5_2 = builder.and(old_is_noteq_new, is_divergence_level); let in_case_5_2 = builder.and(old_is_noteq_new, is_divergence_level);
// do the case2's checks // do the case2's checks
let sel = builder.and(proof.enabled, in_case_5_2); let sel = in_case_5_2;
for j in 0..HASH_SIZE { for j in 0..HASH_SIZE {
builder.conditional_assert_eq(sel.target, old_siblings[i].elements[j], zero); builder.conditional_assert_eq(sel.target, old_siblings[i].elements[j], zero);
builder.conditional_assert_eq( builder.conditional_assert_eq(
@ -641,7 +592,6 @@ impl MerkleTreeStateTransitionProofTarget {
pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self {
Self { Self {
max_depth, max_depth,
enabled: builder.add_virtual_bool_target_safe(),
op: builder.add_virtual_target(), op: builder.add_virtual_target(),
old_root: builder.add_virtual_hash(), old_root: builder.add_virtual_hash(),
@ -661,7 +611,6 @@ impl MerkleTreeStateTransitionProofTarget {
pub fn set_targets( pub fn set_targets(
&self, &self,
pw: &mut PartialWitness<F>, pw: &mut PartialWitness<F>,
enabled: bool,
mp: &MerkleTreeStateTransitionProof, mp: &MerkleTreeStateTransitionProof,
) -> Result<()> { ) -> Result<()> {
let new_siblings = mp.siblings.clone(); let new_siblings = mp.siblings.clone();
@ -672,13 +621,11 @@ impl MerkleTreeStateTransitionProofTarget {
))); )));
} }
pw.set_bool_target(self.enabled, enabled)?;
pw.set_target(self.op, F::from_canonical_u8(mp.op as u8))?; pw.set_target(self.op, F::from_canonical_u8(mp.op as u8))?;
pw.set_hash_target(self.old_root, HashOut::from_vec(mp.old_root.0.to_vec()))?; pw.set_hash_target(self.old_root, HashOut::from_vec(mp.old_root.0.to_vec()))?;
self.op_proof.set_targets( self.op_proof.set_targets(
pw, pw,
enabled,
&MerkleClaimAndProof { &MerkleClaimAndProof {
root: if mp.op == MerkleTreeOp::Delete { root: if mp.op == MerkleTreeOp::Delete {
mp.new_root mp.new_root
@ -703,10 +650,13 @@ impl MerkleTreeStateTransitionProofTarget {
{ {
pw.set_hash_target(self.siblings[i], HashOut::from_vec(sibling.0.to_vec()))?; pw.set_hash_target(self.siblings[i], HashOut::from_vec(sibling.0.to_vec()))?;
} }
pw.set_target( let div_lvl = if new_siblings.is_empty() {
self.divergence_level, // don't subtract since it would underflow, use MAX_DEPTH
F::from_canonical_u64((new_siblings.len() - 1) as u64), MAX_DEPTH as u64
)?; } else {
(new_siblings.len() - 1) as u64
};
pw.set_target(self.divergence_level, F::from_canonical_u64(div_lvl))?;
Ok(()) Ok(())
} }
@ -856,7 +806,6 @@ pub mod tests {
verify_merkle_proof_circuit(&mut builder, &targets); verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets( targets.set_targets(
&mut pw, &mut pw,
true,
&MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof),
)?; )?;
@ -868,6 +817,42 @@ pub mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_merkleproof_pad_valid() -> Result<()> {
// circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleClaimAndProofTarget::new_virtual(32, &mut builder);
verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, &MerkleClaimAndProof::pad())?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
#[test]
fn test_merkleproof_transition_pad_valid() -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(32, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, &MerkleTreeStateTransitionProof::pad())?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
#[test] #[test]
fn test_merkleproof_only_existence_verify() -> Result<()> { fn test_merkleproof_only_existence_verify() -> Result<()> {
for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] {
@ -903,7 +888,6 @@ pub mod tests {
verify_merkle_proof_circuit(&mut builder, &targets); verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets( targets.set_targets(
&mut pw, &mut pw,
true,
&MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof),
)?; )?;
@ -979,7 +963,6 @@ pub mod tests {
verify_merkle_proof_circuit(&mut builder, &targets); verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets( targets.set_targets(
&mut pw, &mut pw,
true,
&MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof),
)?; )?;
@ -1025,32 +1008,15 @@ pub mod tests {
let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_proof_circuit(&mut builder, &targets); verify_merkle_proof_circuit(&mut builder, &targets);
// verification enabled & proof of existence // proof of existence
let mp = MerkleClaimAndProof::new(tree2.root(), key, Some(value), proof); let mp = MerkleClaimAndProof::new(tree2.root(), key, Some(value), proof);
targets.set_targets(&mut pw, true, &mp)?; targets.set_targets(&mut pw, &mp)?;
// generate proof, expecting it to fail (since we're using the wrong // generate proof, expecting it to fail (since we're using the wrong
// root) // root)
let data = builder.build::<C>(); let data = builder.build::<C>();
assert!(data.prove(pw).is_err()); assert!(data.prove(pw).is_err());
// Now generate a new proof, using `enabled=false`, which should pass the verification
// despite containing 'wrong' witness.
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_proof_circuit(&mut builder, &targets);
// verification disabled & proof of existence
targets.set_targets(&mut pw, false, &mp)?;
// generate proof, should pass despite using wrong witness, since the
// `enabled=false`
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(()) Ok(())
} }
@ -1073,7 +1039,7 @@ pub mod tests {
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets); verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, true, state_transition_proof)?; targets.set_targets(&mut pw, state_transition_proof)?;
// generate & verify proof // generate & verify proof
let data = builder.build::<C>(); let data = builder.build::<C>();
@ -1270,71 +1236,4 @@ pub mod tests {
assert_ne!(state_transition_proof.new_root, tree.root()); // Tamper check assert_ne!(state_transition_proof.new_root, tree.root()); // Tamper check
Ok(()) Ok(())
} }
#[test]
fn test_state_transition_gadget_disabled() -> Result<()> {
let max_depth: usize = 32;
let mut kvs = HashMap::new();
for i in 0..8 {
kvs.insert(RawValue::from(i), RawValue::from(1000 + i));
}
let mut tree = MerkleTree::new(&kvs);
let key = RawValue::from(37);
let value = RawValue::from(1037);
let _ = tree.insert(&key, &value)?;
let key = RawValue::from(21);
let value = RawValue::from(1021);
let original_state_transition_proof = tree.insert(&key, &value)?;
let mut state_transition_proof = original_state_transition_proof.clone();
// modify the proof, so that it should fail when `enabled=true`, by
// changing the new_root
state_transition_proof.new_root = state_transition_proof.old_root;
run_circuit_disabled(max_depth, &state_transition_proof)?;
// modify the proof, so that it should fail when `enabled=true`, by
// changing the new_sibling at the divergence level, which should not
// pass the verification in the case where we're inserting key=21
let mut state_transition_proof = original_state_transition_proof.clone();
state_transition_proof.siblings[4] = EMPTY_HASH;
run_circuit_disabled(max_depth, &state_transition_proof)?;
Ok(())
}
fn run_circuit_disabled(
max_depth: usize,
state_transition_proof: &MerkleTreeStateTransitionProof,
) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, true, state_transition_proof)?;
// generate proof, and expect it to fail
let data = builder.build::<C>();
assert!(data.prove(pw).is_err()); // expect prove to fail
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, false, state_transition_proof)?;
// generate and expect it to pass
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
} }

View file

@ -0,0 +1,97 @@
//! Module that implements the key-value DB used at the MerkleTree module.
use std::{
collections::HashMap,
fmt::Debug,
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Result};
use dyn_clone::DynClone;
use crate::{
backends::plonky2::primitives::merkletree::{Intermediate, Node},
middleware::{Hash, EMPTY_HASH},
};
#[cfg(feature = "db_rocksdb")]
pub mod rocks;
pub trait DB: Debug + DynClone + Sync + Send {
/// Must always return the empty intermediate node when hash is EMPTY_HASH
fn load_node(&self, hash: Hash) -> Result<Option<Node>>;
fn store_node(&mut self, node: Node) -> Result<()>;
}
dyn_clone::clone_trait_object!(DB);
/// MemDB implements the DB trait in a in-memory HashMap.
#[derive(Clone, Debug, Default)]
pub(crate) struct MemDB {
inner: Arc<Mutex<HashMap<Hash, Node>>>,
}
impl MemDB {
pub fn new() -> Self {
Self::default()
}
}
impl DB for MemDB {
fn load_node(&self, hash: Hash) -> Result<Option<Node>> {
let db = self
.inner
.lock()
.map_err(|e| anyhow!("failed to acquire memdb lock for read: {}", e))?;
if hash == EMPTY_HASH {
return Ok(Some(Node::Intermediate(Intermediate::new(
EMPTY_HASH, EMPTY_HASH,
))));
}
Ok(db.get(&hash).cloned())
}
fn store_node(&mut self, node: Node) -> Result<()> {
let mut db = self
.inner
.lock()
.map_err(|e| anyhow!("failed to acquire memdb lock for write: {}", e))?;
db.insert(node.hash(), node);
Ok(())
}
}
#[cfg(test)]
pub mod tests {
use super::{super::Leaf, *};
#[test]
fn test_db() -> Result<()> {
let mut db = MemDB::new();
test_db_opt(&mut db)?;
#[cfg(feature = "db_rocksdb")]
{
let path = "/tmp/rocksdb";
let mut db = rocks::RocksDB::open(path)?;
test_db_opt(&mut db)?;
}
Ok(())
}
fn test_db_opt(db: &mut dyn DB) -> Result<()> {
let node = Leaf::new(1.into(), 1.into());
db.store_node(Node::Leaf(node.clone()))?;
let obtained_node = db.load_node(node.hash)?.unwrap();
let leaf = match obtained_node {
Node::Leaf(l) => l,
_ => panic!("expected a leaf"),
};
assert_eq!(leaf.hash, node.hash);
Ok(())
}
}

View file

@ -0,0 +1,55 @@
use std::{fmt, path::Path, sync::Arc};
use anyhow::{anyhow, Result};
use rocksdb::{Options, TransactionDB, TransactionDBOptions};
use crate::{
backends::plonky2::primitives::merkletree::{self, db},
middleware::{Hash, RawValue, EMPTY_HASH},
};
#[derive(Clone)]
pub struct RocksDB(Arc<TransactionDB>);
#[allow(dead_code)]
impl RocksDB {
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let mut options = Options::default();
options.create_if_missing(true);
let txn_options = TransactionDBOptions::default();
let inner =
TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?;
Ok(Self(Arc::new(inner)))
}
}
impl fmt::Debug for RocksDB {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "RocksDB")
}
}
impl db::DB for RocksDB {
fn load_node(&self, hash: Hash) -> Result<Option<merkletree::Node>> {
if hash == EMPTY_HASH {
return Ok(Some(merkletree::Node::Intermediate(
merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH),
)));
}
match self
.0
.get(RawValue::from(hash).to_bytes())
.map_err(|e| anyhow!("rocksdb: get failed: {e}"))?
{
None => Ok(None),
Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)),
}
}
fn store_node(&mut self, node: merkletree::Node) -> Result<()> {
self.0
.put(RawValue::from(node.hash()).to_bytes(), node.encode()?)
.map_err(|e| anyhow!("rocksdb transaction put failed: {e}"))
}
}

View file

@ -2,12 +2,16 @@
use std::{backtrace::Backtrace, fmt::Debug}; use std::{backtrace::Backtrace, fmt::Debug};
use crate::middleware::Hash;
pub type TreeResult<T, E = TreeError> = core::result::Result<T, E>; pub type TreeResult<T, E = TreeError> = core::result::Result<T, E>;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum TreeInnerError { pub enum TreeInnerError {
#[error("key not found")] #[error("key not found")]
KeyNotFound, KeyNotFound,
#[error("node with hash {0} not found")]
NodeNotFound(Hash),
#[error("key already exists")] #[error("key already exists")]
KeyExists, KeyExists,
#[error("max depth reached")] #[error("max depth reached")]
@ -22,6 +26,9 @@ pub enum TreeInnerError {
StateTransitionProofFail(String), StateTransitionProofFail(String),
#[error("circuit max_depth {0} is smaller than proof depth {1}")] #[error("circuit max_depth {0} is smaller than proof depth {1}")]
CircuitDepthTooSmall(usize, usize), CircuitDepthTooSmall(usize, usize),
// Other
#[error("{0}")]
Custom(String),
} }
#[derive(thiserror::Error)] #[derive(thiserror::Error)]
@ -31,8 +38,8 @@ pub enum TreeError {
inner: Box<TreeInnerError>, inner: Box<TreeInnerError>,
backtrace: Box<Backtrace>, backtrace: Box<Backtrace>,
}, },
#[error("anyhow::Error: {0}")] #[error("database error: {0}")]
Anyhow(#[from] anyhow::Error), Database(anyhow::Error),
} }
impl Debug for TreeError { impl Debug for TreeError {
@ -60,6 +67,9 @@ impl TreeError {
pub(crate) fn key_not_found() -> Self { pub(crate) fn key_not_found() -> Self {
new!(KeyNotFound) new!(KeyNotFound)
} }
pub(crate) fn node_not_found(hash: Hash) -> Self {
new!(NodeNotFound(hash))
}
pub(crate) fn key_exists() -> Self { pub(crate) fn key_exists() -> Self {
new!(KeyExists) new!(KeyExists)
} }
@ -81,4 +91,7 @@ impl TreeError {
pub(crate) fn circuit_depth_too_small(circuit_depth: usize, proof_depth: usize) -> Self { pub(crate) fn circuit_depth_too_small(circuit_depth: usize, proof_depth: usize) -> Self {
new!(CircuitDepthTooSmall(circuit_depth, proof_depth)) new!(CircuitDepthTooSmall(circuit_depth, proof_depth))
} }
pub(crate) fn custom(s: impl Into<String>) -> Self {
new!(Custom(s.into()))
}
} }

File diff suppressed because it is too large Load diff

View file

@ -180,11 +180,7 @@ impl EthDosHelper {
}; };
assert_eq!(int, Value::from(int_attestation.public_key)); assert_eq!(int, Value::from(int_attestation.public_key));
let n_i64 = if let TypedValue::Int(x) = n.typed() { let n_i64 = n.as_int().unwrap();
*x
} else {
panic!("distance value is not Int")
};
// eth_dos src->dst dist=n+1 // eth_dos src->dst dist=n+1
self.n_plus_1(&mut pod, eth_dos_int_to_dst, int_attestation, n_i64)?; self.n_plus_1(&mut pod, eth_dos_int_to_dst, int_attestation, n_i64)?;

View file

@ -18,6 +18,8 @@ pub enum BuilderArg {
/// Key: (origin, key), where origin is Wildcard and key is Key /// Key: (origin, key), where origin is Wildcard and key is Key
Key(String, String), Key(String, String),
WildcardLiteral(String), WildcardLiteral(String),
/// Reference to a same-batch predicate's identity hash (resolved by name in finish()).
SelfPredicateHash(String),
} }
/// When defining a `BuilderArg`, it can be done from 3 different inputs: /// When defining a `BuilderArg`, it can be done from 3 different inputs:
@ -130,6 +132,8 @@ pub struct CustomPredicateBatchBuilder {
params: Params, params: Params,
pub name: String, pub name: String,
pub predicates: Vec<CustomPredicate>, pub predicates: Vec<CustomPredicate>,
/// Forward references to resolve in finish(): (predicate_idx, statement_idx, arg_idx, name)
pending_self_pred_hashes: Vec<(usize, usize, usize, String)>,
} }
impl CustomPredicateBatchBuilder { impl CustomPredicateBatchBuilder {
@ -138,6 +142,7 @@ impl CustomPredicateBatchBuilder {
params, params,
name, name,
predicates: Vec::new(), predicates: Vec::new(),
pending_self_pred_hashes: Vec::new(),
} }
} }
@ -171,6 +176,12 @@ impl CustomPredicateBatchBuilder {
priv_args: &[&str], priv_args: &[&str],
sts: &[StatementTmplBuilder], sts: &[StatementTmplBuilder],
) -> Result<Predicate> { ) -> Result<Predicate> {
if self.predicates.iter().any(|p| p.name == name) {
return Err(Error::custom(format!(
"Duplicate predicate name '{}' in batch",
name
)));
}
if self.predicates.len() >= Params::max_custom_batch_size() { if self.predicates.len() >= Params::max_custom_batch_size() {
return Err(Error::max_length( return Err(Error::max_length(
"self.predicates.len".to_string(), "self.predicates.len".to_string(),
@ -194,14 +205,18 @@ impl CustomPredicateBatchBuilder {
)); ));
} }
let pred_idx = self.predicates.len();
let mut pending = Vec::new();
let statements = sts let statements = sts
.iter() .iter()
.map(|sb| { .enumerate()
.map(|(stmt_idx, sb)| {
let stb = sb.clone().desugar(); let stb = sb.clone().desugar();
let st_tmpl_args = stb let st_tmpl_args = stb
.args .args
.iter() .iter()
.map(|a| { .enumerate()
.map(|(arg_idx, a)| {
Ok::<_, Error>(match a { Ok::<_, Error>(match a {
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey( BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey(
@ -211,6 +226,22 @@ impl CustomPredicateBatchBuilder {
BuilderArg::WildcardLiteral(v) => { BuilderArg::WildcardLiteral(v) => {
StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?) StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?)
} }
BuilderArg::SelfPredicateHash(pred_name) => {
// Try backward reference first
match self.predicates.iter().position(|p| p.name == *pred_name) {
Some(index) => StatementTmplArg::SelfPredicateHash(index),
None => {
// Forward reference - placeholder, resolved in finish()
pending.push((
pred_idx,
stmt_idx,
arg_idx,
pred_name.clone(),
));
StatementTmplArg::SelfPredicateHash(0)
}
}
}
}) })
}) })
.collect::<Result<_>>()?; .collect::<Result<_>>()?;
@ -240,11 +271,27 @@ impl CustomPredicateBatchBuilder {
.collect(), .collect(),
)?; )?;
self.predicates.push(custom_predicate); self.predicates.push(custom_predicate);
self.pending_self_pred_hashes.extend(pending);
Ok(Predicate::BatchSelf(self.predicates.len() - 1)) Ok(Predicate::BatchSelf(self.predicates.len() - 1))
} }
pub fn finish(self) -> Arc<CustomPredicateBatch> { pub fn finish(mut self) -> Result<Arc<CustomPredicateBatch>> {
CustomPredicateBatch::new(self.name, self.predicates) // Resolve forward references for SelfPredicateHash
for (pred_idx, stmt_idx, arg_idx, ref name) in &self.pending_self_pred_hashes {
let target_idx = self
.predicates
.iter()
.position(|p| p.name == *name)
.ok_or_else(|| {
Error::custom(format!(
"SelfPredicateHash references unknown predicate '{}'",
name
))
})?;
self.predicates[*pred_idx].statements[*stmt_idx].args[*arg_idx] =
StatementTmplArg::SelfPredicateHash(target_idx);
}
Ok(CustomPredicateBatch::new(self.name, self.predicates))
} }
} }
@ -269,7 +316,9 @@ mod tests {
backends::plonky2::mock::mainpod::MockProver, backends::plonky2::mock::mainpod::MockProver,
examples::{custom::eth_dos_batch, MOCK_VD_SET}, examples::{custom::eth_dos_batch, MOCK_VD_SET},
frontend::{MainPodBuilder, Operation}, frontend::{MainPodBuilder, Operation},
middleware::{self, containers::Set, CustomPredicateRef, Params, PodType, DEFAULT_VD_SET}, middleware::{
self, containers::Set, CustomPredicateRef, Params, PodType, ValueRef, DEFAULT_VD_SET,
},
}; };
#[test] #[test]
@ -306,7 +355,7 @@ mod tests {
.arg("s2"); .arg("s2");
builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?; builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?;
let batch = builder.finish(); let batch = builder.finish()?;
let batch_clone = batch.clone(); let batch_clone = batch.clone();
let gt_custom_pred = CustomPredicateRef::new(batch, 0); let gt_custom_pred = CustomPredicateRef::new(batch, 0);
@ -356,7 +405,7 @@ mod tests {
&[], &[],
&[set_contains_stb], &[set_contains_stb],
)?; )?;
let batch = builder.finish(); let batch = builder.finish()?;
let batch_clone = batch.clone(); let batch_clone = batch.clone();
let mut mp_builder = MainPodBuilder::new(&params, vd_set); let mut mp_builder = MainPodBuilder::new(&params, vd_set);
@ -386,4 +435,83 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_builder_self_predicate_hash_unknown_ref() {
let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("nonexistent".into()));
builder
.predicate_and("pred_A", &["x"], &[], &[stb])
.unwrap();
// finish() should fail because "nonexistent" was never defined
assert!(builder.finish().is_err());
}
/// Tests cyclic SelfPredicateHash references end-to-end:
/// pred_A references pred_B's hash (forward ref), pred_B references pred_A's hash (backward
/// ref). Exercises forward reference resolution in finish(), then builds and verifies a POD
/// using pred_A via MockProver.
#[test]
fn test_builder_self_predicate_hash_e2e() -> Result<()> {
let params = Params::default();
let vd_set = &*MOCK_VD_SET;
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
// pred_A references pred_B's hash (forward ref, pred_B not yet defined)
let stb_a = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
builder.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
// pred_B references pred_A's hash (backward ref, pred_A already defined)
let stb_b = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
builder.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
let batch = builder.finish()?;
// Verify resolution: pred_A references pred_B (index 1), pred_B references pred_A (index 0)
assert_eq!(
batch.predicates()[0].statements[0].args[1],
StatementTmplArg::SelfPredicateHash(1)
);
assert_eq!(
batch.predicates()[1].statements[0].args[1],
StatementTmplArg::SelfPredicateHash(0)
);
// Compute concrete hashes
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
let mut mp_builder = MainPodBuilder::new(&params, vd_set);
let eq_st = mp_builder.priv_op(Operation::eq(pred_b_hash.clone(), pred_b_hash.clone()))?;
mp_builder.pub_op(Operation::custom(pred_a_ref, [eq_st]))?;
// Prove and verify
let prover = MockProver {};
let proof = mp_builder.prove(&prover)?;
proof.pod.verify()?;
// Verify the public statement contains pred_b_hash as its argument
let pub_sts = proof.pod.pub_self_statements();
let custom_st = pub_sts
.iter()
.find(|s| matches!(s, middleware::Statement::Custom(_, _)))
.expect("should have a custom statement");
if let middleware::Statement::Custom(_, args) = custom_st {
assert_eq!(args[0], ValueRef::Literal(pred_b_hash));
}
Ok(())
}
} }

View file

@ -4,7 +4,7 @@
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::From, convert::From,
fmt, fmt, iter,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -13,10 +13,12 @@ use serde::{Deserialize, Serialize};
pub use serialization::SerializedMainPod; pub use serialization::SerializedMainPod;
use crate::middleware::{ use crate::middleware::{
self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op, self, check_custom_pred,
prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, containers::{Container, Dictionary},
OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key,
StatementArg, VDSet, Value, ValueRef, MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey,
RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS,
EMPTY_VALUE,
}; };
mod custom; mod custom;
@ -92,8 +94,11 @@ impl fmt::Display for SignedDict {
// https://0xparc.github.io/pod2/merkletree.html will not need it since it will be // https://0xparc.github.io/pod2/merkletree.html will not need it since it will be
// deterministic based on the keys values not on the order of the keys when added into the // deterministic based on the keys values not on the order of the keys when added into the
// tree. // tree.
for (k, v) in self.dict.kvs().iter().sorted_by_key(|kv| kv.0.hash()) { for kv in self.dict.iter() {
writeln!(f, " - {} = {}", k, v)?; match kv {
Ok((k, v)) => writeln!(f, " - {} = {}", k, v)?,
Err(e) => writeln!(f, " - ERR: {}", e)?,
}
} }
Ok(()) Ok(())
} }
@ -106,16 +111,13 @@ impl SignedDict {
.then_some(()) .then_some(())
.ok_or(Error::custom("Invalid signature!")) .ok_or(Error::custom("Invalid signature!"))
} }
pub fn kvs(&self) -> &HashMap<Key, Value> { pub fn get(&self, key: impl Into<Key>) -> Option<Value> {
self.dict.kvs() self.dict.get(&key.into()).unwrap()
}
pub fn get(&self, key: impl Into<Key>) -> Option<&Value> {
self.kvs().get(&key.into())
} }
// Returns the Contains statement that defines key if it exists. // Returns the Contains statement that defines key if it exists.
pub fn get_statement(&self, key: impl Into<Key>) -> Option<Statement> { pub fn get_statement(&self, key: impl Into<Key>) -> Option<Statement> {
let key: Key = key.into(); let key: Key = key.into();
self.kvs().get(&key).map(|value| { self.dict.get(&key).unwrap().map(|value| {
Statement::Contains( Statement::Contains(
ValueRef::Literal(Value::from(self.dict.clone())), ValueRef::Literal(Value::from(self.dict.clone())),
ValueRef::Literal(Value::from(key.name())), ValueRef::Literal(Value::from(key.name())),
@ -136,7 +138,7 @@ pub struct MainPodBuilder {
pub operations: Vec<Operation>, pub operations: Vec<Operation>,
pub public_statements: Vec<Statement>, pub public_statements: Vec<Statement>,
// Internal state // Internal state
dict_contains: Vec<(Value, Value)>, // (root, key) contains: Vec<(RawValue, RawValue)>, // (root, key)
} }
impl fmt::Display for MainPodBuilder { impl fmt::Display for MainPodBuilder {
@ -156,6 +158,11 @@ impl fmt::Display for MainPodBuilder {
} }
} }
fn as_container_or_err(v: &Value) -> Result<Container> {
v.as_container()
.ok_or_else(|| Error::custom(format!("{v} not a container")))
}
impl MainPodBuilder { impl MainPodBuilder {
pub fn new(params: &Params, vd_set: &VDSet) -> Self { pub fn new(params: &Params, vd_set: &VDSet) -> Self {
Self { Self {
@ -165,10 +172,16 @@ impl MainPodBuilder {
statements: Vec::new(), statements: Vec::new(),
operations: Vec::new(), operations: Vec::new(),
public_statements: Vec::new(), public_statements: Vec::new(),
dict_contains: Vec::new(), contains: Vec::new(),
} }
} }
pub fn stmt_len(&self) -> usize {
self.statements.len()
}
pub fn add_pod(&mut self, pod: MainPod) -> Result<()> { pub fn add_pod(&mut self, pod: MainPod) -> Result<()> {
for st in &pod.public_statements {
self.track_contains(st);
}
self.input_pods.push(pod); self.input_pods.push(pod);
match self.input_pods.len() > self.params.max_input_pods { match self.input_pods.len() > self.params.max_input_pods {
true => Err(Error::too_many_input_pods( true => Err(Error::too_many_input_pods(
@ -178,31 +191,26 @@ impl MainPodBuilder {
_ => Ok(()), _ => Ok(()),
} }
} }
pub fn insert(&mut self, public: bool, st_op: (Statement, Operation)) -> Result<()> {
// TODO: Do error handling instead of panic
let (st, op) = st_op;
// If we're adding a Contains statement with literal arguments (an Entry), track it in // If we're adding a Contains statement with literal arguments (an Entry), track it in
// `dict_contains` to avoid adding it again via `Self::add_entries_contains`. // `dict_contains` to avoid adding it again via `Self::add_entries_contains`.
fn track_contains(&mut self, st: &Statement) {
if let Statement::Contains( if let Statement::Contains(
ValueRef::Literal(dict), ValueRef::Literal(dict),
ValueRef::Literal(key), ValueRef::Literal(key),
ValueRef::Literal(_), ValueRef::Literal(_),
) = &st ) = &st
{ {
let root_key = (dict.clone(), key.clone()); let root_key = (dict.raw(), key.raw());
self.dict_contains.push(root_key); self.contains.push(root_key);
}
} }
if public { pub fn insert(&mut self, st_op: (Statement, Operation)) -> Result<()> {
self.public_statements.push(st.clone()); // TODO: Do error handling instead of panic
} let (st, op) = st_op;
if self.public_statements.len() > self.params.max_public_statements { self.track_contains(&st);
return Err(Error::too_many_public_statements(
self.public_statements.len(),
self.params.max_public_statements,
));
}
self.statements.push(st); self.statements.push(st);
self.operations.push(op); self.operations.push(op);
if self.statements.len() > self.params.max_statements { if self.statements.len() > self.params.max_statements {
@ -347,11 +355,12 @@ impl MainPodBuilder {
.ok_or(Error::custom(format!( .ok_or(Error::custom(format!(
"Invalid key argument for op {}.", "Invalid key argument for op {}.",
op op
)))?; )))?
.raw();
let proof = if op_type == &Native(ContainsFromEntries) { let proof = if op_type == &Native(ContainsFromEntries) {
container.prove_existence(key)?.1 as_container_or_err(container)?.prove(key)?.1
} else { } else {
container.prove_nonexistence(key)? as_container_or_err(container)?.prove_nonexistence(key)?
}; };
Ok(Operation(op_type.clone(), op.1, OpAux::MerkleProof(proof))) Ok(Operation(op_type.clone(), op.1, OpAux::MerkleProof(proof)))
} }
@ -375,18 +384,16 @@ impl MainPodBuilder {
let value = let value =
op.1.get(3) op.1.get(3)
.and_then(|arg| arg.value()) .and_then(|arg| arg.value())
.ok_or(Error::custom(format!( .cloned()
"Invalid key argument for op {}.", .unwrap_or(Value::from(EMPTY_VALUE));
op
)));
let proof = match op_type { let proof = match op_type {
Native(ContainerInsertFromEntries) => { Native(ContainerInsertFromEntries) => {
old_container.prove_insertion(key, value?)? as_container_or_err(old_container)?.insert(key.clone(), value)?
} }
Native(ContainerUpdateFromEntries) => { Native(ContainerUpdateFromEntries) => {
old_container.prove_update(key, value?)? as_container_or_err(old_container)?.update(key.raw(), value)?
} }
_ => old_container.prove_deletion(key)?, _ => as_container_or_err(old_container)?.delete(key.raw())?,
}; };
Ok(Operation( Ok(Operation(
op_type.clone(), op_type.clone(),
@ -399,7 +406,7 @@ impl MainPodBuilder {
} }
fn op_statement( fn op_statement(
&mut self, &self,
wildcard_values: Vec<(usize, Value)>, wildcard_values: Vec<(usize, Value)>,
op: Operation, op: Operation,
) -> Result<Statement> { ) -> Result<Statement> {
@ -560,6 +567,37 @@ impl MainPodBuilder {
// TODO: validate proof // TODO: validate proof
Statement::ContainerDelete(r1, r2, r3) Statement::ContainerDelete(r1, r2, r3)
} }
(ReplaceValueWithEntry, &args, _) => {
let mut args = args.to_vec();
if args.len() != BASE_PARAMS.max_statement_args + 1 {
return Err(Error::custom(format!(
"ReplaceValueWithEntry requires exactly {} args but {} were found",
BASE_PARAMS.max_statement_args + 1,
args.len()
)));
}
let st = match args.pop().expect("valid vec len") {
OperationArg::Statement(st) => st,
_ => return Err(Error::custom("expected statement")),
};
let new_st_args = iter::zip(st.args().into_iter(), args)
.map(|(st_arg, arg)| match (st_arg, arg) {
(st_arg, OperationArg::Statement(Statement::None)) => Ok(st_arg),
(
StatementArg::Literal(st_arg_v),
OperationArg::Statement(Statement::Contains(
ValueRef::Literal(root),
ValueRef::Literal(key),
ValueRef::Literal(v),
)),
) if st_arg_v == v => root_key_to_ak(&root, &key)
.map(StatementArg::Key)
.ok_or_else(native_arg_error),
_ => Err(Error::custom("unexpected operation argument")),
})
.collect::<Result<Vec<_>>>()?;
Statement::from_args(st.predicate(), new_st_args)?
}
(t, _, _) => { (t, _, _) => {
if t.is_syntactic_sugar() { if t.is_syntactic_sugar() {
return Err(Error::custom(format!( return Err(Error::custom(format!(
@ -573,7 +611,7 @@ impl MainPodBuilder {
} }
} }
OperationType::Custom(cpr) => { OperationType::Custom(cpr) => {
let pred = &cpr.batch.predicates()[cpr.index]; let pred = cpr.normalized_predicate();
if pred.statements.len() != op.1.len() { if pred.statements.len() != op.1.len() {
return Err(Error::custom(format!( return Err(Error::custom(format!(
"Custom predicate operation needs {} statements but has {}.", "Custom predicate operation needs {} statements but has {}.",
@ -601,7 +639,7 @@ impl MainPodBuilder {
} }
wildcard_map[index] = Some(value); wildcard_map[index] = Some(value);
} }
fill_wildcard_values(pred, &args, &mut wildcard_map)?; fill_wildcard_values(&pred, &args, &mut wildcard_map)?;
let v_default = Value::from(0); let v_default = Value::from(0);
let st_args: Vec<_> = wildcard_map let st_args: Vec<_> = wildcard_map
.into_iter() .into_iter()
@ -609,14 +647,14 @@ impl MainPodBuilder {
.map(|v| v.unwrap_or_else(|| v_default.clone())) .map(|v| v.unwrap_or_else(|| v_default.clone()))
.collect(); .collect();
check_custom_pred(&self.params, &cpr, &args, &st_args)?; check_custom_pred(&self.params, &cpr, &args, &st_args)?;
Statement::Custom(cpr, st_args) Statement::Custom(cpr, st_args.into_iter().map(ValueRef::Literal).collect())
} }
}; };
Ok(st) Ok(st)
} }
/// For every operation that has Entry statements as arguments we add a Contains statement to /// For every operation that has Entry statements as arguments we add a Contains statement to
/// open the dictionary. /// open the dictionary (unless such Contains already exists).
fn add_entries_contains(&mut self, op: &Operation) -> Result<()> { fn add_entries_contains(&mut self, op: &Operation) -> Result<()> {
for arg in &op.1 { for arg in &op.1 {
if let OperationArg::Statement(Statement::Contains( if let OperationArg::Statement(Statement::Contains(
@ -625,9 +663,9 @@ impl MainPodBuilder {
ValueRef::Literal(v), ValueRef::Literal(v),
)) = arg )) = arg
{ {
let root_key = (dict.clone(), key.clone()); let root_key = (dict.raw(), key.raw());
if !self.dict_contains.contains(&root_key) { if !self.contains.contains(&root_key) {
self.dict_contains.push(root_key); self.contains.push(root_key);
self.priv_op(Operation::dict_contains(dict, key, v))?; self.priv_op(Operation::dict_contains(dict, key, v))?;
} }
} }
@ -645,14 +683,29 @@ impl MainPodBuilder {
self.add_entries_contains(&op)?; self.add_entries_contains(&op)?;
let op = Self::fill_in_aux(Self::lower_op(op)?)?; let op = Self::fill_in_aux(Self::lower_op(op)?)?;
let st = self.op_statement(wildcard_values, op.clone())?; let st = self.op_statement(wildcard_values, op.clone())?;
self.insert(public, (st, op))?; // Skip adding the statement and operation if it already exists
if !self.statements.contains(&st) {
Ok(self.statements[self.statements.len() - 1].clone()) self.insert((st.clone(), op))?;
}
if public {
self.reveal(&st)?;
} }
pub fn reveal(&mut self, st: &Statement) { Ok(st)
}
pub fn reveal(&mut self, st: &Statement) -> Result<()> {
if !self.public_statements.contains(st) {
self.public_statements.push(st.clone()); self.public_statements.push(st.clone());
} }
if self.public_statements.len() > self.params.max_public_statements {
return Err(Error::too_many_public_statements(
self.public_statements.len(),
self.params.max_public_statements,
));
}
Ok(())
}
pub fn prove(&self, prover: &dyn MainPodProver) -> Result<MainPod> { pub fn prove(&self, prover: &dyn MainPodProver) -> Result<MainPod> {
let compiler = MainPodCompiler::new(&self.params); let compiler = MainPodCompiler::new(&self.params);
@ -1346,11 +1399,9 @@ pub mod tests {
OperationAux::None, OperationAux::None,
); );
builder builder
.insert(false, (value_of_a.clone(), op_contains.clone())) .insert((value_of_a.clone(), op_contains.clone()))
.unwrap();
builder
.insert(false, (value_of_b.clone(), op_contains))
.unwrap(); .unwrap();
builder.insert((value_of_b.clone(), op_contains)).unwrap();
let st = Statement::equal( let st = Statement::equal(
AnchoredKey::from((&local, "a")), AnchoredKey::from((&local, "a")),
AnchoredKey::from((&local, "b")), AnchoredKey::from((&local, "b")),
@ -1363,7 +1414,7 @@ pub mod tests {
], ],
OperationAux::None, OperationAux::None,
); );
builder.insert(false, (st, op)).unwrap(); builder.insert((st, op)).unwrap();
let prover = MockProver {}; let prover = MockProver {};
let pod = builder.prove(&prover).unwrap(); let pod = builder.prove(&prover).unwrap();

View file

@ -6,60 +6,20 @@
use std::collections::BTreeSet; use std::collections::BTreeSet;
use crate::{ use crate::{
frontend::{Operation, OperationArg}, frontend::Operation,
middleware::{ middleware::{CustomPredicateRef, Hash, NativeOperation, OperationType, Predicate},
CustomPredicateBatch, Hash, NativeOperation, OperationType, RawValue, Statement, ValueRef,
},
}; };
/// Unique identifier for a custom predicate batch. /// Unique identifier for a custom predicate in a module.
/// ///
/// Uses the batch's cryptographic hash as identifier. Two batches with the same /// Uses the predicate's cryptographic hash as identifier. Two predicates with the same
/// hash are considered identical for resource counting purposes. /// hash are considered identical for resource counting purposes.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CustomBatchId(pub Hash); pub struct CustomPredicateId(pub Hash);
impl From<&CustomPredicateBatch> for CustomBatchId { impl From<&CustomPredicateRef> for CustomPredicateId {
fn from(batch: &CustomPredicateBatch) -> Self { fn from(predicate: &CustomPredicateRef) -> Self {
Self(batch.id()) Self(Predicate::Custom(predicate.clone()).hash())
}
}
/// Unique identifier for an anchored key (dict, key) pair.
///
/// When a Contains statement is used as an argument to operations like gt(), eq(), etc.,
/// the value is accessed via an "anchored key" - a reference to a specific key in a
/// specific dictionary. Each unique anchored key used in a POD requires a Contains
/// statement to be present in that POD (auto-inserted by MainPodBuilder if needed).
///
/// We use the raw values of the dict and key for comparison, as they uniquely identify
/// the anchored key regardless of the specific Value types involved.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct AnchoredKeyId {
/// The dictionary root value (raw representation for Ord).
pub dict: RawValue,
/// The key within the dictionary (raw representation for Ord).
pub key: RawValue,
}
impl AnchoredKeyId {
/// Create a new anchored key ID from raw values.
pub fn new(dict: RawValue, key: RawValue) -> Self {
Self { dict, key }
}
/// Try to extract an anchored key ID from a Contains statement with all literal values.
pub fn from_contains_statement(stmt: &Statement) -> Option<Self> {
if let Statement::Contains(
ValueRef::Literal(dict),
ValueRef::Literal(key),
ValueRef::Literal(_value),
) = stmt
{
Some(Self::new(dict.raw(), key.raw()))
} else {
None
}
} }
} }
@ -88,17 +48,9 @@ pub struct StatementCost {
/// Limit: `params.max_public_key_of` /// Limit: `params.max_public_key_of`
pub public_key_of: usize, pub public_key_of: usize,
/// Custom predicate batches used (for batch cardinality constraint). /// Custom predicates used (for custom predicate cardinality constraint).
/// Limit: `params.max_custom_predicate_batches` distinct batches per POD. /// Limit: `params.max_custom_predicates` distinct custom predicates per POD.
pub custom_batch_ids: BTreeSet<CustomBatchId>, pub custom_predicates_ids: BTreeSet<CustomPredicateId>,
/// Anchored keys referenced by this operation.
///
/// When a Contains statement with all literal values is used as an argument,
/// the operation references an "anchored key" (dict, key pair). Each unique
/// anchored key used in a POD incurs an additional Contains statement cost,
/// as MainPodBuilder::add_entries_contains will auto-insert it if not already present.
pub anchored_keys: BTreeSet<AnchoredKeyId>,
} }
impl StatementCost { impl StatementCost {
@ -159,25 +111,14 @@ impl StatementCost {
// Syntactic sugar variants (lowered before proving) // Syntactic sugar variants (lowered before proving)
| NativeOperation::GtEqFromEntries | NativeOperation::GtEqFromEntries
| NativeOperation::GtFromEntries | NativeOperation::GtFromEntries
| NativeOperation::GtToNotEqual => {} | NativeOperation::GtToNotEqual
| NativeOperation::ReplaceValueWithEntry => {}
} }
} }
OperationType::Custom(cpr) => { OperationType::Custom(cpr) => {
cost.custom_pred_verifications = 1; cost.custom_pred_verifications = 1;
cost.custom_batch_ids cost.custom_predicates_ids
.insert(CustomBatchId::from(&*cpr.batch)); .insert(CustomPredicateId::from(cpr));
}
}
// Extract anchored keys from operation arguments.
// Any argument that is a Contains statement with all literal values
// represents an anchored key reference that will require a Contains
// statement in the POD (auto-inserted by MainPodBuilder if needed).
for arg in &op.1 {
if let OperationArg::Statement(stmt) = arg {
if let Some(anchored_key) = AnchoredKeyId::from_contains_statement(stmt) {
cost.anchored_keys.insert(anchored_key);
}
} }
} }

View file

@ -5,7 +5,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use super::cost::AnchoredKeyId;
use crate::{ use crate::{
frontend::{Operation, OperationArg}, frontend::{Operation, OperationArg},
middleware::{Hash, Statement}, middleware::{Hash, Statement},
@ -100,11 +99,6 @@ impl DependencyGraph {
pod_hash, pod_hash,
statement: dep_stmt.clone(), statement: dep_stmt.clone(),
})); }));
} else if AnchoredKeyId::from_contains_statement(dep_stmt).is_some() {
// Anchored-key Contains args may be implicit requirements that are
// auto-materialized by MainPodBuilder. They are handled by anchored-key
// resource accounting, not by statement dependency edges.
continue;
} else { } else {
// Statement arguments should either be internal (created earlier) // Statement arguments should either be internal (created earlier)
// or from external PODs (except anchored-key implicit Contains). // or from external PODs (except anchored-key implicit Contains).
@ -128,9 +122,8 @@ impl DependencyGraph {
mod tests { mod tests {
use super::*; use super::*;
use crate::{ use crate::{
dict,
frontend::Operation as FrontendOp, frontend::Operation as FrontendOp,
middleware::{AnchoredKey, NativeOperation, OperationAux, OperationType, Value, ValueRef}, middleware::{NativeOperation, OperationAux, OperationType, Value, ValueRef},
}; };
fn equal_stmt(n: i64) -> Statement { fn equal_stmt(n: i64) -> Statement {
@ -195,32 +188,4 @@ mod tests {
assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]); assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]);
assert_eq!(graph.statement_deps[2], vec![StatementSource::Internal(0)]); assert_eq!(graph.statement_deps[2], vec![StatementSource::Internal(0)]);
} }
#[test]
fn test_anchored_key_contains_arg_is_treated_as_implicit_requirement() {
// A literal Contains statement can be used as an anchored-key argument even when
// no explicit producer statement exists in internal/external statements, because
// MainPodBuilder auto-inserts Contains statements for anchored keys.
let dict = dict!({
"k" => 7_i64
});
let anchored_contains = Statement::Contains(
ValueRef::Literal(Value::from(dict.clone())),
ValueRef::Literal(Value::from("k")),
ValueRef::Literal(Value::from(7_i64)),
);
let ak = AnchoredKey::from((&dict, "k"));
let produced_statement = Statement::Equal(ValueRef::Key(ak.clone()), ValueRef::Key(ak));
// Use a typical frontend operation that consumes entry-like args.
// We're only testing the dependency graph, not the actual proof, so the operation
// just needs to have the right arguments to test what we're looking for.
let statements = vec![produced_statement];
let operations = vec![FrontendOp::eq(anchored_contains.clone(), anchored_contains)];
let graph = DependencyGraph::build(&statements, &operations, &HashMap::new());
assert!(graph.statement_deps[0].is_empty());
}
} }

View file

@ -0,0 +1,466 @@
//! Diagnostic utilities for multi-POD resource analysis.
//!
//! Provides two views:
//! - [`ResourceSummary`]: Pre-solve aggregate resource demand vs. per-POD limits.
//! Shows which resource category is the bottleneck (requires the most PODs).
//! - [`SolutionBreakdown`]: Post-solve per-POD utilization showing how full each POD is.
use std::{collections::BTreeSet, fmt};
use super::cost::StatementCost;
use crate::middleware::Params;
/// A single resource category's usage vs. per-POD limit.
///
/// Used both for pre-solve aggregate demand (in [`ResourceSummary`]) where
/// `used` is the total across all statements, and for post-solve per-POD
/// breakdown (in [`PodUtilization`]) where `used` is the POD's consumption.
#[derive(Clone, Debug)]
pub struct UtilizationRow {
pub name: &'static str,
pub used: usize,
pub limit: usize,
}
impl UtilizationRow {
/// Utilization as a fraction (0.0 to 1.0).
pub fn utilization(&self) -> f64 {
if self.limit == 0 {
if self.used == 0 {
0.0
} else {
f64::INFINITY
}
} else {
self.used as f64 / self.limit as f64
}
}
/// Minimum PODs needed for this resource alone: `ceil(used / limit)`.
/// `None` if `limit` is 0 and `used > 0` (infeasible).
pub fn min_pods(&self) -> Option<usize> {
lower_bound(self.used, self.limit)
}
}
/// Aggregate resource usage over a set of statement costs into per-category rows.
///
/// Single source of truth for the resource categories and their corresponding
/// `Params` limits. Used both for pre-solve totals and per-POD breakdowns.
fn aggregate_rows<'a>(
costs: impl IntoIterator<Item = &'a StatementCost>,
params: &Params,
) -> (Vec<UtilizationRow>, usize) {
let mut num_stmts = 0usize;
let mut merkle_proofs = 0usize;
let mut merkle_state_transitions = 0usize;
let mut custom_pred_verifications = 0usize;
let mut signed_by = 0usize;
let mut public_key_of = 0usize;
let mut custom_pred_ids = BTreeSet::new();
for c in costs {
num_stmts += 1;
merkle_proofs += c.merkle_proofs;
merkle_state_transitions += c.merkle_state_transitions;
custom_pred_verifications += c.custom_pred_verifications;
signed_by += c.signed_by;
public_key_of += c.public_key_of;
custom_pred_ids.extend(c.custom_predicates_ids.iter().cloned());
}
let rows = vec![
UtilizationRow {
name: "private statements",
used: num_stmts,
limit: params.max_priv_statements(),
},
UtilizationRow {
name: "merkle proofs",
used: merkle_proofs,
limit: params.containers.state.max_medium,
},
UtilizationRow {
name: "merkle state transitions",
used: merkle_state_transitions,
limit: params.containers.transition.max_medium,
},
UtilizationRow {
name: "custom pred verifications",
used: custom_pred_verifications,
limit: params.max_custom_predicate_verifications,
},
UtilizationRow {
name: "signed_by",
used: signed_by,
limit: params.max_signed_by,
},
UtilizationRow {
name: "public_key_of",
used: public_key_of,
limit: params.max_public_key_of,
},
UtilizationRow {
name: "distinct custom predicates",
used: custom_pred_ids.len(),
limit: params.max_custom_predicates,
},
];
(rows, num_stmts)
}
/// Pre-solve aggregate resource summary.
///
/// Shows total resource demand across all operations and the minimum PODs
/// each resource category would require independently.
#[derive(Clone, Debug)]
pub struct ResourceSummary {
pub rows: Vec<UtilizationRow>,
pub num_statements: usize,
}
impl ResourceSummary {
/// Compute a resource summary from per-statement costs and params.
pub fn from_costs(costs: &[StatementCost], params: &Params) -> Self {
let (rows, num_statements) = aggregate_rows(costs.iter(), params);
Self {
rows,
num_statements,
}
}
/// The resource category requiring the most PODs (the bottleneck).
/// Returns `None` only if there are no statements.
pub fn bottleneck(&self) -> Option<&UtilizationRow> {
self.rows
.iter()
.filter(|r| r.used > 0)
.max_by_key(|r| r.min_pods().unwrap_or(usize::MAX))
}
}
impl fmt::Display for ResourceSummary {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Resource Summary ({} statements)", self.num_statements)?;
writeln!(
f,
" {:<30} {:>5} {:>9} {:>8}",
"Category", "Total", "Limit/POD", "Min PODs"
)?;
let bottleneck_name = self.bottleneck().map(|r| r.name);
for row in &self.rows {
let min_pods_str = match row.min_pods() {
Some(n) => format!("{}", n),
None => "inf".to_string(),
};
let marker = if Some(row.name) == bottleneck_name && row.used > 0 {
" <<<"
} else {
""
};
writeln!(
f,
" {:<30} {:>5} {:>9} {:>8}{}",
row.name, row.used, row.limit, min_pods_str, marker
)?;
}
Ok(())
}
}
/// Per-POD resource utilization in a solved solution.
#[derive(Clone, Debug)]
pub struct PodUtilization {
/// POD index.
pub pod_idx: usize,
/// Whether this is the output POD (last).
pub is_output: bool,
/// Number of statements in this POD.
pub num_statements: usize,
/// Resource usage vs. limits for each category.
pub resources: Vec<UtilizationRow>,
}
/// Post-solve per-POD resource breakdown.
#[derive(Clone, Debug)]
pub struct SolutionBreakdown {
pub pods: Vec<PodUtilization>,
pub num_statements: usize,
pub pod_count: usize,
}
impl SolutionBreakdown {
/// Compute a solution breakdown from per-statement costs, the solution's
/// pod_statements assignment, and params.
pub fn from_solution(
costs: &[StatementCost],
pod_statements: &[Vec<usize>],
pod_count: usize,
num_statements: usize,
params: &Params,
) -> Self {
let pods = (0..pod_count)
.map(|pod_idx| {
let stmts = &pod_statements[pod_idx];
let (resources, num_stmts) =
aggregate_rows(stmts.iter().map(|&s| &costs[s]), params);
PodUtilization {
pod_idx,
is_output: pod_idx == pod_count - 1,
num_statements: num_stmts,
resources,
}
})
.collect();
Self {
pods,
num_statements,
pod_count,
}
}
}
impl fmt::Display for SolutionBreakdown {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"Solution Breakdown ({} statements -> {} PODs)",
self.num_statements, self.pod_count
)?;
for pod in &self.pods {
let role = if pod.is_output {
"output"
} else {
"intermediate"
};
writeln!(f, " POD {} ({}):", pod.pod_idx, role)?;
for row in &pod.resources {
// Only show rows with nonzero usage to reduce noise
if row.used > 0 {
let pct = if row.limit > 0 {
format!("({:>3}%)", (row.used * 100) / row.limit)
} else {
"".to_string()
};
writeln!(
f,
" {:<30} {:>3}/{:<3} {}",
row.name, row.used, row.limit, pct
)?;
}
}
writeln!(f)?;
}
Ok(())
}
}
fn lower_bound(used: usize, limit: usize) -> Option<usize> {
if used == 0 {
Some(0)
} else if limit == 0 {
None
} else {
Some(used.div_ceil(limit))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
frontend::multi_pod::cost::CustomPredicateId,
middleware::{Hash, ParamsContainers, ParamsMerkleProofs, RawValue},
};
fn default_params() -> Params {
Params {
max_statements: 48,
max_public_statements: 8,
containers: ParamsContainers {
state: ParamsMerkleProofs {
max_small: 0,
max_medium: 8,
},
transition: ParamsMerkleProofs {
max_small: 0,
max_medium: 4,
},
..Default::default()
},
max_custom_predicate_verifications: 10,
max_custom_predicates: 2,
max_signed_by: 4,
max_public_key_of: 4,
..Params::default()
}
}
#[test]
fn test_resource_summary_bottleneck() {
let params = default_params();
// max_priv = 48 - 8 = 40
// 6 merkle proofs, 3 state transitions, rest zero-cost
let costs: Vec<StatementCost> = (0..14)
.map(|i| {
let mut c = StatementCost::default();
if i < 6 {
c.merkle_proofs = 1;
} else if i < 9 {
c.merkle_state_transitions = 1;
}
c
})
.collect();
let summary = ResourceSummary::from_costs(&costs, &params);
// 14 statements / 40 per pod = 1 pod for statements
// 6 merkle proofs / 8 per pod = 1 pod
// 3 state transitions / 4 per pod = 1 pod
// All categories need 1 pod, so bottleneck is whichever has the highest min_pods.
// They're all 1, so the first with total > 0 wins in max_by_key (stable).
let bottleneck = summary.bottleneck().unwrap();
assert_eq!(bottleneck.min_pods(), Some(1));
// Verify display doesn't panic
let display = format!("{}", summary);
assert!(display.contains("Resource Summary (14 statements)"));
assert!(display.contains("merkle proofs"));
}
#[test]
fn test_resource_summary_signed_by_bottleneck() {
let params = Params {
max_statements: 48,
max_public_statements: 8,
max_signed_by: 2,
..Params::default()
};
// max_priv = 40
// 6 signed_by operations
let costs: Vec<StatementCost> = (0..6)
.map(|_| StatementCost {
signed_by: 1,
..Default::default()
})
.collect();
let summary = ResourceSummary::from_costs(&costs, &params);
let bottleneck = summary.bottleneck().unwrap();
assert_eq!(bottleneck.name, "signed_by");
// 6 / 2 = 3 pods
assert_eq!(bottleneck.min_pods(), Some(3));
}
#[test]
fn test_resource_summary_custom_predicates_bottleneck() {
let params = Params {
max_statements: 48,
max_public_statements: 8,
max_custom_predicates: 1, // Only 1 distinct predicate per POD
max_custom_predicate_verifications: 10,
..Params::default()
};
// 3 statements using 3 different custom predicates
let costs: Vec<StatementCost> = (0..3)
.map(|i| {
let mut ids = std::collections::BTreeSet::new();
ids.insert(CustomPredicateId(Hash::from(RawValue::from(i as i64))));
StatementCost {
custom_pred_verifications: 1,
custom_predicates_ids: ids,
..Default::default()
}
})
.collect();
let summary = ResourceSummary::from_costs(&costs, &params);
let bottleneck = summary.bottleneck().unwrap();
assert_eq!(bottleneck.name, "distinct custom predicates");
// 3 distinct predicates / 1 per pod = 3 pods
assert_eq!(bottleneck.min_pods(), Some(3));
}
#[test]
fn test_solution_breakdown_display() {
let params = default_params();
let costs: Vec<StatementCost> = (0..8)
.map(|i| {
let mut c = StatementCost::default();
if i < 4 {
c.merkle_proofs = 1;
} else {
c.merkle_state_transitions = 1;
}
c
})
.collect();
let pod_statements = vec![
vec![0, 1, 2, 3], // POD 0: 4 merkle proofs
vec![4, 5, 6, 7], // POD 1: 4 state transitions
];
let breakdown = SolutionBreakdown::from_solution(&costs, &pod_statements, 2, 8, &params);
assert_eq!(breakdown.pods.len(), 2);
assert!(!breakdown.pods[0].is_output);
assert!(breakdown.pods[1].is_output);
// POD 0 should have 4 merkle proofs
let mp = breakdown.pods[0]
.resources
.iter()
.find(|r| r.name == "merkle proofs")
.unwrap();
assert_eq!(mp.used, 4);
assert_eq!(mp.limit, 8);
// POD 1 should have 4 state transitions
let mst = breakdown.pods[1]
.resources
.iter()
.find(|r| r.name == "merkle state transitions")
.unwrap();
assert_eq!(mst.used, 4);
assert_eq!(mst.limit, 4);
// Verify display doesn't panic and contains expected content
let display = format!("{}", breakdown);
assert!(display.contains("Solution Breakdown (8 statements -> 2 PODs)"));
assert!(display.contains("POD 0 (intermediate)"));
assert!(display.contains("POD 1 (output)"));
}
#[test]
fn test_utilization_row_fraction() {
let row = UtilizationRow {
name: "test",
used: 3,
limit: 4,
};
assert!((row.utilization() - 0.75).abs() < f64::EPSILON);
let zero_row = UtilizationRow {
name: "test",
used: 0,
limit: 4,
};
assert!((zero_row.utilization()).abs() < f64::EPSILON);
}
}

View file

@ -48,21 +48,23 @@
//! [`MainPodBuilder`]: crate::frontend::MainPodBuilder //! [`MainPodBuilder`]: crate::frontend::MainPodBuilder
use std::{ use std::{
collections::{BTreeMap, BTreeSet, HashMap}, collections::{BTreeSet, HashMap},
fmt, fmt,
}; };
use crate::{ use crate::{
frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, frontend::{MainPod, MainPodBuilder, Operation},
middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value}, middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value},
}; };
mod cost; mod cost;
mod deps; mod deps;
pub mod diagnostics;
mod solver; mod solver;
use cost::{AnchoredKeyId, StatementCost}; use cost::StatementCost;
use deps::{DependencyGraph, StatementSource}; use deps::{DependencyGraph, StatementSource};
pub use diagnostics::{ResourceSummary, SolutionBreakdown};
pub use solver::MultiPodSolution; pub use solver::MultiPodSolution;
/// Error type for multi-POD operations. /// Error type for multi-POD operations.
@ -168,12 +170,8 @@ pub struct MultiPodBuilder {
options: Options, options: Options,
/// External input PODs (already proved). /// External input PODs (already proved).
input_pods: Vec<MainPod>, input_pods: Vec<MainPod>,
/// Statements created by this builder.
statements: Vec<Statement>,
/// Operations that produce each statement.
operations: Vec<Operation>,
/// Optional initial wildcard values for custom operations /// Optional initial wildcard values for custom operations
operations_wildcard_values: Vec<Vec<(usize, Value)>>, operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
/// Indices of statements that should be public in output PODs. /// Indices of statements that should be public in output PODs.
/// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted. /// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted.
output_public_indices: Vec<usize>, output_public_indices: Vec<usize>,
@ -193,7 +191,7 @@ pub struct SolvedMultiPod {
statements: Vec<Statement>, statements: Vec<Statement>,
operations: Vec<Operation>, operations: Vec<Operation>,
output_public_indices: Vec<usize>, output_public_indices: Vec<usize>,
operations_wildcard_values: Vec<Vec<(usize, Value)>>, operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
solution: MultiPodSolution, solution: MultiPodSolution,
deps: DependencyGraph, deps: DependencyGraph,
} }
@ -204,6 +202,22 @@ impl SolvedMultiPod {
&self.solution &self.solution
} }
/// Compute a post-solve per-POD resource utilization breakdown.
pub fn solution_breakdown(&self) -> SolutionBreakdown {
let costs: Vec<StatementCost> = self
.operations
.iter()
.map(StatementCost::from_operation)
.collect();
SolutionBreakdown::from_solution(
&costs,
&self.solution.pod_statements,
self.solution.pod_count,
self.statements.len(),
&self.params,
)
}
/// Build and prove all PODs. /// Build and prove all PODs.
/// ///
/// Builds PODs in dependency order (0, 1, ..., k) and proves each one. /// Builds PODs in dependency order (0, 1, ..., k) and proves each one.
@ -260,56 +274,27 @@ impl SolvedMultiPod {
let statements_sorted: BTreeSet<usize> = statements_in_this_pod.iter().copied().collect(); let statements_sorted: BTreeSet<usize> = statements_in_this_pod.iter().copied().collect();
let public_set = &solution.pod_public_statements[pod_idx]; let public_set = &solution.pod_public_statements[pod_idx];
// Track statements proved locally in this POD for argument remapping.
// We index by statement content so duplicate statements can reuse a single
// built statement slot in MainPodBuilder.
let mut added_statements_by_content: HashMap<Statement, Statement> = HashMap::new();
for &stmt_idx in &statements_sorted { for &stmt_idx in &statements_sorted {
let original_stmt = self.statements[stmt_idx].clone(); let op = self.operations[stmt_idx].clone();
let wildcard_values = self
// If this statement content was already built in this POD, reuse it instead .operations_wildcard_values
// of replaying the operation. If any duplicate is public, reveal the .get(&stmt_idx)
// already-built statement. .cloned()
if let Some(_existing_stmt) = added_statements_by_content.get(&original_stmt) { .unwrap_or_default();
continue;
}
let mut op = self.operations[stmt_idx].clone();
let wildcard_values = self.operations_wildcard_values[stmt_idx].clone();
// Remap Statement arguments that reference locally-proved statements.
// For external dependencies (from input PODs including earlier generated PODs),
// the original Statement is used directly - MainPodBuilder will find it in
// the input POD's public statements via find_op_arg.
for arg in &mut op.1 {
if let OperationArg::Statement(ref orig_stmt) = arg {
if let Some(remapped_stmt) = added_statements_by_content.get(orig_stmt) {
*arg = OperationArg::Statement(remapped_stmt.clone());
}
}
}
let stmt = builder.op(false, wildcard_values, op)?; let stmt = builder.op(false, wildcard_values, op)?;
assert_eq!(stmt, self.statements[stmt_idx]); // Sanity check
added_statements_by_content.insert(original_stmt, stmt);
} }
// For the output pod, make statements public in the original order. // For the output pod, make statements public in the original order.
// Intermediate pods use the solver-selected public set. // Intermediate pods use the solver-selected public set.
if pod_idx == solution.pod_count - 1 { if pod_idx == solution.pod_count - 1 {
for idx in &self.output_public_indices { for idx in &self.output_public_indices {
let stmt = added_statements_by_content builder.reveal(&self.statements[*idx])?;
.get(&self.statements[*idx])
.expect("exists");
builder.reveal(stmt);
} }
} else { } else {
for idx in public_set { for idx in public_set {
let stmt = added_statements_by_content builder.reveal(&self.statements[*idx])?;
.get(&self.statements[*idx])
.expect("exists");
builder.reveal(stmt);
} }
} }
@ -317,7 +302,7 @@ impl SolvedMultiPod {
// for this POD. These do not require local proving in this POD. // for this POD. These do not require local proving in this POD.
for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] { for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] {
let ext_premise = &solution.external_premises[*ext_premise_idx]; let ext_premise = &solution.external_premises[*ext_premise_idx];
builder.reveal(&ext_premise.statement); builder.reveal(&ext_premise.statement)?;
} }
// Step 4: Prove the POD // Step 4: Prove the POD
@ -456,9 +441,7 @@ impl MultiPodBuilder {
options, options,
builder, builder,
input_pods: Vec::new(), input_pods: Vec::new(),
statements: Vec::new(), operations_wildcard_values: HashMap::new(),
operations: Vec::new(),
operations_wildcard_values: Vec::new(),
output_public_indices: Vec::new(), output_public_indices: Vec::new(),
} }
} }
@ -480,6 +463,16 @@ impl MultiPodBuilder {
self.op(false, vec![], op) self.op(false, vec![], op)
} }
// Find the index of a statement that has been added. Panics if the statement doesn't
// exist.
fn stmt_index(&self, stmt: &Statement) -> usize {
self.builder
.statements
.iter()
.position(|s| s == stmt)
.expect("exists")
}
pub fn op( pub fn op(
&mut self, &mut self,
public: bool, public: bool,
@ -488,8 +481,10 @@ impl MultiPodBuilder {
) -> Result<Statement> { ) -> Result<Statement> {
let stmt = self.add_operation(wildcard_values, op)?; let stmt = self.add_operation(wildcard_values, op)?;
if public { if public {
// Index is always new (just added), so push without duplicate check let index = self.stmt_index(&stmt);
self.output_public_indices.push(self.statements.len() - 1); if !self.output_public_indices.contains(&index) {
self.output_public_indices.push(index);
}
} }
Ok(stmt) Ok(stmt)
} }
@ -510,10 +505,8 @@ impl MultiPodBuilder {
let stmt = self let stmt = self
.builder .builder
.op(false, wildcard_values.clone(), op.clone())?; .op(false, wildcard_values.clone(), op.clone())?;
self.operations_wildcard_values
self.statements.push(stmt.clone()); .insert(self.stmt_index(&stmt), wildcard_values.clone());
self.operations.push(op);
self.operations_wildcard_values.push(wildcard_values);
Ok(stmt) Ok(stmt)
} }
@ -523,7 +516,7 @@ impl MultiPodBuilder {
/// Returns an error if the statement was not found in the builder. /// Returns an error if the statement was not found in the builder.
/// Calling this multiple times on the same statement is idempotent. /// Calling this multiple times on the same statement is idempotent.
pub fn reveal(&mut self, stmt: &Statement) -> Result<()> { pub fn reveal(&mut self, stmt: &Statement) -> Result<()> {
if let Some(idx) = self.statements.iter().position(|s| s == stmt) { if let Some(idx) = self.builder.statements.iter().position(|s| s == stmt) {
if !self.output_public_indices.contains(&idx) { if !self.output_public_indices.contains(&idx) {
self.output_public_indices.push(idx); self.output_public_indices.push(idx);
} }
@ -536,8 +529,22 @@ impl MultiPodBuilder {
} }
/// Get the number of statements. /// Get the number of statements.
pub fn num_statements(&self) -> usize { pub fn stmt_len(&self) -> usize {
self.statements.len() self.builder.stmt_len()
}
/// Compute a pre-solve resource summary showing aggregate demand vs. per-POD limits.
///
/// This is useful for understanding which resource category is the bottleneck
/// before running the solver, especially when debugging solver performance issues.
pub fn resource_summary(&self) -> ResourceSummary {
let costs: Vec<StatementCost> = self
.builder
.operations
.iter()
.map(StatementCost::from_operation)
.collect();
ResourceSummary::from_costs(&costs, &self.params)
} }
/// Solve the packing problem and return a solved builder ready for proving. /// Solve the packing problem and return a solved builder ready for proving.
@ -545,66 +552,31 @@ impl MultiPodBuilder {
/// This runs the MILP solver to find the optimal POD assignment. /// This runs the MILP solver to find the optimal POD assignment.
/// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved. /// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved.
pub fn solve(self) -> Result<SolvedMultiPod> { pub fn solve(self) -> Result<SolvedMultiPod> {
let MainPodBuilder {
statements,
operations,
..
} = self.builder;
// Compute costs for each statement // Compute costs for each statement
let costs: Vec<StatementCost> = self let costs: Vec<StatementCost> = operations
.operations
.iter() .iter()
.map(StatementCost::from_operation) .map(StatementCost::from_operation)
.collect(); .collect();
// Collect all unique anchored keys from the costs
let all_anchored_keys: Vec<AnchoredKeyId> = costs
.iter()
.flat_map(|c| c.anchored_keys.iter().cloned())
.collect::<std::collections::BTreeSet<_>>()
.into_iter()
.collect();
// Build map from anchored key to its producing statement index (if any).
// A Contains statement with literal (dict, key, value) "produces" that anchored key.
let mut ak_to_producer: HashMap<AnchoredKeyId, usize> = HashMap::new();
for (stmt_idx, stmt) in self.statements.iter().enumerate() {
if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) {
// First producer wins (shouldn't have duplicates in practice)
ak_to_producer.entry(ak).or_insert(stmt_idx);
}
}
// Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i]
let anchored_key_producers: Vec<Option<usize>> = all_anchored_keys
.iter()
.map(|ak| ak_to_producer.get(ak).copied())
.collect();
// Build external POD statement mapping // Build external POD statement mapping
let external_pod_statements = build_external_statement_map(&self.input_pods); let external_pod_statements = build_external_statement_map(&self.input_pods);
// Build dependency graph // Build dependency graph
let deps = let deps = DependencyGraph::build(&statements, &operations, &external_pod_statements);
DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements);
// Build statement content groups for deduplication.
// Statements with identical content share a single slot in the POD.
// Keep groups ordered by first occurrence index for deterministic solver input.
let mut first_idx_by_stmt: HashMap<&Statement, usize> = HashMap::new();
let mut groups_by_first_idx: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
for (idx, stmt) in self.statements.iter().enumerate() {
let first_idx = *first_idx_by_stmt.entry(stmt).or_insert(idx);
groups_by_first_idx.entry(first_idx).or_default().push(idx);
}
let statement_content_groups: Vec<Vec<usize>> = groups_by_first_idx.into_values().collect();
// Run solver // Run solver
let input = solver::SolverInput { let input = solver::SolverInput {
num_statements: self.statements.len(), num_statements: statements.len(),
costs: &costs, costs: &costs,
deps: &deps, deps: &deps,
output_public_indices: &self.output_public_indices, output_public_indices: &self.output_public_indices,
params: &self.params, params: &self.params,
max_pods: self.options.max_pods, max_pods: self.options.max_pods,
all_anchored_keys: &all_anchored_keys,
anchored_key_producers: &anchored_key_producers,
statement_content_groups: &statement_content_groups,
}; };
let solution = solver::solve(&input)?; let solution = solver::solve(&input)?;
@ -613,8 +585,8 @@ impl MultiPodBuilder {
params: self.params, params: self.params,
vd_set: self.vd_set, vd_set: self.vd_set,
input_pods: self.input_pods, input_pods: self.input_pods,
statements: self.statements, statements,
operations: self.operations, operations,
output_public_indices: self.output_public_indices, output_public_indices: self.output_public_indices,
operations_wildcard_values: self.operations_wildcard_values, operations_wildcard_values: self.operations_wildcard_values,
solution, solution,
@ -845,33 +817,13 @@ mod tests {
let solution = solved.solution(); let solution = solved.solution();
// Expected: exactly 2 PODs // Expected: exactly 2 PODs
// - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public // Solution A:
// - POD 1 (output): statement 2 (b_out); b_out is public // - POD 0 (intermediate): public statements 0 (contains)
// The output POD accesses a_out from POD 0 to satisfy b_out's dependency. // - POD 1 (output): inherits statement 0 (contains) from POD0, statement 1 (a_out),
assert_eq!( // public statement 2 (b_out)
solution.pod_count, 2, // Solution B:
"Expected exactly 2 PODs for 3-statement chain with max_priv=2" // - POD 0 (intermediate): statements 0 (contains), public statement 1 (a_out)
); // - POD 1 (output): inherits statement 1 (a_out) from POD0, public statement 2 (b_out)
// POD 0 should contain statements 0 and 1 (contains and a_out)
assert!(
solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1),
"POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}",
solution.pod_statements[0]
);
// Statement 1 (a_out) should be public in POD 0 so POD 1 can access it
assert!(
solution.pod_public_statements[0].contains(&1),
"Statement 1 (a_out) should be public in POD 0"
);
// POD 1 (output) should contain statement 2 (b_out)
assert!(
solution.pod_statements[1].contains(&2),
"POD 1 should contain statement 2 (b_out), got {:?}",
solution.pod_statements[1]
);
// Statement 2 (b_out) should be public in POD 1 (it's output-public) // Statement 2 (b_out) should be public in POD 1 (it's output-public)
assert!( assert!(

View file

@ -52,7 +52,7 @@ use itertools::Itertools;
use super::Result; use super::Result;
use crate::{ use crate::{
frontend::multi_pod::{ frontend::multi_pod::{
cost::{AnchoredKeyId, CustomBatchId, StatementCost}, cost::{CustomPredicateId, StatementCost},
deps::{DependencyGraph, ExternalDependency, StatementSource}, deps::{DependencyGraph, ExternalDependency, StatementSource},
}, },
middleware::{Hash, Params}, middleware::{Hash, Params},
@ -95,7 +95,6 @@ struct DependencyStats {
struct SolveDebugContext { struct SolveDebugContext {
dep_stats: DependencyStats, dep_stats: DependencyStats,
batch_memberships: usize, batch_memberships: usize,
anchored_key_memberships: usize,
} }
#[derive(Clone, Copy, Debug, Default)] #[derive(Clone, Copy, Debug, Default)]
@ -105,10 +104,8 @@ struct ModelSizeEstimate {
vars_public_external: usize, vars_public_external: usize,
vars_pod_used: usize, vars_pod_used: usize,
vars_batch_used: usize, vars_batch_used: usize,
vars_anchored_key_used: usize,
vars_uses_input: usize, vars_uses_input: usize,
vars_uses_external: usize, vars_uses_external: usize,
vars_content_group_used: usize,
vars_total: usize, vars_total: usize,
c1_coverage: usize, c1_coverage: usize,
c2_output_public: usize, c2_output_public: usize,
@ -120,7 +117,6 @@ struct ModelSizeEstimate {
c6_pre_content_group: usize, c6_pre_content_group: usize,
c6_resource_limits: usize, c6_resource_limits: usize,
c7_batch_cardinality: usize, c7_batch_cardinality: usize,
c7b_anchored_key_tracking: usize,
c8a_internal_inputs: usize, c8a_internal_inputs: usize,
c8b_external_dep_inputs: usize, c8b_external_dep_inputs: usize,
c8c_external_forward_inputs: usize, c8c_external_forward_inputs: usize,
@ -141,8 +137,6 @@ impl ModelSizeEstimate {
debug_ctx: &SolveDebugContext, debug_ctx: &SolveDebugContext,
) -> Self { ) -> Self {
let n = input.num_statements; let n = input.num_statements;
let num_groups = input.statement_content_groups.len();
let num_anchored_keys = input.all_anchored_keys.len();
let triangular_k = target_pods * target_pods.saturating_sub(1) / 2; let triangular_k = target_pods * target_pods.saturating_sub(1) / 2;
let vars_prove = n * target_pods; let vars_prove = n * target_pods;
@ -150,19 +144,15 @@ impl ModelSizeEstimate {
let vars_public_external = external_premises_len * target_pods; let vars_public_external = external_premises_len * target_pods;
let vars_pod_used = target_pods; let vars_pod_used = target_pods;
let vars_batch_used = all_batches_len * target_pods; let vars_batch_used = all_batches_len * target_pods;
let vars_anchored_key_used = num_anchored_keys * target_pods;
let vars_uses_input = triangular_k; let vars_uses_input = triangular_k;
let vars_uses_external = external_pods_len * target_pods; let vars_uses_external = external_pods_len * target_pods;
let vars_content_group_used = num_groups * target_pods;
let vars_total = vars_prove let vars_total = vars_prove
+ vars_public + vars_public
+ vars_public_external + vars_public_external
+ vars_pod_used + vars_pod_used
+ vars_batch_used + vars_batch_used
+ vars_anchored_key_used
+ vars_uses_input + vars_uses_input
+ vars_uses_external + vars_uses_external;
+ vars_content_group_used;
let c1_coverage = n; let c1_coverage = n;
let c2_output_public = input.output_public_indices.len(); let c2_output_public = input.output_public_indices.len();
@ -171,12 +161,10 @@ impl ModelSizeEstimate {
let c4_pod_existence = n * target_pods; let c4_pod_existence = n * target_pods;
let c5_internal_dependencies = debug_ctx.dep_stats.internal_edges * target_pods; let c5_internal_dependencies = debug_ctx.dep_stats.internal_edges * target_pods;
let c5_external_dependencies = debug_ctx.dep_stats.external_edges * target_pods; let c5_external_dependencies = debug_ctx.dep_stats.external_edges * target_pods;
let c6_pre_content_group = (n * target_pods) + (num_groups * target_pods); let c6_pre_content_group = n * target_pods;
let c6_resource_limits = 7 * target_pods; let c6_resource_limits = 7 * target_pods;
let c7_batch_cardinality = let c7_batch_cardinality =
(debug_ctx.batch_memberships * target_pods) + (all_batches_len * target_pods); (debug_ctx.batch_memberships * target_pods) + (all_batches_len * target_pods);
let c7b_anchored_key_tracking =
(debug_ctx.anchored_key_memberships * target_pods) + (num_anchored_keys * target_pods);
let c8a_internal_inputs = debug_ctx.dep_stats.internal_edges * triangular_k; let c8a_internal_inputs = debug_ctx.dep_stats.internal_edges * triangular_k;
let c8b_external_dep_inputs = debug_ctx.dep_stats.external_edges * triangular_k; let c8b_external_dep_inputs = debug_ctx.dep_stats.external_edges * triangular_k;
let c8c_external_forward_inputs = external_premises_len * triangular_k; let c8c_external_forward_inputs = external_premises_len * triangular_k;
@ -194,7 +182,6 @@ impl ModelSizeEstimate {
+ c6_pre_content_group + c6_pre_content_group
+ c6_resource_limits + c6_resource_limits
+ c7_batch_cardinality + c7_batch_cardinality
+ c7b_anchored_key_tracking
+ c8a_internal_inputs + c8a_internal_inputs
+ c8b_external_dep_inputs + c8b_external_dep_inputs
+ c8c_external_forward_inputs + c8c_external_forward_inputs
@ -209,10 +196,8 @@ impl ModelSizeEstimate {
vars_public_external, vars_public_external,
vars_pod_used, vars_pod_used,
vars_batch_used, vars_batch_used,
vars_anchored_key_used,
vars_uses_input, vars_uses_input,
vars_uses_external, vars_uses_external,
vars_content_group_used,
vars_total, vars_total,
c1_coverage, c1_coverage,
c2_output_public, c2_output_public,
@ -224,7 +209,6 @@ impl ModelSizeEstimate {
c6_pre_content_group, c6_pre_content_group,
c6_resource_limits, c6_resource_limits,
c7_batch_cardinality, c7_batch_cardinality,
c7b_anchored_key_tracking,
c8a_internal_inputs, c8a_internal_inputs,
c8b_external_dep_inputs, c8b_external_dep_inputs,
c8c_external_forward_inputs, c8c_external_forward_inputs,
@ -300,6 +284,7 @@ pub struct MultiPodSolution {
} }
/// Input to the MILP solver. /// Input to the MILP solver.
#[derive(Debug)]
pub struct SolverInput<'a> { pub struct SolverInput<'a> {
/// Number of statements. /// Number of statements.
pub num_statements: usize, pub num_statements: usize,
@ -318,28 +303,6 @@ pub struct SolverInput<'a> {
/// Maximum number of PODs the solver will consider. /// Maximum number of PODs the solver will consider.
pub max_pods: usize, pub max_pods: usize,
/// All unique anchored keys referenced by any statement.
///
/// Each unique (dict, key) pair that is used as an anchored key reference
/// in any operation. When a Contains statement with literal values is used
/// as an argument, it creates an anchored key reference.
pub all_anchored_keys: &'a [AnchoredKeyId],
/// For each anchored key, the statement index that produces it (if any).
///
/// When a Contains statement with literal (dict, key, value) args is explicitly
/// added, it "produces" that anchored key. If the producer is in the same POD
/// as statements using the anchored key, no auto-insertion is needed.
/// `anchored_key_producers[i]` corresponds to `all_anchored_keys[i]`.
pub anchored_key_producers: &'a [Option<usize>],
/// Statement content groups for deduplication.
///
/// Each inner Vec contains statement indices that have identical content.
/// When multiple statements with the same content are proved in the same POD,
/// they only use one statement slot (the POD deduplicates identical statements).
pub statement_content_groups: &'a [Vec<usize>],
} }
/// Solve the MILP problem to find optimal POD packing. /// Solve the MILP problem to find optimal POD packing.
@ -386,11 +349,11 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
))); )));
} }
// Collect all unique custom batch IDs used // Collect all unique custom predicate IDs used
let all_batches: Vec<CustomBatchId> = input let all_custom_predicates: Vec<CustomPredicateId> = input
.costs .costs
.iter() .iter()
.flat_map(|c| c.custom_batch_ids.iter().cloned()) .flat_map(|c| c.custom_predicates_ids.iter().cloned())
.unique() .unique()
.collect(); .collect();
@ -417,27 +380,26 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
} }
let dep_stats = dependency_stats(input.deps); let dep_stats = dependency_stats(input.deps);
let batch_memberships: usize = input.costs.iter().map(|c| c.custom_batch_ids.len()).sum(); let batch_memberships: usize = input
let anchored_key_memberships: usize = input.costs.iter().map(|c| c.anchored_keys.len()).sum(); .costs
.iter()
.map(|c| c.custom_predicates_ids.len())
.sum();
let debug_ctx = SolveDebugContext { let debug_ctx = SolveDebugContext {
dep_stats, dep_stats,
batch_memberships, batch_memberships,
anchored_key_memberships,
}; };
if log::log_enabled!(log::Level::Debug) { if log::log_enabled!(log::Level::Debug) {
let resource_totals = ResourceTotals::from_costs(input.costs); let resource_totals = ResourceTotals::from_costs(input.costs);
let lb_statement_groups = let lb_statement_groups = lower_bound_from_total(input.num_statements, max_stmts_per_pod);
lower_bound_from_total(input.statement_content_groups.len(), max_stmts_per_pod);
let lb_merkle = lower_bound_from_total( let lb_merkle = lower_bound_from_total(
resource_totals.merkle_proofs, resource_totals.merkle_proofs,
input.params.max_merkle_proofs_containers, input.params.containers.state.max_medium,
); );
let lb_merkle_transitions = lower_bound_from_total( let lb_merkle_transitions = lower_bound_from_total(
resource_totals.merkle_state_transitions, resource_totals.merkle_state_transitions,
input input.params.containers.transition.max_medium,
.params
.max_merkle_tree_state_transition_proofs_containers,
); );
let lb_custom_pred_verifications = lower_bound_from_total( let lb_custom_pred_verifications = lower_bound_from_total(
resource_totals.custom_pred_verifications, resource_totals.custom_pred_verifications,
@ -463,14 +425,12 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
.expect("non-empty lower-bound candidate list"); .expect("non-empty lower-bound candidate list");
log::debug!( log::debug!(
"MILP summary: statements={} output_public={} content_groups={} anchored_keys={} \ "MILP summary: statements={} output_public={} \
batches={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \ custom_predicates={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \
external_premises={} search_min_pods={} max_pods={}", external_premises={} search_min_pods={} max_pods={}",
n, n,
num_output_public, num_output_public,
input.statement_content_groups.len(), all_custom_predicates.len(),
input.all_anchored_keys.len(),
all_batches.len(),
dep_stats.internal_edges, dep_stats.internal_edges,
dep_stats.external_edges, dep_stats.external_edges,
external_pods.len(), external_pods.len(),
@ -481,14 +441,13 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
log::debug!( log::debug!(
"MILP resource totals: merkle_proofs={} merkle_state_transitions={} \ "MILP resource totals: merkle_proofs={} merkle_state_transitions={} \
custom_pred_verifications={} signed_by={} public_key_of={} \ custom_pred_verifications={} signed_by={} public_key_of={} \
batch_memberships={} anchored_key_memberships={}", batch_memberships={}",
resource_totals.merkle_proofs, resource_totals.merkle_proofs,
resource_totals.merkle_state_transitions, resource_totals.merkle_state_transitions,
resource_totals.custom_pred_verifications, resource_totals.custom_pred_verifications,
resource_totals.signed_by, resource_totals.signed_by,
resource_totals.public_key_of, resource_totals.public_key_of,
batch_memberships, batch_memberships,
anchored_key_memberships
); );
log::debug!( log::debug!(
"MILP lower bounds (pods): statements_raw={} statements_dedup={} merkle_proofs={} \ "MILP lower bounds (pods): statements_raw={} statements_dedup={} merkle_proofs={} \
@ -513,7 +472,7 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
if let Some(solution) = try_solve_with_pods( if let Some(solution) = try_solve_with_pods(
input, input,
target_pods, target_pods,
&all_batches, &all_custom_predicates,
&external_pods, &external_pods,
&external_premises, &external_premises,
&debug_ctx, &debug_ctx,
@ -540,7 +499,7 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
fn try_solve_with_pods( fn try_solve_with_pods(
input: &SolverInput, input: &SolverInput,
target_pods: usize, target_pods: usize,
all_batches: &[CustomBatchId], all_custom_predicates: &[CustomPredicateId],
external_pods: &[Hash], external_pods: &[Hash],
external_premises: &[ExternalDependency], external_premises: &[ExternalDependency],
debug_ctx: &SolveDebugContext, debug_ctx: &SolveDebugContext,
@ -574,21 +533,8 @@ fn try_solve_with_pods(
.map(|_| vars.add(variable().binary())) .map(|_| vars.add(variable().binary()))
.collect(); .collect();
// batch_used[b][p] - custom batch b is used in POD p // custom_predicates[b][p] - custom predicate b is used in POD p
let batch_used: Vec<Vec<Variable>> = (0..all_batches.len()) let custom_predicate_used: Vec<Vec<Variable>> = (0..all_custom_predicates.len())
.map(|_| {
(0..target_pods)
.map(|_| vars.add(variable().binary()))
.collect()
})
.collect();
// anchored_key_used[ak][p] - anchored key ak is used in POD p
// When a statement references an anchored key (via a Contains statement argument),
// that POD must have a Contains statement for that (dict, key) pair.
// MainPodBuilder::add_entries_contains auto-inserts these, and we must account
// for them in the statement count.
let anchored_key_used: Vec<Vec<Variable>> = (0..input.all_anchored_keys.len())
.map(|_| { .map(|_| {
(0..target_pods) (0..target_pods)
.map(|_| vars.add(variable().binary())) .map(|_| vars.add(variable().binary()))
@ -633,31 +579,19 @@ fn try_solve_with_pods(
.map(|(i, ext)| (ext.clone(), i)) .map(|(i, ext)| (ext.clone(), i))
.collect(); .collect();
// content_group_used[g][p] - content group g has at least one statement proved in POD p
// When multiple statements have identical content, they share a slot in the POD.
// This variable tracks whether at least one statement from each content group is proved.
let num_groups = input.statement_content_groups.len();
let content_group_used: Vec<Vec<Variable>> = (0..num_groups)
.map(|_| {
(0..target_pods)
.map(|_| vars.add(variable().binary()))
.collect()
})
.collect();
if log::log_enabled!(log::Level::Debug) { if log::log_enabled!(log::Level::Debug) {
let estimate = ModelSizeEstimate::for_target_pods( let estimate = ModelSizeEstimate::for_target_pods(
input, input,
target_pods, target_pods,
all_batches.len(), all_custom_predicates.len(),
external_pods.len(), external_pods.len(),
external_premises.len(), external_premises.len(),
debug_ctx, debug_ctx,
); );
log::debug!( log::debug!(
"MILP(k={}) model estimate vars_total={} [prove={} public={} pod_used={} \ "MILP(k={}) model estimate vars_total={} [prove={} public={} pod_used={} \
public_external={} batch_used={} anchored_key_used={} uses_input={} \ public_external={} batch_used={} uses_input={} \
uses_external={} content_group_used={}]", uses_external={}]",
target_pods, target_pods,
estimate.vars_total, estimate.vars_total,
estimate.vars_prove, estimate.vars_prove,
@ -665,14 +599,12 @@ fn try_solve_with_pods(
estimate.vars_pod_used, estimate.vars_pod_used,
estimate.vars_public_external, estimate.vars_public_external,
estimate.vars_batch_used, estimate.vars_batch_used,
estimate.vars_anchored_key_used,
estimate.vars_uses_input, estimate.vars_uses_input,
estimate.vars_uses_external, estimate.vars_uses_external,
estimate.vars_content_group_used
); );
log::debug!( log::debug!(
"MILP(k={}) model estimate constraints_total={} [c1={} c2={} c2b={} c3={} c4={} \ "MILP(k={}) model estimate constraints_total={} [c1={} c2={} c2b={} c3={} c4={} \
c5i={} c5e={} c6_pre={} c6_limits={} c7={} c7b={} c8a={} c8b={} c8c={} \ c5i={} c5e={} c6_pre={} c6_limits={} c7={} c8a={} c8b={} c8c={} \
c8d={} c9={} c10={} c10b={}]", c8d={} c9={} c10={} c10b={}]",
target_pods, target_pods,
estimate.constraints_total, estimate.constraints_total,
@ -686,7 +618,6 @@ fn try_solve_with_pods(
estimate.c6_pre_content_group, estimate.c6_pre_content_group,
estimate.c6_resource_limits, estimate.c6_resource_limits,
estimate.c7_batch_cardinality, estimate.c7_batch_cardinality,
estimate.c7b_anchored_key_tracking,
estimate.c8a_internal_inputs, estimate.c8a_internal_inputs,
estimate.c8b_external_dep_inputs, estimate.c8b_external_dep_inputs,
estimate.c8c_external_forward_inputs, estimate.c8c_external_forward_inputs,
@ -798,35 +729,11 @@ fn try_solve_with_pods(
} }
} }
// Constraint 6: Resource limits per POD
//
// 6a-pre: Content group tracking for statement deduplication
// When multiple statement indices have identical content, they share a single slot in the POD.
// content_group_used[g][p] = 1 iff at least one statement from group g is proved in POD p.
for (g, group) in input.statement_content_groups.iter().enumerate() {
for p in 0..target_pods { for p in 0..target_pods {
// Lower bound: if any statement in the group is proved, the group is used // 6a: Statement count
for &s in group { let stmt_sum: Expression = (0..n).map(|g| prove[g][p]).sum();
model.add_constraint(constraint!(content_group_used[g][p] >= prove[s][p]));
}
// Upper bound: if no statements in the group are proved, the group is not used
let group_prove_sum: Expression = group.iter().map(|&s| prove[s][p]).sum();
model.add_constraint(constraint!(content_group_used[g][p] <= group_prove_sum));
}
}
for p in 0..target_pods {
// 6a: Unique statement count (unique content groups + anchored key Contains)
// Statements with identical content share a slot, so we count content groups, not indices.
// Anchored key Contains statements are auto-inserted by MainPodBuilder when needed.
// The total must not exceed max_priv_statements (= max_statements - max_public_statements).
let unique_stmt_sum: Expression = (0..num_groups).map(|g| content_group_used[g][p]).sum();
let anchored_key_sum: Expression = (0..input.all_anchored_keys.len())
.map(|ak| anchored_key_used[ak][p])
.sum();
model.add_constraint(constraint!( model.add_constraint(constraint!(
unique_stmt_sum + anchored_key_sum stmt_sum <= (input.params.max_priv_statements() as f64) * pod_used[p]
<= (input.params.max_priv_statements() as f64) * pod_used[p]
)); ));
// 6b: Public statement count (internal public statements + forwarded external premises) // 6b: Public statement count (internal public statements + forwarded external premises)
@ -844,7 +751,7 @@ fn try_solve_with_pods(
.map(|s| (input.costs[s].merkle_proofs as f64) * prove[s][p]) .map(|s| (input.costs[s].merkle_proofs as f64) * prove[s][p])
.sum(); .sum();
model.add_constraint(constraint!( model.add_constraint(constraint!(
merkle_sum <= (input.params.max_merkle_proofs_containers as f64) * pod_used[p] merkle_sum <= (input.params.containers.state.max_medium as f64) * pod_used[p]
)); ));
// 6d: Merkle state transitions // 6d: Merkle state transitions
@ -852,11 +759,7 @@ fn try_solve_with_pods(
.map(|s| (input.costs[s].merkle_state_transitions as f64) * prove[s][p]) .map(|s| (input.costs[s].merkle_state_transitions as f64) * prove[s][p])
.sum(); .sum();
model.add_constraint(constraint!( model.add_constraint(constraint!(
mst_sum mst_sum <= (input.params.containers.transition.max_medium as f64) * pod_used[p]
<= (input
.params
.max_merkle_tree_state_transition_proofs_containers as f64)
* pod_used[p]
)); ));
// 6e: Custom predicate verifications // 6e: Custom predicate verifications
@ -885,67 +788,31 @@ fn try_solve_with_pods(
} }
// Constraint 7: Batch cardinality // Constraint 7: Batch cardinality
// batch_used[b][p] >= prove[s][p] for all s that use batch b (batch is used if any statement uses it) // custom_predicate_used[b][p] >= prove[s][p] for all s that use custom predicate b (custom
// batch_used[b][p] <= sum of prove[s][p] for all s using batch b (batch is 0 if no statements use it) // predicate is used if any statement uses it)
for (b, batch_id) in all_batches.iter().enumerate() { // custom_predicate_used[b][p] <= sum of prove[s][p] for all s using custom predicate b (custom
// predicate is 0 if no statements use it)
for (b, predicate_id) in all_custom_predicates.iter().enumerate() {
for p in 0..target_pods { for p in 0..target_pods {
let mut sum: Expression = 0.into(); let mut sum: Expression = 0.into();
for s in 0..n { for s in 0..n {
if input.costs[s].custom_batch_ids.contains(batch_id) { if input.costs[s].custom_predicates_ids.contains(predicate_id) {
model.add_constraint(constraint!(batch_used[b][p] >= prove[s][p])); model.add_constraint(constraint!(custom_predicate_used[b][p] >= prove[s][p]));
sum += prove[s][p]; sum += prove[s][p];
} }
} }
model.add_constraint(constraint!(batch_used[b][p] <= sum)); model.add_constraint(constraint!(custom_predicate_used[b][p] <= sum));
} }
} }
// Constraint 7b: Anchored key tracking // Custom predicate count per POD
//
// anchored_key_used[ak][p] = 1 when auto-insertion of a Contains is needed for anchored key ak in POD p.
// This happens when: some statement using ak is in POD p, AND the producing Contains is NOT in POD p.
//
// If a Contains statement explicitly produces ak (anchored_key_producers[ak] = Some(prod_idx)):
// - Lower: anchored_key_used[ak][p] >= prove[s][p] - prove[prod_idx][p] for all s using ak
// - Upper: anchored_key_used[ak][p] <= 1 - prove[prod_idx][p]
// This ensures overhead is 0 when the producer is in the same POD.
//
// If no Contains produces ak (anchored_key_producers[ak] = None):
// - Lower: anchored_key_used[ak][p] >= prove[s][p] for all s using ak
// - Upper: anchored_key_used[ak][p] <= sum of prove[s][p] for all s using ak
// Auto-insertion is always needed when any user is present.
for (ak_idx, ak) in input.all_anchored_keys.iter().enumerate() {
let producer = input.anchored_key_producers[ak_idx];
for p in 0..target_pods { for p in 0..target_pods {
let mut user_sum: Expression = 0.into(); let custom_predicate_sum: Expression = (0..all_custom_predicates.len())
for s in 0..n { .map(|b| custom_predicate_used[b][p])
if input.costs[s].anchored_keys.contains(ak) { .sum();
if let Some(prod_idx) = producer {
// Producer exists: only count overhead if producer not in this POD
model.add_constraint(constraint!( model.add_constraint(constraint!(
anchored_key_used[ak_idx][p] >= prove[s][p] - prove[prod_idx][p] custom_predicate_sum <= (input.params.max_custom_predicates as f64) * pod_used[p]
)); ));
} else {
// No producer: always need auto-insertion if user is present
model.add_constraint(constraint!(
anchored_key_used[ak_idx][p] >= prove[s][p]
));
}
user_sum += prove[s][p];
}
}
if let Some(prod_idx) = producer {
// If producer is in POD, no auto-insertion needed (overhead = 0)
model.add_constraint(constraint!(
anchored_key_used[ak_idx][p] <= 1 - prove[prod_idx][p]
));
} else {
// No producer: overhead is bounded by whether any user is present
model.add_constraint(constraint!(anchored_key_used[ak_idx][p] <= user_sum));
}
}
} }
// Constraint 8a: Internal input POD tracking using uses_input. // Constraint 8a: Internal input POD tracking using uses_input.
@ -1147,9 +1014,6 @@ mod tests {
output_public_indices: &[], output_public_indices: &[],
params: &params, params: &params,
max_pods: 20, max_pods: 20,
all_anchored_keys: &[],
anchored_key_producers: &[],
statement_content_groups: &[],
}; };
let result = solve(&input); let result = solve(&input);
@ -1195,7 +1059,6 @@ mod tests {
}; };
let costs = vec![StatementCost::default(), StatementCost::default()]; let costs = vec![StatementCost::default(), StatementCost::default()];
let statement_content_groups = vec![vec![0], vec![1]];
let output_public = vec![1]; let output_public = vec![1];
let input = SolverInput { let input = SolverInput {
@ -1205,9 +1068,6 @@ mod tests {
output_public_indices: &output_public, output_public_indices: &output_public,
params: &params, params: &params,
max_pods: 4, max_pods: 4,
all_anchored_keys: &[],
anchored_key_producers: &[],
statement_content_groups: &statement_content_groups,
}; };
let solution = solve(&input).expect("solver should find a feasible forwarding layout"); let solution = solve(&input).expect("solver should find a feasible forwarding layout");

View file

@ -1,10 +1,10 @@
use std::fmt; use std::{fmt, iter};
use crate::{ use crate::{
frontend::SignedDict, frontend::SignedDict,
middleware::{ middleware::{
containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux, containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux,
OperationType, Signature, Statement, TypedValue, Value, ValueRef, OperationType, Signature, Statement, Value, ValueRef, BASE_PARAMS,
}, },
}; };
@ -39,10 +39,9 @@ impl OperationArg {
} }
pub(crate) fn int_value_and_ref(&self) -> Option<(ValueRef, i64)> { pub(crate) fn int_value_and_ref(&self) -> Option<(ValueRef, i64)> {
self.value_and_ref().and_then(|(r, v)| match v.typed() { self.value_and_ref()
&TypedValue::Int(i) => Some((r, i)), .and_then(|(r, v)| v.as_int().map(|i| Some((r, i))))
_ => None, .flatten()
})
} }
} }
@ -71,7 +70,7 @@ impl From<&Value> for OperationArg {
impl From<(&Dictionary, &str)> for OperationArg { impl From<(&Dictionary, &str)> for OperationArg {
fn from((dict, key): (&Dictionary, &str)) -> Self { fn from((dict, key): (&Dictionary, &str)) -> Self {
// TODO: Use TryFrom // TODO: Use TryFrom
let value = dict.get(&key.into()).cloned().unwrap(); let value = dict.get(&key.into()).unwrap().unwrap();
Self::Statement(Statement::Contains( Self::Statement(Statement::Contains(
dict.clone().into(), dict.clone().into(),
key.into(), key.into(),
@ -220,6 +219,24 @@ impl Operation {
op_impl_oa!(set_insert, SetInsertFromEntries, 3); op_impl_oa!(set_insert, SetInsertFromEntries, 3);
op_impl_oa!(set_delete, SetDeleteFromEntries, 3); op_impl_oa!(set_delete, SetDeleteFromEntries, 3);
op_impl_oa!(array_update, ArrayUpdateFromEntries, 4); op_impl_oa!(array_update, ArrayUpdateFromEntries, 4);
pub fn replace_value_with_entry(args: Vec<Option<(&Dictionary, &str)>>, st: Statement) -> Self {
assert!(args.len() <= BASE_PARAMS.max_statement_args);
let args = args
.into_iter()
.chain(iter::repeat(None))
.take(BASE_PARAMS.max_statement_args)
.map(|a| match a {
None => OperationArg::Statement(Statement::None),
Some((dict, key)) => OperationArg::from((dict, key)),
})
.chain(iter::once(OperationArg::Statement(st)))
.collect();
Self(
OperationType::Native(NativeOperation::ReplaceValueWithEntry),
args,
OperationAux::None,
)
}
pub fn signed_by( pub fn signed_by(
msg: impl Into<OperationArg>, msg: impl Into<OperationArg>,
pk: impl Into<OperationArg>, pk: impl Into<OperationArg>,

View file

@ -83,7 +83,7 @@ mod tests {
middleware::{ middleware::{
self, self,
containers::{Array, Dictionary, Set}, containers::{Array, Dictionary, Set},
Params, Signer as _, TypedValue, DEFAULT_VD_LIST, Params, Signer as _, Value, DEFAULT_VD_LIST,
}, },
}; };
@ -91,16 +91,15 @@ mod tests {
fn test_value_serialization() { fn test_value_serialization() {
// Pairs of values and their expected serialized representations // Pairs of values and their expected serialized representations
let values = vec![ let values = vec![
(TypedValue::String("hello".to_string()), "\"hello\""), (Value::from("hello"), "\"hello\""),
(TypedValue::Int(42), "{\"Int\":\"42\"}"), (Value::from(42), "{\"Int\":\"42\"}"),
(TypedValue::Bool(true), "true"), (Value::from(true), r#"{"Int":"1"}"#),
( (
TypedValue::Array(Array::new(vec!["foo".into(), false.into()])), Value::from(Array::new(vec![Value::from("foo"), Value::from(false)])),
"{\"array\":[\"foo\",false]}", r#"{"inner":[[{"Int":"0"},"foo"],[{"Int":"1"},{"Int":"0"}]]}"#,
), ),
( (
TypedValue::Dictionary( Value::from(Dictionary::new(HashMap::from([
Dictionary::new(HashMap::from([
// The set of valid keys is equal to the set of valid JSON keys // The set of valid keys is equal to the set of valid JSON keys
("foo".into(), 123.into()), ("foo".into(), 123.into()),
// Empty strings are valid JSON keys // Empty strings are valid JSON keys
@ -113,26 +112,25 @@ mod tests {
(("\0".into()), "".into()), (("\0".into()), "".into()),
// Keys can contain emojis // Keys can contain emojis
(("🥳".into()), "party time!".into()), (("🥳".into()), "party time!".into()),
])) ]))),
), r#"{"inner":[["!@£$%^&&*()",""],["🥳","party time!"],[" hi",{"Int":"0"}],["foo",{"Int":"123"}],["\u0000",""],["","baz"]]}"#,
"{\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}",
), ),
( (
TypedValue::Set(Set::new(HashSet::from(["foo".into(), "bar".into()]))), Value::from(Set::new(HashSet::from(["foo".into(), "bar".into()]))),
"{\"set\":[\"bar\",\"foo\"]}", r#"{"inner":[["bar"],["foo"]]}"#,
), ),
]; ];
for (value, expected) in values { for (value, expected) in values {
let serialized = serde_json::to_string(&value).unwrap(); let serialized = serde_json::to_string(&value).unwrap();
assert_eq!(serialized, expected); assert_eq!(serialized, expected);
let deserialized: TypedValue = serde_json::from_str(&serialized).unwrap(); let deserialized: Value = serde_json::from_str(&serialized).unwrap();
assert_eq!( assert_eq!(
value, deserialized, value, deserialized,
"value {:#?} should equal deserialized {:#?}", "value {:#?} should equal deserialized {:#?}",
value, deserialized value, deserialized
); );
let expected_deserialized: TypedValue = serde_json::from_str(expected).unwrap(); let expected_deserialized: Value = serde_json::from_str(expected).unwrap();
assert_eq!(value, expected_deserialized); assert_eq!(value, expected_deserialized);
} }
} }
@ -177,7 +175,10 @@ mod tests {
"deserialized: {}", "deserialized: {}",
serde_json::to_string_pretty(&deserialized).unwrap() serde_json::to_string_pretty(&deserialized).unwrap()
); );
assert_eq!(signed_dict.dict.kvs(), deserialized.dict.kvs()); assert_eq!(
signed_dict.dict.dump().unwrap(),
deserialized.dict.dump().unwrap()
);
assert_eq!(signed_dict.public_key, deserialized.public_key); assert_eq!(signed_dict.public_key, deserialized.public_key);
assert_eq!(signed_dict.signature, deserialized.signature); assert_eq!(signed_dict.signature, deserialized.signature);
assert_eq!(signed_dict.verify().is_ok(), deserialized.verify().is_ok()); assert_eq!(signed_dict.verify().is_ok(), deserialized.verify().is_ok());

View file

@ -174,18 +174,6 @@ fn render_validation_error(
"second REQUEST here", "second REQUEST here",
), ),
ValidationError::InvalidArgumentType { predicate, span } => {
let title = format!("invalid argument type for `{}`", predicate);
render_with_optional_span(
renderer,
source,
path,
&title,
span.as_ref(),
"anchored keys not allowed here",
)
}
ValidationError::DuplicateWildcard { name, span } => { ValidationError::DuplicateWildcard { name, span } => {
let title = format!("duplicate wildcard: {}", name); let title = format!("duplicate wildcard: {}", name);
render_with_optional_span( render_with_optional_span(
@ -287,6 +275,17 @@ fn render_validation_error(
ValidationError::NoRequestBlock => { ValidationError::NoRequestBlock => {
render_title_only(renderer, "requests must contain a REQUEST block") render_title_only(renderer, "requests must contain a REQUEST block")
} }
ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { span } => {
render_with_optional_span(
renderer,
source,
path,
"self-referential predicate literal not allowed in requests",
span.as_ref(),
"not allowed here",
)
}
} }
} }

View file

@ -135,12 +135,6 @@ pub enum ValidationError {
span: Option<Span>, span: Option<Span>,
}, },
#[error("Invalid argument type for {predicate}: anchored keys not allowed")]
InvalidArgumentType {
predicate: String,
span: Option<Span>,
},
#[error("Duplicate wildcard in predicate arguments: {name}")] #[error("Duplicate wildcard in predicate arguments: {name}")]
DuplicateWildcard { name: String, span: Option<Span> }, DuplicateWildcard { name: String, span: Option<Span> },
@ -165,6 +159,9 @@ pub enum ValidationError {
#[error("Modules must contain at least one predicate definition")] #[error("Modules must contain at least one predicate definition")]
NoPredicatesInModule, NoPredicatesInModule,
#[error("Self-referential predicate literal not allowed in requests")]
SelfReferentialPredicateLiteralNotAllowedInRequests { span: Option<Span> },
#[error("Requests must contain a REQUEST block")] #[error("Requests must contain a REQUEST block")]
NoRequestBlock, NoRequestBlock,
} }

View file

@ -116,6 +116,8 @@ pub enum StatementTmplArg {
Literal(LiteralValue), Literal(LiteralValue),
Wildcard(Identifier), Wildcard(Identifier),
AnchoredKey(AnchoredKey), AnchoredKey(AnchoredKey),
/// Hash of a same-module predicate, resolved at batch finalization time.
SelfPredicateHash(Identifier),
} }
/// Anchored key: Var["key"] or Var.key /// Anchored key: Var["key"] or Var.key
@ -168,6 +170,13 @@ pub enum LiteralValue {
Array(LiteralArray), Array(LiteralArray),
Set(LiteralSet), Set(LiteralSet),
Dict(LiteralDict), Dict(LiteralDict),
/// Hash of a native predicate (resolved immediately).
NativePredicateHash(Identifier),
/// Hash of an external module's predicate (resolved immediately).
ExternalPredicateHash {
module: Identifier,
predicate: Identifier,
},
} }
/// Integer literal /// Integer literal
@ -391,6 +400,9 @@ impl fmt::Display for StatementTmplArg {
StatementTmplArg::Literal(lit) => write!(f, "{}", lit), StatementTmplArg::Literal(lit) => write!(f, "{}", lit),
StatementTmplArg::Wildcard(id) => write!(f, "{}", id), StatementTmplArg::Wildcard(id) => write!(f, "{}", id),
StatementTmplArg::AnchoredKey(ak) => write!(f, "{}", ak), StatementTmplArg::AnchoredKey(ak) => write!(f, "{}", ak),
StatementTmplArg::SelfPredicateHash(id) => {
write!(f, "@self_predicate({})", id)
}
} }
} }
} }
@ -422,6 +434,12 @@ impl fmt::Display for LiteralValue {
LiteralValue::Array(a) => write!(f, "{}", a), LiteralValue::Array(a) => write!(f, "{}", a),
LiteralValue::Set(s) => write!(f, "{}", s), LiteralValue::Set(s) => write!(f, "{}", s),
LiteralValue::Dict(d) => write!(f, "{}", d), LiteralValue::Dict(d) => write!(f, "{}", d),
LiteralValue::NativePredicateHash(id) => {
write!(f, "@native_predicate({})", id)
}
LiteralValue::ExternalPredicateHash {
module, predicate, ..
} => write!(f, "@external_predicate({}, {})", module, predicate),
} }
} }
} }
@ -769,6 +787,10 @@ pub mod parse {
let inner = pair.into_inner().next().unwrap(); let inner = pair.into_inner().next().unwrap();
match inner.as_rule() { match inner.as_rule() {
Rule::predicate_hash_self => {
let id = parse_identifier(inner.into_inner().next().unwrap());
Ok(StatementTmplArg::SelfPredicateHash(id))
}
Rule::literal_value => Ok(StatementTmplArg::Literal(parse_literal_value(inner)?)), Rule::literal_value => Ok(StatementTmplArg::Literal(parse_literal_value(inner)?)),
Rule::identifier => Ok(StatementTmplArg::Wildcard(parse_identifier(inner))), Rule::identifier => Ok(StatementTmplArg::Wildcard(parse_identifier(inner))),
Rule::anchored_key => Ok(StatementTmplArg::AnchoredKey(parse_anchored_key(inner)?)), Rule::anchored_key => Ok(StatementTmplArg::AnchoredKey(parse_anchored_key(inner)?)),
@ -823,6 +845,16 @@ pub mod parse {
Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)), Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)),
Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)), Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)),
Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)), Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)),
Rule::predicate_hash_native => {
let id = parse_identifier(inner.into_inner().next().unwrap());
Ok(LiteralValue::NativePredicateHash(id))
}
Rule::predicate_hash_external => {
let mut parts = inner.into_inner();
let module = parse_identifier(parts.next().unwrap());
let predicate = parse_identifier(parts.next().unwrap());
Ok(LiteralValue::ExternalPredicateHash { module, predicate })
}
_ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()), _ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()),
} }
} }
@ -1104,6 +1136,7 @@ mod tests {
AnchoredKeyPath::Dot(id) => id.span = None, AnchoredKeyPath::Dot(id) => id.span = None,
} }
} }
StatementTmplArg::SelfPredicateHash(id) => id.span = None,
} }
} }
} }
@ -1139,6 +1172,13 @@ mod tests {
clear_literal_spans(&mut pair.value); clear_literal_spans(&mut pair.value);
} }
} }
LiteralValue::NativePredicateHash(id) => id.span = None,
LiteralValue::ExternalPredicateHash {
module, predicate, ..
} => {
module.span = None;
predicate.span = None;
}
} }
} }

View file

@ -157,8 +157,10 @@ fn resolve_local_predicate(
/// Lower a literal value from AST to middleware Value. /// Lower a literal value from AST to middleware Value.
/// ///
/// This is a pure conversion that cannot fail. /// This is a pure conversion that cannot fail for context-free literals.
pub fn lower_literal(lit: &LiteralValue) -> Value { /// Panics on ExternalPredicateHash — use `lower_literal_with_context` when
/// external predicate references may appear (e.g. inside containers).
pub(crate) fn lower_literal(lit: &LiteralValue) -> Value {
match lit { match lit {
LiteralValue::Int(i) => Value::from(i.value), LiteralValue::Int(i) => Value::from(i.value),
LiteralValue::Bool(b) => Value::from(b.value), LiteralValue::Bool(b) => Value::from(b.value),
@ -190,13 +192,83 @@ pub fn lower_literal(lit: &LiteralValue) -> Value {
let dict = containers::Dictionary::new(pairs); let dict = containers::Dictionary::new(pairs);
Value::from(dict) Value::from(dict)
} }
LiteralValue::NativePredicateHash(id) => {
let np = NativePredicate::from_str(&id.name).expect("validated native predicate");
Value::from(Predicate::Native(np).hash())
}
LiteralValue::ExternalPredicateHash { .. } => {
unreachable!(
"ExternalPredicateHash must be lowered with context via lower_literal_with_context"
)
}
}
}
/// Lower a literal value, resolving external predicate references using the symbol table.
pub fn lower_literal_with_context(
lit: &LiteralValue,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Result<Value, LoweringError> {
match lit {
LiteralValue::ExternalPredicateHash { module, predicate } => {
let pred_or_wc = resolve_predicate_ref(
&PredicateRef::Qualified {
module: module.clone(),
predicate: predicate.clone(),
},
symbols,
context,
)
.ok_or_else(|| LoweringError::PredicateNotFound {
name: format!("{}::{}", module.name, predicate.name),
})?;
let pred = match pred_or_wc {
crate::frontend::PredicateOrWildcard::Predicate(p) => p,
_ => unreachable!(
"`resolve_predicate_ref` always returns `PredicateOrWildcard::Predicate` on `PredicateRef::Qualified`"
)
};
Ok(Value::from(pred.hash()))
}
LiteralValue::Array(a) => {
let elements: Vec<_> = a
.elements
.iter()
.map(|e| lower_literal_with_context(e, symbols, context))
.collect::<Result<_, _>>()?;
Ok(Value::from(containers::Array::new(elements)))
}
LiteralValue::Set(s) => {
let elements: std::collections::HashSet<_> = s
.elements
.iter()
.map(|e| lower_literal_with_context(e, symbols, context))
.collect::<Result<_, _>>()?;
Ok(Value::from(containers::Set::new(elements)))
}
LiteralValue::Dict(d) => {
let pairs: HashMap<_, _> = d
.pairs
.iter()
.map(|pair| {
let key = Key::from(pair.key.value.as_str());
let value = lower_literal_with_context(&pair.value, symbols, context)?;
Ok((key, value))
})
.collect::<Result<_, LoweringError>>()?;
Ok(Value::from(containers::Dictionary::new(pairs)))
}
// All other variants are context-free
other => Ok(lower_literal(other)),
} }
} }
/// Lower a statement argument from AST to BuilderArg. /// Lower a statement argument from AST to BuilderArg.
/// ///
/// This is a pure conversion that cannot fail. /// Context-free for most arg types. Panics on ExternalPredicateHash inside literals —
pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { /// use `lower_statement_arg_with_context` when external predicate references may appear.
pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
match arg { match arg {
StatementTmplArg::Literal(lit) => { StatementTmplArg::Literal(lit) => {
let value = lower_literal(lit); let value = lower_literal(lit);
@ -210,6 +282,25 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
}; };
BuilderArg::Key(ak.root.name.clone(), key_str) BuilderArg::Key(ak.root.name.clone(), key_str)
} }
StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()),
}
}
/// Lower a statement argument, resolving external predicate references using the symbol table.
pub fn lower_statement_arg_with_context(
arg: &StatementTmplArg,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Result<BuilderArg, LoweringError> {
match arg {
StatementTmplArg::Literal(lit) => {
let value = lower_literal_with_context(lit, symbols, context)?;
Ok(BuilderArg::Literal(value))
}
StatementTmplArg::SelfPredicateHash(id) => {
Ok(BuilderArg::SelfPredicateHash(id.name.clone()))
}
other => Ok(lower_statement_arg(other)),
} }
} }
@ -324,7 +415,7 @@ impl<'a> Lowerer<'a> {
// Create a builder with the resolved predicate and desugar // Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate.clone()); let mut builder = StatementTmplBuilder::new(predicate.clone());
for arg in &stmt.args { for arg in &stmt.args {
let builder_arg = lower_statement_arg(arg); let builder_arg = lower_statement_arg_with_context(arg, symbols, &context)?;
builder = builder.arg(builder_arg); builder = builder.arg(builder_arg);
} }
let desugared = builder.desugar(); let desugared = builder.desugar();
@ -346,6 +437,9 @@ impl<'a> Lowerer<'a> {
let key = Key::from(key_str.as_str()); let key = Key::from(key_str.as_str());
MWStatementTmplArg::AnchoredKey(wildcard, key) MWStatementTmplArg::AnchoredKey(wildcard, key)
} }
BuilderArg::SelfPredicateHash(_) => {
unreachable!("SelfPredicateHash should not appear in request lowering")
}
}; };
mw_args.push(mw_arg); mw_args.push(mw_arg);
} }
@ -399,7 +493,7 @@ impl<'a> Lowerer<'a> {
names.push(ak.root.name.clone()); names.push(ak.root.name.clone());
} }
} }
StatementTmplArg::Literal(_) => {} StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {}
} }
} }
} }

View file

@ -123,7 +123,7 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
StatementTmplArg::AnchoredKey(ak) => { StatementTmplArg::AnchoredKey(ak) => {
wildcards.insert(ak.root.name.clone()); wildcards.insert(ak.root.name.clone());
} }
StatementTmplArg::Literal(_) => {} StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {}
} }
} }

View file

@ -522,7 +522,7 @@ impl Validator {
} }
// Validate arguments // Validate arguments
self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?; self.validate_statement_args(stmt, wildcard_context)?;
Ok(()) Ok(())
} }
@ -530,40 +530,8 @@ impl Validator {
fn validate_statement_args( fn validate_statement_args(
&self, &self,
stmt: &StatementTmpl, stmt: &StatementTmpl,
pred_info: Option<&PredicateInfo>,
wildcard_context: Option<(&str, &WildcardScope)>, wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> { ) -> Result<(), ValidationError> {
// For custom predicates, only wildcards and literals are allowed
if matches!(
pred_info.map(|i| &i.kind),
Some(PredicateKind::Custom { .. })
| Some(PredicateKind::BatchImported { .. })
| Some(PredicateKind::ModuleImported { .. })
) {
for arg in &stmt.args {
match arg {
StatementTmplArg::AnchoredKey(_) => {
return Err(ValidationError::InvalidArgumentType {
predicate: stmt.predicate.predicate_name().to_string(),
span: stmt.span,
});
}
StatementTmplArg::Wildcard(id) => {
if let Some((pred_name, scope)) = wildcard_context {
if !scope.wildcards.contains_key(&id.name) {
return Err(ValidationError::UndefinedWildcard {
name: id.name.clone(),
pred_name: pred_name.to_string(),
span: id.span,
});
}
}
}
StatementTmplArg::Literal(_) => {}
}
}
} else {
// Native predicates can have anchored keys
for arg in &stmt.args { for arg in &stmt.args {
match arg { match arg {
StatementTmplArg::Wildcard(id) => { StatementTmplArg::Wildcard(id) => {
@ -588,13 +556,91 @@ impl Validator {
} }
} }
} }
StatementTmplArg::Literal(_) => {} StatementTmplArg::Literal(lit) => {
self.validate_literal_value(lit)?;
}
StatementTmplArg::SelfPredicateHash(id) => {
self.validate_self_predicate_hash(id, wildcard_context)?;
} }
} }
} }
Ok(()) Ok(())
} }
/// Validate a @self_predicate reference: the name must be a custom predicate in this module.
fn validate_self_predicate_hash(
&self,
id: &Identifier,
wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> {
// @self_predicate only makes sense inside module predicate definitions
if wildcard_context.is_none() {
return Err(
ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests {
span: id.span,
},
);
}
// Must refer to a custom predicate defined in this module (not intro/imported)
match self.symbols.predicates.get(&id.name) {
Some(info) if matches!(info.kind, PredicateKind::Custom { .. }) => Ok(()),
_ => Err(ValidationError::UndefinedPredicate {
name: id.name.clone(),
span: id.span,
}),
}
}
/// Recursively validate a literal value, checking predicate hash references.
fn validate_literal_value(&self, lit: &LiteralValue) -> Result<(), ValidationError> {
match lit {
LiteralValue::NativePredicateHash(id) => {
if NativePredicate::from_str(&id.name).is_err() {
return Err(ValidationError::UndefinedPredicate {
name: id.name.clone(),
span: id.span,
});
}
Ok(())
}
LiteralValue::ExternalPredicateHash { module, predicate } => {
if let Some(imported) = self.symbols.imported_modules.get(&module.name) {
if !imported.predicate_index.contains_key(&predicate.name) {
return Err(ValidationError::UndefinedPredicate {
name: format!("{}::{}", module.name, predicate.name),
span: predicate.span,
});
}
} else {
return Err(ValidationError::ModuleNotFound {
name: module.name.clone(),
span: module.span,
});
}
Ok(())
}
LiteralValue::Array(a) => {
for elem in &a.elements {
self.validate_literal_value(elem)?;
}
Ok(())
}
LiteralValue::Set(s) => {
for elem in &s.elements {
self.validate_literal_value(elem)?;
}
Ok(())
}
LiteralValue::Dict(d) => {
for pair in &d.pairs {
self.validate_literal_value(&pair.value)?;
}
Ok(())
}
_ => Ok(()),
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -755,10 +801,7 @@ mod tests {
module_hash module_hash
); );
let result = parse_and_validate_request(&input, &available_modules); let result = parse_and_validate_request(&input, &available_modules);
assert!(matches!( assert!(result.is_ok());
result,
Err(ValidationError::InvalidArgumentType { .. })
));
} }
#[test] #[test]

View file

@ -49,7 +49,14 @@ custom_predicate_def = {
statement_list = { statement+ } statement_list = { statement+ }
statement_arg = { literal_value | anchored_key | identifier } // Predicate hash literals: resolve to the predicate's identity hash as a value.
// @native_predicate and @external_predicate are in literal_value (usable in containers).
// @self_predicate is only in statement_arg (not in containers — deferred resolution).
predicate_hash_native = { "@native_predicate" ~ "(" ~ identifier ~ ")" }
predicate_hash_external = { "@external_predicate" ~ "(" ~ identifier ~ "," ~ identifier ~ ")" }
predicate_hash_self = { "@self_predicate" ~ "(" ~ identifier ~ ")" }
statement_arg = { predicate_hash_self | literal_value | anchored_key | identifier }
statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* } statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* }
// Predicate reference: either qualified (module::predicate) or local (predicate) // Predicate reference: either qualified (module::predicate) or local (predicate)
@ -74,6 +81,8 @@ literal_value = {
literal_bool | literal_bool |
literal_raw | literal_raw |
literal_string | literal_string |
predicate_hash_native |
predicate_hash_external |
literal_int literal_int
} }

View file

@ -578,7 +578,6 @@ mod tests {
max_input_pods: 3, max_input_pods: 3,
max_statements: 31, max_statements: 31,
max_public_statements: 10, max_public_statements: 10,
max_operation_args: 5,
max_custom_predicate_wildcards: 12, max_custom_predicate_wildcards: 12,
..Default::default() ..Default::default()
}; };

View file

@ -11,7 +11,9 @@ use crate::{
lang::{ lang::{
error::BatchingError, error::BatchingError,
frontend_ast::{ConjunctionType, CustomPredicateDef}, frontend_ast::{ConjunctionType, CustomPredicateDef},
frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext}, frontend_ast_lower::{
lower_statement_arg_with_context, resolve_predicate_ref, ResolutionContext,
},
frontend_ast_split::{SplitChainInfo, SplitResult}, frontend_ast_split::{SplitChainInfo, SplitResult},
frontend_ast_validate::SymbolTable, frontend_ast_validate::SymbolTable,
}, },
@ -345,7 +347,9 @@ fn build_single_batch(
})?; })?;
} }
Ok(builder.finish()) builder.finish().map_err(|e| BatchingError::Internal {
message: format!("Failed to finalize batch '{}': {}", batch_name, e),
})
} }
/// Build a statement template with properly resolved predicate references /// Build a statement template with properly resolved predicate references
@ -372,7 +376,13 @@ fn build_statement_with_resolved_refs(
let mut builder = StatementTmplBuilder::new(pred_or_wc); let mut builder = StatementTmplBuilder::new(pred_or_wc);
for arg in &stmt.args { for arg in &stmt.args {
builder = builder.arg(lower_statement_arg(arg)); let builder_arg =
lower_statement_arg_with_context(arg, symbols, &context).map_err(|e| {
BatchingError::Internal {
message: format!("Failed to lower argument: {}", e),
}
})?;
builder = builder.arg(builder_arg);
} }
Ok(builder) Ok(builder)
@ -668,4 +678,110 @@ mod tests {
PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref)) PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref))
); );
} }
#[test]
fn test_self_predicate_hash_podlang() {
let params = Params::default();
let module = load_module(
r#"
pred_A(x, y) = AND(
Equal(x, y)
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#,
"test",
&params,
&[],
)
.unwrap();
let batch = &module.batch;
// pred_B is at index 1, its template should have SelfPredicateHash(0) resolved
// to a Literal containing pred_A's hash after normalization
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = crate::middleware::Value::from(Predicate::Custom(pred_a_ref).hash());
// Use normalized_predicate to resolve
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let normalized = pred_b_ref.normalized_predicate();
assert_eq!(
normalized.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(pred_a_hash)
);
}
#[test]
fn test_self_predicate_hash_podlang_cyclic() {
let params = Params::default();
let module = load_module(
r#"
pred_A(x) = AND(
Equal(x, @self_predicate(pred_B))
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#,
"test",
&params,
&[],
)
.unwrap();
let batch = &module.batch;
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_a_hash =
crate::middleware::Value::from(Predicate::Custom(pred_a_ref.clone()).hash());
let pred_b_hash =
crate::middleware::Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// pred_A's normalized form should contain pred_B's hash
let norm_a = pred_a_ref.normalized_predicate();
assert_eq!(
norm_a.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(pred_b_hash)
);
// pred_B's normalized form should contain pred_A's hash
let norm_b = pred_b_ref.normalized_predicate();
assert_eq!(
norm_b.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(pred_a_hash)
);
}
#[test]
fn test_native_predicate_hash_podlang() {
let params = Params::default();
let module = load_module(
r#"
pred_C(x) = AND(
Equal(x, @native_predicate(Equal))
)
"#,
"test",
&params,
&[],
)
.unwrap();
let batch = &module.batch;
let pred_c_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_c = pred_c_ref.predicate();
// The second arg should be a Literal containing Equal's predicate hash
let equal_hash = crate::middleware::Value::from(
Predicate::Native(crate::middleware::NativePredicate::Equal).hash(),
);
assert_eq!(
pred_c.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(equal_hash)
);
}
} }

View file

@ -137,6 +137,9 @@ mod tests {
assert_inner(&Rule::anchored_key, "someVar[\"key\"]"); assert_inner(&Rule::anchored_key, "someVar[\"key\"]");
assert_inner(&Rule::literal_value, "true"); assert_inner(&Rule::literal_value, "true");
assert_inner(&Rule::literal_value, "PublicKey(abc)"); assert_inner(&Rule::literal_value, "PublicKey(abc)");
assert_inner(&Rule::predicate_hash_self, "@self_predicate(foo)");
assert_inner(&Rule::literal_value, "@native_predicate(Equal)");
assert_inner(&Rule::literal_value, "@external_predicate(mod_a, pred_b)");
} }
#[test] #[test]
@ -207,6 +210,33 @@ mod tests {
"{ \"raw_val\": Raw(0x0000000000000000000000000000000000000000000000000000000000000000) } ", "{ \"raw_val\": Raw(0x0000000000000000000000000000000000000000000000000000000000000000) } ",
); );
assert_fails(Rule::literal_dict, "{ name: \"Alice\" }"); // Key must be string literal with quotes assert_fails(Rule::literal_dict, "{ name: \"Alice\" }"); // Key must be string literal with quotes
// Predicate hash literals
assert_parses(Rule::predicate_hash_native, "@native_predicate(Equal)");
assert_parses(Rule::predicate_hash_native, "@native_predicate(Lt)");
assert_parses(
Rule::predicate_hash_external,
"@external_predicate(my_module, my_pred)",
);
assert_parses(Rule::predicate_hash_self, "@self_predicate(local_pred)");
// Predicate hashes inside containers (native and external only)
assert_parses(
Rule::literal_array,
"[1, @native_predicate(Equal), @external_predicate(m, p)]",
);
assert_parses(
Rule::literal_set,
"#[@native_predicate(Equal), @native_predicate(Lt)]",
);
assert_parses(
Rule::literal_dict,
"{ \"pred\": @external_predicate(m, p) }",
);
// @self_predicate is NOT a literal_value, so it cannot appear inside containers
assert_fails(Rule::test_literal_value, "@self_predicate(local_pred)");
assert_fails(Rule::literal_array, "[@self_predicate(foo)]");
} }
#[test] #[test]

View file

@ -92,7 +92,7 @@ impl StatementTmpl {
if i > 0 { if i > 0 {
write!(w, ", ")?; write!(w, ", ")?;
} }
arg.fmt_podlang(w)?; arg.fmt_podlang_with_batch_context(w, batch_context)?;
} }
write!(w, ")")?; write!(w, ")")?;
@ -102,7 +102,30 @@ impl StatementTmpl {
impl PrettyPrint for StatementTmplArg { impl PrettyPrint for StatementTmplArg {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
write!(w, "{}", self) self.fmt_podlang_with_batch_context(w, None)
}
}
impl StatementTmplArg {
fn fmt_podlang_with_batch_context(
&self,
w: &mut dyn Write,
batch_context: Option<&CustomPredicateBatch>,
) -> std::fmt::Result {
match self {
StatementTmplArg::SelfPredicateHash(index) => {
if let Some(batch) = batch_context {
if let Some(predicate) = batch.predicates().get(*index) {
write!(w, "@self_predicate({})", predicate.name)
} else {
write!(w, "@self_predicate(self_{})", index)
}
} else {
write!(w, "@self_predicate(self_{})", index)
}
}
other => write!(w, "{}", other),
}
} }
} }
@ -131,7 +154,7 @@ impl CustomPredicateBatch {
impl PrettyPrint for Value { impl PrettyPrint for Value {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
write!(w, "{}", self.typed()) write!(w, "{}", self.typed)
} }
} }
@ -540,6 +563,34 @@ mod tests {
assert_round_trip(&input); assert_round_trip(&input);
} }
#[test]
fn test_round_trip_self_predicate_hash() {
let input = r#"
pred_A(x, y) = AND(
Equal(x, y)
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_self_predicate_hash_cyclic() {
let input = r#"
pred_A(x) = AND(
Equal(x, @self_predicate(pred_B))
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#;
assert_round_trip(input);
}
#[test] #[test]
fn test_pretty_print_demonstration() { fn test_pretty_print_demonstration() {
let input = r#" let input = r#"

View file

@ -169,6 +169,12 @@ pub struct Hash(
pub [F; HASH_SIZE], pub [F; HASH_SIZE],
); );
impl Hash {
pub fn raw(self) -> RawValue {
RawValue::from(self)
}
}
impl From<Hash> for HashOut { impl From<Hash> for HashOut {
fn from(hash: Hash) -> HashOut { fn from(hash: Hash) -> HashOut {
HashOut { elements: hash.0 } HashOut { elements: hash.0 }

View file

@ -1,29 +1,260 @@
//! This file implements the types defined at //! This file implements the types defined at
//! <https://0xparc.github.io/pod2/values.html#dictionary-array-set> . //! <https://0xparc.github.io/pod2/values.html#dictionary-array-set> .
use std::collections::{HashMap, HashSet}; use std::{
collections::{HashMap, HashSet},
fmt::{self, Debug},
};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{
de::{Error as _, SeqAccess, Visitor},
ser, Deserialize, Deserializer, Serialize,
};
use super::serialization::{ordered_map, ordered_set};
#[cfg(feature = "backend_plonky2")] #[cfg(feature = "backend_plonky2")]
use crate::backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree}; use crate::backends::plonky2::primitives::merkletree::{self, MerkleProof, MerkleTree};
use crate::{ use crate::{
backends::plonky2::primitives::merkletree::MerkleTreeStateTransitionProof, backends::plonky2::primitives::merkletree::MerkleTreeStateTransitionProof,
middleware::{Error, Hash, Key, RawValue, Result, Value}, middleware::{
db::{mem::MemDB, DB},
Error, Hash, Key, RawValue, Result, TypedValue, Value, EMPTY_HASH,
},
}; };
#[derive(Clone, Debug)]
pub struct Container {
root: Hash,
db: Box<dyn DB>,
}
impl JsonSchema for Container {
fn schema_name() -> String {
"Container".to_string()
}
fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
// Just use the schema of Vec<Vec<Value>> since that's what we're actually serializing
Vec::<Vec<Value>>::json_schema(gen)
}
}
impl Serialize for Container {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut pairs = self
.iter()
.collect::<Result<Vec<(Value, Value)>>>()
.map_err(ser::Error::custom)?;
pairs.sort_by(|(k1, _), (k2, _)| k1.raw().cmp(&k2.raw()));
// Serialize as an array
use serde::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(pairs.len()))?;
for (k, v) in pairs {
if k == v {
seq.serialize_element(&[&v])?;
} else {
seq.serialize_element(&[&k, &v])?;
}
}
seq.end()
}
}
struct ContainerVisitor;
impl<'de> Visitor<'de> for ContainerVisitor {
type Value = HashMap<Value, Value>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence of `[Value]` or `[Value, Value]`")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut kvs = HashMap::<Value, Value>::new();
while let Some(mut elem) = seq.next_element::<Vec<Value>>()? {
match elem.len() {
1 => {
let v = elem.pop().unwrap();
kvs.insert(v.clone(), v);
}
2 => {
let (v, k) = (elem.pop().unwrap(), elem.pop().unwrap());
kvs.insert(k, v);
}
n => {
return Err(A::Error::custom(format!(
"invalid vec length of {n} in container entry"
)))
}
}
}
Ok(kvs)
}
}
impl<'de> Deserialize<'de> for Container {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let kvs = deserializer.deserialize_seq(ContainerVisitor)?;
Ok(Container::new(kvs))
}
}
impl PartialEq for Container {
fn eq(&self, other: &Self) -> bool {
self.root == other.root
}
}
impl Eq for Container {}
fn store_container_mt(db: &mut dyn DB, container: &Container) -> Result<()> {
match db.load_node(container.root) {
Err(e) => return Err(Error::Database(e)),
// Container already exists in the DB
Ok(Some(_)) => return Ok(()),
// Container not existing, we need to save it
Ok(None) => {}
};
let mut container_copy = Container::empty_with_db(db.clone_box());
for kv_result in container.iter() {
let (k, v) = kv_result?;
container_copy.insert(k, v)?;
}
Ok(())
}
fn store_value(db: &mut dyn DB, v: Value) -> Result<()> {
match &v.typed {
TypedValue::Set(Set { inner })
| TypedValue::Dictionary(Dictionary { inner })
| TypedValue::Array(Array { inner }) => {
if db.is_persistent() {
store_container_mt(db, inner)?;
}
db.store_value(v).map_err(Error::Database)?
}
_ => db.store_value(v).map_err(Error::Database)?,
}
Ok(())
}
fn load_value(db: &dyn DB, value_raw: RawValue) -> Result<Value> {
match db.load_value(value_raw) {
Err(e) => Err(Error::Database(e)),
Ok(Some(v)) => Ok(v),
Ok(None) => Err(Error::custom(format!(
"Value from {value_raw} not found in DB"
))),
}
}
impl Container {
fn mt(&self) -> MerkleTree {
MerkleTree::from_db(self.root, self.db.clone())
}
pub fn new(kvs: HashMap<Value, Value>) -> Self {
let db = Box::new(MemDB::new());
let mut container = Self::empty_with_db(db);
for (k, v) in kvs {
container.insert(k, v).expect("no duplicates, no db errors");
}
container
}
pub fn empty_with_db(db: Box<dyn DB>) -> Self {
Self::from_db(EMPTY_HASH, db).expect("EMPTY_HASH exists implicitly")
}
pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
// Make sure the root exists in the db
let _ = merkletree::load_node(db.as_ref(), root)?;
Ok(Self { root, db })
}
pub fn commitment(&self) -> Hash {
self.root
}
pub fn get(&self, key_raw: RawValue) -> Result<Option<Value>> {
Ok(match self.mt().get(&key_raw)? {
Some(value_raw) => Some(load_value(self.db.as_ref(), value_raw)?),
None => None,
})
}
pub fn prove(&self, key_raw: RawValue) -> Result<(Value, MerkleProof)> {
let (value_raw, mtp) = self.mt().prove(&key_raw)?;
let value = load_value(self.db.as_ref(), value_raw)?;
Ok((value, mtp))
}
pub fn prove_nonexistence(&self, key_raw: RawValue) -> Result<MerkleProof> {
Ok(self.mt().prove_nonexistence(&key_raw)?)
}
pub fn insert(&mut self, key: Value, value: Value) -> Result<MerkleTreeStateTransitionProof> {
let (key_raw, value_raw) = (key.raw(), value.raw());
store_value(self.db.as_mut(), key)?;
store_value(self.db.as_mut(), value)?;
let mut mt = self.mt();
let mtp = mt.insert(&key_raw, &value_raw)?;
self.root = mt.root();
Ok(mtp)
}
pub fn update(
&mut self,
key_raw: RawValue,
value: Value,
) -> Result<MerkleTreeStateTransitionProof> {
let value_raw = value.raw();
store_value(self.db.as_mut(), value)?;
let mut mt = self.mt();
let mtp = mt.update(&key_raw, &value_raw)?;
self.root = mt.root();
Ok(mtp)
}
pub fn delete(&mut self, key_raw: RawValue) -> Result<MerkleTreeStateTransitionProof> {
let mut mt = self.mt();
let mtp = mt.delete(&key_raw)?;
self.root = mt.root();
Ok(mtp)
}
pub fn verify(
root: Hash,
proof: &MerkleProof,
key_raw: RawValue,
value_raw: RawValue,
) -> Result<()> {
Ok(MerkleTree::verify(root, proof, &key_raw, &value_raw)?)
}
pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key_raw: RawValue) -> Result<()> {
Ok(MerkleTree::verify_nonexistence(root, proof, &key_raw)?)
}
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into())
}
pub fn iter(&self) -> impl Iterator<Item = Result<(Value, Value)>> {
let db = self.db.clone();
self.mt().iter().map(move |(key_raw, value_raw)| {
let key = load_value(db.as_ref(), key_raw)?;
let value = load_value(db.as_ref(), value_raw)?;
Ok((key, value))
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashMap<Value, Value>> {
self.iter().collect()
}
}
/// Dictionary: the user original keys and values are hashed to be used in the leaf. /// Dictionary: the user original keys and values are hashed to be used in the leaf.
/// leaf.key=hash(original_key) /// leaf.key=hash(original_key)
/// leaf.value=hash(original_value) /// leaf.value=hash(original_value)
#[derive(Clone, Debug, Serialize, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Dictionary { pub struct Dictionary {
#[serde(skip)] pub(crate) inner: Container,
#[schemars(skip)]
mt: MerkleTree,
#[serde(serialize_with = "ordered_map")]
kvs: HashMap<Key, Value>,
} }
#[macro_export] #[macro_export]
@ -34,255 +265,371 @@ macro_rules! dict {
({ $($key:expr => $val:expr),* }) => ({ ({ $($key:expr => $val:expr),* }) => ({
let mut map = ::std::collections::HashMap::new(); let mut map = ::std::collections::HashMap::new();
$( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )* $( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )*
$crate::middleware::containers::Dictionary::new( map) $crate::middleware::containers::Dictionary::new(map)
}); });
} }
// TODO: Replace all methods that receive a `&Key` by either `impl Into<String>` for write
// methods and `impl AsRef<str>` for read methods.
// TODO: Replace all methods that receive a `&Value` in write methods for `Value`. Consider a
// trait?
impl Dictionary { impl Dictionary {
pub fn new(kvs: HashMap<Key, Value>) -> Self { pub fn new(kvs: HashMap<Key, Value>) -> Self {
let kvs_raw: HashMap<RawValue, RawValue> =
kvs.iter().map(|(k, v)| (k.raw(), v.raw())).collect();
Self { Self {
mt: MerkleTree::new(&kvs_raw), inner: Container::new(
kvs, kvs.into_iter()
.map(|(k, v)| (Value::from(k.name), v))
.collect(),
),
} }
} }
pub fn empty_with_db(db: Box<dyn DB>) -> Self {
Self {
inner: Container::empty_with_db(db),
}
}
pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
Ok(Self {
inner: Container::from_db(root, db)?,
})
}
pub fn commitment(&self) -> Hash { pub fn commitment(&self) -> Hash {
self.mt.root() self.inner.commitment()
} }
pub fn get(&self, key: &Key) -> Result<&Value> { pub fn get(&self, key: &Key) -> Result<Option<Value>> {
self.kvs self.inner.get(key.raw())
.get(key)
.ok_or_else(|| Error::custom(format!("key \"{}\" not found", key.name())))
} }
pub fn prove(&self, key: &Key) -> Result<(&Value, MerkleProof)> { pub fn prove(&self, key: &Key) -> Result<(Value, MerkleProof)> {
let (_, mtp) = self.mt.prove(&key.raw())?; self.inner.prove(key.raw())
let value = self.kvs.get(key).expect("key exists");
Ok((value, mtp))
} }
pub fn prove_nonexistence(&self, key: &Key) -> Result<MerkleProof> { pub fn prove_nonexistence(&self, key: &Key) -> Result<MerkleProof> {
Ok(self.mt.prove_nonexistence(&key.raw())?) self.inner.prove_nonexistence(key.raw())
} }
pub fn insert(&mut self, key: &Key, value: &Value) -> Result<MerkleTreeStateTransitionProof> { pub fn insert(&mut self, key: &Key, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.insert(&key.raw(), &value.raw())?; self.inner
self.kvs.insert(key.clone(), value.clone()); .insert(Value::from(key.name.clone()), value.clone())
Ok(mtp)
} }
pub fn update(&mut self, key: &Key, value: &Value) -> Result<MerkleTreeStateTransitionProof> { pub fn update(&mut self, key: &Key, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.update(&key.raw(), &value.raw())?; self.inner.update(key.raw(), value.clone())
self.kvs.insert(key.clone(), value.clone());
Ok(mtp)
} }
pub fn delete(&mut self, key: &Key) -> Result<MerkleTreeStateTransitionProof> { pub fn delete(&mut self, key: &Key) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.delete(&key.raw())?; self.inner.delete(key.raw())
self.kvs.remove(key);
Ok(mtp)
} }
pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> { pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> {
let key = key.raw(); Container::verify(root, proof, key.raw(), value.raw())
Ok(MerkleTree::verify(root, proof, &key, &value.raw())?)
} }
pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> { pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> {
let key = key.raw(); Container::verify_nonexistence(root, proof, key.raw())
Ok(MerkleTree::verify_nonexistence(root, proof, &key)?)
} }
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) Container::verify_state_transition(proof)
} }
// TODO: Rename to dict to be consistent maybe? pub fn iter(&self) -> impl Iterator<Item = Result<(String, Value)>> + use<'_> {
pub fn kvs(&self) -> &HashMap<Key, Value> { self.inner.iter().map(|r| match r {
&self.kvs Ok((key, value)) => Ok((
key.as_string()
.ok_or_else(|| Error::custom("dictionary: key is not string"))?,
value,
)),
Err(e) => Err(e),
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashMap<String, Value>> {
self.iter().collect()
} }
} }
impl PartialEq for Dictionary { impl PartialEq for Dictionary {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.mt.root() == other.mt.root() self.inner.eq(&other.inner)
} }
} }
impl Eq for Dictionary {} impl Eq for Dictionary {}
impl<'de> Deserialize<'de> for Dictionary {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct Aux {
#[serde(serialize_with = "ordered_map")]
kvs: HashMap<Key, Value>,
}
let aux = Aux::deserialize(deserializer)?;
Ok(Dictionary::new(aux.kvs))
}
}
/// Set: the value field of the leaf is unused, and the key contains the hash of the element. /// Set: the value field of the leaf is unused, and the key contains the hash of the element.
/// leaf.key=hash(original_value) /// leaf.key=hash(original_value)
/// leaf.value=0 /// leaf.value=0
#[derive(Clone, Debug, Serialize, JsonSchema)] #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Set { pub struct Set {
#[serde(skip)] pub(crate) inner: Container,
#[schemars(skip)]
mt: MerkleTree,
#[serde(serialize_with = "ordered_set")]
set: HashSet<Value>,
} }
impl Set { impl Set {
pub fn new(set: HashSet<Value>) -> Self { pub fn new(set: HashSet<Value>) -> Self {
let kvs_raw: HashMap<RawValue, RawValue> = set
.iter()
.map(|e| {
let rv = e.raw();
(rv, rv)
})
.collect();
Self { Self {
mt: MerkleTree::new(&kvs_raw), inner: Container::new(set.into_iter().map(|v| (v.clone(), v)).collect()),
set,
} }
} }
pub fn empty_with_db(db: Box<dyn DB>) -> Self {
Self {
inner: Container::empty_with_db(db),
}
}
pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
Ok(Self {
inner: Container::from_db(root, db)?,
})
}
pub fn commitment(&self) -> Hash { pub fn commitment(&self) -> Hash {
self.mt.root() self.inner.commitment()
} }
pub fn contains(&self, value: &Value) -> bool { pub fn contains(&self, value: &Value) -> Result<bool> {
self.set.contains(value) Ok(self.inner.get(value.raw())?.is_some())
} }
pub fn prove(&self, value: &Value) -> Result<MerkleProof> { pub fn prove(&self, value: &Value) -> Result<MerkleProof> {
let rv = value.raw(); let (_, proof) = self.inner.prove(value.raw())?;
let (_, proof) = self.mt.prove(&rv)?;
Ok(proof) Ok(proof)
} }
pub fn prove_nonexistence(&self, value: &Value) -> Result<MerkleProof> { pub fn prove_nonexistence(&self, value: &Value) -> Result<MerkleProof> {
let rv = value.raw(); self.inner.prove_nonexistence(value.raw())
Ok(self.mt.prove_nonexistence(&rv)?)
} }
pub fn insert(&mut self, value: &Value) -> Result<MerkleTreeStateTransitionProof> { pub fn insert(&mut self, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let raw_value = value.raw(); self.inner.insert(value.clone(), value.clone())
let mtp = self.mt.insert(&raw_value, &raw_value)?;
self.set.insert(value.clone());
Ok(mtp)
} }
pub fn delete(&mut self, value: &Value) -> Result<MerkleTreeStateTransitionProof> { pub fn delete(&mut self, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.delete(&value.raw())?; self.inner.delete(value.raw())
self.set.remove(value);
Ok(mtp)
} }
pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> {
let rv = value.raw(); Container::verify(root, proof, value.raw(), value.raw())
Ok(MerkleTree::verify(root, proof, &rv, &rv)?)
} }
pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> {
let rv = value.raw(); Container::verify_nonexistence(root, proof, value.raw())
Ok(MerkleTree::verify_nonexistence(root, proof, &rv)?)
} }
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) Container::verify_state_transition(proof)
} }
pub fn set(&self) -> &HashSet<Value> { pub fn iter(&self) -> impl Iterator<Item = Result<Value>> + use<'_> {
&self.set self.inner.iter().map(|r| match r {
Ok((key, value)) => {
if key != value {
return Err(Error::custom("set: key != value"));
}
Ok(value)
}
Err(e) => Err(e),
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashSet<Value>> {
self.iter().collect()
} }
} }
impl PartialEq for Set { impl PartialEq for Set {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.mt.root() == other.mt.root() self.inner.eq(&other.inner)
} }
} }
impl Eq for Set {} impl Eq for Set {}
impl<'de> Deserialize<'de> for Set {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize, JsonSchema)]
struct Aux {
#[serde(serialize_with = "ordered_set")]
set: HashSet<Value>,
}
let aux = Aux::deserialize(deserializer)?;
Ok(Set::new(aux.set))
}
}
/// Array: the elements are placed at the value field of each leaf, and the key field is just the /// Array: the elements are placed at the value field of each leaf, and the key field is just the
/// array index (integer). /// array index (integer).
/// leaf.key=i /// leaf.key=i
/// leaf.value=original_value /// leaf.value=original_value
#[derive(Clone, Debug, Serialize, JsonSchema)] /// Due to its construction this should be seen as a sparse array, where there can be gaps
/// (unused indices).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Array { pub struct Array {
#[serde(skip)] pub(crate) inner: Container,
#[schemars(skip)]
mt: MerkleTree,
array: Vec<Value>,
} }
impl Array { impl Array {
pub fn new(array: Vec<Value>) -> Self { pub fn new(array: Vec<Value>) -> Self {
let kvs_raw: HashMap<RawValue, RawValue> = array
.iter()
.enumerate()
.map(|(i, e)| (RawValue::from(i as i64), e.raw()))
.collect();
Self { Self {
mt: MerkleTree::new(&kvs_raw), inner: Container::new(
array, array
.into_iter()
.enumerate()
.map(|(i, v)| (Value::from(i as i64), v))
.collect(),
),
} }
} }
pub fn commitment(&self) -> Hash { pub fn empty_with_db(db: Box<dyn DB>) -> Self {
self.mt.root() Self {
inner: Container::empty_with_db(db),
} }
pub fn get(&self, i: usize) -> Result<&Value> { }
self.array.get(i).ok_or_else(|| { pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
Error::custom(format!("index {} out of bounds 0..{}", i, self.array.len())) Ok(Self {
inner: Container::from_db(root, db)?,
}) })
} }
pub fn prove(&self, i: usize) -> Result<(&Value, MerkleProof)> { pub fn commitment(&self) -> Hash {
let (_, mtp) = self.mt.prove(&RawValue::from(i as i64))?; self.inner.commitment()
let value = self.array.get(i).expect("valid index"); }
Ok((value, mtp)) pub fn get(&self, i: usize) -> Result<Option<Value>> {
self.inner.get(Value::from(i as i64).raw())
}
pub fn prove(&self, i: usize) -> Result<(Value, MerkleProof)> {
self.inner.prove(Value::from(i as i64).raw())
}
pub fn insert(&mut self, i: usize, value: Value) -> Result<MerkleTreeStateTransitionProof> {
self.inner.insert(Value::from(i as i64), value)
}
pub fn delete(&mut self, i: usize) -> Result<MerkleTreeStateTransitionProof> {
self.inner.delete(Value::from(i as i64).raw())
} }
pub fn update(&mut self, i: usize, value: &Value) -> Result<MerkleTreeStateTransitionProof> { pub fn update(&mut self, i: usize, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.update(&(i as i64).into(), &value.raw())?; self.inner
self.array[i] = value.clone(); .update(Value::from(i as i64).raw(), value.clone())
Ok(mtp)
} }
pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> { pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> {
Ok(MerkleTree::verify( Container::verify(root, proof, Value::from(i as i64).raw(), value.raw())
root,
proof,
&RawValue::from(i as i64),
&value.raw(),
)?)
} }
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) Container::verify_state_transition(proof)
} }
pub fn array(&self) -> &[Value] { pub fn iter(&self) -> impl Iterator<Item = Result<(usize, Value)>> + use<'_> {
&self.array self.inner.iter().map(|r| match r {
Ok((key, value)) => {
let index = key
.as_int()
.ok_or_else(|| Error::custom("array: key is not int"))?;
Ok((index as usize, value))
}
Err(e) => Err(e),
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashMap<usize, Value>> {
self.iter().collect()
} }
} }
impl PartialEq for Array { impl PartialEq for Array {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.mt.root() == other.mt.root() self.inner.eq(&other.inner)
} }
} }
impl Eq for Array {} impl Eq for Array {}
impl<'de> Deserialize<'de> for Array { #[cfg(test)]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> mod tests {
where use super::*;
D: Deserializer<'de>, use crate::middleware::db::mem::MemDB;
fn test_databases(test_fn: &dyn Fn(Box<dyn DB>)) {
let db = MemDB::new();
test_fn(Box::new(db));
#[cfg(feature = "db_rocksdb")]
{ {
#[derive(Deserialize, JsonSchema)] use crate::middleware::db;
struct Aux { let db = db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap();
array: Vec<Value>, test_fn(Box::new(db));
} }
let aux = Aux::deserialize(deserializer)?; }
Ok(Array::new(aux.array))
fn _test_dict(db: Box<dyn DB>) {
let mut dict0 = Dictionary::empty_with_db(db.clone());
dict0.insert(&Key::from("a"), &Value::from(1)).unwrap();
dict0.insert(&Key::from("b"), &Value::from(2)).unwrap();
dict0.update(&Key::from("a"), &Value::from(3)).unwrap();
dict0.insert(&Key::from("c"), &Value::from(4)).unwrap();
dict0.delete(&Key::from("c")).unwrap();
let kvs0 = dict0.dump().unwrap();
assert_eq!(
kvs0,
[
("a".to_string(), Value::from(3)),
("b".to_string(), Value::from(2))
]
.into_iter()
.collect()
);
let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap();
let kvs1 = dict1.dump().unwrap();
assert_eq!(kvs0, kvs1);
}
fn _test_set(db: Box<dyn DB>) {
let mut set0 = Set::empty_with_db(db.clone());
set0.insert(&Value::from(1)).unwrap();
set0.insert(&Value::from(2)).unwrap();
set0.insert(&Value::from(3)).unwrap();
set0.delete(&Value::from(2)).unwrap();
let s0 = set0.dump().unwrap();
assert_eq!(s0, [Value::from(1), Value::from(3)].into_iter().collect());
let set1 = Set::from_db(set0.commitment(), db).unwrap();
let s1 = set1.dump().unwrap();
assert_eq!(s0, s1);
}
fn _test_array(db: Box<dyn DB>) {
let mut arr0 = Array::empty_with_db(db.clone());
arr0.insert(0, Value::from("a")).unwrap();
arr0.insert(1, Value::from("b")).unwrap();
arr0.insert(2, Value::from("c")).unwrap();
arr0.delete(1).unwrap();
let a0 = arr0.dump().unwrap();
assert_eq!(
a0,
[(0, Value::from("a")), (2, Value::from("c"))]
.into_iter()
.collect()
);
let arr1 = Array::from_db(arr0.commitment(), db).unwrap();
let a1 = arr1.dump().unwrap();
assert_eq!(a0, a1);
}
fn _test_nested(db: Box<dyn DB>) {
let mut nested = Dictionary::empty_with_db(db.clone());
nested.insert(&Key::from("a"), &Value::from(1)).unwrap();
nested.insert(&Key::from("b"), &Value::from(2)).unwrap();
let nested_kvs0 = nested.dump().unwrap();
let mut dict0 = Dictionary::empty_with_db(db.clone());
dict0.insert(&Key::from("x"), &Value::from(1)).unwrap();
dict0
.insert(&Key::from("y"), &Value::from(nested.clone()))
.unwrap();
let kvs0 = dict0.dump().unwrap();
assert_eq!(
kvs0,
[
("x".to_string(), Value::from(1)),
("y".to_string(), Value::from(nested))
]
.into_iter()
.collect()
);
let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap();
let kvs1 = dict1.dump().unwrap();
assert_eq!(kvs0, kvs1);
match &kvs1["y"].typed {
TypedValue::Dictionary(d) => {
let nested_kvs1 = d.dump().unwrap();
assert_eq!(nested_kvs0, nested_kvs1);
}
_ => unreachable!(),
}
}
#[test]
fn test_dict() {
test_databases(&_test_dict);
}
#[test]
fn test_set() {
test_databases(&_test_set);
}
#[test]
fn test_array() {
test_databases(&_test_array);
}
#[test]
fn test_nested() {
test_databases(&_test_nested);
} }
} }

View file

@ -49,6 +49,9 @@ pub enum StatementTmplArg {
// AnchoredKey where the origin is a wildcard // AnchoredKey where the origin is a wildcard
AnchoredKey(Wildcard, Key), AnchoredKey(Wildcard, Key),
Wildcard(Wildcard), Wildcard(Wildcard),
/// Reference to a same-batch predicate's identity hash, resolved at verification time.
/// The usize is the predicate index within the batch.
SelfPredicateHash(usize),
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
@ -57,6 +60,7 @@ pub enum StatementTmplArgPrefix {
Literal = 1, Literal = 1,
AnchoredKey = 2, AnchoredKey = 2,
WildcardLiteral = 3, WildcardLiteral = 3,
SelfPredicateHash = 4,
} }
impl From<StatementTmplArgPrefix> for F { impl From<StatementTmplArgPrefix> for F {
@ -72,7 +76,8 @@ impl ToFields for StatementTmplArg {
// Literal(v) => (1, [v ], 0, 0, 0, 0) // Literal(v) => (1, [v ], 0, 0, 0, 0)
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements // SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0)
// In all cases, we pad to 2 * hash_size + 1 = 9 field elements
match self { match self {
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None)) StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
.chain(iter::repeat(F::ZERO)) .chain(iter::repeat(F::ZERO))
@ -97,6 +102,13 @@ impl ToFields for StatementTmplArg {
.take(Params::statement_tmpl_arg_size()) .take(Params::statement_tmpl_arg_size())
.collect_vec() .collect_vec()
} }
StatementTmplArg::SelfPredicateHash(index) => {
iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash))
.chain(iter::once(F::from_canonical_usize(*index)))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect_vec()
}
} }
} }
} }
@ -113,6 +125,7 @@ impl fmt::Display for StatementTmplArg {
write!(f, "]") write!(f, "]")
} }
Self::Wildcard(v) => v.fmt(f), Self::Wildcard(v) => v.fmt(f),
Self::SelfPredicateHash(i) => write!(f, "::self.{}", i),
} }
} }
} }
@ -423,7 +436,7 @@ impl fmt::Display for CustomPredicate {
} }
} }
#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] #[derive(Clone, PartialEq, Eq, Serialize, JsonSchema)]
enum CustomPredicateBatchData { enum CustomPredicateBatchData {
Full { Full {
#[serde(skip)] #[serde(skip)]
@ -436,6 +449,20 @@ enum CustomPredicateBatchData {
}, },
} }
// Explicit implementation of Debug to skip the merkle tree
impl fmt::Debug for CustomPredicateBatchData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self {
Self::Full { mt, predicates } => f
.debug_struct("Full")
.field("id", &mt.root())
.field("predicates", &predicates)
.finish(),
Self::Opaque { id } => f.debug_struct("Opaque").field("id", &id).finish(),
}
}
}
// TODO: Rename Batch for Module everywhere in the code base // TODO: Rename Batch for Module everywhere in the code base
impl CustomPredicateBatchData { impl CustomPredicateBatchData {
fn new_full(predicates: Vec<CustomPredicate>) -> Self { fn new_full(predicates: Vec<CustomPredicate>) -> Self {
@ -569,6 +596,44 @@ impl CustomPredicateRef {
pub fn predicate(&self) -> &CustomPredicate { pub fn predicate(&self) -> &CustomPredicate {
&self.batch.predicates()[self.index] &self.batch.predicates()[self.index]
} }
/// Returns a copy of this predicate with all `SelfPredicateHash(i)` args
/// resolved to `Literal(hash(Custom(batch, i)))`.
pub fn normalized_predicate(&self) -> CustomPredicate {
let pred = self.predicate();
let normalized_statements = pred
.statements
.iter()
.map(|st_tmpl| {
let args = st_tmpl
.args
.iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: self.batch.clone(),
index: *i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
}
other => other.clone(),
})
.collect();
StatementTmpl {
pred_or_wc: st_tmpl.pred_or_wc.clone(),
args,
}
})
.collect();
CustomPredicate {
name: pred.name.clone(),
conjunction: pred.conjunction,
statements: normalized_statements,
args_len: pred.args_len,
wildcard_names: pred.wildcard_names.clone(),
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -579,7 +644,7 @@ mod tests {
middleware::{ middleware::{
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl, NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl,
StatementTmplArg, StatementTmplArg, ValueRef,
}, },
}; };
@ -602,6 +667,9 @@ mod tests {
fn names(names: &[&str]) -> Vec<String> { fn names(names: &[&str]) -> Vec<String> {
names.iter().map(|s| s.to_string()).collect() names.iter().map(|s| s.to_string()).collect()
} }
fn value_ref(v: impl Into<ValueRef>) -> ValueRef {
v.into()
}
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
type STA = StatementTmplArg; type STA = StatementTmplArg;
@ -650,7 +718,7 @@ mod tests {
}); });
let custom_statement = Statement::Custom( let custom_statement = Statement::Custom(
CustomPredicateRef::new(cust_pred_batch.clone(), 0), CustomPredicateRef::new(cust_pred_batch.clone(), 0),
vec![Value::from(d0.clone())], vec![value_ref(d0.clone())],
); );
let custom_deduction = Operation::Custom( let custom_deduction = Operation::Custom(
@ -782,7 +850,7 @@ mod tests {
// Example statement // Example statement
let ethdos_example = Statement::Custom( let ethdos_example = Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2),
vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)],
); );
// Copies should work. // Copies should work.
@ -791,7 +859,7 @@ mod tests {
// This could arise as the inductive step. // This could arise as the inductive step.
let ethdos_ind_example = Statement::Custom( let ethdos_ind_example = Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1), CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1),
vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)],
); );
assert!(Operation::Custom( assert!(Operation::Custom(
@ -806,12 +874,12 @@ mod tests {
let ethdos_facts = vec![ let ethdos_facts = vec![
Statement::Custom( Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2),
vec![Value::from("Alice"), Value::from("Charlie"), Value::from(6)], vec![value_ref("Alice"), value_ref("Charlie"), value_ref(6)],
), ),
Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)), Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)),
Statement::Custom( Statement::Custom(
CustomPredicateRef::new(eth_friend_batch.clone(), 0), CustomPredicateRef::new(eth_friend_batch.clone(), 0),
vec![Value::from("Charlie"), Value::from("Bob")], vec![value_ref("Charlie"), value_ref("Bob")],
), ),
]; ];
@ -823,4 +891,173 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_normalized_predicate() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
// Compute expected pred_A hash
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
// Normalize pred_B
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let normalized = pred_b_ref.normalized_predicate();
// The second arg should be resolved to Literal(pred_A_hash)
assert_eq!(
normalized.statements[0].args[1],
STA::Literal(expected_hash)
);
// First arg should be unchanged (still a wildcard)
assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0)));
Ok(())
}
#[test]
fn test_self_predicate_hash_check() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
// Construct a valid operation: Equal(some_value, pred_a_hash)
let some_value = Value::from(42);
let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())];
// The output statement
let output_st = Statement::Custom(
pred_b_ref.clone(),
vec![ValueRef::Literal(some_value.clone())],
);
// This should pass
assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(&params, &output_st)?);
// Now try with wrong hash, should fail
let wrong_hash = Value::from(999);
let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)];
assert!(Operation::Custom(pred_b_ref, bad_op_args)
.check(&params, &output_st)
.is_err());
Ok(())
}
#[test]
fn test_self_predicate_hash_cyclic() -> Result<()> {
let params = Params::default();
// Build a batch where pred_A references pred_B's hash and vice versa
// pred_A = Equal(x, SelfPredicateHash(1))
// pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)],
)],
1,
names(&["x"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash());
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// pred_A's normalized form should reference pred_B's hash
let norm_a = pred_a_ref.normalized_predicate();
assert_eq!(
norm_a.statements[0].args[1],
STA::Literal(pred_b_hash.clone())
);
// pred_B's normalized form should reference pred_A's hash
let norm_b = pred_b_ref.normalized_predicate();
assert_eq!(
norm_b.statements[0].args[1],
STA::Literal(pred_a_hash.clone())
);
// Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass
let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())];
let st_a = Statement::Custom(
pred_a_ref.clone(),
vec![ValueRef::Literal(pred_b_hash.clone())],
);
assert!(Operation::Custom(pred_a_ref, op_a).check(&params, &st_a)?);
// Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass
let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())];
let st_b = Statement::Custom(
pred_b_ref.clone(),
vec![ValueRef::Literal(pred_a_hash.clone())],
);
assert!(Operation::Custom(pred_b_ref, op_b).check(&params, &st_b)?);
Ok(())
}
} }

62
src/middleware/db/mem.rs Normal file
View file

@ -0,0 +1,62 @@
use super::*;
/// MemDB implements the DB trait in a in-memory HashMap.
#[derive(Clone, Debug, Default)]
pub struct MemDB {
nodes: Arc<RwLock<HashMap<Hash, merkletree::Node>>>,
values: Arc<RwLock<HashMap<RawValue, Value>>>,
}
impl MemDB {
pub fn new() -> Self {
Self::default()
}
}
impl merkletree::db::DB for MemDB {
fn load_node(&self, hash: Hash) -> anyhow::Result<Option<merkletree::Node>> {
let nodes = self.nodes.read().expect("lock not poisoned");
if hash == EMPTY_HASH {
return Ok(Some(merkletree::Node::Intermediate(
merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH),
)));
}
Ok(nodes.get(&hash).cloned())
}
fn store_node(&mut self, node: merkletree::Node) -> anyhow::Result<()> {
let mut nodes = self.nodes.write().expect("lock not poisoned");
nodes.insert(node.hash(), node);
Ok(())
}
}
impl DB for MemDB {
fn load_value(&self, raw: RawValue) -> anyhow::Result<Option<Value>> {
let values = self.values.read().expect("lock not poisoned");
Ok(values.get(&raw).cloned())
}
fn store_value(&mut self, value: Value) -> anyhow::Result<()> {
let mut values = self.values.write().expect("lock not poisoned");
let value_raw = value.raw();
if let Some(old_value) = values.get(&value_raw) {
let old_is_raw = old_value.is_raw();
// If we had a non-RawValue stored don't overwrite it (specially not with a
// RawValue). Also skip redundant RawValue overwrite.
if !old_is_raw || value.is_raw() {
return Ok(());
}
}
values.insert(value_raw, value);
Ok(())
}
fn is_persistent(&self) -> bool {
false
}
fn clone_box(&self) -> Box<dyn DB> {
Box::new(self.clone())
}
}

30
src/middleware/db/mod.rs Normal file
View file

@ -0,0 +1,30 @@
use std::{
collections::HashMap,
fmt::Debug,
sync::{Arc, RwLock},
};
use dyn_clone::DynClone;
#[cfg(feature = "backend_plonky2")]
use crate::backends::plonky2::primitives::merkletree::{self};
use crate::middleware::{Hash, RawValue, Value, EMPTY_HASH};
pub mod mem;
#[cfg(feature = "db_rocksdb")]
pub mod rocks;
// Trait for database that stores values. Must be cheap to clone.
pub trait DB: Debug + DynClone + Sync + Send + merkletree::db::DB {
fn load_value(&self, raw: RawValue) -> anyhow::Result<Option<Value>>;
// If the DB is persistent, for containers only the root needs to be stored because the
// Container type makes sure the underlying merkle tree is stored in the DB independently, so
// that it can be recovered back just with the root and the DB.
// If the value is RawValue and a previous non-RawValue exists, no store overwrite it.
// should be done. If the value is non-RawValue and a previous RawValue exists, store
// should overwrite it.
fn store_value(&mut self, value: Value) -> anyhow::Result<()>;
fn is_persistent(&self) -> bool;
fn clone_box(&self) -> Box<dyn DB>;
}
dyn_clone::clone_trait_object!(DB);

107
src/middleware/db/rocks.rs Normal file
View file

@ -0,0 +1,107 @@
use std::{fmt, path::Path, sync::Arc};
use anyhow::{anyhow, Result};
use rocksdb::{Options, TransactionDB, TransactionDBOptions};
use super::*;
fn node_key(hash: Hash) -> Vec<u8> {
let mut k = Vec::with_capacity(2 + 4);
k.extend_from_slice(b"n/");
k.extend_from_slice(&RawValue::from(hash).to_bytes());
k
}
fn value_key(raw: RawValue) -> Vec<u8> {
let mut k = Vec::with_capacity(2 + 4);
k.extend_from_slice(b"v/");
k.extend_from_slice(&raw.to_bytes());
k
}
#[derive(Clone)]
pub struct RocksDB {
db: Arc<TransactionDB>,
}
impl fmt::Debug for RocksDB {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "RocksDB(path: {:?})", self.db.path())
}
}
impl RocksDB {
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let mut options = Options::default();
options.create_if_missing(true);
let txn_options = TransactionDBOptions::default();
let inner =
TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?;
Ok(Self {
db: Arc::new(inner),
})
}
}
impl merkletree::db::DB for RocksDB {
fn load_node(&self, hash: Hash) -> Result<Option<merkletree::Node>> {
if hash == EMPTY_HASH {
return Ok(Some(merkletree::Node::Intermediate(
merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH),
)));
}
match self.db.get(node_key(hash))? {
None => Ok(None),
Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)),
}
}
fn store_node(&mut self, node: merkletree::Node) -> Result<()> {
self.db
.put(node_key(node.hash()), node.encode()?)
.map_err(|e| anyhow!("rocksdb transaction put failed: {e}"))
}
}
impl DB for RocksDB {
fn load_value(&self, raw: RawValue) -> anyhow::Result<Option<Value>> {
match self.db.get(value_key(raw))? {
None => Ok(None),
Some(bytes) => Ok(Some({
if bytes.is_empty() {
Value::from(raw)
} else {
Value::from_bytes(bytes.as_ref(), self.clone_box())?
}
})),
}
}
fn store_value(&mut self, value: Value) -> anyhow::Result<()> {
let value_key = value_key(value.raw());
let tx = self.db.transaction();
if let Some(old_value_bytes) = tx.get_for_update(&value_key, true)? {
let is_raw = old_value_bytes.is_empty();
// If we had a non-RawValue stored don't overwrite it (specially not with a
// RawValue). Also skip redundant RawValue overwrite.
if !is_raw || (is_raw && value.is_raw()) {
return Ok(());
}
}
let value_bytes = if value.is_raw() {
// For RawValue we store an empty vector because it's a duplicate of the key.
// This way we can easily check for RawValue without decoding.
vec![]
} else {
Value::to_bytes(&value)
};
tx.put(value_key, value_bytes)?;
Ok(tx.commit()?)
}
fn is_persistent(&self) -> bool {
true
}
fn clone_box(&self) -> Box<dyn DB> {
Box::new(self.clone())
}
}

View file

@ -72,6 +72,10 @@ pub enum Error {
}, },
#[error(transparent)] #[error(transparent)]
Tree(#[from] crate::backends::plonky2::primitives::merkletree::error::TreeError), Tree(#[from] crate::backends::plonky2::primitives::merkletree::error::TreeError),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error("database error: {0}")]
Database(anyhow::Error),
} }
impl Debug for Error { impl Debug for Error {
@ -164,7 +168,7 @@ impl Error {
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self { pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
new!(UnsatisfiedCustomPredicateDisjunction(pred)) new!(UnsatisfiedCustomPredicateDisjunction(pred))
} }
pub(crate) fn custom(s: String) -> Self { pub(crate) fn custom(s: impl Into<String>) -> Self {
new!(Custom(s)) new!(Custom(s.into()))
} }
} }

View file

@ -1,16 +1,13 @@
//! The middleware includes the type definitions and the traits used to connect the frontend and //! The middleware includes the type definitions and the traits used to connect the frontend and
//! the backend. //! the backend.
use std::sync::Arc;
use hex::ToHex; use hex::ToHex;
use itertools::Itertools;
use strum_macros::FromRepr; use strum_macros::FromRepr;
mod basetypes; mod basetypes;
use std::{cmp::PartialEq, hash}; use std::{cmp::PartialEq, hash};
use containers::{Array, Dictionary, Set}; use containers::{Array, Container, Dictionary, Set};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod containers; pub mod containers;
@ -22,6 +19,7 @@ pub mod serialization;
mod statement; mod statement;
use std::{any::Any, fmt}; use std::{any::Any, fmt};
pub mod db;
pub use basetypes::*; pub use basetypes::*;
pub use custom::*; pub use custom::*;
use dyn_clone::DynClone; use dyn_clone::DynClone;
@ -31,14 +29,10 @@ pub use pod_deserialization::*;
use serialization::*; use serialization::*;
pub use statement::*; pub use statement::*;
use crate::backends::plonky2::primitives::merkletree::{
MerkleProof, MerkleTreeStateTransitionProof,
};
// TODO: Move all value-related types to to `value.rs` // TODO: Move all value-related types to to `value.rs`
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
// TODO #[schemars(transform = serialization::transform_value_schema)] // TODO #[schemars(transform = serialization::transform_value_schema)]
pub enum TypedValue { pub(crate) enum TypedValue {
// Serde cares about the order of the enum variants, with untagged variants // Serde cares about the order of the enum variants, with untagged variants
// appearing at the end. // appearing at the end.
// Variants without "untagged" will be serialized as "tagged" values by // Variants without "untagged" will be serialized as "tagged" values by
@ -73,8 +67,6 @@ pub enum TypedValue {
Array(Array), Array(Array),
#[serde(untagged)] #[serde(untagged)]
String(String), String(String),
#[serde(untagged)]
Bool(bool),
} }
impl From<&str> for TypedValue { impl From<&str> for TypedValue {
@ -97,7 +89,11 @@ impl From<i64> for TypedValue {
impl From<bool> for TypedValue { impl From<bool> for TypedValue {
fn from(b: bool) -> Self { fn from(b: bool) -> Self {
TypedValue::Bool(b) if b {
TypedValue::Int(1)
} else {
TypedValue::Int(0)
}
} }
} }
@ -149,70 +145,6 @@ impl From<RawValue> for TypedValue {
} }
} }
impl TryFrom<&TypedValue> for i64 {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::Int(n) = v {
Ok(*n)
} else {
Err(Error::custom("Value not an int".to_string()))
}
}
}
impl TryFrom<&TypedValue> for String {
type Error = Error;
fn try_from(tv: &TypedValue) -> Result<Self> {
match tv {
TypedValue::String(s) => Ok(s.clone()),
_ => Err(Error::custom(format!(
"Value {} cannot be converted to a string.",
tv
))),
}
}
}
impl TryFrom<&TypedValue> for Key {
type Error = Error;
fn try_from(tv: &TypedValue) -> Result<Self> {
Ok(Key::new(String::try_from(tv)?))
}
}
impl TryFrom<&TypedValue> for PublicKey {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::PublicKey(pk) = v {
Ok(*pk)
} else {
Err(Error::custom("Value not a public key".to_string()))
}
}
}
impl TryFrom<&TypedValue> for SecretKey {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::SecretKey(sk) = v {
Ok(sk.clone())
} else {
Err(Error::custom("Value not a secret key".to_string()))
}
}
}
impl TryFrom<&TypedValue> for Predicate {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::Predicate(p) = v {
Ok(p.clone())
} else {
Err(Error::custom("Value not a Predicate".to_string()))
}
}
}
impl fmt::Display for TypedValue { impl fmt::Display for TypedValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
@ -224,36 +156,54 @@ impl fmt::Display for TypedValue {
Err(_) => write!(f, "\"{}\"", s), Err(_) => write!(f, "\"{}\"", s),
} }
} }
TypedValue::Bool(b) => write!(f, "{}", b),
TypedValue::Array(a) => { TypedValue::Array(a) => {
write!(f, "[")?; write!(f, "[")?;
for (i, v) in a.array().iter().enumerate() { for (i, r) in a.iter().enumerate() {
if i > 0 { if i > 0 {
write!(f, ", ")?; write!(f, ", ")?;
} }
write!(f, "{}", v)?; if i == 8 {
write!(f, "")?;
break;
}
match r {
Ok((index, value)) => write!(f, "{}: {}", index, value)?,
Err(e) => write!(f, "{e}")?,
}
} }
write!(f, "]") write!(f, "]")
} }
TypedValue::Dictionary(d) => { TypedValue::Dictionary(d) => {
write!(f, "{{ ")?; write!(f, "{{ ")?;
let kvs: Vec<_> = d.kvs().iter().sorted_by_key(|(k, _)| k.name()).collect(); for (i, r) in d.iter().enumerate() {
for (i, (k, v)) in kvs.iter().enumerate() {
if i > 0 { if i > 0 {
write!(f, ", ")?; write!(f, ", ")?;
} }
write!(f, "{}: {}", k, v)?; if i == 8 {
write!(f, "")?;
break;
}
match r {
Ok((key, value)) => write!(f, "{}: {}", key, value)?,
Err(e) => write!(f, "{e}")?,
}
} }
write!(f, " }}") write!(f, " }}")
} }
TypedValue::Set(s) => { TypedValue::Set(s) => {
write!(f, "#[")?; write!(f, "#[")?;
let values: Vec<_> = s.set().iter().sorted_by_key(|k| k.raw()).collect(); for (i, r) in s.iter().enumerate() {
for (i, v) in values.iter().enumerate() {
if i > 0 { if i > 0 {
write!(f, ", ")?; write!(f, ", ")?;
} }
write!(f, "{}", v)?; if i == 8 {
write!(f, "")?;
break;
}
match r {
Ok(value) => write!(f, "{}", value)?,
Err(e) => write!(f, "{e}")?,
}
} }
write!(f, "]") write!(f, "]")
} }
@ -272,7 +222,6 @@ impl From<&TypedValue> for RawValue {
match v { match v {
TypedValue::String(s) => RawValue::from(hash_str(s)), TypedValue::String(s) => RawValue::from(hash_str(s)),
TypedValue::Int(v) => RawValue::from(*v), TypedValue::Int(v) => RawValue::from(*v),
TypedValue::Bool(b) => RawValue::from(*b as i64),
TypedValue::Dictionary(d) => RawValue::from(d.commitment()), TypedValue::Dictionary(d) => RawValue::from(d.commitment()),
TypedValue::Set(s) => RawValue::from(s.commitment()), TypedValue::Set(s) => RawValue::from(s.commitment()),
TypedValue::Array(a) => RawValue::from(a.commitment()), TypedValue::Array(a) => RawValue::from(a.commitment()),
@ -405,9 +354,8 @@ impl JsonSchema for TypedValue {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Value { pub struct Value {
// The `TypedValue` is under `Arc` so that cloning a `Value` is cheap. pub(crate) typed: TypedValue,
typed: Arc<TypedValue>, pub(crate) raw: RawValue,
raw: RawValue,
} }
// Values are serialized as their TypedValue. // Values are serialized as their TypedValue.
@ -441,6 +389,55 @@ impl JsonSchema for Value {
} }
} }
/// Dual of TypedValue that is not recursive: for container types no entry only the commitment
/// (merkle tree root of underlying data) is available. Used for byte serialization for
/// persistent storage.
#[derive(Serialize, Deserialize)]
enum TypedValueNoRec {
Raw(RawValue),
Int(i64),
PublicKey(PublicKey),
SecretKey(SecretKey),
Predicate(Predicate),
Set(Hash),
Dictionary(Hash),
Array(Hash),
String(String),
}
// NOTE: byte serialization is using json. Using a byte-native serialization would improve
// performance and storage usage.
impl Value {
pub fn to_bytes(&self) -> Vec<u8> {
let v = match &self.typed {
TypedValue::Int(v) => TypedValueNoRec::Int(*v),
TypedValue::Raw(v) => TypedValueNoRec::Raw(*v),
TypedValue::PublicKey(v) => TypedValueNoRec::PublicKey(*v),
TypedValue::SecretKey(v) => TypedValueNoRec::SecretKey(v.clone()),
TypedValue::Predicate(v) => TypedValueNoRec::Predicate(v.clone()),
TypedValue::Set(v) => TypedValueNoRec::Set(v.commitment()),
TypedValue::Dictionary(v) => TypedValueNoRec::Dictionary(v.commitment()),
TypedValue::Array(v) => TypedValueNoRec::Array(v.commitment()),
TypedValue::String(v) => TypedValueNoRec::String(v.clone()),
};
serde_json::to_vec(&v).expect("json serialization succeeds")
}
pub fn from_bytes(bytes: &[u8], db: Box<dyn db::DB>) -> Result<Self> {
let v: TypedValueNoRec = serde_json::from_slice(bytes)?;
Ok(match v {
TypedValueNoRec::Int(v) => Value::from(v),
TypedValueNoRec::Raw(v) => Value::from(v),
TypedValueNoRec::PublicKey(v) => Value::from(v),
TypedValueNoRec::SecretKey(v) => Value::from(v),
TypedValueNoRec::Predicate(v) => Value::from(v),
TypedValueNoRec::Set(v) => Value::from(Set::from_db(v, db)?),
TypedValueNoRec::Dictionary(v) => Value::from(Dictionary::from_db(v, db)?),
TypedValueNoRec::Array(v) => Value::from(Array::from_db(v, db)?),
TypedValueNoRec::String(v) => Value::from(v),
})
}
}
impl PartialEq for Value { impl PartialEq for Value {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.raw == other.raw self.raw == other.raw
@ -462,106 +459,110 @@ impl fmt::Display for Value {
} }
impl Value { impl Value {
pub fn new(value: TypedValue) -> Self { pub(crate) fn new(value: TypedValue) -> Self {
let raw_value = RawValue::from(&value); let raw_value = RawValue::from(&value);
Self { Self {
typed: Arc::new(value), typed: value,
raw: raw_value, raw: raw_value,
} }
} }
pub fn typed(&self) -> &TypedValue {
&self.typed
}
pub fn raw(&self) -> RawValue { pub fn raw(&self) -> RawValue {
self.raw self.raw
} }
/// Determines Merkle existence proof for `key` in `self` (if applicable). /// Returns true if the typed value is RawValue, which means it's a generic value with no type
pub(crate) fn prove_existence<'a>( /// information and no extra value data.
&'a self, pub fn is_raw(&self) -> bool {
key: &'a Value, matches!(self.typed, TypedValue::Raw(_))
) -> Result<(&'a Value, MerkleProof)> { }
match &self.typed() { pub fn as_raw(&self) -> RawValue {
TypedValue::Array(a) => match key.typed() { self.raw
TypedValue::Int(i) if i >= &0 => a.prove((*i) as usize), }
_ => Err(Error::custom(format!( pub fn as_int(&self) -> Option<i64> {
"Invalid key {} for container {}.", match self.typed {
key, self TypedValue::Int(i) => Some(i),
)))?, _ => None,
}
}
pub fn as_public_key(&self) -> Option<PublicKey> {
match &self.typed {
TypedValue::PublicKey(pk) => Some(*pk),
_ => None,
}
}
pub fn as_secret_key(&self) -> Option<SecretKey> {
match &self.typed {
TypedValue::SecretKey(sk) => Some(sk.clone()),
_ => None,
}
}
pub fn as_predicate(&self) -> Option<Predicate> {
match &self.typed {
TypedValue::Predicate(p) => Some(p.clone()),
_ => None,
}
}
pub fn as_set(&self) -> Option<Set> {
match &self.typed {
TypedValue::Set(s) => Some(s.clone()),
TypedValue::Dictionary(d) => Some(Set {
inner: d.inner.clone(),
}),
TypedValue::Array(a) => Some(Set {
inner: a.inner.clone(),
}),
_ => None,
}
}
pub fn as_container(&self) -> Option<Container> {
match &self.typed {
TypedValue::Set(s) => Some(s.inner.clone()),
TypedValue::Dictionary(d) => Some(d.inner.clone()),
TypedValue::Array(a) => Some(a.inner.clone()),
_ => None,
}
}
pub fn as_dictionary(&self) -> Option<Dictionary> {
match &self.typed {
TypedValue::Set(s) => Some(Dictionary {
inner: s.inner.clone(),
}),
TypedValue::Dictionary(d) => Some(d.clone()),
TypedValue::Array(a) => Some(Dictionary {
inner: a.inner.clone(),
}),
_ => None,
}
}
pub fn as_array(&self) -> Option<Array> {
match &self.typed {
TypedValue::Set(s) => Some(Array {
inner: s.inner.clone(),
}),
TypedValue::Dictionary(d) => Some(Array {
inner: d.inner.clone(),
}),
TypedValue::Array(a) => Some(a.clone()),
_ => None,
}
}
pub fn as_str(&self) -> Option<&str> {
match &self.typed {
TypedValue::String(s) => Some(s.as_str()),
_ => None,
}
}
pub fn as_string(&self) -> Option<String> {
self.as_str().map(|s| s.to_string())
}
pub fn as_bool(&self) -> Option<bool> {
match self.typed {
TypedValue::Int(i) => match i {
0 => Some(false),
1 => Some(true),
_ => None,
}, },
TypedValue::Dictionary(d) => d.prove(&key.typed().try_into()?), _ => None,
TypedValue::Set(s) => Ok((key, s.prove(key)?)),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
}
}
/// Determines Merkle non-existence proof for `key` in `self` (if applicable).
pub(crate) fn prove_nonexistence<'a>(&'a self, key: &'a Value) -> Result<MerkleProof> {
match &self.typed() {
TypedValue::Array(_) => Err(Error::custom(
"Arrays do not support `NotContains` operation.".to_string(),
)),
TypedValue::Dictionary(d) => d.prove_nonexistence(&key.typed().try_into()?),
TypedValue::Set(s) => s.prove_nonexistence(key),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
}
}
/// Returns a Merkle state transition proof for inserting a
/// key-value pair (if applicable).
pub(crate) fn prove_insertion(
&self,
key: &Value,
value: &Value,
) -> Result<MerkleTreeStateTransitionProof> {
let container = self.typed().clone();
match container {
TypedValue::Dictionary(mut d) => d.insert(&key.typed().try_into()?, value),
TypedValue::Set(mut s) => s.insert(value),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
}
}
/// Returns a Merkle state transition proof for updating a
/// key-value pair (if applicable).
pub(crate) fn prove_update(
&self,
key: &Value,
value: &Value,
) -> Result<MerkleTreeStateTransitionProof> {
let container = self.typed().clone();
match container {
TypedValue::Array(mut a) => match key.typed() {
TypedValue::Int(i) if i >= &0 => a.update(*i as usize, value),
_ => Err(Error::custom(format!(
"Invalid key {} for container {}.",
key, self
)))?,
},
TypedValue::Dictionary(mut d) => d.update(&key.typed().try_into()?, value),
_ => Err(Error::custom(format!(
"Invalid container value {} for update op",
self.typed()
))),
}
}
/// Returns a Merkle state transition proof for deleting a
/// key (if applicable).
pub(crate) fn prove_deletion(&self, key: &Value) -> Result<MerkleTreeStateTransitionProof> {
let container = self.typed().clone();
match container {
TypedValue::Dictionary(mut d) => d.delete(&key.typed().try_into()?),
TypedValue::Set(mut s) => s.delete(key),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
} }
} }
} }
@ -767,6 +768,8 @@ pub struct BaseParams {
/// in a custom predicate /// in a custom predicate
pub max_custom_predicate_arity: usize, pub max_custom_predicate_arity: usize,
pub max_depth_custom_batch_mt: usize, pub max_depth_custom_batch_mt: usize,
// This value depends on `max_custom_predicate_arity`
pub max_operation_args: usize,
} }
pub const BASE_PARAMS: BaseParams = BaseParams { pub const BASE_PARAMS: BaseParams = BaseParams {
@ -774,8 +777,53 @@ pub const BASE_PARAMS: BaseParams = BaseParams {
max_statement_args: 5, max_statement_args: 5,
max_custom_predicate_arity: 5, max_custom_predicate_arity: 5,
max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch
max_operation_args: 5 + 1,
}; };
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)]
#[serde(rename_all = "camelCase")]
pub struct ParamsMerkleProofs {
pub max_small: usize,
pub max_medium: usize,
}
impl ParamsMerkleProofs {
pub fn max_total(&self) -> usize {
self.max_small + self.max_medium
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)]
#[serde(rename_all = "camelCase")]
pub struct ParamsContainers {
// Parameters for exists/nonexists container operations. The small set only supports exists
pub state: ParamsMerkleProofs,
// Parameters for transition container operations (insert, delete, update). The small set only
// supports update.
pub transition: ParamsMerkleProofs,
// Max depth of small proofs
pub max_depth_small: usize,
// Max depth of medium proofs
pub max_depth_medium: usize,
}
impl Default for ParamsContainers {
fn default() -> Self {
Self {
state: ParamsMerkleProofs {
max_small: 22,
max_medium: 8,
},
transition: ParamsMerkleProofs {
max_small: 12,
max_medium: 6,
},
max_depth_small: 8,
max_depth_medium: 32,
}
}
}
/// Params: non dynamic parameters that define the circuit. /// Params: non dynamic parameters that define the circuit.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -784,18 +832,12 @@ pub struct Params {
pub max_input_pods_public_statements: usize, pub max_input_pods_public_statements: usize,
pub max_statements: usize, pub max_statements: usize,
pub max_public_statements: usize, pub max_public_statements: usize,
pub max_operation_args: usize,
// max number of different custom predicates that can be used in a MainPod // max number of different custom predicates that can be used in a MainPod
pub max_custom_predicates: usize, pub max_custom_predicates: usize,
// max number of operations using custom predicates that can be verified in the MainPod // max number of operations using custom predicates that can be verified in the MainPod
pub max_custom_predicate_verifications: usize, pub max_custom_predicate_verifications: usize,
pub max_custom_predicate_wildcards: usize, pub max_custom_predicate_wildcards: usize,
// maximum number of merkle proofs used for container operations pub containers: ParamsContainers,
pub max_merkle_proofs_containers: usize,
// maximum number of merkle tree state transition proofs used for container update operations
pub max_merkle_tree_state_transition_proofs_containers: usize,
// maximum depth for merkle tree gadget used for container operations
pub max_depth_mt_containers: usize,
// maximum depth of the merkle tree gadget used for verifier_data membership // maximum depth of the merkle tree gadget used for verifier_data membership
// check. This allows creating verifying sets of pod circuits of size // check. This allows creating verifying sets of pod circuits of size
// 2^max_depth_mt_vds. Limits the number of container operations of the type Contains, // 2^max_depth_mt_vds. Limits the number of container operations of the type Contains,
@ -814,13 +856,10 @@ impl Default for Params {
max_input_pods_public_statements: 8, max_input_pods_public_statements: 8,
max_statements: 48, max_statements: 48,
max_public_statements: 8, max_public_statements: 8,
max_operation_args: 5,
max_custom_predicates: 8, max_custom_predicates: 8,
max_custom_predicate_verifications: 8, max_custom_predicate_verifications: 8,
max_custom_predicate_wildcards: 8, max_custom_predicate_wildcards: 8,
max_merkle_proofs_containers: 20, containers: ParamsContainers::default(),
max_merkle_tree_state_transition_proofs_containers: 6,
max_depth_mt_containers: 32,
max_depth_mt_vds: 6, // up to 64 (2^6) different pod circuits max_depth_mt_vds: 6, // up to 64 (2^6) different pod circuits
max_public_key_of: 2, max_public_key_of: 2,
max_signed_by: 4, max_signed_by: 4,

View file

@ -7,17 +7,14 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
backends::plonky2::primitives::{ backends::plonky2::primitives::{
ec::{ ec::{curve::GROUP_ORDER, schnorr::Signature},
curve::{Point as PublicKey, GROUP_ORDER},
schnorr::{SecretKey, Signature},
},
merkletree::{MerkleProof, MerkleTree, MerkleTreeOp, MerkleTreeStateTransitionProof}, merkletree::{MerkleProof, MerkleTree, MerkleTreeOp, MerkleTreeStateTransitionProof},
}, },
middleware::{ middleware::{
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result, MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result,
Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef,
ValueRef, Wildcard, F, Wildcard, BASE_PARAMS, F,
}, },
}; };
@ -92,6 +89,7 @@ pub enum NativeOperation {
ContainerInsertFromEntries = 16, ContainerInsertFromEntries = 16,
ContainerUpdateFromEntries = 17, ContainerUpdateFromEntries = 17,
ContainerDeleteFromEntries = 18, ContainerDeleteFromEntries = 18,
ReplaceValueWithEntry = 19,
// Syntactic sugar operations. These operations are not supported by the backend. The // Syntactic sugar operations. These operations are not supported by the backend. The
// frontend compiler is responsible of translating these operations into the operations above. // frontend compiler is responsible of translating these operations into the operations above.
@ -167,6 +165,7 @@ impl OperationType {
NativeOperation::ContainerDeleteFromEntries => { NativeOperation::ContainerDeleteFromEntries => {
Some(Predicate::Native(NativePredicate::ContainerDelete)) Some(Predicate::Native(NativePredicate::ContainerDelete))
} }
NativeOperation::ReplaceValueWithEntry => None,
no => unreachable!("Unexpected syntactic sugar op {:?}", no), no => unreachable!("Unexpected syntactic sugar op {:?}", no),
}, },
OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())),
@ -222,6 +221,10 @@ pub enum Operation {
/* key */ Statement, /* key */ Statement,
/* proof */ MerkleTreeStateTransitionProof, /* proof */ MerkleTreeStateTransitionProof,
), ),
ReplaceValueWithEntry(
/* Contains/None len=max_statement_args */ Vec<Statement>,
/* to copy */ Statement,
),
Custom(CustomPredicateRef, Vec<Statement>), Custom(CustomPredicateRef, Vec<Statement>),
} }
@ -241,6 +244,10 @@ pub(crate) fn hash_op(x: Value, y: Value) -> Value {
Value::from(hash_values(&[x, y])) Value::from(hash_values(&[x, y]))
} }
fn ok_or_type_err<T>(o: Option<T>, v: &Value, typ: &'static str) -> Result<T> {
o.ok_or_else(|| Error::custom(format!("{v} type is not {typ}")))
}
impl Operation { impl Operation {
pub fn op_type(&self) -> OperationType { pub fn op_type(&self) -> OperationType {
type OT = OperationType; type OT = OperationType;
@ -269,6 +276,7 @@ impl Operation {
OT::Native(ContainerUpdateFromEntries) OT::Native(ContainerUpdateFromEntries)
} }
Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries), Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries),
Self::ReplaceValueWithEntry(_, _) => OT::Native(ReplaceValueWithEntry),
Self::Custom(cpr, _) => OT::Custom(cpr.clone()), Self::Custom(cpr, _) => OT::Custom(cpr.clone()),
} }
} }
@ -294,6 +302,11 @@ impl Operation {
Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4],
Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4],
Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3], Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3],
Self::ReplaceValueWithEntry(args, s) => {
let mut sts = args;
sts.push(s);
sts
}
Self::Custom(_, args) => args, Self::Custom(_, args) => args,
} }
} }
@ -376,6 +389,18 @@ impl Operation {
&[s1, s2, s3], &[s1, s2, s3],
OA::MerkleTreeStateTransitionProof(pf), OA::MerkleTreeStateTransitionProof(pf),
) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf), ) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf),
(NO::ReplaceValueWithEntry, args, OA::None) => {
let mut args = args.to_vec();
if args.len() != BASE_PARAMS.max_statement_args + 1 {
return Err(Error::custom(format!(
"ReplaceValueWithEntry requires exactly {} args but {} were found",
BASE_PARAMS.max_statement_args + 1,
args.len()
)));
}
let st = args.pop().expect("valid vec len");
Self::ReplaceValueWithEntry(args, st)
}
_ => Err(Error::custom(format!( _ => Err(Error::custom(format!(
"Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.", "Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.",
op_code, op_code,
@ -404,23 +429,55 @@ impl Operation {
v3: &Value, v3: &Value,
f: impl FnOnce(i64, i64) -> i64, f: impl FnOnce(i64, i64) -> i64,
) -> Result<bool> { ) -> Result<bool> {
let i1: i64 = v1.typed().try_into()?; let i1 = ok_or_type_err(v1.as_int(), v1, "Int")?;
let i2: i64 = v2.typed().try_into()?; let i2 = ok_or_type_err(v2.as_int(), v2, "Int")?;
let i3: i64 = v3.typed().try_into()?; let i3 = ok_or_type_err(v3.as_int(), v3, "Int")?;
Ok(i1 == f(i2, i3)) Ok(i1 == f(i2, i3))
} }
pub(crate) fn check_public_key(v1: &Value, v2: &Value) -> Result<bool> { pub(crate) fn check_public_key(v1: &Value, v2: &Value) -> Result<bool> {
let pk: PublicKey = v1.typed().try_into()?; let pk = ok_or_type_err(v1.as_public_key(), v1, "PublicKey")?;
let sk: SecretKey = v2.typed().try_into()?; let sk = ok_or_type_err(v2.as_secret_key(), v2, "SecretKey")?;
Ok(sk.0 < *GROUP_ORDER && pk == sk.public_key()) Ok(sk.0 < *GROUP_ORDER && pk == sk.public_key())
} }
pub(crate) fn check_signed_by(msg: &Value, pk: &Value, sig: &Signature) -> Result<bool> { pub(crate) fn check_signed_by(msg: &Value, pk: &Value, sig: &Signature) -> Result<bool> {
let pk: PublicKey = pk.typed().try_into()?; let pk = ok_or_type_err(pk.as_public_key(), pk, "PublicKey")?;
Ok(sig.verify(pk, msg.raw())) Ok(sig.verify(pk, msg.raw()))
} }
fn check_replace_value_with_entry(
entries: &[Statement],
st_in: &Statement,
expected_st_out: &Statement,
) -> Result<bool> {
if entries.len() != BASE_PARAMS.max_statement_args {
return Ok(false);
}
let args = iter::zip(st_in.args(), entries)
.map(|(arg_in, entry)| match (arg_in, entry) {
(arg_in, Statement::None) => Ok(arg_in),
(
StatementArg::Literal(v_in),
Statement::Contains(
ValueRef::Literal(root),
ValueRef::Literal(key),
ValueRef::Literal(v),
),
) if v == &v_in => Ok(StatementArg::Key(AnchoredKey::new(
Hash::from(root.raw()),
Key::from(key.as_str().ok_or_else(|| Error::custom("not a string"))?),
))),
_ => Err(Error::custom(
"invalid statement argument in ReplaceValueWithEntry",
)),
})
.collect::<Result<Vec<_>>>()?;
let st_out = Statement::from_args(st_in.predicate(), args)?;
Ok(&st_out == expected_st_out)
}
/// Checks the given operation against a statement. /// Checks the given operation against a statement.
pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> { pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
use Statement::*; use Statement::*;
@ -428,8 +485,8 @@ impl Operation {
let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err); let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err);
let int_val = |v, s| { let int_val = |v, s| {
let v_op = value_from_op(s, v).ok_or_else(deduction_err)?; let v_op = value_from_op(s, v).ok_or_else(deduction_err)?;
match v_op.typed() { match v_op.as_int() {
&TypedValue::Int(i) => Ok(i), Some(i) => Ok(i),
_ => Err(deduction_err()), _ => Err(deduction_err()),
} }
}; };
@ -494,8 +551,7 @@ impl Operation {
&& pf.op_value == value.raw()) && pf.op_value == value.raw())
.then_some(()) .then_some(())
.ok_or(Error::custom( .ok_or(Error::custom(
"The provided Merkle tree state transition proof does not match the claim." "The provided Merkle tree state transition proof does not match the claim.",
.into(),
))?; ))?;
MerkleTree::verify_state_transition(pf)?; MerkleTree::verify_state_transition(pf)?;
true true
@ -515,8 +571,7 @@ impl Operation {
&& pf.op_value == value.raw()) && pf.op_value == value.raw())
.then_some(()) .then_some(())
.ok_or(Error::custom( .ok_or(Error::custom(
"The provided Merkle tree state transition proof does not match the claim." "The provided Merkle tree state transition proof does not match the claim.",
.into(),
))?; ))?;
MerkleTree::verify_state_transition(pf)?; MerkleTree::verify_state_transition(pf)?;
true true
@ -534,8 +589,7 @@ impl Operation {
&& pf.op_key == key.raw()) && pf.op_key == key.raw())
.then_some(()) .then_some(())
.ok_or(Error::custom( .ok_or(Error::custom(
"The provided Merkle tree state transition proof does not match the claim." "The provided Merkle tree state transition proof does not match the claim.",
.into(),
))?; ))?;
MerkleTree::verify_state_transition(pf)?; MerkleTree::verify_state_transition(pf)?;
true true
@ -543,7 +597,19 @@ impl Operation {
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index => if batch == &cpr.batch && index == &cpr.index =>
{ {
check_custom_pred(params, cpr, args, s_args).map(|_| true)? // The custom operation outputs statements with literal arguments. They can be
// replaced by references later with ReplaceValueWithEntry.
let s_args = s_args
.iter()
.map(|arg| match arg {
ValueRef::Literal(v) => Ok(v.clone()),
_ => Err(deduction_err()),
})
.collect::<Result<Vec<_>>>()?;
check_custom_pred(params, cpr, args, &s_args).map(|_| true)?
}
(Self::ReplaceValueWithEntry(entries, st_in), st_out) => {
Self::check_replace_value_with_entry(entries, st_in, st_out)?
} }
_ => return Err(deduction_err()), _ => return Err(deduction_err()),
}; };
@ -597,6 +663,11 @@ pub fn check_st_tmpl(
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
wc_check_or_set(v.clone(), wc, wildcard_map) wc_check_or_set(v.clone(), wc, wildcard_map)
} }
(StatementTmplArg::SelfPredicateHash(_), _) => {
unreachable!(
"SelfPredicateHash should be normalized to Literal before template matching"
)
}
_ => Err(Error::mismatched_statement_tmpl_arg( _ => Err(Error::mismatched_statement_tmpl_arg(
st_tmpl_arg.clone(), st_tmpl_arg.clone(),
st_arg.clone(), st_arg.clone(),
@ -645,9 +716,9 @@ pub fn wildcard_values_from_op_st(
params: &Params, params: &Params,
pred: &CustomPredicate, pred: &CustomPredicate,
op_args: &[Statement], op_args: &[Statement],
st_args: &[Value], resolved_st_args: &[Value],
) -> Result<Vec<Value>> { ) -> Result<Vec<Value>> {
let mut wildcard_map = st_args let mut wildcard_map = resolved_st_args
.iter() .iter()
.map(|v| Some(v.clone())) .map(|v| Some(v.clone()))
.chain(core::iter::repeat(None)) .chain(core::iter::repeat(None))
@ -714,7 +785,7 @@ pub(crate) fn check_custom_pred(
args: &[Statement], args: &[Statement],
s_args: &[Value], s_args: &[Value],
) -> Result<()> { ) -> Result<()> {
let pred = custom_pred_ref.predicate(); let pred = custom_pred_ref.normalized_predicate();
if pred.statements.len() != args.len() { if pred.statements.len() != args.len() {
return Err(Error::diff_amount( return Err(Error::diff_amount(
"custom predicate operation".to_string(), "custom predicate operation".to_string(),
@ -733,7 +804,7 @@ pub(crate) fn check_custom_pred(
} }
// Check that the resolved wildcards match the statement arguments. // Check that the resolved wildcards match the statement arguments.
let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) { let wc_values = match wildcard_values_from_op_st(params, &pred, args, s_args) {
Ok(wc_values) => wc_values, Ok(wc_values) => wc_values,
Err(Error::Inner { inner, backtrace }) => match *inner { Err(Error::Inner { inner, backtrace }) => match *inner {
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev) MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
@ -789,9 +860,8 @@ impl fmt::Display for Operation {
pub(crate) fn root_key_to_ak(root: &Value, key: &Value) -> Option<AnchoredKey> { pub(crate) fn root_key_to_ak(root: &Value, key: &Value) -> Option<AnchoredKey> {
let root_hash = Hash::from(root.raw()); let root_hash = Hash::from(root.raw());
Key::try_from(key.typed()) key.as_str()
.map(|key| AnchoredKey::new(root_hash, key)) .map(|s| AnchoredKey::new(root_hash, Key::from(s)))
.ok()
} }
/// Returns the value associated with `output_ref`. /// Returns the value associated with `output_ref`.

View file

@ -311,7 +311,7 @@ pub enum Statement {
/* old_root */ ValueRef, /* old_root */ ValueRef,
/* key */ ValueRef, /* key */ ValueRef,
), ),
Custom(CustomPredicateRef, Vec<Value>), Custom(CustomPredicateRef, Vec<ValueRef>),
Intro(IntroPredicateRef, Vec<Value>), Intro(IntroPredicateRef, Vec<Value>),
} }
@ -407,7 +407,7 @@ impl Statement {
vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()] vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()]
} }
Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Literal)), Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(StatementArg::from)),
Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)), Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)),
} }
} }
@ -478,14 +478,11 @@ impl Statement {
} }
(BatchSelf(_), _) => unreachable!(), (BatchSelf(_), _) => unreachable!(),
(Custom(cpr), _) => { (Custom(cpr), _) => {
let v_args: Result<Vec<Value>> = args let v_args = args
.iter() .iter()
.map(|x| match x { .map(|x| x.try_into())
StatementArg::Literal(v) => Ok(v.clone()), .collect::<Result<Vec<ValueRef>>>()?;
_ => Err(Error::incorrect_statements_args()), Self::Custom(cpr, v_args)
})
.collect();
Self::Custom(cpr, v_args?)
} }
(Intro(ir), _) => { (Intro(ir), _) => {
let v_args: Result<Vec<Value>> = args let v_args: Result<Vec<Value>> = args