Compare commits

...

10 commits

Author SHA1 Message Date
5e3ac9a101
Support mixed depth container merkle proofs (#508)
Some checks failed
Rust Build with features / Rust tests (push) Has been cancelled
Clippy Check / Rust formatting (push) Has been cancelled
Publish MainPod circuit info / Update Wiki with new MainPod circuit info (push) Has been cancelled
Check mdbook compilation / compile (push) Has been cancelled
Publish mdbook / build (push) Has been cancelled
Rustfmt Check / Rust formatting (push) Has been cancelled
Rust Tests / Rust tests (push) Has been cancelled
typos / Spell Check with Typos (push) Has been cancelled
Publish mdbook / deploy (push) Has been cancelled
* remove enabled flag from merkle tree proofs

* add small existence mpt proofs in MainPod

* refactor params, add small transition proofs

* complete

* fix edge case in vdset

* fix: use existence only proof for vdset

* use consistent order for aux table
2026-05-06 12:39:27 +02:00
Rob Knight
111b132a00
Use projected statement lookup for op arg resolution (#503)
* Use projected statement lookup for op arg resolution

* Add projected op-arg index coverage test

* Tidying and reorganising
2026-04-29 09:56:39 +02:00
Rob Knight
8844fe124c
Diagnostics for MultiPodBuilder (#500)
* Diagnostics for MultiPodBuilder

* Reduce duplication
2026-04-23 01:41:29 -07:00
Rob Knight
3203c883e5
Use windowed ECAddXuGate for PublicKeyOf (#501) 2026-04-20 00:07:20 -07:00
dbd958dcca
Allow entries as args in custom statements (#498)
- Introduce a new operation ReplaceValueWithEntry that allows taking any statement and replacing literal arguments with entries given a matching Contains statement.
- Allow entries as args in custom statements
- Circuit optimization: For the public statements slots in the circuit we only support None and Copy which take at most 1 argument; but we were still doing max_statement_args random accesses per slot; so I reduced that to just 1 random access to a previous statement.
2026-04-01 23:49:29 +02:00
Rob Knight
22d25e5cb2
Podlang syntax for quoted predicates (#495) 2026-03-30 15:16:19 +01:00
a4069bcc55
Fix pod builder (#496)
Several fixes and code simplifications:
- MainPodBuilder
  - Fix: It was not tracking Contains statements inherited via input pods (via public statements) when automatically generating Contains statements for Entry arguments.
  - Enhancement: Deduplicate statements
- MultiPodBuilder
    - Simplify: Remove the "statement groups" logic and instead deduplicate statements in the MainPodBuilder (which is much simpler to do)
    - Remove the "anchored key" explicit dependency tracking and instead rely on regular dependency tracking by using all the implicit operations and statements generated by MainPodBuilder as input to the solver.
    - Fix: Count and constrain custom predicates used in a pod instead of batches used
2026-03-25 18:48:28 +01:00
Rob Knight
1e592e11cf
Self-referential predicate hashes as statement template args (#494)
* Support quoted predicate hashes, including self-referential predicates

* Clippy

* Review feedback
2026-03-24 07:25:11 -07:00
13cabdb511
Support persistent storage in Containers (#493)
Extend the work of https://github.com/0xPARC/pod2/pull/487 to the Containers (Dictionary, Set, Array).

The merkle tree only stores `RawValue` for both the key and the value, so it is the responsibility of the Container to store the rich value.

In order to handle containers with persistent storage efficiently (which means, cloning them or updating them should not cause an O(n) data copy) I figured we need to have a database of `Value`s indexed by their raw value; as this gives us deduplication and free cloning of containers.
The issue with this approach is that in the current design we have collisions between Value's of different types: https://github.com/0xPARC/pod2/issues/426 and the current API relies on the single type of values.

To resolve this issue I decided to change the API, instead of assuming that a Value has a fixed type, let the value be possibly multiple compatible types and let the user of the library try casting the Value to a particular type.
For this I deprecated the public access of everything related to `TypedValue` and I propose for it to be considered an implementation detail and a blackbox from the external developer point of view.  The `Value` type is now used like this:
- To create a new Value use `Value::from(...)` where you can pass any compatible type (the same types as before)
- To access the Value in typed form you cast it like `value.as_foo()` which returns `Option<Foo>`.

Previously we had a collision between `true` and `1` (and `false` and `0`).  Now it doesn't matter whether a value holds a `true` or a `1`, both should be seen as the same and both return `Some` when doing `as_int` and `as_bool`.

Similarly we had collisions with containers.  For example `set(0, 1, 2) == array[0, 1, 2]` and `set("a", "b") = dict("a": "a", "b": "b")`.  Now any container can be casted to any of `set, array, dict`.  There's a caveat here: each of these types expects a particular encoding of keys, so casting to the wrong type will return errors on some operations.

With this design it no longer matters what is being stored and recovered because the API requires the user to express the expected type and any type with collisions for particular values can be casted to the right type.

There's only one case where it's not desirable to swap one `TypedValue` for another: the `TypedValue::Raw`.  If a non-`RawValue` in the DB is replaced by the corresponding `RawValue` we erase the required information to recover the rich value.  For this reason the implementations of the database treat the `RawValue` as a special case: if an value is stored in non-`RawValue`, the corresponding `RawValue` can never overwrite it.  If a value is stored in `RawValue`, a matching non-`RawValue` will overwrite it (promoting it to a rich value).  This way we never lose data.

A consequence of this is that the serialization, `Display` and `Debug` of a container is not stable.  At any point any of the entries can be swapped for a "compatible" one if they share the storage with other containers that introduce collisions.

I rewrote all containers as wrapper to a generic `Container` which holds a `Map` from `Value` to `Value`.  The serialization of each container now uses the single implementation of the generic `Container`.
2026-03-23 12:31:28 +01:00
arnaucube
32f45872d7
Re-implement merkletree with persistent storage (key-value db) (#487)
* refactor merkletree to work with disk keyvalue database (wip)

* various fixes post reimplementation; pending delete leaf

* add delete operation case for the new in db tree approach

* polish tree update & delete; everything works (pending polishing)

* polish panics into errs, prints, etc

* Implement iterator

* Lint

* fix case no-siblings

* case delete with semi-empty branch

* polishing

* starting to add rocksdb & heeddb for the DB & Txn traits

* Satisfy the borrow checker

* abstract merkletree tests to use the various available DBs

* update store_node interface (rm hash input), rm heed.rs

* polishing

* typos

* Ditch transactions

* add feature for rocksdb, return errs at new_with_db, remove empty leaf case in Leaf::new

* intermediate instead of leaf in empty node when deleting leaf

---------

Co-authored-by: Ahmad <root@ahmadafuni.com>
2026-03-11 16:32:42 +01:00
50 changed files with 6517 additions and 3603 deletions

View file

@ -24,6 +24,8 @@ jobs:
run: cargo build --features metrics
- name: Build time
run: cargo build --features time
- name: Build db_rocksdb
run: cargo build --features db_rocksdb
- name: Build disk_cache
run: cargo build --no-default-features --features backend_plonky2,zk,disk_cache

View file

@ -17,4 +17,5 @@ jobs:
- name: Set up Rust
uses: actions-rust-lang/setup-rust-toolchain@v1
- name: Run tests
run: cargo test --release
# RocksDB is disabled by default but we still want to test it.
run: cargo test --release --features db_rocksdb

View file

@ -48,6 +48,7 @@ good_lp = { version = "1.8", default-features = false, features = [
"scip_bundled",
] }
annotate-snippets = "0.11"
rocksdb = { version = "0.24.0", optional = true } # keyvalue database for merkletree
# Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory.
# [patch."https://github.com/0xPARC/plonky2"]
@ -57,6 +58,7 @@ annotate-snippets = "0.11"
pretty_assertions = "1.4.1"
# Used only for testing JSON Schema generation and validation.
jsonschema = "0.30.0"
tempfile = "3"
[build-dependencies]
vergen-gitcl = { version = "1.0.0", features = ["build"] }
@ -70,6 +72,7 @@ time = []
examples = []
disk_cache = ["directories", "minicbor-serde"]
mem_cache = []
db_rocksdb = ["rocksdb"]
# Uncomment in order to enable debug information in the release builds. This allows getting panic backtraces with a performance similar to regular release.
# [profile.release]

View file

@ -51,7 +51,7 @@ use crate::{
mainpod::cache_get_rec_main_pod_verifier_circuit_data,
primitives::merkletree::MerkleClaimAndProof,
},
middleware::{containers::Array, Hash, Params, RawValue, Result, Value},
middleware::{containers::Array, Hash, Params, RawValue, Result, Value, EMPTY_HASH},
};
pub static DEFAULT_VD_LIST: LazyLock<Vec<VerifierOnlyCircuitData>> = LazyLock::new(|| {
@ -95,6 +95,12 @@ impl Eq for VDSet {}
impl VDSet {
fn new_from_vds_hashes(mut vds_hashes: Vec<Hash>) -> Self {
// If vds_hashes is empty we add an zero entry to be used as padding when verifying merkle
// proofs of inclusion in the vds set. This zero entry can't be abused because no circuit
// exists with a vds_hash = 0.
if vds_hashes.is_empty() {
vds_hashes.push(EMPTY_HASH);
}
// before using the hash values, sort them, so that each set of
// verifier_datas gets the same VDSet root
vds_hashes.sort();
@ -150,6 +156,9 @@ impl VDSet {
))?
.clone())
}
pub fn get_vds_proof_0(&self) -> MerkleClaimAndProof {
self.proofs_map[&self.vds_hashes[0]].clone()
}
/// Returns true if the `verifier_data_hash` is in the set
pub fn contains(&self, verifier_data_hash: HashOut) -> bool {
self.proofs_map

View file

@ -25,20 +25,20 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::{
basetypes::{CircuitBuilder, CommonCircuitData, D},
circuits::mainpod::CustomPredicateVerification,
circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator},
error::Result,
mainpod::{Operation, OperationArg, OperationAux, Statement},
primitives::merkletree::{
verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget,
MerkleProof, MerkleTreeStateTransitionProofTarget,
MerkleProof, MerkleProofExistenceTarget, MerkleTreeStateTransitionProofTarget,
},
},
middleware::{
hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate,
OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix,
PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg,
StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN,
VALUE_SIZE,
StatementTmplArgPrefix, ToFields, Value, BASE_PARAMS, EMPTY_VALUE, F, HASH_SIZE,
STATEMENT_ARG_F_LEN, VALUE_SIZE,
},
};
@ -103,6 +103,20 @@ pub struct StatementArgTarget {
pub elements: [Target; STATEMENT_ARG_F_LEN],
}
impl Flattenable for StatementArgTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self {
elements: vs.try_into().expect("STATEMENT_ARG_F_LEN elements"),
}
}
fn size(_params: &Params) -> usize {
STATEMENT_ARG_F_LEN
}
}
impl StatementArgTarget {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, arg: &StatementArg) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?)
@ -318,7 +332,7 @@ impl OperationTarget {
.args()
.iter()
.chain(iter::repeat(&OperationArg::None))
.take(params.max_operation_args)
.take(BASE_PARAMS.max_operation_args)
.enumerate()
{
self.args[i].set_targets(pw, arg.as_usize())?;
@ -328,7 +342,7 @@ impl OperationTarget {
fn size(params: &Params) -> usize {
OperationTypeTarget::size(params)
+ params.max_operation_args * IndexTarget::size(params)
+ BASE_PARAMS.max_operation_args * IndexTarget::size(params)
+ IndexTarget::size(params)
}
}
@ -711,7 +725,6 @@ impl CustomPredicateInBatchTarget {
let mtp =
MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder);
let _true = builder._true();
builder.connect(_true.target, mtp.enabled.target);
builder.connect(_true.target, mtp.existence.target);
let zero = builder.constant(F(0));
let key = ValueTarget {
@ -749,7 +762,7 @@ impl CustomPredicateInBatchTarget {
value: RawValue::from(hash_fields(&predicate.to_fields())),
proof: mtp.clone(),
};
self.mtp.set_targets(pw, true, &mtp_claim)?;
self.mtp.set_targets(pw, &mtp_claim)?;
Ok(())
}
}
@ -771,7 +784,8 @@ impl CustomPredicateEntryTarget {
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
// Replace statement templates of batch-self with (id,index)
// Replace BatchSelf predicates with Custom(batch, i), and
// SelfPredicateHash args with Literal(hash(Custom(batch, i)))
let batch = &predicate.batch;
let predicate = predicate.predicate();
let statements = predicate
@ -788,10 +802,22 @@ impl CustomPredicateEntryTarget {
}
x => x.clone(),
};
StatementTmpl {
pred_or_wc,
args: st_tmpl.args,
let args = st_tmpl
.args
.into_iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: batch.clone(),
index: i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
}
other => other,
})
.collect();
StatementTmpl { pred_or_wc, args }
})
.collect_vec();
let predicate = CustomPredicate {
@ -855,7 +881,7 @@ impl CustomPredicateVerifyEntryTarget {
args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect(),
op_args: (0..params.max_operation_args)
op_args: (0..BASE_PARAMS.max_operation_args)
.map(|_| builder.add_virtual_statement(false))
.collect(),
}
@ -885,7 +911,7 @@ impl CustomPredicateVerifyEntryTarget {
cpv.op_args
.iter()
.chain(iter::repeat(&pad_op_arg))
.take(params.max_operation_args),
.take(BASE_PARAMS.max_operation_args),
) {
op_arg_target.set_targets(pw, op_arg)?
}
@ -928,7 +954,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
.expect("len = operation_type_size"),
};
let (pos, size) = (pos + size, StatementTarget::size(params));
let op_args = (0..params.max_operation_args)
let op_args = (0..BASE_PARAMS.max_operation_args)
.map(|i| {
StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size])
})
@ -940,7 +966,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
}
}
fn size(params: &Params) -> usize {
StatementTarget::size(params) * (1 + params.max_operation_args)
StatementTarget::size(params) * (1 + BASE_PARAMS.max_operation_args)
+ OperationTarget::size(params)
}
}
@ -960,7 +986,6 @@ pub trait Flattenable {
/// elsewhere.
#[derive(Copy, Clone)]
pub struct MerkleClaimTarget {
pub(crate) enabled: BoolTarget,
pub(crate) root: HashOutTarget,
pub(crate) key: ValueTarget,
pub(crate) value: ValueTarget,
@ -970,7 +995,6 @@ pub struct MerkleClaimTarget {
impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
fn from(pf: MerkleClaimAndProofTarget) -> Self {
Self {
enabled: pf.enabled,
root: pf.root,
key: pf.key,
value: pf.value,
@ -979,12 +1003,25 @@ impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
}
}
impl MerkleClaimTarget {
pub fn from_proof_existence(
builder: &mut CircuitBuilder,
pf: MerkleProofExistenceTarget,
) -> Self {
Self {
root: pf.root,
key: pf.key,
value: pf.value,
existence: builder._true(),
}
}
}
/// For the purpose of op verification, we need only look up the
/// Merkle state transition claim rather than the Merkle state
/// transition proof since it is verified elsewhere.
#[derive(Copy, Clone)]
pub struct MerkleTreeStateTransitionClaimTarget {
pub(crate) enabled: BoolTarget,
pub(crate) op: Target,
pub(crate) old_root: HashOutTarget,
pub(crate) new_root: HashOutTarget,
@ -995,7 +1032,6 @@ pub struct MerkleTreeStateTransitionClaimTarget {
impl From<MerkleTreeStateTransitionProofTarget> for MerkleTreeStateTransitionClaimTarget {
fn from(pf: MerkleTreeStateTransitionProofTarget) -> Self {
Self {
enabled: pf.enabled,
op: pf.op,
old_root: pf.old_root,
new_root: pf.new_root,
@ -1036,7 +1072,6 @@ impl Flattenable for ValueTarget {
impl Flattenable for MerkleClaimTarget {
fn flatten(&self) -> Vec<Target> {
[
vec![self.enabled.target],
self.root.elements.to_vec(),
self.key.elements.to_vec(),
self.value.elements.to_vec(),
@ -1048,31 +1083,28 @@ impl Flattenable for MerkleClaimTarget {
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
Self {
enabled: BoolTarget::new_unsafe(vs[0]),
root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()),
key: ValueTarget::from_slice(
&vs[1 + NUM_HASH_OUT_ELTS..1 + NUM_HASH_OUT_ELTS + VALUE_SIZE],
),
root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()),
key: ValueTarget::from_slice(&vs[NUM_HASH_OUT_ELTS..NUM_HASH_OUT_ELTS + VALUE_SIZE]),
value: ValueTarget::from_slice(
&vs[1 + NUM_HASH_OUT_ELTS + VALUE_SIZE..1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
&vs[NUM_HASH_OUT_ELTS + VALUE_SIZE..NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
),
existence: BoolTarget::new_unsafe(vs[1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]),
existence: BoolTarget::new_unsafe(vs[NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]),
}
}
fn size(params: &Params) -> usize {
2 + HashOutTarget::size(params) + 2 * ValueTarget::size(params)
HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1
}
}
impl Flattenable for MerkleTreeStateTransitionClaimTarget {
fn flatten(&self) -> Vec<Target> {
[
vec![self.enabled.target, self.op],
self.old_root.elements.to_vec(),
self.new_root.elements.to_vec(),
self.op_key.elements.to_vec(),
self.op_value.elements.to_vec(),
vec![self.op],
]
.concat()
}
@ -1080,24 +1112,22 @@ impl Flattenable for MerkleTreeStateTransitionClaimTarget {
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
Self {
enabled: BoolTarget::new_unsafe(vs[0]),
op: vs[1],
old_root: HashOutTarget::from_vec(vs[2..2 + NUM_HASH_OUT_ELTS].to_vec()),
old_root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()),
new_root: HashOutTarget::from_vec(
vs[2 + NUM_HASH_OUT_ELTS..2 * (1 + NUM_HASH_OUT_ELTS)].to_vec(),
vs[NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS].to_vec(),
),
op_key: ValueTarget::from_slice(
&vs[2 * (1 + NUM_HASH_OUT_ELTS)..2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE],
&vs[2 * NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS + VALUE_SIZE],
),
op_value: ValueTarget::from_slice(
&vs[2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE
..2 * (1 + NUM_HASH_OUT_ELTS) + 2 * VALUE_SIZE],
&vs[2 * NUM_HASH_OUT_ELTS + VALUE_SIZE..2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
),
op: vs[2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
}
}
fn size(params: &Params) -> usize {
2 * (1 + HashOutTarget::size(params)) + 2 * ValueTarget::size(params)
2 * HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1
}
}
@ -1335,6 +1365,18 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
/// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then
/// materializes the selected row via a witness generator and constrains its hash. Cheaper than
/// `vec_ref` when each entry has many fields, since random access runs only over the 4-field
/// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and
/// reusing the same slices across multiple lookups.
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T;
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
@ -1412,7 +1454,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
OperationTarget {
op_type: self.add_virtual_operation_type(),
args: (0..params.max_operation_args)
args: (0..BASE_PARAMS.max_operation_args)
.map(|_| IndexTarget::new_virtual(params.statement_table_size(), self))
.collect(),
aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self),
@ -1722,7 +1764,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
let num_chunks = array.len().div_ceil(CHUNK_LEN);
for chunk in array.chunks(CHUNK_LEN) {
let mut index_chunk = i.low;
// I we have several chunks and the last one is smaller (it's index needs less than 6
// If we have several chunks and the last one is smaller (it's index needs less than 6
// bits), make it zero except when it's used so that the range check over the index
// passes.
if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 {
@ -1737,12 +1779,6 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
self.random_access(i.high, chunk_res)
}
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
// The idea would be the following: Take the array `ts` and hash each element. Then do the
// random access on the hash result. Finally "unhash" to recover the resolved element.
// We don't want to hash each element from the array each time, so we should cache the hashed
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
// then do `ts: &[HashCache<T>]`.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
let num_rows = m.len();
@ -1766,6 +1802,28 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
}
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T {
assert_eq!(ts_flattened.len(), ts_hashes.len());
let selected_hash = self.vec_ref(params, ts_hashes, i);
let selected_flattened = self.add_virtual_targets(T::size(params));
let selected_flattened_hash =
self.hash_n_to_hash_no_pad::<PoseidonHash>(selected_flattened.clone());
self.connect_hashes(selected_hash, selected_flattened_hash);
let result = T::from_flattened(params, &selected_flattened);
self.add_simple_generator(TableGetGenerator::new(
i.clone(),
ts_flattened.to_vec(),
selected_flattened,
));
result
}
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
let zero = self.zero();
self.vec_ref(
@ -2012,7 +2070,7 @@ pub(crate) mod tests {
// Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish();
let custom_predicate_batch = cpb_builder.finish()?;
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
// Some cases from the examples

File diff suppressed because it is too large Load diff

View file

@ -107,11 +107,11 @@ impl MuxTableTarget {
rev_resolved_tagged_flattened.reverse();
let resolved_tagged_flattened = rev_resolved_tagged_flattened;
builder.add_simple_generator(TableGetGenerator {
index: index.clone(),
tagged_entries: self.tagged_entries.clone(),
get_tagged_entry: resolved_tagged_flattened.clone(),
});
builder.add_simple_generator(TableGetGenerator::new(
index.clone(),
self.tagged_entries.clone(),
resolved_tagged_flattened.clone(),
));
measure_gates_end!(builder, measure);
TableEntryTarget {
params: self.params.clone(),
@ -123,8 +123,18 @@ impl MuxTableTarget {
#[derive(Debug, Clone, Default)]
pub struct TableGetGenerator {
index: IndexTarget,
tagged_entries: Vec<Vec<Target>>,
get_tagged_entry: Vec<Target>,
entries: Vec<Vec<Target>>,
revealed_entry: Vec<Target>,
}
impl TableGetGenerator {
pub fn new(index: IndexTarget, entries: Vec<Vec<Target>>, revealed_entry: Vec<Target>) -> Self {
Self {
index,
entries,
revealed_entry,
}
}
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator {
@ -135,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
fn dependencies(&self) -> Vec<Target> {
[self.index.low, self.index.high]
.into_iter()
.chain(self.tagged_entries.iter().flatten().copied())
.chain(self.entries.iter().flatten().copied())
.collect()
}
@ -148,12 +158,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
let index_high = witness.get_target(self.index.high);
let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64();
let entry = witness.get_targets(&self.tagged_entries[index as usize]);
let entry = witness.get_targets(&self.entries[index as usize]);
for (target, value) in self.get_tagged_entry.iter().zip(
for (target, value) in self.revealed_entry.iter().zip(
entry
.iter()
.chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())),
.chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())),
) {
out_buffer.set_target(*target, *value)?;
}
@ -166,12 +176,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
dst.write_target(self.index.low)?;
dst.write_target(self.index.high)?;
dst.write_usize(self.tagged_entries.len())?;
for tagged_entry in &self.tagged_entries {
dst.write_target_vec(tagged_entry)?;
dst.write_usize(self.entries.len())?;
for entry in &self.entries {
dst.write_target_vec(entry)?;
}
dst.write_target_vec(&self.get_tagged_entry)
dst.write_target_vec(&self.revealed_entry)
}
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
@ -181,16 +191,16 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
high: src.read_target()?,
};
let len = src.read_usize()?;
let mut tagged_entries = Vec::with_capacity(len);
let mut entries = Vec::with_capacity(len);
for _ in 0..len {
tagged_entries.push(src.read_target_vec()?);
entries.push(src.read_target_vec()?);
}
let get_tagged_entry = src.read_target_vec()?;
let revealed_entry = src.read_target_vec()?;
Ok(Self {
index,
tagged_entries,
get_tagged_entry,
entries,
revealed_entry,
})
}
}

View file

@ -61,8 +61,8 @@ macro_rules! new {
}
use InnerError::*;
impl Error {
pub fn custom(s: String) -> Self {
new!(Custom(s))
pub fn custom(s: impl Into<String>) -> Self {
new!(Custom(s.into()))
}
pub fn plonky2_proof_fail(context: impl Into<String>, e: anyhow::Error) -> Self {
Self::Plonky2ProofFail(context.into(), e)

View file

@ -1,5 +1,5 @@
pub mod operation;
use crate::middleware::{wildcard_values_from_op_st, PodType};
use crate::middleware::{wildcard_values_from_op_st, PodType, BASE_PARAMS};
pub mod statement;
use std::iter;
@ -39,7 +39,7 @@ use crate::{
middleware::{
self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs,
MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg,
ToFields, VDSet, Value,
ToFields, VDSet, Value, ValueRef,
},
timed,
};
@ -104,8 +104,20 @@ pub(crate) fn extract_custom_predicate_verifications(
if let middleware::Operation::Custom(cpr, sts) = op {
if let middleware::Statement::Custom(st_cpr, st_args) = st {
assert_eq!(cpr, st_cpr);
// The custom operation outputs statements with literal arguments. They can be
// replaced by references later with ReplaceValueWithEntry.
let st_args = st_args
.iter()
.map(|arg| match arg {
ValueRef::Literal(v) => Ok(v.clone()),
_ => Err(Error::custom(
"custom operation cannot output entries as arguments",
)),
})
.collect::<Result<Vec<_>>>()?;
let normalized_pred = cpr.normalized_predicate();
let wildcard_values =
wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args)
wildcard_values_from_op_st(params, &normalized_pred, sts, &st_args)
.expect("resolved wildcards");
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
let custom_predicate_table_index = custom_predicates
@ -136,14 +148,20 @@ pub(crate) fn extract_custom_predicate_verifications(
Ok(table)
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MerkleProofs {
pub(crate) medium: Vec<MerkleClaimAndProof>,
pub(crate) small: Vec<MerkleClaimAndProof>,
}
/// Extracts Merkle proofs from Contains/NotContains ops.
pub(crate) fn extract_merkle_proofs(
params: &Params,
aux_list: &mut [OperationAux],
operations: &[middleware::Operation],
statements: &[middleware::Statement],
) -> Result<Vec<MerkleClaimAndProof>> {
let mut table = Vec::new();
) -> Result<MerkleProofs> {
let mut tables = MerkleProofs::default();
for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() {
let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone());
let (root, key, value, pf) = match (op, st) {
@ -166,31 +184,42 @@ pub(crate) fn extract_merkle_proofs(
}
_ => continue,
};
aux_list[i] = OperationAux::MerkleProofIndex(table.len());
table.push(MerkleClaimAndProof::new(
Hash::from(root),
key,
value,
pf.clone(),
));
let claim_proof = MerkleClaimAndProof::new(Hash::from(root), key, value, pf.clone());
if pf.existence
// TODO: Make sure there's no off-by-one error here
&& pf.siblings.len() <= params.containers.max_depth_small
&& tables.small.len() < params.containers.state.max_small
{
aux_list[i] = OperationAux::MerkleProofIndex(Size::Small, tables.small.len());
tables.small.push(claim_proof);
} else {
aux_list[i] = OperationAux::MerkleProofIndex(Size::Medium, tables.medium.len());
tables.medium.push(claim_proof);
}
if table.len() > params.max_merkle_proofs_containers {
}
if tables.medium.len() > params.containers.state.max_medium {
return Err(Error::custom(format!(
"The number of required Merkle proofs ({}) exceeds the maximum number ({}).",
table.len(),
params.max_merkle_proofs_containers
tables.medium.len(),
params.containers.state.max_medium
)));
}
Ok(table)
Ok(tables)
}
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MerkleTransitionProofs {
pub(crate) medium: Vec<MerkleTreeStateTransitionProof>,
pub(crate) small: Vec<MerkleTreeStateTransitionProof>,
}
/// Extracts Merkle state transition proofs from container update ops.
pub(crate) fn extract_merkle_tree_state_transition_proofs(
pub(crate) fn extract_merkle_transition_proofs(
params: &Params,
aux_list: &mut [OperationAux],
operations: &[middleware::Operation],
) -> Result<Vec<MerkleTreeStateTransitionProof>> {
let mut table = Vec::new();
) -> Result<MerkleTransitionProofs> {
let mut tables = MerkleTransitionProofs::default();
for (i, op) in operations.iter().enumerate() {
let pf = match op {
middleware::Operation::ContainerInsertFromEntries(_, _, _, _, pf)
@ -198,17 +227,27 @@ pub(crate) fn extract_merkle_tree_state_transition_proofs(
| middleware::Operation::ContainerDeleteFromEntries(_, _, _, pf) => pf.clone(),
_ => continue,
};
aux_list[i] = OperationAux::MerkleTreeStateTransitionProofIndex(table.len());
table.push(pf);
if pf.op_proof.existence
// TODO: Make sure there's no off-by-one error here
&& pf.siblings.len() <= params.containers.max_depth_small
&& tables.small.len() < params.containers.transition.max_small
{
aux_list[i] = OperationAux::MerkleTransitionProofIndex(Size::Small, tables.small.len());
tables.small.push(pf);
} else {
aux_list[i] =
OperationAux::MerkleTransitionProofIndex(Size::Medium, tables.medium.len());
tables.medium.push(pf);
}
if table.len() > params.max_merkle_tree_state_transition_proofs_containers {
}
if tables.medium.len() > params.containers.transition.max_medium {
return Err(Error::custom(format!(
"The number of required Merkle proofs ({}) exceeds the maximum number ({}).",
table.len(),
params.max_merkle_tree_state_transition_proofs_containers
tables.medium.len(),
params.containers.transition.max_medium
)));
}
Ok(table)
Ok(tables)
}
pub(crate) fn extract_public_key_of(
@ -225,11 +264,10 @@ pub(crate) fn extract_public_key_of(
) = (op, st)
{
let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone());
let sk = SecretKey::try_from(
value_from_op(sk_s, sk_ref)
.ok_or_else(deduction_err)?
.typed(),
)?;
let value = value_from_op(sk_s, sk_ref).ok_or_else(deduction_err)?;
let sk = value
.as_secret_key()
.ok_or_else(|| Error::custom("{value} not SecretKey"))?;
aux_list[i] = OperationAux::PublicKeyOfIndex(table.len());
table.push(sk);
}
@ -283,7 +321,9 @@ pub(crate) fn extract_signatures(
aux_list[i] = OperationAux::SignedByIndex(table.len());
table.push(SignedBy {
msg: msg.raw(),
pk: PublicKey::try_from(pk.typed())?,
pk: pk
.as_public_key()
.ok_or_else(|| Error::custom(format!("{pk} is not PublicKey")))?,
sig: sig.clone(),
});
}
@ -327,8 +367,8 @@ pub fn pad_statement(s: &mut Statement) {
fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args())
}
fn pad_operation_args(params: &Params, args: &mut Vec<OperationArg>) {
fill_pad(args, OperationArg::None, params.max_operation_args)
fn pad_operation_args(args: &mut Vec<OperationArg>) {
fill_pad(args, OperationArg::None, BASE_PARAMS.max_operation_args)
}
/// Returns the statements from the given MainPodInputs, padding to the respective max lengths
@ -426,7 +466,7 @@ pub(crate) fn process_private_statements_operations(
.map(|mid_arg| find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>>>()?;
pad_operation_args(params, &mut args);
pad_operation_args(&mut args);
operations.push(Operation(op.op_type(), args, *aux));
}
Ok(operations)
@ -457,7 +497,11 @@ pub(crate) fn process_public_statements_operations(
OperationAux::None,
)
};
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args);
fill_pad(
&mut op.1,
OperationArg::None,
BASE_PARAMS.max_operation_args,
);
operations.push(op);
}
Ok(operations)
@ -467,6 +511,7 @@ pub struct Prover {}
impl MainPodProver for Prover {
fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result<Box<dyn Pod>> {
assert_eq!(inputs.statements.len(), inputs.operations.len());
// Pad input recursive pods with empty pods if necessary
let empty_pod = if inputs.pods.len() == params.max_input_pods {
// We don't need padding so we skip creating an EmptyPod
@ -495,6 +540,8 @@ impl MainPodProver for Prover {
let mut aux_list = vec![OperationAux::None; params.max_priv_statements()];
let merkle_proofs =
extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?;
let merkle_transition_proofs =
extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?;
let custom_predicates = extract_custom_predicates(params, inputs.operations)?;
let custom_predicate_verifications = extract_custom_predicate_verifications(
params,
@ -519,9 +566,6 @@ impl MainPodProver for Prover {
let signed_bys =
extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?;
let merkle_tree_state_transition_proofs =
extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?;
let (statements, public_statements) = layout_statements(params, false, &inputs)?;
let operations = process_private_statements_operations(
params,
@ -554,20 +598,15 @@ impl MainPodProver for Prover {
.collect_vec();
let mut vd_mt_proofs = Vec::with_capacity(inputs.pods.len());
let pad_vd_mt_proof = inputs.vd_set.get_vds_proof_0();
for (pod, vd) in inputs.pods.iter().zip(&verifier_datas) {
vd_mt_proofs.push(if pod.is_main() {
(true, inputs.vd_set.get_vds_proof(vd)?)
inputs.vd_set.get_vds_proof(vd)?
} else {
// For intro pods we don't verify inclusion of their vk into the vd set, so we
// generate a dummy mt proof with expected root and value to pass some constraints
(
false,
MerkleClaimAndProof {
root: inputs.vd_set.root(),
value: RawValue::from(pod.verifier_data_hash()),
..MerkleClaimAndProof::empty()
},
)
// use a valid vds proof that matches the expected root but not the value to pass
// the constraints
pad_vd_mt_proof.clone()
});
}
@ -580,7 +619,7 @@ impl MainPodProver for Prover {
merkle_proofs,
public_key_of_sks,
signed_bys,
merkle_tree_state_transition_proofs,
merkle_transition_proofs,
custom_predicates_with_mpt_proofs,
custom_predicate_verifications,
};
@ -967,7 +1006,18 @@ pub mod tests {
max_statements: 2,
max_public_statements: 1,
max_input_pods_public_statements: 0,
max_merkle_proofs_containers: 0,
containers: middleware::ParamsContainers {
state: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
transition: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
max_depth_small: 8,
max_depth_medium: 32,
},
max_public_key_of: 0,
max_custom_predicate_verifications: 0,
max_custom_predicates: 0,
@ -1003,15 +1053,23 @@ pub mod tests {
max_input_pods_public_statements: 2,
max_statements: 5,
max_public_statements: 2,
max_operation_args: 5,
max_custom_predicates: 2,
max_custom_predicate_verifications: 2,
max_custom_predicate_wildcards: 3,
max_merkle_proofs_containers: 2,
max_merkle_tree_state_transition_proofs_containers: 2,
max_public_key_of: 2,
max_depth_mt_containers: 4,
max_depth_mt_vds: 6,
containers: middleware::ParamsContainers {
state: middleware::ParamsMerkleProofs {
max_small: 2,
max_medium: 2,
},
transition: middleware::ParamsMerkleProofs {
max_small: 2,
max_medium: 2,
},
max_depth_small: 2,
max_depth_medium: 4,
},
};
let mut vds = DEFAULT_VD_LIST.clone();
vds.push(rec_main_pod_circuit_data(&params).1.verifier_only.clone());
@ -1068,11 +1126,20 @@ pub mod tests {
max_input_pods: 0,
max_statements: 9,
max_public_statements: 4,
max_operation_args: 5,
max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2,
max_merkle_proofs_containers: 3,
max_merkle_tree_state_transition_proofs_containers: 0,
containers: middleware::ParamsContainers {
state: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 3,
},
transition: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
max_depth_small: 8,
max_depth_medium: 32,
},
..Default::default()
};
println!("{:#?}", params);
@ -1095,7 +1162,7 @@ pub mod tests {
&[stb0.clone(), stb1.clone()],
)?;
let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?;
let cpb = cpb_builder.finish();
let cpb = cpb_builder.finish()?;
let cpb_and = CustomPredicateRef::new(cpb.clone(), 0);
let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1);
@ -1129,6 +1196,72 @@ pub mod tests {
Ok(pod.verify()?)
}
#[test]
fn test_main_self_predicate_hash() -> frontend::Result<()> {
use frontend::BuilderArg;
let params = Params {
max_signed_by: 0,
max_input_pods: 0,
max_statements: 6,
max_public_statements: 2,
max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2,
containers: middleware::ParamsContainers {
state: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
transition: middleware::ParamsMerkleProofs {
max_small: 0,
max_medium: 0,
},
max_depth_small: 8,
max_depth_medium: 32,
},
..Default::default()
};
let mut vds = DEFAULT_VD_LIST.clone();
vds.push(rec_main_pod_circuit_data(&params).1.verifier_only.clone());
let vd_set = VDSet::new(&vds);
// Build a batch: pred_A references pred_B's hash, pred_B references pred_A's hash
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb_a = STB::new_from_pred(NP::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
cpb.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
let stb_b = STB::new_from_pred(NP::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
cpb.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
let batch = cpb.finish()?;
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_hash = middleware::Value::from(middleware::Predicate::Custom(pred_b_ref).hash());
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
let mut pod_builder = MainPodBuilder::new(&params, &vd_set);
let eq_st =
pod_builder.priv_op(frontend::Operation::eq(pred_b_hash.clone(), pred_b_hash))?;
pod_builder.pub_op(frontend::Operation::custom(pred_a_ref, [eq_st]))?;
// Mock
let prover = MockProver {};
let pod = pod_builder.prove(&prover)?;
assert!(pod.pod.verify().is_ok());
// Real
let prover = Prover {};
let pod = pod_builder.prove(&prover)?;
let pod = (pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
Ok(pod.verify()?)
}
#[test]
fn test_set_contains() -> frontend::Result<()> {
let params = Params::default();
@ -1192,10 +1325,108 @@ pub mod tests {
);
let st = middleware::Statement::Custom(
cpr,
[1, 1, 2].into_iter().map(middleware::Value::from).collect(),
[1, 1, 2]
.into_iter()
.map(middleware::ValueRef::from)
.collect(),
);
builder.insert(true, (st, op)).unwrap();
builder.insert((st.clone(), op)).unwrap();
builder.reveal(&st).unwrap();
let prover = Prover {};
builder.prove(&prover).unwrap();
}
#[test]
fn test_replace_value_with_entry() {
let params = middleware::Params::default();
let vd_set = &*DEFAULT_VD_SET;
let mut builder = MainPodBuilder::new(&params, vd_set);
let d = dict!({"a" => 42, "b" => 33});
builder
.priv_op(frontend::Operation::dict_contains(d.clone(), "a", 42))
.unwrap();
let st = builder.priv_op(frontend::Operation::lt(5, 42)).unwrap();
// Transform `Lt(5, 42)` into `Lt(5, d.a)` by using `DictContains(d, "a", 42)`
builder
.pub_op(frontend::Operation::replace_value_with_entry(
vec![None, Some((&d, "a"))],
st,
))
.unwrap();
// Mock
let prover = MockProver {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap();
assert_eq!(
middleware::Statement::Lt(
middleware::ValueRef::Literal(Value::from(5)),
middleware::ValueRef::Key(middleware::AnchoredKey {
root: d.commitment(),
key: middleware::Key::from("a")
})
),
pod.public_statements[0]
);
// Real
let prover = Prover {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap()
}
#[test]
fn test_entry_custom_statement_arg() {
let params = middleware::Params::default();
let vd_set = &*DEFAULT_VD_SET;
let input = r#"
PredA(x) = AND(
Lt(x, 100)
)
PredB(d) = AND(
PredA(d.x)
)
"#;
let module = load_module(input, "my_mod", &params, &[]).expect("lang parse");
let pred_a = module.batch.predicate_ref_by_name("PredA").unwrap();
let pred_b = module.batch.predicate_ref_by_name("PredB").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
let d = dict!({"x" => 42, "y" => 33});
let st_lt = builder.priv_op(frontend::Operation::lt(42, 100)).unwrap();
let st_a = builder
.priv_op(frontend::Operation::custom(pred_a, [st_lt]))
.unwrap();
builder
.priv_op(frontend::Operation::dict_contains(d.clone(), "x", 42))
.unwrap();
// Transform `PredA(42)` into `PredA(d.x)` by using `DictContains(d, "x", 42)`
let st_a1 = builder
.priv_op(frontend::Operation::replace_value_with_entry(
vec![Some((&d, "x"))],
st_a,
))
.unwrap();
builder
.pub_op(frontend::Operation::custom(pred_b.clone(), [st_a1]))
.unwrap();
// Mock
let prover = MockProver {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap();
let expected = middleware::Statement::Custom(
pred_b,
vec![middleware::ValueRef::Literal(Value::from(d))],
);
assert_eq!(expected, pod.public_statements[0]);
// Real
let prover = Prover {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap()
}
}

View file

@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::{
error::{Error, Result},
mainpod::{SignedBy, Statement},
primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof},
mainpod::{MerkleProofs, MerkleTransitionProofs, SignedBy, Statement},
},
middleware::{self, OperationType, Params},
};
@ -30,50 +29,89 @@ impl OperationArg {
}
}
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum Size {
Small,
Medium,
}
impl fmt::Display for Size {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Small => write!(f, "small"),
Self::Medium => write!(f, "medium"),
}
}
}
impl Size {
pub const fn min() -> Self {
Self::Small
}
pub const fn max() -> Self {
Self::Medium
}
}
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum OperationAux {
None,
MerkleProofIndex(usize),
MerkleProofIndex(Size, usize),
MerkleTransitionProofIndex(Size, usize),
PublicKeyOfIndex(usize),
SignedByIndex(usize),
MerkleTreeStateTransitionProofIndex(usize),
CustomPredVerifyIndex(usize),
}
impl OperationAux {
fn table_offset_merkle_proof(_params: &Params) -> usize {
fn table_offset_merkle_proof(params: &Params, size: Size) -> usize {
match size {
// At index 0 we store a zero entry
1
Size::Small => 1,
Size::Medium => {
Self::table_offset_merkle_proof(params, Size::Small)
+ params.containers.state.max_small
}
}
}
fn table_offset_merkle_transition_proof(params: &Params, size: Size) -> usize {
match size {
Size::Small => {
Self::table_offset_merkle_proof(params, Size::min())
+ params.containers.state.max_total()
}
Size::Medium => {
Self::table_offset_merkle_transition_proof(params, Size::Small)
+ params.containers.transition.max_small
}
}
}
fn table_offset_custom_pred_verify(params: &Params) -> usize {
Self::table_offset_merkle_transition_proof(params, Size::min())
+ params.containers.transition.max_total()
}
fn table_offset_public_key_of(params: &Params) -> usize {
Self::table_offset_merkle_proof(params) + params.max_merkle_proofs_containers
Self::table_offset_custom_pred_verify(params) + params.max_custom_predicate_verifications
}
fn table_offset_signed_by(params: &Params) -> usize {
Self::table_offset_public_key_of(params) + params.max_public_key_of
}
fn table_offset_merkle_tree_state_transition_proof(params: &Params) -> usize {
Self::table_offset_signed_by(params) + params.max_signed_by
}
fn table_offset_custom_pred_verify(params: &Params) -> usize {
Self::table_offset_merkle_tree_state_transition_proof(params)
+ params.max_merkle_tree_state_transition_proofs_containers
}
pub(crate) fn table_size(params: &Params) -> usize {
1 + params.max_merkle_proofs_containers
1 + params.containers.state.max_total()
+ params.containers.transition.max_total()
+ params.max_custom_predicate_verifications
+ params.max_public_key_of
+ params.max_signed_by
+ params.max_merkle_tree_state_transition_proofs_containers
+ params.max_custom_predicate_verifications
}
pub fn table_index(&self, params: &Params) -> usize {
match self {
Self::None => 0,
Self::MerkleProofIndex(i) => Self::table_offset_merkle_proof(params) + *i,
Self::MerkleProofIndex(size, i) => Self::table_offset_merkle_proof(params, *size) + *i,
Self::MerkleTransitionProofIndex(size, i) => {
Self::table_offset_merkle_transition_proof(params, *size) + *i
}
Self::PublicKeyOfIndex(i) => Self::table_offset_public_key_of(params) + *i,
Self::SignedByIndex(i) => Self::table_offset_signed_by(params) + *i,
Self::MerkleTreeStateTransitionProofIndex(i) => {
Self::table_offset_merkle_tree_state_transition_proof(params) + *i
}
Self::CustomPredVerifyIndex(i) => Self::table_offset_custom_pred_verify(params) + *i,
}
}
@ -96,8 +134,8 @@ impl Operation {
&self,
statements: &[Statement],
signatures: &[SignedBy],
merkle_proofs: &[MerkleClaimAndProof],
merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProof],
merkle_proofs: &MerkleProofs,
merkle_transition_proofs: &MerkleTransitionProofs,
) -> Result<crate::middleware::Operation> {
let deref_args = self
.1
@ -113,17 +151,26 @@ impl Operation {
.collect::<Result<Vec<_>>>()?;
let deref_aux = match self.2 {
OperationAux::None => crate::middleware::OperationAux::None,
OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None,
OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof(
merkle_proofs
OperationAux::MerkleProofIndex(size, i) => {
let table = match size {
Size::Small => &merkle_proofs.small,
Size::Medium => &merkle_proofs.medium,
};
crate::middleware::OperationAux::MerkleProof(
table
.get(i)
.ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))?
.proof
.clone(),
),
OperationAux::MerkleTreeStateTransitionProofIndex(i) => {
)
}
OperationAux::MerkleTransitionProofIndex(size, i) => {
let table = match size {
Size::Small => &merkle_transition_proofs.small,
Size::Medium => &merkle_transition_proofs.medium,
};
crate::middleware::OperationAux::MerkleTreeStateTransitionProof(
merkle_tree_state_transition_proofs
table
.get(i)
.ok_or(Error::custom(format!(
"Missing Merkle state transition proof index {}",
@ -132,6 +179,7 @@ impl Operation {
.clone(),
)
}
OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None,
OperationAux::SignedByIndex(i) => crate::middleware::OperationAux::Signature(
signatures
.get(i)
@ -165,12 +213,14 @@ impl fmt::Display for Operation {
}
match self.2 {
OperationAux::None => (),
OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?,
OperationAux::MerkleProofIndex(size, i) => {
write!(f, " {}_merkle_proof_{:02}", size, i)?
}
OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?,
OperationAux::PublicKeyOfIndex(i) => write!(f, " public_key_of_{:02}", i)?,
OperationAux::SignedByIndex(i) => write!(f, " signed_by_{:02}", i)?,
OperationAux::MerkleTreeStateTransitionProofIndex(i) => {
write!(f, " merkle_tree_state_transition_proof_{:02}", i)?
OperationAux::MerkleTransitionProofIndex(size, i) => {
write!(f, " {}_merkle_transition_proof_{:02}", size, i)?
}
}
Ok(())

View file

@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::error::{Error, Result},
middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS},
middleware::{
self, NativePredicate, Predicate, StatementArg, ToFields, Value, ValueRef, BASE_PARAMS,
},
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@ -96,15 +98,15 @@ impl TryFrom<Statement> for middleware::Statement {
)))?,
},
Predicate::Custom(cpr) => {
let vs: Vec<Value> = proper_args
let args: Vec<ValueRef> = proper_args
.into_iter()
.filter_map(|arg| match arg {
SA::None => None,
SA::Literal(v) => Some(v),
_ => unreachable!(),
StatementArg::Literal(v) => Some(ValueRef::Literal(v)),
StatementArg::Key(k) => Some(ValueRef::Key(k)),
StatementArg::None => None,
})
.collect();
S::Custom(cpr, vs)
S::Custom(cpr, args)
}
Predicate::Intro(ir) => {
let vs: Vec<Value> = proper_args

View file

@ -11,13 +11,12 @@ use crate::{
basetypes::{Proof, VerifierOnlyCircuitData},
error::{Error, Result},
mainpod::{
calculate_statements_hash, extract_merkle_proofs,
extract_merkle_tree_state_transition_proofs, extract_signatures, layout_statements,
process_private_statements_operations, process_public_statements_operations, Operation,
calculate_statements_hash, extract_merkle_proofs, extract_merkle_transition_proofs,
extract_signatures, layout_statements, process_private_statements_operations,
process_public_statements_operations, MerkleProofs, MerkleTransitionProofs, Operation,
OperationAux, SignedBy, Statement,
},
mock::emptypod::MockEmptyPod,
primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof},
recursion::hash_verifier_data,
},
middleware::{
@ -45,10 +44,10 @@ pub struct MockMainPod {
operations: Vec<Operation>,
// public subset of the `statements` vector
public_statements: Vec<Statement>,
// All Merkle proofs
merkle_proofs_containers: Vec<MerkleClaimAndProof>,
// All Merkle tree state transition proofs
merkle_tree_state_transition_proofs_containers: Vec<MerkleTreeStateTransitionProof>,
// All Merkle proofs for containers
merkle_proofs: MerkleProofs,
// All Merkle tree state transition proofs for containers
merkle_transition_proofs: MerkleTransitionProofs,
// All verified signatures
signatures: Vec<SignedBy>,
}
@ -124,8 +123,8 @@ struct Data {
public_statements: Vec<Statement>,
operations: Vec<Operation>,
statements: Vec<Statement>,
merkle_proofs: Vec<MerkleClaimAndProof>,
merkle_tree_state_transition_proofs: Vec<MerkleTreeStateTransitionProof>,
merkle_proofs: MerkleProofs,
merkle_transition_proofs: MerkleTransitionProofs,
signatures: Vec<SignedBy>,
input_pods: Vec<(usize, Params, Hash, VDSet, serde_json::Value)>,
}
@ -153,8 +152,8 @@ impl MockMainPod {
let merkle_proofs =
extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?;
// Similarly for Merkle state transition proofs.
let merkle_tree_state_transition_proofs =
extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?;
let merkle_transition_proofs =
extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?;
let signatures =
extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?;
@ -185,8 +184,8 @@ impl MockMainPod {
public_statements,
statements,
operations,
merkle_proofs_containers: merkle_proofs,
merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs,
merkle_proofs,
merkle_transition_proofs,
signatures,
})
}
@ -260,8 +259,8 @@ impl Pod for MockMainPod {
.deref(
&self.statements[..input_statement_offset + i],
&self.signatures,
&self.merkle_proofs_containers,
&self.merkle_tree_state_transition_proofs_containers,
&self.merkle_proofs,
&self.merkle_transition_proofs,
)?
.check_and_log(&self.params, &s.clone().try_into()?)
.map_err(|e| e.into())
@ -321,10 +320,8 @@ impl Pod for MockMainPod {
public_statements: self.public_statements.clone(),
operations: self.operations.clone(),
statements: self.statements.clone(),
merkle_proofs: self.merkle_proofs_containers.clone(),
merkle_tree_state_transition_proofs: self
.merkle_tree_state_transition_proofs_containers
.clone(),
merkle_proofs: self.merkle_proofs.clone(),
merkle_transition_proofs: self.merkle_transition_proofs.clone(),
signatures: self.signatures.clone(),
input_pods,
})
@ -344,7 +341,7 @@ impl Pod for MockMainPod {
operations,
statements,
merkle_proofs,
merkle_tree_state_transition_proofs,
merkle_transition_proofs,
signatures,
input_pods,
} = serde_json::from_value(data)?;
@ -362,8 +359,8 @@ impl Pod for MockMainPod {
public_statements,
operations,
statements,
merkle_proofs_containers: merkle_proofs,
merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs,
merkle_proofs,
merkle_transition_proofs,
signatures,
})
}
@ -380,7 +377,8 @@ pub mod tests {
great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request,
zu_kyc_sign_dict_builders, MOCK_VD_SET,
},
frontend, middleware,
frontend::{self},
middleware,
middleware::{Signer as _, Value},
};

View file

@ -207,7 +207,7 @@ impl Point {
u: *u,
});
points.find(|p| p.is_in_subgroup()).ok_or(Error::custom(
"One of the points must lie in the EC subgroup.".into(),
"One of the points must lie in the EC subgroup.",
))
}
pub fn as_bytes_from_subgroup(&self) -> Result<Vec<u8>, Error> {

View file

@ -32,7 +32,7 @@ use crate::{
circuits::common::{CircuitBuilderPod, ValueTarget},
error::{Error, Result},
primitives::merkletree::{
MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError,
MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, MAX_DEPTH,
},
},
measure_gates_begin, measure_gates_end,
@ -42,8 +42,6 @@ use crate::{
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MerkleClaimAndProofTarget {
pub(crate) max_depth: usize,
// `enabled` determines if the merkleproof verification is enabled
pub(crate) enabled: BoolTarget,
pub(crate) root: HashOutTarget,
pub(crate) key: ValueTarget,
pub(crate) value: ValueTarget,
@ -121,16 +119,9 @@ pub fn verify_merkle_proof_circuit(
let obtained_root =
compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings);
// check that obtained_root==root (from inputs), when enabled==true
let zero = builder.zero();
let expected_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, proof.root.elements[j], zero))
.collect();
let computed_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero))
.collect();
// check that obtained_root==root (from inputs)
for j in 0..HASH_SIZE {
builder.connect(computed_root[j], expected_root[j]);
builder.connect(obtained_root.elements[j], proof.root.elements[j]);
}
measure_gates_end!(builder, measure);
}
@ -139,7 +130,6 @@ impl MerkleClaimAndProofTarget {
pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self {
MerkleClaimAndProofTarget {
max_depth,
enabled: builder.add_virtual_bool_target_safe(),
root: builder.add_virtual_hash(),
key: builder.add_virtual_value(),
value: builder.add_virtual_value(),
@ -154,12 +144,7 @@ impl MerkleClaimAndProofTarget {
}
/// assigns the given values to the targets
#[allow(clippy::too_many_arguments)]
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
enabled: bool,
mp: &MerkleClaimAndProof,
) -> Result<()> {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, mp: &MerkleClaimAndProof) -> Result<()> {
if mp.proof.siblings.len() > self.max_depth {
return Err(Error::Tree(TreeError::circuit_depth_too_small(
self.max_depth,
@ -167,7 +152,6 @@ impl MerkleClaimAndProofTarget {
)));
}
pw.set_bool_target(self.enabled, enabled)?;
pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?;
pw.set_target_arr(&self.key.elements, &mp.key.0)?;
pw.set_target_arr(&self.value.elements, &mp.value.0)?;
@ -207,8 +191,6 @@ impl MerkleClaimAndProofTarget {
#[derive(Clone, Serialize, Deserialize)]
pub struct MerkleProofExistenceTarget {
max_depth: usize,
// `enabled` determines if the merkleproof verification is enabled
pub(crate) enabled: BoolTarget,
pub(crate) root: HashOutTarget,
pub(crate) key: ValueTarget,
pub(crate) value: ValueTarget,
@ -236,16 +218,9 @@ pub fn verify_merkle_proof_existence_circuit(
let obtained_root =
compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings);
// check that obtained_root==root (from inputs), when enabled==true
let zero = builder.zero();
let expected_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, proof.root.elements[j], zero))
.collect();
let computed_root: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero))
.collect();
// check that obtained_root==root (from inputs)
for j in 0..HASH_SIZE {
builder.connect(computed_root[j], expected_root[j]);
builder.connect(obtained_root.elements[j], proof.root.elements[j]);
}
measure_gates_end!(builder, measure);
@ -256,7 +231,6 @@ impl MerkleProofExistenceTarget {
pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self {
MerkleProofExistenceTarget {
max_depth,
enabled: builder.add_virtual_bool_target_safe(),
root: builder.add_virtual_hash(),
key: builder.add_virtual_value(),
value: builder.add_virtual_value(),
@ -265,12 +239,7 @@ impl MerkleProofExistenceTarget {
}
}
/// assigns the given values to the targets
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
enabled: bool,
mp: &MerkleClaimAndProof,
) -> Result<()> {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, mp: &MerkleClaimAndProof) -> Result<()> {
assert!(mp.proof.existence); // sanity check
if mp.proof.siblings.len() > self.max_depth {
return Err(Error::Tree(TreeError::circuit_depth_too_small(
@ -279,7 +248,6 @@ impl MerkleProofExistenceTarget {
)));
}
pw.set_bool_target(self.enabled, enabled)?;
pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?;
pw.set_target_arr(&self.key.elements, &mp.key.0)?;
pw.set_target_arr(&self.value.elements, &mp.value.0)?;
@ -456,8 +424,6 @@ fn hash_with_flag_target<H: AlgebraicHasher<F>>(
#[derive(Clone, Serialize, Deserialize)]
pub struct MerkleTreeStateTransitionProofTarget {
pub(crate) max_depth: usize,
// `enabled` determines if the merkleproof state transition verification is enabled
pub(crate) enabled: BoolTarget,
pub(crate) op: Target,
pub(crate) old_root: HashOutTarget,
pub(crate) op_proof: MerkleClaimAndProofTarget,
@ -511,7 +477,6 @@ pub fn verify_merkle_state_transition_circuit(
};
let new_key_proof = MerkleProofExistenceTarget {
max_depth: proof.max_depth,
enabled: proof.enabled,
root,
key: proof.op_key,
value: proof.op_value,
@ -523,13 +488,7 @@ pub fn verify_merkle_state_transition_circuit(
// Insert/Delete: Non-existence
// Update: Existence
let proof_type = is_update;
builder.conditional_assert_eq(
proof.enabled.target,
proof.op_proof.existence.target,
proof_type.target,
);
// 3.2) assert that proof.enabled matches with op_proof.enabled
builder.connect(proof.op_proof.enabled.target, proof.enabled.target);
builder.connect(proof.op_proof.existence.target, proof_type.target);
// 4) assert proof_non_existence.root corresponds to the root
// specified by the op (old_root for Insert/Update and new_root
@ -545,17 +504,9 @@ pub fn verify_merkle_state_transition_circuit(
};
for j in 0..HASH_SIZE {
// 4.1) assert that proof.proof_non_existence.root == proof.old_root
builder.conditional_assert_eq(
proof.enabled.target,
proof.op_proof.root.elements[j],
claim_root.elements[j],
);
builder.connect(proof.op_proof.root.elements[j], claim_root.elements[j]);
// 4.2) assert that the non-existence proof uses the op_key (value not needed).
builder.conditional_assert_eq(
proof.enabled.target,
proof.op_proof.key.elements[j],
proof.op_key.elements[j],
);
builder.connect(proof.op_proof.key.elements[j], proof.op_key.elements[j]);
}
// prepare value for check 5.2)
@ -593,7 +544,7 @@ pub fn verify_merkle_state_transition_circuit(
.map(|j| builder.select(is_divergence_level, zero, new_siblings[i].elements[j]))
.collect();
for j in 0..HASH_SIZE {
builder.conditional_assert_eq(proof.enabled.target, old_sibling_i[j], new_sibling_i[j]);
builder.connect(old_sibling_i[j], new_sibling_i[j]);
}
// 5.2) when i==d && if old_siblings[i] != new_siblings[i], check that:
@ -611,7 +562,7 @@ pub fn verify_merkle_state_transition_circuit(
let in_case_5_2 = builder.and(old_is_noteq_new, is_divergence_level);
// do the case2's checks
let sel = builder.and(proof.enabled, in_case_5_2);
let sel = in_case_5_2;
for j in 0..HASH_SIZE {
builder.conditional_assert_eq(sel.target, old_siblings[i].elements[j], zero);
builder.conditional_assert_eq(
@ -641,7 +592,6 @@ impl MerkleTreeStateTransitionProofTarget {
pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder<F, D>) -> Self {
Self {
max_depth,
enabled: builder.add_virtual_bool_target_safe(),
op: builder.add_virtual_target(),
old_root: builder.add_virtual_hash(),
@ -661,7 +611,6 @@ impl MerkleTreeStateTransitionProofTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
enabled: bool,
mp: &MerkleTreeStateTransitionProof,
) -> Result<()> {
let new_siblings = mp.siblings.clone();
@ -672,13 +621,11 @@ impl MerkleTreeStateTransitionProofTarget {
)));
}
pw.set_bool_target(self.enabled, enabled)?;
pw.set_target(self.op, F::from_canonical_u8(mp.op as u8))?;
pw.set_hash_target(self.old_root, HashOut::from_vec(mp.old_root.0.to_vec()))?;
self.op_proof.set_targets(
pw,
enabled,
&MerkleClaimAndProof {
root: if mp.op == MerkleTreeOp::Delete {
mp.new_root
@ -703,10 +650,13 @@ impl MerkleTreeStateTransitionProofTarget {
{
pw.set_hash_target(self.siblings[i], HashOut::from_vec(sibling.0.to_vec()))?;
}
pw.set_target(
self.divergence_level,
F::from_canonical_u64((new_siblings.len() - 1) as u64),
)?;
let div_lvl = if new_siblings.is_empty() {
// don't subtract since it would underflow, use MAX_DEPTH
MAX_DEPTH as u64
} else {
(new_siblings.len() - 1) as u64
};
pw.set_target(self.divergence_level, F::from_canonical_u64(div_lvl))?;
Ok(())
}
@ -856,7 +806,6 @@ pub mod tests {
verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets(
&mut pw,
true,
&MerkleClaimAndProof::new(tree.root(), key, Some(value), proof),
)?;
@ -868,6 +817,42 @@ pub mod tests {
Ok(())
}
#[test]
fn test_merkleproof_pad_valid() -> Result<()> {
// circuit
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleClaimAndProofTarget::new_virtual(32, &mut builder);
verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, &MerkleClaimAndProof::pad())?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
#[test]
fn test_merkleproof_transition_pad_valid() -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(32, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, &MerkleTreeStateTransitionProof::pad())?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
#[test]
fn test_merkleproof_only_existence_verify() -> Result<()> {
for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] {
@ -903,7 +888,6 @@ pub mod tests {
verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets(
&mut pw,
true,
&MerkleClaimAndProof::new(tree.root(), key, Some(value), proof),
)?;
@ -979,7 +963,6 @@ pub mod tests {
verify_merkle_proof_circuit(&mut builder, &targets);
targets.set_targets(
&mut pw,
true,
&MerkleClaimAndProof::new(tree.root(), key, Some(value), proof),
)?;
@ -1025,32 +1008,15 @@ pub mod tests {
let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_proof_circuit(&mut builder, &targets);
// verification enabled & proof of existence
// proof of existence
let mp = MerkleClaimAndProof::new(tree2.root(), key, Some(value), proof);
targets.set_targets(&mut pw, true, &mp)?;
targets.set_targets(&mut pw, &mp)?;
// generate proof, expecting it to fail (since we're using the wrong
// root)
let data = builder.build::<C>();
assert!(data.prove(pw).is_err());
// Now generate a new proof, using `enabled=false`, which should pass the verification
// despite containing 'wrong' witness.
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_proof_circuit(&mut builder, &targets);
// verification disabled & proof of existence
targets.set_targets(&mut pw, false, &mp)?;
// generate proof, should pass despite using wrong witness, since the
// `enabled=false`
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
@ -1073,7 +1039,7 @@ pub mod tests {
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, true, state_transition_proof)?;
targets.set_targets(&mut pw, state_transition_proof)?;
// generate & verify proof
let data = builder.build::<C>();
@ -1270,71 +1236,4 @@ pub mod tests {
assert_ne!(state_transition_proof.new_root, tree.root()); // Tamper check
Ok(())
}
#[test]
fn test_state_transition_gadget_disabled() -> Result<()> {
let max_depth: usize = 32;
let mut kvs = HashMap::new();
for i in 0..8 {
kvs.insert(RawValue::from(i), RawValue::from(1000 + i));
}
let mut tree = MerkleTree::new(&kvs);
let key = RawValue::from(37);
let value = RawValue::from(1037);
let _ = tree.insert(&key, &value)?;
let key = RawValue::from(21);
let value = RawValue::from(1021);
let original_state_transition_proof = tree.insert(&key, &value)?;
let mut state_transition_proof = original_state_transition_proof.clone();
// modify the proof, so that it should fail when `enabled=true`, by
// changing the new_root
state_transition_proof.new_root = state_transition_proof.old_root;
run_circuit_disabled(max_depth, &state_transition_proof)?;
// modify the proof, so that it should fail when `enabled=true`, by
// changing the new_sibling at the divergence level, which should not
// pass the verification in the case where we're inserting key=21
let mut state_transition_proof = original_state_transition_proof.clone();
state_transition_proof.siblings[4] = EMPTY_HASH;
run_circuit_disabled(max_depth, &state_transition_proof)?;
Ok(())
}
fn run_circuit_disabled(
max_depth: usize,
state_transition_proof: &MerkleTreeStateTransitionProof,
) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, true, state_transition_proof)?;
// generate proof, and expect it to fail
let data = builder.build::<C>();
assert!(data.prove(pw).is_err()); // expect prove to fail
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::<F>::new();
let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder);
verify_merkle_state_transition_circuit(&mut builder, &targets);
targets.set_targets(&mut pw, false, state_transition_proof)?;
// generate and expect it to pass
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
}

View file

@ -0,0 +1,97 @@
//! Module that implements the key-value DB used at the MerkleTree module.
use std::{
collections::HashMap,
fmt::Debug,
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Result};
use dyn_clone::DynClone;
use crate::{
backends::plonky2::primitives::merkletree::{Intermediate, Node},
middleware::{Hash, EMPTY_HASH},
};
#[cfg(feature = "db_rocksdb")]
pub mod rocks;
pub trait DB: Debug + DynClone + Sync + Send {
/// Must always return the empty intermediate node when hash is EMPTY_HASH
fn load_node(&self, hash: Hash) -> Result<Option<Node>>;
fn store_node(&mut self, node: Node) -> Result<()>;
}
dyn_clone::clone_trait_object!(DB);
/// MemDB implements the DB trait in a in-memory HashMap.
#[derive(Clone, Debug, Default)]
pub(crate) struct MemDB {
inner: Arc<Mutex<HashMap<Hash, Node>>>,
}
impl MemDB {
pub fn new() -> Self {
Self::default()
}
}
impl DB for MemDB {
fn load_node(&self, hash: Hash) -> Result<Option<Node>> {
let db = self
.inner
.lock()
.map_err(|e| anyhow!("failed to acquire memdb lock for read: {}", e))?;
if hash == EMPTY_HASH {
return Ok(Some(Node::Intermediate(Intermediate::new(
EMPTY_HASH, EMPTY_HASH,
))));
}
Ok(db.get(&hash).cloned())
}
fn store_node(&mut self, node: Node) -> Result<()> {
let mut db = self
.inner
.lock()
.map_err(|e| anyhow!("failed to acquire memdb lock for write: {}", e))?;
db.insert(node.hash(), node);
Ok(())
}
}
#[cfg(test)]
pub mod tests {
use super::{super::Leaf, *};
#[test]
fn test_db() -> Result<()> {
let mut db = MemDB::new();
test_db_opt(&mut db)?;
#[cfg(feature = "db_rocksdb")]
{
let path = "/tmp/rocksdb";
let mut db = rocks::RocksDB::open(path)?;
test_db_opt(&mut db)?;
}
Ok(())
}
fn test_db_opt(db: &mut dyn DB) -> Result<()> {
let node = Leaf::new(1.into(), 1.into());
db.store_node(Node::Leaf(node.clone()))?;
let obtained_node = db.load_node(node.hash)?.unwrap();
let leaf = match obtained_node {
Node::Leaf(l) => l,
_ => panic!("expected a leaf"),
};
assert_eq!(leaf.hash, node.hash);
Ok(())
}
}

View file

@ -0,0 +1,55 @@
use std::{fmt, path::Path, sync::Arc};
use anyhow::{anyhow, Result};
use rocksdb::{Options, TransactionDB, TransactionDBOptions};
use crate::{
backends::plonky2::primitives::merkletree::{self, db},
middleware::{Hash, RawValue, EMPTY_HASH},
};
#[derive(Clone)]
pub struct RocksDB(Arc<TransactionDB>);
#[allow(dead_code)]
impl RocksDB {
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let mut options = Options::default();
options.create_if_missing(true);
let txn_options = TransactionDBOptions::default();
let inner =
TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?;
Ok(Self(Arc::new(inner)))
}
}
impl fmt::Debug for RocksDB {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "RocksDB")
}
}
impl db::DB for RocksDB {
fn load_node(&self, hash: Hash) -> Result<Option<merkletree::Node>> {
if hash == EMPTY_HASH {
return Ok(Some(merkletree::Node::Intermediate(
merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH),
)));
}
match self
.0
.get(RawValue::from(hash).to_bytes())
.map_err(|e| anyhow!("rocksdb: get failed: {e}"))?
{
None => Ok(None),
Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)),
}
}
fn store_node(&mut self, node: merkletree::Node) -> Result<()> {
self.0
.put(RawValue::from(node.hash()).to_bytes(), node.encode()?)
.map_err(|e| anyhow!("rocksdb transaction put failed: {e}"))
}
}

View file

@ -2,12 +2,16 @@
use std::{backtrace::Backtrace, fmt::Debug};
use crate::middleware::Hash;
pub type TreeResult<T, E = TreeError> = core::result::Result<T, E>;
#[derive(Debug, thiserror::Error)]
pub enum TreeInnerError {
#[error("key not found")]
KeyNotFound,
#[error("node with hash {0} not found")]
NodeNotFound(Hash),
#[error("key already exists")]
KeyExists,
#[error("max depth reached")]
@ -22,6 +26,9 @@ pub enum TreeInnerError {
StateTransitionProofFail(String),
#[error("circuit max_depth {0} is smaller than proof depth {1}")]
CircuitDepthTooSmall(usize, usize),
// Other
#[error("{0}")]
Custom(String),
}
#[derive(thiserror::Error)]
@ -31,8 +38,8 @@ pub enum TreeError {
inner: Box<TreeInnerError>,
backtrace: Box<Backtrace>,
},
#[error("anyhow::Error: {0}")]
Anyhow(#[from] anyhow::Error),
#[error("database error: {0}")]
Database(anyhow::Error),
}
impl Debug for TreeError {
@ -60,6 +67,9 @@ impl TreeError {
pub(crate) fn key_not_found() -> Self {
new!(KeyNotFound)
}
pub(crate) fn node_not_found(hash: Hash) -> Self {
new!(NodeNotFound(hash))
}
pub(crate) fn key_exists() -> Self {
new!(KeyExists)
}
@ -81,4 +91,7 @@ impl TreeError {
pub(crate) fn circuit_depth_too_small(circuit_depth: usize, proof_depth: usize) -> Self {
new!(CircuitDepthTooSmall(circuit_depth, proof_depth))
}
pub(crate) fn custom(s: impl Into<String>) -> Self {
new!(Custom(s.into()))
}
}

File diff suppressed because it is too large Load diff

View file

@ -180,11 +180,7 @@ impl EthDosHelper {
};
assert_eq!(int, Value::from(int_attestation.public_key));
let n_i64 = if let TypedValue::Int(x) = n.typed() {
*x
} else {
panic!("distance value is not Int")
};
let n_i64 = n.as_int().unwrap();
// eth_dos src->dst dist=n+1
self.n_plus_1(&mut pod, eth_dos_int_to_dst, int_attestation, n_i64)?;

View file

@ -18,6 +18,8 @@ pub enum BuilderArg {
/// Key: (origin, key), where origin is Wildcard and key is Key
Key(String, String),
WildcardLiteral(String),
/// Reference to a same-batch predicate's identity hash (resolved by name in finish()).
SelfPredicateHash(String),
}
/// When defining a `BuilderArg`, it can be done from 3 different inputs:
@ -130,6 +132,8 @@ pub struct CustomPredicateBatchBuilder {
params: Params,
pub name: String,
pub predicates: Vec<CustomPredicate>,
/// Forward references to resolve in finish(): (predicate_idx, statement_idx, arg_idx, name)
pending_self_pred_hashes: Vec<(usize, usize, usize, String)>,
}
impl CustomPredicateBatchBuilder {
@ -138,6 +142,7 @@ impl CustomPredicateBatchBuilder {
params,
name,
predicates: Vec::new(),
pending_self_pred_hashes: Vec::new(),
}
}
@ -171,6 +176,12 @@ impl CustomPredicateBatchBuilder {
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Result<Predicate> {
if self.predicates.iter().any(|p| p.name == name) {
return Err(Error::custom(format!(
"Duplicate predicate name '{}' in batch",
name
)));
}
if self.predicates.len() >= Params::max_custom_batch_size() {
return Err(Error::max_length(
"self.predicates.len".to_string(),
@ -194,14 +205,18 @@ impl CustomPredicateBatchBuilder {
));
}
let pred_idx = self.predicates.len();
let mut pending = Vec::new();
let statements = sts
.iter()
.map(|sb| {
.enumerate()
.map(|(stmt_idx, sb)| {
let stb = sb.clone().desugar();
let st_tmpl_args = stb
.args
.iter()
.map(|a| {
.enumerate()
.map(|(arg_idx, a)| {
Ok::<_, Error>(match a {
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey(
@ -211,6 +226,22 @@ impl CustomPredicateBatchBuilder {
BuilderArg::WildcardLiteral(v) => {
StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?)
}
BuilderArg::SelfPredicateHash(pred_name) => {
// Try backward reference first
match self.predicates.iter().position(|p| p.name == *pred_name) {
Some(index) => StatementTmplArg::SelfPredicateHash(index),
None => {
// Forward reference - placeholder, resolved in finish()
pending.push((
pred_idx,
stmt_idx,
arg_idx,
pred_name.clone(),
));
StatementTmplArg::SelfPredicateHash(0)
}
}
}
})
})
.collect::<Result<_>>()?;
@ -240,11 +271,27 @@ impl CustomPredicateBatchBuilder {
.collect(),
)?;
self.predicates.push(custom_predicate);
self.pending_self_pred_hashes.extend(pending);
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
}
pub fn finish(self) -> Arc<CustomPredicateBatch> {
CustomPredicateBatch::new(self.name, self.predicates)
pub fn finish(mut self) -> Result<Arc<CustomPredicateBatch>> {
// Resolve forward references for SelfPredicateHash
for (pred_idx, stmt_idx, arg_idx, ref name) in &self.pending_self_pred_hashes {
let target_idx = self
.predicates
.iter()
.position(|p| p.name == *name)
.ok_or_else(|| {
Error::custom(format!(
"SelfPredicateHash references unknown predicate '{}'",
name
))
})?;
self.predicates[*pred_idx].statements[*stmt_idx].args[*arg_idx] =
StatementTmplArg::SelfPredicateHash(target_idx);
}
Ok(CustomPredicateBatch::new(self.name, self.predicates))
}
}
@ -269,7 +316,9 @@ mod tests {
backends::plonky2::mock::mainpod::MockProver,
examples::{custom::eth_dos_batch, MOCK_VD_SET},
frontend::{MainPodBuilder, Operation},
middleware::{self, containers::Set, CustomPredicateRef, Params, PodType, DEFAULT_VD_SET},
middleware::{
self, containers::Set, CustomPredicateRef, Params, PodType, ValueRef, DEFAULT_VD_SET,
},
};
#[test]
@ -306,7 +355,7 @@ mod tests {
.arg("s2");
builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?;
let batch = builder.finish();
let batch = builder.finish()?;
let batch_clone = batch.clone();
let gt_custom_pred = CustomPredicateRef::new(batch, 0);
@ -356,7 +405,7 @@ mod tests {
&[],
&[set_contains_stb],
)?;
let batch = builder.finish();
let batch = builder.finish()?;
let batch_clone = batch.clone();
let mut mp_builder = MainPodBuilder::new(&params, vd_set);
@ -386,4 +435,83 @@ mod tests {
Ok(())
}
#[test]
fn test_builder_self_predicate_hash_unknown_ref() {
let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("nonexistent".into()));
builder
.predicate_and("pred_A", &["x"], &[], &[stb])
.unwrap();
// finish() should fail because "nonexistent" was never defined
assert!(builder.finish().is_err());
}
/// Tests cyclic SelfPredicateHash references end-to-end:
/// pred_A references pred_B's hash (forward ref), pred_B references pred_A's hash (backward
/// ref). Exercises forward reference resolution in finish(), then builds and verifies a POD
/// using pred_A via MockProver.
#[test]
fn test_builder_self_predicate_hash_e2e() -> Result<()> {
let params = Params::default();
let vd_set = &*MOCK_VD_SET;
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
// pred_A references pred_B's hash (forward ref, pred_B not yet defined)
let stb_a = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
builder.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
// pred_B references pred_A's hash (backward ref, pred_A already defined)
let stb_b = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
builder.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
let batch = builder.finish()?;
// Verify resolution: pred_A references pred_B (index 1), pred_B references pred_A (index 0)
assert_eq!(
batch.predicates()[0].statements[0].args[1],
StatementTmplArg::SelfPredicateHash(1)
);
assert_eq!(
batch.predicates()[1].statements[0].args[1],
StatementTmplArg::SelfPredicateHash(0)
);
// Compute concrete hashes
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
let mut mp_builder = MainPodBuilder::new(&params, vd_set);
let eq_st = mp_builder.priv_op(Operation::eq(pred_b_hash.clone(), pred_b_hash.clone()))?;
mp_builder.pub_op(Operation::custom(pred_a_ref, [eq_st]))?;
// Prove and verify
let prover = MockProver {};
let proof = mp_builder.prove(&prover)?;
proof.pod.verify()?;
// Verify the public statement contains pred_b_hash as its argument
let pub_sts = proof.pod.pub_self_statements();
let custom_st = pub_sts
.iter()
.find(|s| matches!(s, middleware::Statement::Custom(_, _)))
.expect("should have a custom statement");
if let middleware::Statement::Custom(_, args) = custom_st {
assert_eq!(args[0], ValueRef::Literal(pred_b_hash));
}
Ok(())
}
}

View file

@ -4,7 +4,7 @@
use std::{
collections::{HashMap, HashSet},
convert::From,
fmt,
fmt, iter,
};
use itertools::Itertools;
@ -13,10 +13,12 @@ use serde::{Deserialize, Serialize};
pub use serialization::SerializedMainPod;
use crate::middleware::{
self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op,
prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation,
OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement,
StatementArg, VDSet, Value, ValueRef,
self, check_custom_pred,
containers::{Container, Dictionary},
fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key,
MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey,
RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS,
EMPTY_VALUE,
};
mod custom;
@ -92,8 +94,11 @@ impl fmt::Display for SignedDict {
// https://0xparc.github.io/pod2/merkletree.html will not need it since it will be
// deterministic based on the keys values not on the order of the keys when added into the
// tree.
for (k, v) in self.dict.kvs().iter().sorted_by_key(|kv| kv.0.hash()) {
writeln!(f, " - {} = {}", k, v)?;
for kv in self.dict.iter() {
match kv {
Ok((k, v)) => writeln!(f, " - {} = {}", k, v)?,
Err(e) => writeln!(f, " - ERR: {}", e)?,
}
}
Ok(())
}
@ -106,16 +111,13 @@ impl SignedDict {
.then_some(())
.ok_or(Error::custom("Invalid signature!"))
}
pub fn kvs(&self) -> &HashMap<Key, Value> {
self.dict.kvs()
}
pub fn get(&self, key: impl Into<Key>) -> Option<&Value> {
self.kvs().get(&key.into())
pub fn get(&self, key: impl Into<Key>) -> Option<Value> {
self.dict.get(&key.into()).unwrap()
}
// Returns the Contains statement that defines key if it exists.
pub fn get_statement(&self, key: impl Into<Key>) -> Option<Statement> {
let key: Key = key.into();
self.kvs().get(&key).map(|value| {
self.dict.get(&key).unwrap().map(|value| {
Statement::Contains(
ValueRef::Literal(Value::from(self.dict.clone())),
ValueRef::Literal(Value::from(key.name())),
@ -136,7 +138,7 @@ pub struct MainPodBuilder {
pub operations: Vec<Operation>,
pub public_statements: Vec<Statement>,
// Internal state
dict_contains: Vec<(Value, Value)>, // (root, key)
contains: Vec<(RawValue, RawValue)>, // (root, key)
}
impl fmt::Display for MainPodBuilder {
@ -156,6 +158,11 @@ impl fmt::Display for MainPodBuilder {
}
}
fn as_container_or_err(v: &Value) -> Result<Container> {
v.as_container()
.ok_or_else(|| Error::custom(format!("{v} not a container")))
}
impl MainPodBuilder {
pub fn new(params: &Params, vd_set: &VDSet) -> Self {
Self {
@ -165,10 +172,16 @@ impl MainPodBuilder {
statements: Vec::new(),
operations: Vec::new(),
public_statements: Vec::new(),
dict_contains: Vec::new(),
contains: Vec::new(),
}
}
pub fn stmt_len(&self) -> usize {
self.statements.len()
}
pub fn add_pod(&mut self, pod: MainPod) -> Result<()> {
for st in &pod.public_statements {
self.track_contains(st);
}
self.input_pods.push(pod);
match self.input_pods.len() > self.params.max_input_pods {
true => Err(Error::too_many_input_pods(
@ -178,31 +191,26 @@ impl MainPodBuilder {
_ => Ok(()),
}
}
pub fn insert(&mut self, public: bool, st_op: (Statement, Operation)) -> Result<()> {
// TODO: Do error handling instead of panic
let (st, op) = st_op;
// If we're adding a Contains statement with literal arguments (an Entry), track it in
// `dict_contains` to avoid adding it again via `Self::add_entries_contains`.
fn track_contains(&mut self, st: &Statement) {
if let Statement::Contains(
ValueRef::Literal(dict),
ValueRef::Literal(key),
ValueRef::Literal(_),
) = &st
{
let root_key = (dict.clone(), key.clone());
self.dict_contains.push(root_key);
let root_key = (dict.raw(), key.raw());
self.contains.push(root_key);
}
}
if public {
self.public_statements.push(st.clone());
}
if self.public_statements.len() > self.params.max_public_statements {
return Err(Error::too_many_public_statements(
self.public_statements.len(),
self.params.max_public_statements,
));
}
pub fn insert(&mut self, st_op: (Statement, Operation)) -> Result<()> {
// TODO: Do error handling instead of panic
let (st, op) = st_op;
self.track_contains(&st);
self.statements.push(st);
self.operations.push(op);
if self.statements.len() > self.params.max_statements {
@ -347,11 +355,12 @@ impl MainPodBuilder {
.ok_or(Error::custom(format!(
"Invalid key argument for op {}.",
op
)))?;
)))?
.raw();
let proof = if op_type == &Native(ContainsFromEntries) {
container.prove_existence(key)?.1
as_container_or_err(container)?.prove(key)?.1
} else {
container.prove_nonexistence(key)?
as_container_or_err(container)?.prove_nonexistence(key)?
};
Ok(Operation(op_type.clone(), op.1, OpAux::MerkleProof(proof)))
}
@ -375,18 +384,16 @@ impl MainPodBuilder {
let value =
op.1.get(3)
.and_then(|arg| arg.value())
.ok_or(Error::custom(format!(
"Invalid key argument for op {}.",
op
)));
.cloned()
.unwrap_or(Value::from(EMPTY_VALUE));
let proof = match op_type {
Native(ContainerInsertFromEntries) => {
old_container.prove_insertion(key, value?)?
as_container_or_err(old_container)?.insert(key.clone(), value)?
}
Native(ContainerUpdateFromEntries) => {
old_container.prove_update(key, value?)?
as_container_or_err(old_container)?.update(key.raw(), value)?
}
_ => old_container.prove_deletion(key)?,
_ => as_container_or_err(old_container)?.delete(key.raw())?,
};
Ok(Operation(
op_type.clone(),
@ -399,7 +406,7 @@ impl MainPodBuilder {
}
fn op_statement(
&mut self,
&self,
wildcard_values: Vec<(usize, Value)>,
op: Operation,
) -> Result<Statement> {
@ -560,6 +567,37 @@ impl MainPodBuilder {
// TODO: validate proof
Statement::ContainerDelete(r1, r2, r3)
}
(ReplaceValueWithEntry, &args, _) => {
let mut args = args.to_vec();
if args.len() != BASE_PARAMS.max_statement_args + 1 {
return Err(Error::custom(format!(
"ReplaceValueWithEntry requires exactly {} args but {} were found",
BASE_PARAMS.max_statement_args + 1,
args.len()
)));
}
let st = match args.pop().expect("valid vec len") {
OperationArg::Statement(st) => st,
_ => return Err(Error::custom("expected statement")),
};
let new_st_args = iter::zip(st.args().into_iter(), args)
.map(|(st_arg, arg)| match (st_arg, arg) {
(st_arg, OperationArg::Statement(Statement::None)) => Ok(st_arg),
(
StatementArg::Literal(st_arg_v),
OperationArg::Statement(Statement::Contains(
ValueRef::Literal(root),
ValueRef::Literal(key),
ValueRef::Literal(v),
)),
) if st_arg_v == v => root_key_to_ak(&root, &key)
.map(StatementArg::Key)
.ok_or_else(native_arg_error),
_ => Err(Error::custom("unexpected operation argument")),
})
.collect::<Result<Vec<_>>>()?;
Statement::from_args(st.predicate(), new_st_args)?
}
(t, _, _) => {
if t.is_syntactic_sugar() {
return Err(Error::custom(format!(
@ -573,7 +611,7 @@ impl MainPodBuilder {
}
}
OperationType::Custom(cpr) => {
let pred = &cpr.batch.predicates()[cpr.index];
let pred = cpr.normalized_predicate();
if pred.statements.len() != op.1.len() {
return Err(Error::custom(format!(
"Custom predicate operation needs {} statements but has {}.",
@ -601,7 +639,7 @@ impl MainPodBuilder {
}
wildcard_map[index] = Some(value);
}
fill_wildcard_values(pred, &args, &mut wildcard_map)?;
fill_wildcard_values(&pred, &args, &mut wildcard_map)?;
let v_default = Value::from(0);
let st_args: Vec<_> = wildcard_map
.into_iter()
@ -609,14 +647,14 @@ impl MainPodBuilder {
.map(|v| v.unwrap_or_else(|| v_default.clone()))
.collect();
check_custom_pred(&self.params, &cpr, &args, &st_args)?;
Statement::Custom(cpr, st_args)
Statement::Custom(cpr, st_args.into_iter().map(ValueRef::Literal).collect())
}
};
Ok(st)
}
/// For every operation that has Entry statements as arguments we add a Contains statement to
/// open the dictionary.
/// open the dictionary (unless such Contains already exists).
fn add_entries_contains(&mut self, op: &Operation) -> Result<()> {
for arg in &op.1 {
if let OperationArg::Statement(Statement::Contains(
@ -625,9 +663,9 @@ impl MainPodBuilder {
ValueRef::Literal(v),
)) = arg
{
let root_key = (dict.clone(), key.clone());
if !self.dict_contains.contains(&root_key) {
self.dict_contains.push(root_key);
let root_key = (dict.raw(), key.raw());
if !self.contains.contains(&root_key) {
self.contains.push(root_key);
self.priv_op(Operation::dict_contains(dict, key, v))?;
}
}
@ -645,14 +683,29 @@ impl MainPodBuilder {
self.add_entries_contains(&op)?;
let op = Self::fill_in_aux(Self::lower_op(op)?)?;
let st = self.op_statement(wildcard_values, op.clone())?;
self.insert(public, (st, op))?;
Ok(self.statements[self.statements.len() - 1].clone())
// Skip adding the statement and operation if it already exists
if !self.statements.contains(&st) {
self.insert((st.clone(), op))?;
}
if public {
self.reveal(&st)?;
}
pub fn reveal(&mut self, st: &Statement) {
Ok(st)
}
pub fn reveal(&mut self, st: &Statement) -> Result<()> {
if !self.public_statements.contains(st) {
self.public_statements.push(st.clone());
}
if self.public_statements.len() > self.params.max_public_statements {
return Err(Error::too_many_public_statements(
self.public_statements.len(),
self.params.max_public_statements,
));
}
Ok(())
}
pub fn prove(&self, prover: &dyn MainPodProver) -> Result<MainPod> {
let compiler = MainPodCompiler::new(&self.params);
@ -1346,11 +1399,9 @@ pub mod tests {
OperationAux::None,
);
builder
.insert(false, (value_of_a.clone(), op_contains.clone()))
.unwrap();
builder
.insert(false, (value_of_b.clone(), op_contains))
.insert((value_of_a.clone(), op_contains.clone()))
.unwrap();
builder.insert((value_of_b.clone(), op_contains)).unwrap();
let st = Statement::equal(
AnchoredKey::from((&local, "a")),
AnchoredKey::from((&local, "b")),
@ -1363,7 +1414,7 @@ pub mod tests {
],
OperationAux::None,
);
builder.insert(false, (st, op)).unwrap();
builder.insert((st, op)).unwrap();
let prover = MockProver {};
let pod = builder.prove(&prover).unwrap();

View file

@ -6,60 +6,20 @@
use std::collections::BTreeSet;
use crate::{
frontend::{Operation, OperationArg},
middleware::{
CustomPredicateBatch, Hash, NativeOperation, OperationType, RawValue, Statement, ValueRef,
},
frontend::Operation,
middleware::{CustomPredicateRef, Hash, NativeOperation, OperationType, Predicate},
};
/// Unique identifier for a custom predicate batch.
/// Unique identifier for a custom predicate in a module.
///
/// Uses the batch's cryptographic hash as identifier. Two batches with the same
/// Uses the predicate's cryptographic hash as identifier. Two predicates with the same
/// hash are considered identical for resource counting purposes.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CustomBatchId(pub Hash);
pub struct CustomPredicateId(pub Hash);
impl From<&CustomPredicateBatch> for CustomBatchId {
fn from(batch: &CustomPredicateBatch) -> Self {
Self(batch.id())
}
}
/// Unique identifier for an anchored key (dict, key) pair.
///
/// When a Contains statement is used as an argument to operations like gt(), eq(), etc.,
/// the value is accessed via an "anchored key" - a reference to a specific key in a
/// specific dictionary. Each unique anchored key used in a POD requires a Contains
/// statement to be present in that POD (auto-inserted by MainPodBuilder if needed).
///
/// We use the raw values of the dict and key for comparison, as they uniquely identify
/// the anchored key regardless of the specific Value types involved.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct AnchoredKeyId {
/// The dictionary root value (raw representation for Ord).
pub dict: RawValue,
/// The key within the dictionary (raw representation for Ord).
pub key: RawValue,
}
impl AnchoredKeyId {
/// Create a new anchored key ID from raw values.
pub fn new(dict: RawValue, key: RawValue) -> Self {
Self { dict, key }
}
/// Try to extract an anchored key ID from a Contains statement with all literal values.
pub fn from_contains_statement(stmt: &Statement) -> Option<Self> {
if let Statement::Contains(
ValueRef::Literal(dict),
ValueRef::Literal(key),
ValueRef::Literal(_value),
) = stmt
{
Some(Self::new(dict.raw(), key.raw()))
} else {
None
}
impl From<&CustomPredicateRef> for CustomPredicateId {
fn from(predicate: &CustomPredicateRef) -> Self {
Self(Predicate::Custom(predicate.clone()).hash())
}
}
@ -88,17 +48,9 @@ pub struct StatementCost {
/// Limit: `params.max_public_key_of`
pub public_key_of: usize,
/// Custom predicate batches used (for batch cardinality constraint).
/// Limit: `params.max_custom_predicate_batches` distinct batches per POD.
pub custom_batch_ids: BTreeSet<CustomBatchId>,
/// Anchored keys referenced by this operation.
///
/// When a Contains statement with all literal values is used as an argument,
/// the operation references an "anchored key" (dict, key pair). Each unique
/// anchored key used in a POD incurs an additional Contains statement cost,
/// as MainPodBuilder::add_entries_contains will auto-insert it if not already present.
pub anchored_keys: BTreeSet<AnchoredKeyId>,
/// Custom predicates used (for custom predicate cardinality constraint).
/// Limit: `params.max_custom_predicates` distinct custom predicates per POD.
pub custom_predicates_ids: BTreeSet<CustomPredicateId>,
}
impl StatementCost {
@ -159,25 +111,14 @@ impl StatementCost {
// Syntactic sugar variants (lowered before proving)
| NativeOperation::GtEqFromEntries
| NativeOperation::GtFromEntries
| NativeOperation::GtToNotEqual => {}
| NativeOperation::GtToNotEqual
| NativeOperation::ReplaceValueWithEntry => {}
}
}
OperationType::Custom(cpr) => {
cost.custom_pred_verifications = 1;
cost.custom_batch_ids
.insert(CustomBatchId::from(&*cpr.batch));
}
}
// Extract anchored keys from operation arguments.
// Any argument that is a Contains statement with all literal values
// represents an anchored key reference that will require a Contains
// statement in the POD (auto-inserted by MainPodBuilder if needed).
for arg in &op.1 {
if let OperationArg::Statement(stmt) = arg {
if let Some(anchored_key) = AnchoredKeyId::from_contains_statement(stmt) {
cost.anchored_keys.insert(anchored_key);
}
cost.custom_predicates_ids
.insert(CustomPredicateId::from(cpr));
}
}

View file

@ -5,7 +5,6 @@
use std::collections::HashMap;
use super::cost::AnchoredKeyId;
use crate::{
frontend::{Operation, OperationArg},
middleware::{Hash, Statement},
@ -100,11 +99,6 @@ impl DependencyGraph {
pod_hash,
statement: dep_stmt.clone(),
}));
} else if AnchoredKeyId::from_contains_statement(dep_stmt).is_some() {
// Anchored-key Contains args may be implicit requirements that are
// auto-materialized by MainPodBuilder. They are handled by anchored-key
// resource accounting, not by statement dependency edges.
continue;
} else {
// Statement arguments should either be internal (created earlier)
// or from external PODs (except anchored-key implicit Contains).
@ -128,9 +122,8 @@ impl DependencyGraph {
mod tests {
use super::*;
use crate::{
dict,
frontend::Operation as FrontendOp,
middleware::{AnchoredKey, NativeOperation, OperationAux, OperationType, Value, ValueRef},
middleware::{NativeOperation, OperationAux, OperationType, Value, ValueRef},
};
fn equal_stmt(n: i64) -> Statement {
@ -195,32 +188,4 @@ mod tests {
assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]);
assert_eq!(graph.statement_deps[2], vec![StatementSource::Internal(0)]);
}
#[test]
fn test_anchored_key_contains_arg_is_treated_as_implicit_requirement() {
// A literal Contains statement can be used as an anchored-key argument even when
// no explicit producer statement exists in internal/external statements, because
// MainPodBuilder auto-inserts Contains statements for anchored keys.
let dict = dict!({
"k" => 7_i64
});
let anchored_contains = Statement::Contains(
ValueRef::Literal(Value::from(dict.clone())),
ValueRef::Literal(Value::from("k")),
ValueRef::Literal(Value::from(7_i64)),
);
let ak = AnchoredKey::from((&dict, "k"));
let produced_statement = Statement::Equal(ValueRef::Key(ak.clone()), ValueRef::Key(ak));
// Use a typical frontend operation that consumes entry-like args.
// We're only testing the dependency graph, not the actual proof, so the operation
// just needs to have the right arguments to test what we're looking for.
let statements = vec![produced_statement];
let operations = vec![FrontendOp::eq(anchored_contains.clone(), anchored_contains)];
let graph = DependencyGraph::build(&statements, &operations, &HashMap::new());
assert!(graph.statement_deps[0].is_empty());
}
}

View file

@ -0,0 +1,466 @@
//! Diagnostic utilities for multi-POD resource analysis.
//!
//! Provides two views:
//! - [`ResourceSummary`]: Pre-solve aggregate resource demand vs. per-POD limits.
//! Shows which resource category is the bottleneck (requires the most PODs).
//! - [`SolutionBreakdown`]: Post-solve per-POD utilization showing how full each POD is.
use std::{collections::BTreeSet, fmt};
use super::cost::StatementCost;
use crate::middleware::Params;
/// A single resource category's usage vs. per-POD limit.
///
/// Used both for pre-solve aggregate demand (in [`ResourceSummary`]) where
/// `used` is the total across all statements, and for post-solve per-POD
/// breakdown (in [`PodUtilization`]) where `used` is the POD's consumption.
#[derive(Clone, Debug)]
pub struct UtilizationRow {
pub name: &'static str,
pub used: usize,
pub limit: usize,
}
impl UtilizationRow {
/// Utilization as a fraction (0.0 to 1.0).
pub fn utilization(&self) -> f64 {
if self.limit == 0 {
if self.used == 0 {
0.0
} else {
f64::INFINITY
}
} else {
self.used as f64 / self.limit as f64
}
}
/// Minimum PODs needed for this resource alone: `ceil(used / limit)`.
/// `None` if `limit` is 0 and `used > 0` (infeasible).
pub fn min_pods(&self) -> Option<usize> {
lower_bound(self.used, self.limit)
}
}
/// Aggregate resource usage over a set of statement costs into per-category rows.
///
/// Single source of truth for the resource categories and their corresponding
/// `Params` limits. Used both for pre-solve totals and per-POD breakdowns.
fn aggregate_rows<'a>(
costs: impl IntoIterator<Item = &'a StatementCost>,
params: &Params,
) -> (Vec<UtilizationRow>, usize) {
let mut num_stmts = 0usize;
let mut merkle_proofs = 0usize;
let mut merkle_state_transitions = 0usize;
let mut custom_pred_verifications = 0usize;
let mut signed_by = 0usize;
let mut public_key_of = 0usize;
let mut custom_pred_ids = BTreeSet::new();
for c in costs {
num_stmts += 1;
merkle_proofs += c.merkle_proofs;
merkle_state_transitions += c.merkle_state_transitions;
custom_pred_verifications += c.custom_pred_verifications;
signed_by += c.signed_by;
public_key_of += c.public_key_of;
custom_pred_ids.extend(c.custom_predicates_ids.iter().cloned());
}
let rows = vec![
UtilizationRow {
name: "private statements",
used: num_stmts,
limit: params.max_priv_statements(),
},
UtilizationRow {
name: "merkle proofs",
used: merkle_proofs,
limit: params.containers.state.max_medium,
},
UtilizationRow {
name: "merkle state transitions",
used: merkle_state_transitions,
limit: params.containers.transition.max_medium,
},
UtilizationRow {
name: "custom pred verifications",
used: custom_pred_verifications,
limit: params.max_custom_predicate_verifications,
},
UtilizationRow {
name: "signed_by",
used: signed_by,
limit: params.max_signed_by,
},
UtilizationRow {
name: "public_key_of",
used: public_key_of,
limit: params.max_public_key_of,
},
UtilizationRow {
name: "distinct custom predicates",
used: custom_pred_ids.len(),
limit: params.max_custom_predicates,
},
];
(rows, num_stmts)
}
/// Pre-solve aggregate resource summary.
///
/// Shows total resource demand across all operations and the minimum PODs
/// each resource category would require independently.
#[derive(Clone, Debug)]
pub struct ResourceSummary {
pub rows: Vec<UtilizationRow>,
pub num_statements: usize,
}
impl ResourceSummary {
/// Compute a resource summary from per-statement costs and params.
pub fn from_costs(costs: &[StatementCost], params: &Params) -> Self {
let (rows, num_statements) = aggregate_rows(costs.iter(), params);
Self {
rows,
num_statements,
}
}
/// The resource category requiring the most PODs (the bottleneck).
/// Returns `None` only if there are no statements.
pub fn bottleneck(&self) -> Option<&UtilizationRow> {
self.rows
.iter()
.filter(|r| r.used > 0)
.max_by_key(|r| r.min_pods().unwrap_or(usize::MAX))
}
}
impl fmt::Display for ResourceSummary {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Resource Summary ({} statements)", self.num_statements)?;
writeln!(
f,
" {:<30} {:>5} {:>9} {:>8}",
"Category", "Total", "Limit/POD", "Min PODs"
)?;
let bottleneck_name = self.bottleneck().map(|r| r.name);
for row in &self.rows {
let min_pods_str = match row.min_pods() {
Some(n) => format!("{}", n),
None => "inf".to_string(),
};
let marker = if Some(row.name) == bottleneck_name && row.used > 0 {
" <<<"
} else {
""
};
writeln!(
f,
" {:<30} {:>5} {:>9} {:>8}{}",
row.name, row.used, row.limit, min_pods_str, marker
)?;
}
Ok(())
}
}
/// Per-POD resource utilization in a solved solution.
#[derive(Clone, Debug)]
pub struct PodUtilization {
/// POD index.
pub pod_idx: usize,
/// Whether this is the output POD (last).
pub is_output: bool,
/// Number of statements in this POD.
pub num_statements: usize,
/// Resource usage vs. limits for each category.
pub resources: Vec<UtilizationRow>,
}
/// Post-solve per-POD resource breakdown.
#[derive(Clone, Debug)]
pub struct SolutionBreakdown {
pub pods: Vec<PodUtilization>,
pub num_statements: usize,
pub pod_count: usize,
}
impl SolutionBreakdown {
/// Compute a solution breakdown from per-statement costs, the solution's
/// pod_statements assignment, and params.
pub fn from_solution(
costs: &[StatementCost],
pod_statements: &[Vec<usize>],
pod_count: usize,
num_statements: usize,
params: &Params,
) -> Self {
let pods = (0..pod_count)
.map(|pod_idx| {
let stmts = &pod_statements[pod_idx];
let (resources, num_stmts) =
aggregate_rows(stmts.iter().map(|&s| &costs[s]), params);
PodUtilization {
pod_idx,
is_output: pod_idx == pod_count - 1,
num_statements: num_stmts,
resources,
}
})
.collect();
Self {
pods,
num_statements,
pod_count,
}
}
}
impl fmt::Display for SolutionBreakdown {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"Solution Breakdown ({} statements -> {} PODs)",
self.num_statements, self.pod_count
)?;
for pod in &self.pods {
let role = if pod.is_output {
"output"
} else {
"intermediate"
};
writeln!(f, " POD {} ({}):", pod.pod_idx, role)?;
for row in &pod.resources {
// Only show rows with nonzero usage to reduce noise
if row.used > 0 {
let pct = if row.limit > 0 {
format!("({:>3}%)", (row.used * 100) / row.limit)
} else {
"".to_string()
};
writeln!(
f,
" {:<30} {:>3}/{:<3} {}",
row.name, row.used, row.limit, pct
)?;
}
}
writeln!(f)?;
}
Ok(())
}
}
fn lower_bound(used: usize, limit: usize) -> Option<usize> {
if used == 0 {
Some(0)
} else if limit == 0 {
None
} else {
Some(used.div_ceil(limit))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
frontend::multi_pod::cost::CustomPredicateId,
middleware::{Hash, ParamsContainers, ParamsMerkleProofs, RawValue},
};
fn default_params() -> Params {
Params {
max_statements: 48,
max_public_statements: 8,
containers: ParamsContainers {
state: ParamsMerkleProofs {
max_small: 0,
max_medium: 8,
},
transition: ParamsMerkleProofs {
max_small: 0,
max_medium: 4,
},
..Default::default()
},
max_custom_predicate_verifications: 10,
max_custom_predicates: 2,
max_signed_by: 4,
max_public_key_of: 4,
..Params::default()
}
}
#[test]
fn test_resource_summary_bottleneck() {
let params = default_params();
// max_priv = 48 - 8 = 40
// 6 merkle proofs, 3 state transitions, rest zero-cost
let costs: Vec<StatementCost> = (0..14)
.map(|i| {
let mut c = StatementCost::default();
if i < 6 {
c.merkle_proofs = 1;
} else if i < 9 {
c.merkle_state_transitions = 1;
}
c
})
.collect();
let summary = ResourceSummary::from_costs(&costs, &params);
// 14 statements / 40 per pod = 1 pod for statements
// 6 merkle proofs / 8 per pod = 1 pod
// 3 state transitions / 4 per pod = 1 pod
// All categories need 1 pod, so bottleneck is whichever has the highest min_pods.
// They're all 1, so the first with total > 0 wins in max_by_key (stable).
let bottleneck = summary.bottleneck().unwrap();
assert_eq!(bottleneck.min_pods(), Some(1));
// Verify display doesn't panic
let display = format!("{}", summary);
assert!(display.contains("Resource Summary (14 statements)"));
assert!(display.contains("merkle proofs"));
}
#[test]
fn test_resource_summary_signed_by_bottleneck() {
let params = Params {
max_statements: 48,
max_public_statements: 8,
max_signed_by: 2,
..Params::default()
};
// max_priv = 40
// 6 signed_by operations
let costs: Vec<StatementCost> = (0..6)
.map(|_| StatementCost {
signed_by: 1,
..Default::default()
})
.collect();
let summary = ResourceSummary::from_costs(&costs, &params);
let bottleneck = summary.bottleneck().unwrap();
assert_eq!(bottleneck.name, "signed_by");
// 6 / 2 = 3 pods
assert_eq!(bottleneck.min_pods(), Some(3));
}
#[test]
fn test_resource_summary_custom_predicates_bottleneck() {
let params = Params {
max_statements: 48,
max_public_statements: 8,
max_custom_predicates: 1, // Only 1 distinct predicate per POD
max_custom_predicate_verifications: 10,
..Params::default()
};
// 3 statements using 3 different custom predicates
let costs: Vec<StatementCost> = (0..3)
.map(|i| {
let mut ids = std::collections::BTreeSet::new();
ids.insert(CustomPredicateId(Hash::from(RawValue::from(i as i64))));
StatementCost {
custom_pred_verifications: 1,
custom_predicates_ids: ids,
..Default::default()
}
})
.collect();
let summary = ResourceSummary::from_costs(&costs, &params);
let bottleneck = summary.bottleneck().unwrap();
assert_eq!(bottleneck.name, "distinct custom predicates");
// 3 distinct predicates / 1 per pod = 3 pods
assert_eq!(bottleneck.min_pods(), Some(3));
}
#[test]
fn test_solution_breakdown_display() {
let params = default_params();
let costs: Vec<StatementCost> = (0..8)
.map(|i| {
let mut c = StatementCost::default();
if i < 4 {
c.merkle_proofs = 1;
} else {
c.merkle_state_transitions = 1;
}
c
})
.collect();
let pod_statements = vec![
vec![0, 1, 2, 3], // POD 0: 4 merkle proofs
vec![4, 5, 6, 7], // POD 1: 4 state transitions
];
let breakdown = SolutionBreakdown::from_solution(&costs, &pod_statements, 2, 8, &params);
assert_eq!(breakdown.pods.len(), 2);
assert!(!breakdown.pods[0].is_output);
assert!(breakdown.pods[1].is_output);
// POD 0 should have 4 merkle proofs
let mp = breakdown.pods[0]
.resources
.iter()
.find(|r| r.name == "merkle proofs")
.unwrap();
assert_eq!(mp.used, 4);
assert_eq!(mp.limit, 8);
// POD 1 should have 4 state transitions
let mst = breakdown.pods[1]
.resources
.iter()
.find(|r| r.name == "merkle state transitions")
.unwrap();
assert_eq!(mst.used, 4);
assert_eq!(mst.limit, 4);
// Verify display doesn't panic and contains expected content
let display = format!("{}", breakdown);
assert!(display.contains("Solution Breakdown (8 statements -> 2 PODs)"));
assert!(display.contains("POD 0 (intermediate)"));
assert!(display.contains("POD 1 (output)"));
}
#[test]
fn test_utilization_row_fraction() {
let row = UtilizationRow {
name: "test",
used: 3,
limit: 4,
};
assert!((row.utilization() - 0.75).abs() < f64::EPSILON);
let zero_row = UtilizationRow {
name: "test",
used: 0,
limit: 4,
};
assert!((zero_row.utilization()).abs() < f64::EPSILON);
}
}

View file

@ -48,21 +48,23 @@
//! [`MainPodBuilder`]: crate::frontend::MainPodBuilder
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
collections::{BTreeSet, HashMap},
fmt,
};
use crate::{
frontend::{MainPod, MainPodBuilder, Operation, OperationArg},
frontend::{MainPod, MainPodBuilder, Operation},
middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value},
};
mod cost;
mod deps;
pub mod diagnostics;
mod solver;
use cost::{AnchoredKeyId, StatementCost};
use cost::StatementCost;
use deps::{DependencyGraph, StatementSource};
pub use diagnostics::{ResourceSummary, SolutionBreakdown};
pub use solver::MultiPodSolution;
/// Error type for multi-POD operations.
@ -168,12 +170,8 @@ pub struct MultiPodBuilder {
options: Options,
/// External input PODs (already proved).
input_pods: Vec<MainPod>,
/// Statements created by this builder.
statements: Vec<Statement>,
/// Operations that produce each statement.
operations: Vec<Operation>,
/// Optional initial wildcard values for custom operations
operations_wildcard_values: Vec<Vec<(usize, Value)>>,
operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
/// Indices of statements that should be public in output PODs.
/// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted.
output_public_indices: Vec<usize>,
@ -193,7 +191,7 @@ pub struct SolvedMultiPod {
statements: Vec<Statement>,
operations: Vec<Operation>,
output_public_indices: Vec<usize>,
operations_wildcard_values: Vec<Vec<(usize, Value)>>,
operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
solution: MultiPodSolution,
deps: DependencyGraph,
}
@ -204,6 +202,22 @@ impl SolvedMultiPod {
&self.solution
}
/// Compute a post-solve per-POD resource utilization breakdown.
pub fn solution_breakdown(&self) -> SolutionBreakdown {
let costs: Vec<StatementCost> = self
.operations
.iter()
.map(StatementCost::from_operation)
.collect();
SolutionBreakdown::from_solution(
&costs,
&self.solution.pod_statements,
self.solution.pod_count,
self.statements.len(),
&self.params,
)
}
/// Build and prove all PODs.
///
/// Builds PODs in dependency order (0, 1, ..., k) and proves each one.
@ -260,56 +274,27 @@ impl SolvedMultiPod {
let statements_sorted: BTreeSet<usize> = statements_in_this_pod.iter().copied().collect();
let public_set = &solution.pod_public_statements[pod_idx];
// Track statements proved locally in this POD for argument remapping.
// We index by statement content so duplicate statements can reuse a single
// built statement slot in MainPodBuilder.
let mut added_statements_by_content: HashMap<Statement, Statement> = HashMap::new();
for &stmt_idx in &statements_sorted {
let original_stmt = self.statements[stmt_idx].clone();
// If this statement content was already built in this POD, reuse it instead
// of replaying the operation. If any duplicate is public, reveal the
// already-built statement.
if let Some(_existing_stmt) = added_statements_by_content.get(&original_stmt) {
continue;
}
let mut op = self.operations[stmt_idx].clone();
let wildcard_values = self.operations_wildcard_values[stmt_idx].clone();
// Remap Statement arguments that reference locally-proved statements.
// For external dependencies (from input PODs including earlier generated PODs),
// the original Statement is used directly - MainPodBuilder will find it in
// the input POD's public statements via find_op_arg.
for arg in &mut op.1 {
if let OperationArg::Statement(ref orig_stmt) = arg {
if let Some(remapped_stmt) = added_statements_by_content.get(orig_stmt) {
*arg = OperationArg::Statement(remapped_stmt.clone());
}
}
}
let op = self.operations[stmt_idx].clone();
let wildcard_values = self
.operations_wildcard_values
.get(&stmt_idx)
.cloned()
.unwrap_or_default();
let stmt = builder.op(false, wildcard_values, op)?;
added_statements_by_content.insert(original_stmt, stmt);
assert_eq!(stmt, self.statements[stmt_idx]); // Sanity check
}
// For the output pod, make statements public in the original order.
// Intermediate pods use the solver-selected public set.
if pod_idx == solution.pod_count - 1 {
for idx in &self.output_public_indices {
let stmt = added_statements_by_content
.get(&self.statements[*idx])
.expect("exists");
builder.reveal(stmt);
builder.reveal(&self.statements[*idx])?;
}
} else {
for idx in public_set {
let stmt = added_statements_by_content
.get(&self.statements[*idx])
.expect("exists");
builder.reveal(stmt);
builder.reveal(&self.statements[*idx])?;
}
}
@ -317,7 +302,7 @@ impl SolvedMultiPod {
// for this POD. These do not require local proving in this POD.
for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] {
let ext_premise = &solution.external_premises[*ext_premise_idx];
builder.reveal(&ext_premise.statement);
builder.reveal(&ext_premise.statement)?;
}
// Step 4: Prove the POD
@ -456,9 +441,7 @@ impl MultiPodBuilder {
options,
builder,
input_pods: Vec::new(),
statements: Vec::new(),
operations: Vec::new(),
operations_wildcard_values: Vec::new(),
operations_wildcard_values: HashMap::new(),
output_public_indices: Vec::new(),
}
}
@ -480,6 +463,16 @@ impl MultiPodBuilder {
self.op(false, vec![], op)
}
// Find the index of a statement that has been added. Panics if the statement doesn't
// exist.
fn stmt_index(&self, stmt: &Statement) -> usize {
self.builder
.statements
.iter()
.position(|s| s == stmt)
.expect("exists")
}
pub fn op(
&mut self,
public: bool,
@ -488,8 +481,10 @@ impl MultiPodBuilder {
) -> Result<Statement> {
let stmt = self.add_operation(wildcard_values, op)?;
if public {
// Index is always new (just added), so push without duplicate check
self.output_public_indices.push(self.statements.len() - 1);
let index = self.stmt_index(&stmt);
if !self.output_public_indices.contains(&index) {
self.output_public_indices.push(index);
}
}
Ok(stmt)
}
@ -510,10 +505,8 @@ impl MultiPodBuilder {
let stmt = self
.builder
.op(false, wildcard_values.clone(), op.clone())?;
self.statements.push(stmt.clone());
self.operations.push(op);
self.operations_wildcard_values.push(wildcard_values);
self.operations_wildcard_values
.insert(self.stmt_index(&stmt), wildcard_values.clone());
Ok(stmt)
}
@ -523,7 +516,7 @@ impl MultiPodBuilder {
/// Returns an error if the statement was not found in the builder.
/// Calling this multiple times on the same statement is idempotent.
pub fn reveal(&mut self, stmt: &Statement) -> Result<()> {
if let Some(idx) = self.statements.iter().position(|s| s == stmt) {
if let Some(idx) = self.builder.statements.iter().position(|s| s == stmt) {
if !self.output_public_indices.contains(&idx) {
self.output_public_indices.push(idx);
}
@ -536,8 +529,22 @@ impl MultiPodBuilder {
}
/// Get the number of statements.
pub fn num_statements(&self) -> usize {
self.statements.len()
pub fn stmt_len(&self) -> usize {
self.builder.stmt_len()
}
/// Compute a pre-solve resource summary showing aggregate demand vs. per-POD limits.
///
/// This is useful for understanding which resource category is the bottleneck
/// before running the solver, especially when debugging solver performance issues.
pub fn resource_summary(&self) -> ResourceSummary {
let costs: Vec<StatementCost> = self
.builder
.operations
.iter()
.map(StatementCost::from_operation)
.collect();
ResourceSummary::from_costs(&costs, &self.params)
}
/// Solve the packing problem and return a solved builder ready for proving.
@ -545,66 +552,31 @@ impl MultiPodBuilder {
/// This runs the MILP solver to find the optimal POD assignment.
/// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved.
pub fn solve(self) -> Result<SolvedMultiPod> {
let MainPodBuilder {
statements,
operations,
..
} = self.builder;
// Compute costs for each statement
let costs: Vec<StatementCost> = self
.operations
let costs: Vec<StatementCost> = operations
.iter()
.map(StatementCost::from_operation)
.collect();
// Collect all unique anchored keys from the costs
let all_anchored_keys: Vec<AnchoredKeyId> = costs
.iter()
.flat_map(|c| c.anchored_keys.iter().cloned())
.collect::<std::collections::BTreeSet<_>>()
.into_iter()
.collect();
// Build map from anchored key to its producing statement index (if any).
// A Contains statement with literal (dict, key, value) "produces" that anchored key.
let mut ak_to_producer: HashMap<AnchoredKeyId, usize> = HashMap::new();
for (stmt_idx, stmt) in self.statements.iter().enumerate() {
if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) {
// First producer wins (shouldn't have duplicates in practice)
ak_to_producer.entry(ak).or_insert(stmt_idx);
}
}
// Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i]
let anchored_key_producers: Vec<Option<usize>> = all_anchored_keys
.iter()
.map(|ak| ak_to_producer.get(ak).copied())
.collect();
// Build external POD statement mapping
let external_pod_statements = build_external_statement_map(&self.input_pods);
// Build dependency graph
let deps =
DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements);
// Build statement content groups for deduplication.
// Statements with identical content share a single slot in the POD.
// Keep groups ordered by first occurrence index for deterministic solver input.
let mut first_idx_by_stmt: HashMap<&Statement, usize> = HashMap::new();
let mut groups_by_first_idx: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
for (idx, stmt) in self.statements.iter().enumerate() {
let first_idx = *first_idx_by_stmt.entry(stmt).or_insert(idx);
groups_by_first_idx.entry(first_idx).or_default().push(idx);
}
let statement_content_groups: Vec<Vec<usize>> = groups_by_first_idx.into_values().collect();
let deps = DependencyGraph::build(&statements, &operations, &external_pod_statements);
// Run solver
let input = solver::SolverInput {
num_statements: self.statements.len(),
num_statements: statements.len(),
costs: &costs,
deps: &deps,
output_public_indices: &self.output_public_indices,
params: &self.params,
max_pods: self.options.max_pods,
all_anchored_keys: &all_anchored_keys,
anchored_key_producers: &anchored_key_producers,
statement_content_groups: &statement_content_groups,
};
let solution = solver::solve(&input)?;
@ -613,8 +585,8 @@ impl MultiPodBuilder {
params: self.params,
vd_set: self.vd_set,
input_pods: self.input_pods,
statements: self.statements,
operations: self.operations,
statements,
operations,
output_public_indices: self.output_public_indices,
operations_wildcard_values: self.operations_wildcard_values,
solution,
@ -845,33 +817,13 @@ mod tests {
let solution = solved.solution();
// Expected: exactly 2 PODs
// - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public
// - POD 1 (output): statement 2 (b_out); b_out is public
// The output POD accesses a_out from POD 0 to satisfy b_out's dependency.
assert_eq!(
solution.pod_count, 2,
"Expected exactly 2 PODs for 3-statement chain with max_priv=2"
);
// POD 0 should contain statements 0 and 1 (contains and a_out)
assert!(
solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1),
"POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}",
solution.pod_statements[0]
);
// Statement 1 (a_out) should be public in POD 0 so POD 1 can access it
assert!(
solution.pod_public_statements[0].contains(&1),
"Statement 1 (a_out) should be public in POD 0"
);
// POD 1 (output) should contain statement 2 (b_out)
assert!(
solution.pod_statements[1].contains(&2),
"POD 1 should contain statement 2 (b_out), got {:?}",
solution.pod_statements[1]
);
// Solution A:
// - POD 0 (intermediate): public statements 0 (contains)
// - POD 1 (output): inherits statement 0 (contains) from POD0, statement 1 (a_out),
// public statement 2 (b_out)
// Solution B:
// - POD 0 (intermediate): statements 0 (contains), public statement 1 (a_out)
// - POD 1 (output): inherits statement 1 (a_out) from POD0, public statement 2 (b_out)
// Statement 2 (b_out) should be public in POD 1 (it's output-public)
assert!(

View file

@ -52,7 +52,7 @@ use itertools::Itertools;
use super::Result;
use crate::{
frontend::multi_pod::{
cost::{AnchoredKeyId, CustomBatchId, StatementCost},
cost::{CustomPredicateId, StatementCost},
deps::{DependencyGraph, ExternalDependency, StatementSource},
},
middleware::{Hash, Params},
@ -95,7 +95,6 @@ struct DependencyStats {
struct SolveDebugContext {
dep_stats: DependencyStats,
batch_memberships: usize,
anchored_key_memberships: usize,
}
#[derive(Clone, Copy, Debug, Default)]
@ -105,10 +104,8 @@ struct ModelSizeEstimate {
vars_public_external: usize,
vars_pod_used: usize,
vars_batch_used: usize,
vars_anchored_key_used: usize,
vars_uses_input: usize,
vars_uses_external: usize,
vars_content_group_used: usize,
vars_total: usize,
c1_coverage: usize,
c2_output_public: usize,
@ -120,7 +117,6 @@ struct ModelSizeEstimate {
c6_pre_content_group: usize,
c6_resource_limits: usize,
c7_batch_cardinality: usize,
c7b_anchored_key_tracking: usize,
c8a_internal_inputs: usize,
c8b_external_dep_inputs: usize,
c8c_external_forward_inputs: usize,
@ -141,8 +137,6 @@ impl ModelSizeEstimate {
debug_ctx: &SolveDebugContext,
) -> Self {
let n = input.num_statements;
let num_groups = input.statement_content_groups.len();
let num_anchored_keys = input.all_anchored_keys.len();
let triangular_k = target_pods * target_pods.saturating_sub(1) / 2;
let vars_prove = n * target_pods;
@ -150,19 +144,15 @@ impl ModelSizeEstimate {
let vars_public_external = external_premises_len * target_pods;
let vars_pod_used = target_pods;
let vars_batch_used = all_batches_len * target_pods;
let vars_anchored_key_used = num_anchored_keys * target_pods;
let vars_uses_input = triangular_k;
let vars_uses_external = external_pods_len * target_pods;
let vars_content_group_used = num_groups * target_pods;
let vars_total = vars_prove
+ vars_public
+ vars_public_external
+ vars_pod_used
+ vars_batch_used
+ vars_anchored_key_used
+ vars_uses_input
+ vars_uses_external
+ vars_content_group_used;
+ vars_uses_external;
let c1_coverage = n;
let c2_output_public = input.output_public_indices.len();
@ -171,12 +161,10 @@ impl ModelSizeEstimate {
let c4_pod_existence = n * target_pods;
let c5_internal_dependencies = debug_ctx.dep_stats.internal_edges * target_pods;
let c5_external_dependencies = debug_ctx.dep_stats.external_edges * target_pods;
let c6_pre_content_group = (n * target_pods) + (num_groups * target_pods);
let c6_pre_content_group = n * target_pods;
let c6_resource_limits = 7 * target_pods;
let c7_batch_cardinality =
(debug_ctx.batch_memberships * target_pods) + (all_batches_len * target_pods);
let c7b_anchored_key_tracking =
(debug_ctx.anchored_key_memberships * target_pods) + (num_anchored_keys * target_pods);
let c8a_internal_inputs = debug_ctx.dep_stats.internal_edges * triangular_k;
let c8b_external_dep_inputs = debug_ctx.dep_stats.external_edges * triangular_k;
let c8c_external_forward_inputs = external_premises_len * triangular_k;
@ -194,7 +182,6 @@ impl ModelSizeEstimate {
+ c6_pre_content_group
+ c6_resource_limits
+ c7_batch_cardinality
+ c7b_anchored_key_tracking
+ c8a_internal_inputs
+ c8b_external_dep_inputs
+ c8c_external_forward_inputs
@ -209,10 +196,8 @@ impl ModelSizeEstimate {
vars_public_external,
vars_pod_used,
vars_batch_used,
vars_anchored_key_used,
vars_uses_input,
vars_uses_external,
vars_content_group_used,
vars_total,
c1_coverage,
c2_output_public,
@ -224,7 +209,6 @@ impl ModelSizeEstimate {
c6_pre_content_group,
c6_resource_limits,
c7_batch_cardinality,
c7b_anchored_key_tracking,
c8a_internal_inputs,
c8b_external_dep_inputs,
c8c_external_forward_inputs,
@ -300,6 +284,7 @@ pub struct MultiPodSolution {
}
/// Input to the MILP solver.
#[derive(Debug)]
pub struct SolverInput<'a> {
/// Number of statements.
pub num_statements: usize,
@ -318,28 +303,6 @@ pub struct SolverInput<'a> {
/// Maximum number of PODs the solver will consider.
pub max_pods: usize,
/// All unique anchored keys referenced by any statement.
///
/// Each unique (dict, key) pair that is used as an anchored key reference
/// in any operation. When a Contains statement with literal values is used
/// as an argument, it creates an anchored key reference.
pub all_anchored_keys: &'a [AnchoredKeyId],
/// For each anchored key, the statement index that produces it (if any).
///
/// When a Contains statement with literal (dict, key, value) args is explicitly
/// added, it "produces" that anchored key. If the producer is in the same POD
/// as statements using the anchored key, no auto-insertion is needed.
/// `anchored_key_producers[i]` corresponds to `all_anchored_keys[i]`.
pub anchored_key_producers: &'a [Option<usize>],
/// Statement content groups for deduplication.
///
/// Each inner Vec contains statement indices that have identical content.
/// When multiple statements with the same content are proved in the same POD,
/// they only use one statement slot (the POD deduplicates identical statements).
pub statement_content_groups: &'a [Vec<usize>],
}
/// Solve the MILP problem to find optimal POD packing.
@ -386,11 +349,11 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
)));
}
// Collect all unique custom batch IDs used
let all_batches: Vec<CustomBatchId> = input
// Collect all unique custom predicate IDs used
let all_custom_predicates: Vec<CustomPredicateId> = input
.costs
.iter()
.flat_map(|c| c.custom_batch_ids.iter().cloned())
.flat_map(|c| c.custom_predicates_ids.iter().cloned())
.unique()
.collect();
@ -417,27 +380,26 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
}
let dep_stats = dependency_stats(input.deps);
let batch_memberships: usize = input.costs.iter().map(|c| c.custom_batch_ids.len()).sum();
let anchored_key_memberships: usize = input.costs.iter().map(|c| c.anchored_keys.len()).sum();
let batch_memberships: usize = input
.costs
.iter()
.map(|c| c.custom_predicates_ids.len())
.sum();
let debug_ctx = SolveDebugContext {
dep_stats,
batch_memberships,
anchored_key_memberships,
};
if log::log_enabled!(log::Level::Debug) {
let resource_totals = ResourceTotals::from_costs(input.costs);
let lb_statement_groups =
lower_bound_from_total(input.statement_content_groups.len(), max_stmts_per_pod);
let lb_statement_groups = lower_bound_from_total(input.num_statements, max_stmts_per_pod);
let lb_merkle = lower_bound_from_total(
resource_totals.merkle_proofs,
input.params.max_merkle_proofs_containers,
input.params.containers.state.max_medium,
);
let lb_merkle_transitions = lower_bound_from_total(
resource_totals.merkle_state_transitions,
input
.params
.max_merkle_tree_state_transition_proofs_containers,
input.params.containers.transition.max_medium,
);
let lb_custom_pred_verifications = lower_bound_from_total(
resource_totals.custom_pred_verifications,
@ -463,14 +425,12 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
.expect("non-empty lower-bound candidate list");
log::debug!(
"MILP summary: statements={} output_public={} content_groups={} anchored_keys={} \
batches={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \
"MILP summary: statements={} output_public={} \
custom_predicates={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \
external_premises={} search_min_pods={} max_pods={}",
n,
num_output_public,
input.statement_content_groups.len(),
input.all_anchored_keys.len(),
all_batches.len(),
all_custom_predicates.len(),
dep_stats.internal_edges,
dep_stats.external_edges,
external_pods.len(),
@ -481,14 +441,13 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
log::debug!(
"MILP resource totals: merkle_proofs={} merkle_state_transitions={} \
custom_pred_verifications={} signed_by={} public_key_of={} \
batch_memberships={} anchored_key_memberships={}",
batch_memberships={}",
resource_totals.merkle_proofs,
resource_totals.merkle_state_transitions,
resource_totals.custom_pred_verifications,
resource_totals.signed_by,
resource_totals.public_key_of,
batch_memberships,
anchored_key_memberships
);
log::debug!(
"MILP lower bounds (pods): statements_raw={} statements_dedup={} merkle_proofs={} \
@ -513,7 +472,7 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
if let Some(solution) = try_solve_with_pods(
input,
target_pods,
&all_batches,
&all_custom_predicates,
&external_pods,
&external_premises,
&debug_ctx,
@ -540,7 +499,7 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
fn try_solve_with_pods(
input: &SolverInput,
target_pods: usize,
all_batches: &[CustomBatchId],
all_custom_predicates: &[CustomPredicateId],
external_pods: &[Hash],
external_premises: &[ExternalDependency],
debug_ctx: &SolveDebugContext,
@ -574,21 +533,8 @@ fn try_solve_with_pods(
.map(|_| vars.add(variable().binary()))
.collect();
// batch_used[b][p] - custom batch b is used in POD p
let batch_used: Vec<Vec<Variable>> = (0..all_batches.len())
.map(|_| {
(0..target_pods)
.map(|_| vars.add(variable().binary()))
.collect()
})
.collect();
// anchored_key_used[ak][p] - anchored key ak is used in POD p
// When a statement references an anchored key (via a Contains statement argument),
// that POD must have a Contains statement for that (dict, key) pair.
// MainPodBuilder::add_entries_contains auto-inserts these, and we must account
// for them in the statement count.
let anchored_key_used: Vec<Vec<Variable>> = (0..input.all_anchored_keys.len())
// custom_predicates[b][p] - custom predicate b is used in POD p
let custom_predicate_used: Vec<Vec<Variable>> = (0..all_custom_predicates.len())
.map(|_| {
(0..target_pods)
.map(|_| vars.add(variable().binary()))
@ -633,31 +579,19 @@ fn try_solve_with_pods(
.map(|(i, ext)| (ext.clone(), i))
.collect();
// content_group_used[g][p] - content group g has at least one statement proved in POD p
// When multiple statements have identical content, they share a slot in the POD.
// This variable tracks whether at least one statement from each content group is proved.
let num_groups = input.statement_content_groups.len();
let content_group_used: Vec<Vec<Variable>> = (0..num_groups)
.map(|_| {
(0..target_pods)
.map(|_| vars.add(variable().binary()))
.collect()
})
.collect();
if log::log_enabled!(log::Level::Debug) {
let estimate = ModelSizeEstimate::for_target_pods(
input,
target_pods,
all_batches.len(),
all_custom_predicates.len(),
external_pods.len(),
external_premises.len(),
debug_ctx,
);
log::debug!(
"MILP(k={}) model estimate vars_total={} [prove={} public={} pod_used={} \
public_external={} batch_used={} anchored_key_used={} uses_input={} \
uses_external={} content_group_used={}]",
public_external={} batch_used={} uses_input={} \
uses_external={}]",
target_pods,
estimate.vars_total,
estimate.vars_prove,
@ -665,14 +599,12 @@ fn try_solve_with_pods(
estimate.vars_pod_used,
estimate.vars_public_external,
estimate.vars_batch_used,
estimate.vars_anchored_key_used,
estimate.vars_uses_input,
estimate.vars_uses_external,
estimate.vars_content_group_used
);
log::debug!(
"MILP(k={}) model estimate constraints_total={} [c1={} c2={} c2b={} c3={} c4={} \
c5i={} c5e={} c6_pre={} c6_limits={} c7={} c7b={} c8a={} c8b={} c8c={} \
c5i={} c5e={} c6_pre={} c6_limits={} c7={} c8a={} c8b={} c8c={} \
c8d={} c9={} c10={} c10b={}]",
target_pods,
estimate.constraints_total,
@ -686,7 +618,6 @@ fn try_solve_with_pods(
estimate.c6_pre_content_group,
estimate.c6_resource_limits,
estimate.c7_batch_cardinality,
estimate.c7b_anchored_key_tracking,
estimate.c8a_internal_inputs,
estimate.c8b_external_dep_inputs,
estimate.c8c_external_forward_inputs,
@ -798,35 +729,11 @@ fn try_solve_with_pods(
}
}
// Constraint 6: Resource limits per POD
//
// 6a-pre: Content group tracking for statement deduplication
// When multiple statement indices have identical content, they share a single slot in the POD.
// content_group_used[g][p] = 1 iff at least one statement from group g is proved in POD p.
for (g, group) in input.statement_content_groups.iter().enumerate() {
for p in 0..target_pods {
// Lower bound: if any statement in the group is proved, the group is used
for &s in group {
model.add_constraint(constraint!(content_group_used[g][p] >= prove[s][p]));
}
// Upper bound: if no statements in the group are proved, the group is not used
let group_prove_sum: Expression = group.iter().map(|&s| prove[s][p]).sum();
model.add_constraint(constraint!(content_group_used[g][p] <= group_prove_sum));
}
}
for p in 0..target_pods {
// 6a: Unique statement count (unique content groups + anchored key Contains)
// Statements with identical content share a slot, so we count content groups, not indices.
// Anchored key Contains statements are auto-inserted by MainPodBuilder when needed.
// The total must not exceed max_priv_statements (= max_statements - max_public_statements).
let unique_stmt_sum: Expression = (0..num_groups).map(|g| content_group_used[g][p]).sum();
let anchored_key_sum: Expression = (0..input.all_anchored_keys.len())
.map(|ak| anchored_key_used[ak][p])
.sum();
// 6a: Statement count
let stmt_sum: Expression = (0..n).map(|g| prove[g][p]).sum();
model.add_constraint(constraint!(
unique_stmt_sum + anchored_key_sum
<= (input.params.max_priv_statements() as f64) * pod_used[p]
stmt_sum <= (input.params.max_priv_statements() as f64) * pod_used[p]
));
// 6b: Public statement count (internal public statements + forwarded external premises)
@ -844,7 +751,7 @@ fn try_solve_with_pods(
.map(|s| (input.costs[s].merkle_proofs as f64) * prove[s][p])
.sum();
model.add_constraint(constraint!(
merkle_sum <= (input.params.max_merkle_proofs_containers as f64) * pod_used[p]
merkle_sum <= (input.params.containers.state.max_medium as f64) * pod_used[p]
));
// 6d: Merkle state transitions
@ -852,11 +759,7 @@ fn try_solve_with_pods(
.map(|s| (input.costs[s].merkle_state_transitions as f64) * prove[s][p])
.sum();
model.add_constraint(constraint!(
mst_sum
<= (input
.params
.max_merkle_tree_state_transition_proofs_containers as f64)
* pod_used[p]
mst_sum <= (input.params.containers.transition.max_medium as f64) * pod_used[p]
));
// 6e: Custom predicate verifications
@ -885,67 +788,31 @@ fn try_solve_with_pods(
}
// Constraint 7: Batch cardinality
// batch_used[b][p] >= prove[s][p] for all s that use batch b (batch is used if any statement uses it)
// batch_used[b][p] <= sum of prove[s][p] for all s using batch b (batch is 0 if no statements use it)
for (b, batch_id) in all_batches.iter().enumerate() {
// custom_predicate_used[b][p] >= prove[s][p] for all s that use custom predicate b (custom
// predicate is used if any statement uses it)
// custom_predicate_used[b][p] <= sum of prove[s][p] for all s using custom predicate b (custom
// predicate is 0 if no statements use it)
for (b, predicate_id) in all_custom_predicates.iter().enumerate() {
for p in 0..target_pods {
let mut sum: Expression = 0.into();
for s in 0..n {
if input.costs[s].custom_batch_ids.contains(batch_id) {
model.add_constraint(constraint!(batch_used[b][p] >= prove[s][p]));
if input.costs[s].custom_predicates_ids.contains(predicate_id) {
model.add_constraint(constraint!(custom_predicate_used[b][p] >= prove[s][p]));
sum += prove[s][p];
}
}
model.add_constraint(constraint!(batch_used[b][p] <= sum));
model.add_constraint(constraint!(custom_predicate_used[b][p] <= sum));
}
}
// Constraint 7b: Anchored key tracking
//
// anchored_key_used[ak][p] = 1 when auto-insertion of a Contains is needed for anchored key ak in POD p.
// This happens when: some statement using ak is in POD p, AND the producing Contains is NOT in POD p.
//
// If a Contains statement explicitly produces ak (anchored_key_producers[ak] = Some(prod_idx)):
// - Lower: anchored_key_used[ak][p] >= prove[s][p] - prove[prod_idx][p] for all s using ak
// - Upper: anchored_key_used[ak][p] <= 1 - prove[prod_idx][p]
// This ensures overhead is 0 when the producer is in the same POD.
//
// If no Contains produces ak (anchored_key_producers[ak] = None):
// - Lower: anchored_key_used[ak][p] >= prove[s][p] for all s using ak
// - Upper: anchored_key_used[ak][p] <= sum of prove[s][p] for all s using ak
// Auto-insertion is always needed when any user is present.
for (ak_idx, ak) in input.all_anchored_keys.iter().enumerate() {
let producer = input.anchored_key_producers[ak_idx];
// Custom predicate count per POD
for p in 0..target_pods {
let mut user_sum: Expression = 0.into();
for s in 0..n {
if input.costs[s].anchored_keys.contains(ak) {
if let Some(prod_idx) = producer {
// Producer exists: only count overhead if producer not in this POD
let custom_predicate_sum: Expression = (0..all_custom_predicates.len())
.map(|b| custom_predicate_used[b][p])
.sum();
model.add_constraint(constraint!(
anchored_key_used[ak_idx][p] >= prove[s][p] - prove[prod_idx][p]
custom_predicate_sum <= (input.params.max_custom_predicates as f64) * pod_used[p]
));
} else {
// No producer: always need auto-insertion if user is present
model.add_constraint(constraint!(
anchored_key_used[ak_idx][p] >= prove[s][p]
));
}
user_sum += prove[s][p];
}
}
if let Some(prod_idx) = producer {
// If producer is in POD, no auto-insertion needed (overhead = 0)
model.add_constraint(constraint!(
anchored_key_used[ak_idx][p] <= 1 - prove[prod_idx][p]
));
} else {
// No producer: overhead is bounded by whether any user is present
model.add_constraint(constraint!(anchored_key_used[ak_idx][p] <= user_sum));
}
}
}
// Constraint 8a: Internal input POD tracking using uses_input.
@ -1147,9 +1014,6 @@ mod tests {
output_public_indices: &[],
params: &params,
max_pods: 20,
all_anchored_keys: &[],
anchored_key_producers: &[],
statement_content_groups: &[],
};
let result = solve(&input);
@ -1195,7 +1059,6 @@ mod tests {
};
let costs = vec![StatementCost::default(), StatementCost::default()];
let statement_content_groups = vec![vec![0], vec![1]];
let output_public = vec![1];
let input = SolverInput {
@ -1205,9 +1068,6 @@ mod tests {
output_public_indices: &output_public,
params: &params,
max_pods: 4,
all_anchored_keys: &[],
anchored_key_producers: &[],
statement_content_groups: &statement_content_groups,
};
let solution = solve(&input).expect("solver should find a feasible forwarding layout");

View file

@ -1,10 +1,10 @@
use std::fmt;
use std::{fmt, iter};
use crate::{
frontend::SignedDict,
middleware::{
containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux,
OperationType, Signature, Statement, TypedValue, Value, ValueRef,
OperationType, Signature, Statement, Value, ValueRef, BASE_PARAMS,
},
};
@ -39,10 +39,9 @@ impl OperationArg {
}
pub(crate) fn int_value_and_ref(&self) -> Option<(ValueRef, i64)> {
self.value_and_ref().and_then(|(r, v)| match v.typed() {
&TypedValue::Int(i) => Some((r, i)),
_ => None,
})
self.value_and_ref()
.and_then(|(r, v)| v.as_int().map(|i| Some((r, i))))
.flatten()
}
}
@ -71,7 +70,7 @@ impl From<&Value> for OperationArg {
impl From<(&Dictionary, &str)> for OperationArg {
fn from((dict, key): (&Dictionary, &str)) -> Self {
// TODO: Use TryFrom
let value = dict.get(&key.into()).cloned().unwrap();
let value = dict.get(&key.into()).unwrap().unwrap();
Self::Statement(Statement::Contains(
dict.clone().into(),
key.into(),
@ -220,6 +219,24 @@ impl Operation {
op_impl_oa!(set_insert, SetInsertFromEntries, 3);
op_impl_oa!(set_delete, SetDeleteFromEntries, 3);
op_impl_oa!(array_update, ArrayUpdateFromEntries, 4);
pub fn replace_value_with_entry(args: Vec<Option<(&Dictionary, &str)>>, st: Statement) -> Self {
assert!(args.len() <= BASE_PARAMS.max_statement_args);
let args = args
.into_iter()
.chain(iter::repeat(None))
.take(BASE_PARAMS.max_statement_args)
.map(|a| match a {
None => OperationArg::Statement(Statement::None),
Some((dict, key)) => OperationArg::from((dict, key)),
})
.chain(iter::once(OperationArg::Statement(st)))
.collect();
Self(
OperationType::Native(NativeOperation::ReplaceValueWithEntry),
args,
OperationAux::None,
)
}
pub fn signed_by(
msg: impl Into<OperationArg>,
pk: impl Into<OperationArg>,

View file

@ -83,7 +83,7 @@ mod tests {
middleware::{
self,
containers::{Array, Dictionary, Set},
Params, Signer as _, TypedValue, DEFAULT_VD_LIST,
Params, Signer as _, Value, DEFAULT_VD_LIST,
},
};
@ -91,16 +91,15 @@ mod tests {
fn test_value_serialization() {
// Pairs of values and their expected serialized representations
let values = vec![
(TypedValue::String("hello".to_string()), "\"hello\""),
(TypedValue::Int(42), "{\"Int\":\"42\"}"),
(TypedValue::Bool(true), "true"),
(Value::from("hello"), "\"hello\""),
(Value::from(42), "{\"Int\":\"42\"}"),
(Value::from(true), r#"{"Int":"1"}"#),
(
TypedValue::Array(Array::new(vec!["foo".into(), false.into()])),
"{\"array\":[\"foo\",false]}",
Value::from(Array::new(vec![Value::from("foo"), Value::from(false)])),
r#"{"inner":[[{"Int":"0"},"foo"],[{"Int":"1"},{"Int":"0"}]]}"#,
),
(
TypedValue::Dictionary(
Dictionary::new(HashMap::from([
Value::from(Dictionary::new(HashMap::from([
// The set of valid keys is equal to the set of valid JSON keys
("foo".into(), 123.into()),
// Empty strings are valid JSON keys
@ -113,26 +112,25 @@ mod tests {
(("\0".into()), "".into()),
// Keys can contain emojis
(("🥳".into()), "party time!".into()),
]))
),
"{\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}",
]))),
r#"{"inner":[["!@£$%^&&*()",""],["🥳","party time!"],[" hi",{"Int":"0"}],["foo",{"Int":"123"}],["\u0000",""],["","baz"]]}"#,
),
(
TypedValue::Set(Set::new(HashSet::from(["foo".into(), "bar".into()]))),
"{\"set\":[\"bar\",\"foo\"]}",
Value::from(Set::new(HashSet::from(["foo".into(), "bar".into()]))),
r#"{"inner":[["bar"],["foo"]]}"#,
),
];
for (value, expected) in values {
let serialized = serde_json::to_string(&value).unwrap();
assert_eq!(serialized, expected);
let deserialized: TypedValue = serde_json::from_str(&serialized).unwrap();
let deserialized: Value = serde_json::from_str(&serialized).unwrap();
assert_eq!(
value, deserialized,
"value {:#?} should equal deserialized {:#?}",
value, deserialized
);
let expected_deserialized: TypedValue = serde_json::from_str(expected).unwrap();
let expected_deserialized: Value = serde_json::from_str(expected).unwrap();
assert_eq!(value, expected_deserialized);
}
}
@ -177,7 +175,10 @@ mod tests {
"deserialized: {}",
serde_json::to_string_pretty(&deserialized).unwrap()
);
assert_eq!(signed_dict.dict.kvs(), deserialized.dict.kvs());
assert_eq!(
signed_dict.dict.dump().unwrap(),
deserialized.dict.dump().unwrap()
);
assert_eq!(signed_dict.public_key, deserialized.public_key);
assert_eq!(signed_dict.signature, deserialized.signature);
assert_eq!(signed_dict.verify().is_ok(), deserialized.verify().is_ok());

View file

@ -174,18 +174,6 @@ fn render_validation_error(
"second REQUEST here",
),
ValidationError::InvalidArgumentType { predicate, span } => {
let title = format!("invalid argument type for `{}`", predicate);
render_with_optional_span(
renderer,
source,
path,
&title,
span.as_ref(),
"anchored keys not allowed here",
)
}
ValidationError::DuplicateWildcard { name, span } => {
let title = format!("duplicate wildcard: {}", name);
render_with_optional_span(
@ -287,6 +275,17 @@ fn render_validation_error(
ValidationError::NoRequestBlock => {
render_title_only(renderer, "requests must contain a REQUEST block")
}
ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { span } => {
render_with_optional_span(
renderer,
source,
path,
"self-referential predicate literal not allowed in requests",
span.as_ref(),
"not allowed here",
)
}
}
}

View file

@ -135,12 +135,6 @@ pub enum ValidationError {
span: Option<Span>,
},
#[error("Invalid argument type for {predicate}: anchored keys not allowed")]
InvalidArgumentType {
predicate: String,
span: Option<Span>,
},
#[error("Duplicate wildcard in predicate arguments: {name}")]
DuplicateWildcard { name: String, span: Option<Span> },
@ -165,6 +159,9 @@ pub enum ValidationError {
#[error("Modules must contain at least one predicate definition")]
NoPredicatesInModule,
#[error("Self-referential predicate literal not allowed in requests")]
SelfReferentialPredicateLiteralNotAllowedInRequests { span: Option<Span> },
#[error("Requests must contain a REQUEST block")]
NoRequestBlock,
}

View file

@ -116,6 +116,8 @@ pub enum StatementTmplArg {
Literal(LiteralValue),
Wildcard(Identifier),
AnchoredKey(AnchoredKey),
/// Hash of a same-module predicate, resolved at batch finalization time.
SelfPredicateHash(Identifier),
}
/// Anchored key: Var["key"] or Var.key
@ -168,6 +170,13 @@ pub enum LiteralValue {
Array(LiteralArray),
Set(LiteralSet),
Dict(LiteralDict),
/// Hash of a native predicate (resolved immediately).
NativePredicateHash(Identifier),
/// Hash of an external module's predicate (resolved immediately).
ExternalPredicateHash {
module: Identifier,
predicate: Identifier,
},
}
/// Integer literal
@ -391,6 +400,9 @@ impl fmt::Display for StatementTmplArg {
StatementTmplArg::Literal(lit) => write!(f, "{}", lit),
StatementTmplArg::Wildcard(id) => write!(f, "{}", id),
StatementTmplArg::AnchoredKey(ak) => write!(f, "{}", ak),
StatementTmplArg::SelfPredicateHash(id) => {
write!(f, "@self_predicate({})", id)
}
}
}
}
@ -422,6 +434,12 @@ impl fmt::Display for LiteralValue {
LiteralValue::Array(a) => write!(f, "{}", a),
LiteralValue::Set(s) => write!(f, "{}", s),
LiteralValue::Dict(d) => write!(f, "{}", d),
LiteralValue::NativePredicateHash(id) => {
write!(f, "@native_predicate({})", id)
}
LiteralValue::ExternalPredicateHash {
module, predicate, ..
} => write!(f, "@external_predicate({}, {})", module, predicate),
}
}
}
@ -769,6 +787,10 @@ pub mod parse {
let inner = pair.into_inner().next().unwrap();
match inner.as_rule() {
Rule::predicate_hash_self => {
let id = parse_identifier(inner.into_inner().next().unwrap());
Ok(StatementTmplArg::SelfPredicateHash(id))
}
Rule::literal_value => Ok(StatementTmplArg::Literal(parse_literal_value(inner)?)),
Rule::identifier => Ok(StatementTmplArg::Wildcard(parse_identifier(inner))),
Rule::anchored_key => Ok(StatementTmplArg::AnchoredKey(parse_anchored_key(inner)?)),
@ -823,6 +845,16 @@ pub mod parse {
Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)),
Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)),
Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)),
Rule::predicate_hash_native => {
let id = parse_identifier(inner.into_inner().next().unwrap());
Ok(LiteralValue::NativePredicateHash(id))
}
Rule::predicate_hash_external => {
let mut parts = inner.into_inner();
let module = parse_identifier(parts.next().unwrap());
let predicate = parse_identifier(parts.next().unwrap());
Ok(LiteralValue::ExternalPredicateHash { module, predicate })
}
_ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()),
}
}
@ -1104,6 +1136,7 @@ mod tests {
AnchoredKeyPath::Dot(id) => id.span = None,
}
}
StatementTmplArg::SelfPredicateHash(id) => id.span = None,
}
}
}
@ -1139,6 +1172,13 @@ mod tests {
clear_literal_spans(&mut pair.value);
}
}
LiteralValue::NativePredicateHash(id) => id.span = None,
LiteralValue::ExternalPredicateHash {
module, predicate, ..
} => {
module.span = None;
predicate.span = None;
}
}
}

View file

@ -157,8 +157,10 @@ fn resolve_local_predicate(
/// Lower a literal value from AST to middleware Value.
///
/// This is a pure conversion that cannot fail.
pub fn lower_literal(lit: &LiteralValue) -> Value {
/// This is a pure conversion that cannot fail for context-free literals.
/// Panics on ExternalPredicateHash — use `lower_literal_with_context` when
/// external predicate references may appear (e.g. inside containers).
pub(crate) fn lower_literal(lit: &LiteralValue) -> Value {
match lit {
LiteralValue::Int(i) => Value::from(i.value),
LiteralValue::Bool(b) => Value::from(b.value),
@ -190,13 +192,83 @@ pub fn lower_literal(lit: &LiteralValue) -> Value {
let dict = containers::Dictionary::new(pairs);
Value::from(dict)
}
LiteralValue::NativePredicateHash(id) => {
let np = NativePredicate::from_str(&id.name).expect("validated native predicate");
Value::from(Predicate::Native(np).hash())
}
LiteralValue::ExternalPredicateHash { .. } => {
unreachable!(
"ExternalPredicateHash must be lowered with context via lower_literal_with_context"
)
}
}
}
/// Lower a literal value, resolving external predicate references using the symbol table.
pub fn lower_literal_with_context(
lit: &LiteralValue,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Result<Value, LoweringError> {
match lit {
LiteralValue::ExternalPredicateHash { module, predicate } => {
let pred_or_wc = resolve_predicate_ref(
&PredicateRef::Qualified {
module: module.clone(),
predicate: predicate.clone(),
},
symbols,
context,
)
.ok_or_else(|| LoweringError::PredicateNotFound {
name: format!("{}::{}", module.name, predicate.name),
})?;
let pred = match pred_or_wc {
crate::frontend::PredicateOrWildcard::Predicate(p) => p,
_ => unreachable!(
"`resolve_predicate_ref` always returns `PredicateOrWildcard::Predicate` on `PredicateRef::Qualified`"
)
};
Ok(Value::from(pred.hash()))
}
LiteralValue::Array(a) => {
let elements: Vec<_> = a
.elements
.iter()
.map(|e| lower_literal_with_context(e, symbols, context))
.collect::<Result<_, _>>()?;
Ok(Value::from(containers::Array::new(elements)))
}
LiteralValue::Set(s) => {
let elements: std::collections::HashSet<_> = s
.elements
.iter()
.map(|e| lower_literal_with_context(e, symbols, context))
.collect::<Result<_, _>>()?;
Ok(Value::from(containers::Set::new(elements)))
}
LiteralValue::Dict(d) => {
let pairs: HashMap<_, _> = d
.pairs
.iter()
.map(|pair| {
let key = Key::from(pair.key.value.as_str());
let value = lower_literal_with_context(&pair.value, symbols, context)?;
Ok((key, value))
})
.collect::<Result<_, LoweringError>>()?;
Ok(Value::from(containers::Dictionary::new(pairs)))
}
// All other variants are context-free
other => Ok(lower_literal(other)),
}
}
/// Lower a statement argument from AST to BuilderArg.
///
/// This is a pure conversion that cannot fail.
pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
/// Context-free for most arg types. Panics on ExternalPredicateHash inside literals —
/// use `lower_statement_arg_with_context` when external predicate references may appear.
pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
match arg {
StatementTmplArg::Literal(lit) => {
let value = lower_literal(lit);
@ -210,6 +282,25 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
};
BuilderArg::Key(ak.root.name.clone(), key_str)
}
StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()),
}
}
/// Lower a statement argument, resolving external predicate references using the symbol table.
pub fn lower_statement_arg_with_context(
arg: &StatementTmplArg,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Result<BuilderArg, LoweringError> {
match arg {
StatementTmplArg::Literal(lit) => {
let value = lower_literal_with_context(lit, symbols, context)?;
Ok(BuilderArg::Literal(value))
}
StatementTmplArg::SelfPredicateHash(id) => {
Ok(BuilderArg::SelfPredicateHash(id.name.clone()))
}
other => Ok(lower_statement_arg(other)),
}
}
@ -324,7 +415,7 @@ impl<'a> Lowerer<'a> {
// Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate.clone());
for arg in &stmt.args {
let builder_arg = lower_statement_arg(arg);
let builder_arg = lower_statement_arg_with_context(arg, symbols, &context)?;
builder = builder.arg(builder_arg);
}
let desugared = builder.desugar();
@ -346,6 +437,9 @@ impl<'a> Lowerer<'a> {
let key = Key::from(key_str.as_str());
MWStatementTmplArg::AnchoredKey(wildcard, key)
}
BuilderArg::SelfPredicateHash(_) => {
unreachable!("SelfPredicateHash should not appear in request lowering")
}
};
mw_args.push(mw_arg);
}
@ -399,7 +493,7 @@ impl<'a> Lowerer<'a> {
names.push(ak.root.name.clone());
}
}
StatementTmplArg::Literal(_) => {}
StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {}
}
}
}

View file

@ -123,7 +123,7 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
StatementTmplArg::AnchoredKey(ak) => {
wildcards.insert(ak.root.name.clone());
}
StatementTmplArg::Literal(_) => {}
StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {}
}
}

View file

@ -522,7 +522,7 @@ impl Validator {
}
// Validate arguments
self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?;
self.validate_statement_args(stmt, wildcard_context)?;
Ok(())
}
@ -530,40 +530,8 @@ impl Validator {
fn validate_statement_args(
&self,
stmt: &StatementTmpl,
pred_info: Option<&PredicateInfo>,
wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> {
// For custom predicates, only wildcards and literals are allowed
if matches!(
pred_info.map(|i| &i.kind),
Some(PredicateKind::Custom { .. })
| Some(PredicateKind::BatchImported { .. })
| Some(PredicateKind::ModuleImported { .. })
) {
for arg in &stmt.args {
match arg {
StatementTmplArg::AnchoredKey(_) => {
return Err(ValidationError::InvalidArgumentType {
predicate: stmt.predicate.predicate_name().to_string(),
span: stmt.span,
});
}
StatementTmplArg::Wildcard(id) => {
if let Some((pred_name, scope)) = wildcard_context {
if !scope.wildcards.contains_key(&id.name) {
return Err(ValidationError::UndefinedWildcard {
name: id.name.clone(),
pred_name: pred_name.to_string(),
span: id.span,
});
}
}
}
StatementTmplArg::Literal(_) => {}
}
}
} else {
// Native predicates can have anchored keys
for arg in &stmt.args {
match arg {
StatementTmplArg::Wildcard(id) => {
@ -588,13 +556,91 @@ impl Validator {
}
}
}
StatementTmplArg::Literal(_) => {}
StatementTmplArg::Literal(lit) => {
self.validate_literal_value(lit)?;
}
StatementTmplArg::SelfPredicateHash(id) => {
self.validate_self_predicate_hash(id, wildcard_context)?;
}
}
}
Ok(())
}
/// Validate a @self_predicate reference: the name must be a custom predicate in this module.
fn validate_self_predicate_hash(
&self,
id: &Identifier,
wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> {
// @self_predicate only makes sense inside module predicate definitions
if wildcard_context.is_none() {
return Err(
ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests {
span: id.span,
},
);
}
// Must refer to a custom predicate defined in this module (not intro/imported)
match self.symbols.predicates.get(&id.name) {
Some(info) if matches!(info.kind, PredicateKind::Custom { .. }) => Ok(()),
_ => Err(ValidationError::UndefinedPredicate {
name: id.name.clone(),
span: id.span,
}),
}
}
/// Recursively validate a literal value, checking predicate hash references.
fn validate_literal_value(&self, lit: &LiteralValue) -> Result<(), ValidationError> {
match lit {
LiteralValue::NativePredicateHash(id) => {
if NativePredicate::from_str(&id.name).is_err() {
return Err(ValidationError::UndefinedPredicate {
name: id.name.clone(),
span: id.span,
});
}
Ok(())
}
LiteralValue::ExternalPredicateHash { module, predicate } => {
if let Some(imported) = self.symbols.imported_modules.get(&module.name) {
if !imported.predicate_index.contains_key(&predicate.name) {
return Err(ValidationError::UndefinedPredicate {
name: format!("{}::{}", module.name, predicate.name),
span: predicate.span,
});
}
} else {
return Err(ValidationError::ModuleNotFound {
name: module.name.clone(),
span: module.span,
});
}
Ok(())
}
LiteralValue::Array(a) => {
for elem in &a.elements {
self.validate_literal_value(elem)?;
}
Ok(())
}
LiteralValue::Set(s) => {
for elem in &s.elements {
self.validate_literal_value(elem)?;
}
Ok(())
}
LiteralValue::Dict(d) => {
for pair in &d.pairs {
self.validate_literal_value(&pair.value)?;
}
Ok(())
}
_ => Ok(()),
}
}
}
#[cfg(test)]
@ -755,10 +801,7 @@ mod tests {
module_hash
);
let result = parse_and_validate_request(&input, &available_modules);
assert!(matches!(
result,
Err(ValidationError::InvalidArgumentType { .. })
));
assert!(result.is_ok());
}
#[test]

View file

@ -49,7 +49,14 @@ custom_predicate_def = {
statement_list = { statement+ }
statement_arg = { literal_value | anchored_key | identifier }
// Predicate hash literals: resolve to the predicate's identity hash as a value.
// @native_predicate and @external_predicate are in literal_value (usable in containers).
// @self_predicate is only in statement_arg (not in containers — deferred resolution).
predicate_hash_native = { "@native_predicate" ~ "(" ~ identifier ~ ")" }
predicate_hash_external = { "@external_predicate" ~ "(" ~ identifier ~ "," ~ identifier ~ ")" }
predicate_hash_self = { "@self_predicate" ~ "(" ~ identifier ~ ")" }
statement_arg = { predicate_hash_self | literal_value | anchored_key | identifier }
statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* }
// Predicate reference: either qualified (module::predicate) or local (predicate)
@ -74,6 +81,8 @@ literal_value = {
literal_bool |
literal_raw |
literal_string |
predicate_hash_native |
predicate_hash_external |
literal_int
}

View file

@ -578,7 +578,6 @@ mod tests {
max_input_pods: 3,
max_statements: 31,
max_public_statements: 10,
max_operation_args: 5,
max_custom_predicate_wildcards: 12,
..Default::default()
};

View file

@ -11,7 +11,9 @@ use crate::{
lang::{
error::BatchingError,
frontend_ast::{ConjunctionType, CustomPredicateDef},
frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext},
frontend_ast_lower::{
lower_statement_arg_with_context, resolve_predicate_ref, ResolutionContext,
},
frontend_ast_split::{SplitChainInfo, SplitResult},
frontend_ast_validate::SymbolTable,
},
@ -345,7 +347,9 @@ fn build_single_batch(
})?;
}
Ok(builder.finish())
builder.finish().map_err(|e| BatchingError::Internal {
message: format!("Failed to finalize batch '{}': {}", batch_name, e),
})
}
/// Build a statement template with properly resolved predicate references
@ -372,7 +376,13 @@ fn build_statement_with_resolved_refs(
let mut builder = StatementTmplBuilder::new(pred_or_wc);
for arg in &stmt.args {
builder = builder.arg(lower_statement_arg(arg));
let builder_arg =
lower_statement_arg_with_context(arg, symbols, &context).map_err(|e| {
BatchingError::Internal {
message: format!("Failed to lower argument: {}", e),
}
})?;
builder = builder.arg(builder_arg);
}
Ok(builder)
@ -668,4 +678,110 @@ mod tests {
PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref))
);
}
#[test]
fn test_self_predicate_hash_podlang() {
let params = Params::default();
let module = load_module(
r#"
pred_A(x, y) = AND(
Equal(x, y)
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#,
"test",
&params,
&[],
)
.unwrap();
let batch = &module.batch;
// pred_B is at index 1, its template should have SelfPredicateHash(0) resolved
// to a Literal containing pred_A's hash after normalization
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = crate::middleware::Value::from(Predicate::Custom(pred_a_ref).hash());
// Use normalized_predicate to resolve
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let normalized = pred_b_ref.normalized_predicate();
assert_eq!(
normalized.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(pred_a_hash)
);
}
#[test]
fn test_self_predicate_hash_podlang_cyclic() {
let params = Params::default();
let module = load_module(
r#"
pred_A(x) = AND(
Equal(x, @self_predicate(pred_B))
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#,
"test",
&params,
&[],
)
.unwrap();
let batch = &module.batch;
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_a_hash =
crate::middleware::Value::from(Predicate::Custom(pred_a_ref.clone()).hash());
let pred_b_hash =
crate::middleware::Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// pred_A's normalized form should contain pred_B's hash
let norm_a = pred_a_ref.normalized_predicate();
assert_eq!(
norm_a.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(pred_b_hash)
);
// pred_B's normalized form should contain pred_A's hash
let norm_b = pred_b_ref.normalized_predicate();
assert_eq!(
norm_b.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(pred_a_hash)
);
}
#[test]
fn test_native_predicate_hash_podlang() {
let params = Params::default();
let module = load_module(
r#"
pred_C(x) = AND(
Equal(x, @native_predicate(Equal))
)
"#,
"test",
&params,
&[],
)
.unwrap();
let batch = &module.batch;
let pred_c_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_c = pred_c_ref.predicate();
// The second arg should be a Literal containing Equal's predicate hash
let equal_hash = crate::middleware::Value::from(
Predicate::Native(crate::middleware::NativePredicate::Equal).hash(),
);
assert_eq!(
pred_c.statements[0].args[1],
crate::middleware::StatementTmplArg::Literal(equal_hash)
);
}
}

View file

@ -137,6 +137,9 @@ mod tests {
assert_inner(&Rule::anchored_key, "someVar[\"key\"]");
assert_inner(&Rule::literal_value, "true");
assert_inner(&Rule::literal_value, "PublicKey(abc)");
assert_inner(&Rule::predicate_hash_self, "@self_predicate(foo)");
assert_inner(&Rule::literal_value, "@native_predicate(Equal)");
assert_inner(&Rule::literal_value, "@external_predicate(mod_a, pred_b)");
}
#[test]
@ -207,6 +210,33 @@ mod tests {
"{ \"raw_val\": Raw(0x0000000000000000000000000000000000000000000000000000000000000000) } ",
);
assert_fails(Rule::literal_dict, "{ name: \"Alice\" }"); // Key must be string literal with quotes
// Predicate hash literals
assert_parses(Rule::predicate_hash_native, "@native_predicate(Equal)");
assert_parses(Rule::predicate_hash_native, "@native_predicate(Lt)");
assert_parses(
Rule::predicate_hash_external,
"@external_predicate(my_module, my_pred)",
);
assert_parses(Rule::predicate_hash_self, "@self_predicate(local_pred)");
// Predicate hashes inside containers (native and external only)
assert_parses(
Rule::literal_array,
"[1, @native_predicate(Equal), @external_predicate(m, p)]",
);
assert_parses(
Rule::literal_set,
"#[@native_predicate(Equal), @native_predicate(Lt)]",
);
assert_parses(
Rule::literal_dict,
"{ \"pred\": @external_predicate(m, p) }",
);
// @self_predicate is NOT a literal_value, so it cannot appear inside containers
assert_fails(Rule::test_literal_value, "@self_predicate(local_pred)");
assert_fails(Rule::literal_array, "[@self_predicate(foo)]");
}
#[test]

View file

@ -92,7 +92,7 @@ impl StatementTmpl {
if i > 0 {
write!(w, ", ")?;
}
arg.fmt_podlang(w)?;
arg.fmt_podlang_with_batch_context(w, batch_context)?;
}
write!(w, ")")?;
@ -102,7 +102,30 @@ impl StatementTmpl {
impl PrettyPrint for StatementTmplArg {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
write!(w, "{}", self)
self.fmt_podlang_with_batch_context(w, None)
}
}
impl StatementTmplArg {
fn fmt_podlang_with_batch_context(
&self,
w: &mut dyn Write,
batch_context: Option<&CustomPredicateBatch>,
) -> std::fmt::Result {
match self {
StatementTmplArg::SelfPredicateHash(index) => {
if let Some(batch) = batch_context {
if let Some(predicate) = batch.predicates().get(*index) {
write!(w, "@self_predicate({})", predicate.name)
} else {
write!(w, "@self_predicate(self_{})", index)
}
} else {
write!(w, "@self_predicate(self_{})", index)
}
}
other => write!(w, "{}", other),
}
}
}
@ -131,7 +154,7 @@ impl CustomPredicateBatch {
impl PrettyPrint for Value {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
write!(w, "{}", self.typed())
write!(w, "{}", self.typed)
}
}
@ -540,6 +563,34 @@ mod tests {
assert_round_trip(&input);
}
#[test]
fn test_round_trip_self_predicate_hash() {
let input = r#"
pred_A(x, y) = AND(
Equal(x, y)
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_self_predicate_hash_cyclic() {
let input = r#"
pred_A(x) = AND(
Equal(x, @self_predicate(pred_B))
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#;
assert_round_trip(input);
}
#[test]
fn test_pretty_print_demonstration() {
let input = r#"

View file

@ -169,6 +169,12 @@ pub struct Hash(
pub [F; HASH_SIZE],
);
impl Hash {
pub fn raw(self) -> RawValue {
RawValue::from(self)
}
}
impl From<Hash> for HashOut {
fn from(hash: Hash) -> HashOut {
HashOut { elements: hash.0 }

View file

@ -1,29 +1,260 @@
//! This file implements the types defined at
//! <https://0xparc.github.io/pod2/values.html#dictionary-array-set> .
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
fmt::{self, Debug},
};
use schemars::JsonSchema;
use serde::{Deserialize, Deserializer, Serialize};
use serde::{
de::{Error as _, SeqAccess, Visitor},
ser, Deserialize, Deserializer, Serialize,
};
use super::serialization::{ordered_map, ordered_set};
#[cfg(feature = "backend_plonky2")]
use crate::backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree};
use crate::backends::plonky2::primitives::merkletree::{self, MerkleProof, MerkleTree};
use crate::{
backends::plonky2::primitives::merkletree::MerkleTreeStateTransitionProof,
middleware::{Error, Hash, Key, RawValue, Result, Value},
middleware::{
db::{mem::MemDB, DB},
Error, Hash, Key, RawValue, Result, TypedValue, Value, EMPTY_HASH,
},
};
#[derive(Clone, Debug)]
pub struct Container {
root: Hash,
db: Box<dyn DB>,
}
impl JsonSchema for Container {
fn schema_name() -> String {
"Container".to_string()
}
fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
// Just use the schema of Vec<Vec<Value>> since that's what we're actually serializing
Vec::<Vec<Value>>::json_schema(gen)
}
}
impl Serialize for Container {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut pairs = self
.iter()
.collect::<Result<Vec<(Value, Value)>>>()
.map_err(ser::Error::custom)?;
pairs.sort_by(|(k1, _), (k2, _)| k1.raw().cmp(&k2.raw()));
// Serialize as an array
use serde::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(pairs.len()))?;
for (k, v) in pairs {
if k == v {
seq.serialize_element(&[&v])?;
} else {
seq.serialize_element(&[&k, &v])?;
}
}
seq.end()
}
}
struct ContainerVisitor;
impl<'de> Visitor<'de> for ContainerVisitor {
type Value = HashMap<Value, Value>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence of `[Value]` or `[Value, Value]`")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut kvs = HashMap::<Value, Value>::new();
while let Some(mut elem) = seq.next_element::<Vec<Value>>()? {
match elem.len() {
1 => {
let v = elem.pop().unwrap();
kvs.insert(v.clone(), v);
}
2 => {
let (v, k) = (elem.pop().unwrap(), elem.pop().unwrap());
kvs.insert(k, v);
}
n => {
return Err(A::Error::custom(format!(
"invalid vec length of {n} in container entry"
)))
}
}
}
Ok(kvs)
}
}
impl<'de> Deserialize<'de> for Container {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let kvs = deserializer.deserialize_seq(ContainerVisitor)?;
Ok(Container::new(kvs))
}
}
impl PartialEq for Container {
fn eq(&self, other: &Self) -> bool {
self.root == other.root
}
}
impl Eq for Container {}
fn store_container_mt(db: &mut dyn DB, container: &Container) -> Result<()> {
match db.load_node(container.root) {
Err(e) => return Err(Error::Database(e)),
// Container already exists in the DB
Ok(Some(_)) => return Ok(()),
// Container not existing, we need to save it
Ok(None) => {}
};
let mut container_copy = Container::empty_with_db(db.clone_box());
for kv_result in container.iter() {
let (k, v) = kv_result?;
container_copy.insert(k, v)?;
}
Ok(())
}
fn store_value(db: &mut dyn DB, v: Value) -> Result<()> {
match &v.typed {
TypedValue::Set(Set { inner })
| TypedValue::Dictionary(Dictionary { inner })
| TypedValue::Array(Array { inner }) => {
if db.is_persistent() {
store_container_mt(db, inner)?;
}
db.store_value(v).map_err(Error::Database)?
}
_ => db.store_value(v).map_err(Error::Database)?,
}
Ok(())
}
fn load_value(db: &dyn DB, value_raw: RawValue) -> Result<Value> {
match db.load_value(value_raw) {
Err(e) => Err(Error::Database(e)),
Ok(Some(v)) => Ok(v),
Ok(None) => Err(Error::custom(format!(
"Value from {value_raw} not found in DB"
))),
}
}
impl Container {
fn mt(&self) -> MerkleTree {
MerkleTree::from_db(self.root, self.db.clone())
}
pub fn new(kvs: HashMap<Value, Value>) -> Self {
let db = Box::new(MemDB::new());
let mut container = Self::empty_with_db(db);
for (k, v) in kvs {
container.insert(k, v).expect("no duplicates, no db errors");
}
container
}
pub fn empty_with_db(db: Box<dyn DB>) -> Self {
Self::from_db(EMPTY_HASH, db).expect("EMPTY_HASH exists implicitly")
}
pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
// Make sure the root exists in the db
let _ = merkletree::load_node(db.as_ref(), root)?;
Ok(Self { root, db })
}
pub fn commitment(&self) -> Hash {
self.root
}
pub fn get(&self, key_raw: RawValue) -> Result<Option<Value>> {
Ok(match self.mt().get(&key_raw)? {
Some(value_raw) => Some(load_value(self.db.as_ref(), value_raw)?),
None => None,
})
}
pub fn prove(&self, key_raw: RawValue) -> Result<(Value, MerkleProof)> {
let (value_raw, mtp) = self.mt().prove(&key_raw)?;
let value = load_value(self.db.as_ref(), value_raw)?;
Ok((value, mtp))
}
pub fn prove_nonexistence(&self, key_raw: RawValue) -> Result<MerkleProof> {
Ok(self.mt().prove_nonexistence(&key_raw)?)
}
pub fn insert(&mut self, key: Value, value: Value) -> Result<MerkleTreeStateTransitionProof> {
let (key_raw, value_raw) = (key.raw(), value.raw());
store_value(self.db.as_mut(), key)?;
store_value(self.db.as_mut(), value)?;
let mut mt = self.mt();
let mtp = mt.insert(&key_raw, &value_raw)?;
self.root = mt.root();
Ok(mtp)
}
pub fn update(
&mut self,
key_raw: RawValue,
value: Value,
) -> Result<MerkleTreeStateTransitionProof> {
let value_raw = value.raw();
store_value(self.db.as_mut(), value)?;
let mut mt = self.mt();
let mtp = mt.update(&key_raw, &value_raw)?;
self.root = mt.root();
Ok(mtp)
}
pub fn delete(&mut self, key_raw: RawValue) -> Result<MerkleTreeStateTransitionProof> {
let mut mt = self.mt();
let mtp = mt.delete(&key_raw)?;
self.root = mt.root();
Ok(mtp)
}
pub fn verify(
root: Hash,
proof: &MerkleProof,
key_raw: RawValue,
value_raw: RawValue,
) -> Result<()> {
Ok(MerkleTree::verify(root, proof, &key_raw, &value_raw)?)
}
pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key_raw: RawValue) -> Result<()> {
Ok(MerkleTree::verify_nonexistence(root, proof, &key_raw)?)
}
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into())
}
pub fn iter(&self) -> impl Iterator<Item = Result<(Value, Value)>> {
let db = self.db.clone();
self.mt().iter().map(move |(key_raw, value_raw)| {
let key = load_value(db.as_ref(), key_raw)?;
let value = load_value(db.as_ref(), value_raw)?;
Ok((key, value))
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashMap<Value, Value>> {
self.iter().collect()
}
}
/// Dictionary: the user original keys and values are hashed to be used in the leaf.
/// leaf.key=hash(original_key)
/// leaf.value=hash(original_value)
#[derive(Clone, Debug, Serialize, JsonSchema)]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Dictionary {
#[serde(skip)]
#[schemars(skip)]
mt: MerkleTree,
#[serde(serialize_with = "ordered_map")]
kvs: HashMap<Key, Value>,
pub(crate) inner: Container,
}
#[macro_export]
@ -34,255 +265,371 @@ macro_rules! dict {
({ $($key:expr => $val:expr),* }) => ({
let mut map = ::std::collections::HashMap::new();
$( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )*
$crate::middleware::containers::Dictionary::new( map)
$crate::middleware::containers::Dictionary::new(map)
});
}
// TODO: Replace all methods that receive a `&Key` by either `impl Into<String>` for write
// methods and `impl AsRef<str>` for read methods.
// TODO: Replace all methods that receive a `&Value` in write methods for `Value`. Consider a
// trait?
impl Dictionary {
pub fn new(kvs: HashMap<Key, Value>) -> Self {
let kvs_raw: HashMap<RawValue, RawValue> =
kvs.iter().map(|(k, v)| (k.raw(), v.raw())).collect();
Self {
mt: MerkleTree::new(&kvs_raw),
kvs,
inner: Container::new(
kvs.into_iter()
.map(|(k, v)| (Value::from(k.name), v))
.collect(),
),
}
}
pub fn empty_with_db(db: Box<dyn DB>) -> Self {
Self {
inner: Container::empty_with_db(db),
}
}
pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
Ok(Self {
inner: Container::from_db(root, db)?,
})
}
pub fn commitment(&self) -> Hash {
self.mt.root()
self.inner.commitment()
}
pub fn get(&self, key: &Key) -> Result<&Value> {
self.kvs
.get(key)
.ok_or_else(|| Error::custom(format!("key \"{}\" not found", key.name())))
pub fn get(&self, key: &Key) -> Result<Option<Value>> {
self.inner.get(key.raw())
}
pub fn prove(&self, key: &Key) -> Result<(&Value, MerkleProof)> {
let (_, mtp) = self.mt.prove(&key.raw())?;
let value = self.kvs.get(key).expect("key exists");
Ok((value, mtp))
pub fn prove(&self, key: &Key) -> Result<(Value, MerkleProof)> {
self.inner.prove(key.raw())
}
pub fn prove_nonexistence(&self, key: &Key) -> Result<MerkleProof> {
Ok(self.mt.prove_nonexistence(&key.raw())?)
self.inner.prove_nonexistence(key.raw())
}
pub fn insert(&mut self, key: &Key, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.insert(&key.raw(), &value.raw())?;
self.kvs.insert(key.clone(), value.clone());
Ok(mtp)
self.inner
.insert(Value::from(key.name.clone()), value.clone())
}
pub fn update(&mut self, key: &Key, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.update(&key.raw(), &value.raw())?;
self.kvs.insert(key.clone(), value.clone());
Ok(mtp)
self.inner.update(key.raw(), value.clone())
}
pub fn delete(&mut self, key: &Key) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.delete(&key.raw())?;
self.kvs.remove(key);
Ok(mtp)
self.inner.delete(key.raw())
}
pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> {
let key = key.raw();
Ok(MerkleTree::verify(root, proof, &key, &value.raw())?)
Container::verify(root, proof, key.raw(), value.raw())
}
pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> {
let key = key.raw();
Ok(MerkleTree::verify_nonexistence(root, proof, &key)?)
Container::verify_nonexistence(root, proof, key.raw())
}
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into())
Container::verify_state_transition(proof)
}
// TODO: Rename to dict to be consistent maybe?
pub fn kvs(&self) -> &HashMap<Key, Value> {
&self.kvs
pub fn iter(&self) -> impl Iterator<Item = Result<(String, Value)>> + use<'_> {
self.inner.iter().map(|r| match r {
Ok((key, value)) => Ok((
key.as_string()
.ok_or_else(|| Error::custom("dictionary: key is not string"))?,
value,
)),
Err(e) => Err(e),
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashMap<String, Value>> {
self.iter().collect()
}
}
impl PartialEq for Dictionary {
fn eq(&self, other: &Self) -> bool {
self.mt.root() == other.mt.root()
self.inner.eq(&other.inner)
}
}
impl Eq for Dictionary {}
impl<'de> Deserialize<'de> for Dictionary {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct Aux {
#[serde(serialize_with = "ordered_map")]
kvs: HashMap<Key, Value>,
}
let aux = Aux::deserialize(deserializer)?;
Ok(Dictionary::new(aux.kvs))
}
}
/// Set: the value field of the leaf is unused, and the key contains the hash of the element.
/// leaf.key=hash(original_value)
/// leaf.value=0
#[derive(Clone, Debug, Serialize, JsonSchema)]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Set {
#[serde(skip)]
#[schemars(skip)]
mt: MerkleTree,
#[serde(serialize_with = "ordered_set")]
set: HashSet<Value>,
pub(crate) inner: Container,
}
impl Set {
pub fn new(set: HashSet<Value>) -> Self {
let kvs_raw: HashMap<RawValue, RawValue> = set
.iter()
.map(|e| {
let rv = e.raw();
(rv, rv)
})
.collect();
Self {
mt: MerkleTree::new(&kvs_raw),
set,
inner: Container::new(set.into_iter().map(|v| (v.clone(), v)).collect()),
}
}
pub fn empty_with_db(db: Box<dyn DB>) -> Self {
Self {
inner: Container::empty_with_db(db),
}
}
pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
Ok(Self {
inner: Container::from_db(root, db)?,
})
}
pub fn commitment(&self) -> Hash {
self.mt.root()
self.inner.commitment()
}
pub fn contains(&self, value: &Value) -> bool {
self.set.contains(value)
pub fn contains(&self, value: &Value) -> Result<bool> {
Ok(self.inner.get(value.raw())?.is_some())
}
pub fn prove(&self, value: &Value) -> Result<MerkleProof> {
let rv = value.raw();
let (_, proof) = self.mt.prove(&rv)?;
let (_, proof) = self.inner.prove(value.raw())?;
Ok(proof)
}
pub fn prove_nonexistence(&self, value: &Value) -> Result<MerkleProof> {
let rv = value.raw();
Ok(self.mt.prove_nonexistence(&rv)?)
self.inner.prove_nonexistence(value.raw())
}
pub fn insert(&mut self, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let raw_value = value.raw();
let mtp = self.mt.insert(&raw_value, &raw_value)?;
self.set.insert(value.clone());
Ok(mtp)
self.inner.insert(value.clone(), value.clone())
}
pub fn delete(&mut self, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.delete(&value.raw())?;
self.set.remove(value);
Ok(mtp)
self.inner.delete(value.raw())
}
pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> {
let rv = value.raw();
Ok(MerkleTree::verify(root, proof, &rv, &rv)?)
Container::verify(root, proof, value.raw(), value.raw())
}
pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> {
let rv = value.raw();
Ok(MerkleTree::verify_nonexistence(root, proof, &rv)?)
Container::verify_nonexistence(root, proof, value.raw())
}
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into())
Container::verify_state_transition(proof)
}
pub fn set(&self) -> &HashSet<Value> {
&self.set
pub fn iter(&self) -> impl Iterator<Item = Result<Value>> + use<'_> {
self.inner.iter().map(|r| match r {
Ok((key, value)) => {
if key != value {
return Err(Error::custom("set: key != value"));
}
Ok(value)
}
Err(e) => Err(e),
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashSet<Value>> {
self.iter().collect()
}
}
impl PartialEq for Set {
fn eq(&self, other: &Self) -> bool {
self.mt.root() == other.mt.root()
self.inner.eq(&other.inner)
}
}
impl Eq for Set {}
impl<'de> Deserialize<'de> for Set {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize, JsonSchema)]
struct Aux {
#[serde(serialize_with = "ordered_set")]
set: HashSet<Value>,
}
let aux = Aux::deserialize(deserializer)?;
Ok(Set::new(aux.set))
}
}
/// Array: the elements are placed at the value field of each leaf, and the key field is just the
/// array index (integer).
/// leaf.key=i
/// leaf.value=original_value
#[derive(Clone, Debug, Serialize, JsonSchema)]
/// Due to its construction this should be seen as a sparse array, where there can be gaps
/// (unused indices).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct Array {
#[serde(skip)]
#[schemars(skip)]
mt: MerkleTree,
array: Vec<Value>,
pub(crate) inner: Container,
}
impl Array {
pub fn new(array: Vec<Value>) -> Self {
let kvs_raw: HashMap<RawValue, RawValue> = array
.iter()
.enumerate()
.map(|(i, e)| (RawValue::from(i as i64), e.raw()))
.collect();
Self {
mt: MerkleTree::new(&kvs_raw),
array,
inner: Container::new(
array
.into_iter()
.enumerate()
.map(|(i, v)| (Value::from(i as i64), v))
.collect(),
),
}
}
pub fn commitment(&self) -> Hash {
self.mt.root()
pub fn empty_with_db(db: Box<dyn DB>) -> Self {
Self {
inner: Container::empty_with_db(db),
}
pub fn get(&self, i: usize) -> Result<&Value> {
self.array.get(i).ok_or_else(|| {
Error::custom(format!("index {} out of bounds 0..{}", i, self.array.len()))
}
pub fn from_db(root: Hash, db: Box<dyn DB>) -> Result<Self> {
Ok(Self {
inner: Container::from_db(root, db)?,
})
}
pub fn prove(&self, i: usize) -> Result<(&Value, MerkleProof)> {
let (_, mtp) = self.mt.prove(&RawValue::from(i as i64))?;
let value = self.array.get(i).expect("valid index");
Ok((value, mtp))
pub fn commitment(&self) -> Hash {
self.inner.commitment()
}
pub fn get(&self, i: usize) -> Result<Option<Value>> {
self.inner.get(Value::from(i as i64).raw())
}
pub fn prove(&self, i: usize) -> Result<(Value, MerkleProof)> {
self.inner.prove(Value::from(i as i64).raw())
}
pub fn insert(&mut self, i: usize, value: Value) -> Result<MerkleTreeStateTransitionProof> {
self.inner.insert(Value::from(i as i64), value)
}
pub fn delete(&mut self, i: usize) -> Result<MerkleTreeStateTransitionProof> {
self.inner.delete(Value::from(i as i64).raw())
}
pub fn update(&mut self, i: usize, value: &Value) -> Result<MerkleTreeStateTransitionProof> {
let mtp = self.mt.update(&(i as i64).into(), &value.raw())?;
self.array[i] = value.clone();
Ok(mtp)
self.inner
.update(Value::from(i as i64).raw(), value.clone())
}
pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> {
Ok(MerkleTree::verify(
root,
proof,
&RawValue::from(i as i64),
&value.raw(),
)?)
Container::verify(root, proof, Value::from(i as i64).raw(), value.raw())
}
pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> {
MerkleTree::verify_state_transition(proof).map_err(|e| e.into())
Container::verify_state_transition(proof)
}
pub fn array(&self) -> &[Value] {
&self.array
pub fn iter(&self) -> impl Iterator<Item = Result<(usize, Value)>> + use<'_> {
self.inner.iter().map(|r| match r {
Ok((key, value)) => {
let index = key
.as_int()
.ok_or_else(|| Error::custom("array: key is not int"))?;
Ok((index as usize, value))
}
Err(e) => Err(e),
})
}
/// This is an expensive operation
pub fn dump(&self) -> Result<HashMap<usize, Value>> {
self.iter().collect()
}
}
impl PartialEq for Array {
fn eq(&self, other: &Self) -> bool {
self.mt.root() == other.mt.root()
self.inner.eq(&other.inner)
}
}
impl Eq for Array {}
impl<'de> Deserialize<'de> for Array {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
#[cfg(test)]
mod tests {
use super::*;
use crate::middleware::db::mem::MemDB;
fn test_databases(test_fn: &dyn Fn(Box<dyn DB>)) {
let db = MemDB::new();
test_fn(Box::new(db));
#[cfg(feature = "db_rocksdb")]
{
#[derive(Deserialize, JsonSchema)]
struct Aux {
array: Vec<Value>,
use crate::middleware::db;
let db = db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap();
test_fn(Box::new(db));
}
let aux = Aux::deserialize(deserializer)?;
Ok(Array::new(aux.array))
}
fn _test_dict(db: Box<dyn DB>) {
let mut dict0 = Dictionary::empty_with_db(db.clone());
dict0.insert(&Key::from("a"), &Value::from(1)).unwrap();
dict0.insert(&Key::from("b"), &Value::from(2)).unwrap();
dict0.update(&Key::from("a"), &Value::from(3)).unwrap();
dict0.insert(&Key::from("c"), &Value::from(4)).unwrap();
dict0.delete(&Key::from("c")).unwrap();
let kvs0 = dict0.dump().unwrap();
assert_eq!(
kvs0,
[
("a".to_string(), Value::from(3)),
("b".to_string(), Value::from(2))
]
.into_iter()
.collect()
);
let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap();
let kvs1 = dict1.dump().unwrap();
assert_eq!(kvs0, kvs1);
}
fn _test_set(db: Box<dyn DB>) {
let mut set0 = Set::empty_with_db(db.clone());
set0.insert(&Value::from(1)).unwrap();
set0.insert(&Value::from(2)).unwrap();
set0.insert(&Value::from(3)).unwrap();
set0.delete(&Value::from(2)).unwrap();
let s0 = set0.dump().unwrap();
assert_eq!(s0, [Value::from(1), Value::from(3)].into_iter().collect());
let set1 = Set::from_db(set0.commitment(), db).unwrap();
let s1 = set1.dump().unwrap();
assert_eq!(s0, s1);
}
fn _test_array(db: Box<dyn DB>) {
let mut arr0 = Array::empty_with_db(db.clone());
arr0.insert(0, Value::from("a")).unwrap();
arr0.insert(1, Value::from("b")).unwrap();
arr0.insert(2, Value::from("c")).unwrap();
arr0.delete(1).unwrap();
let a0 = arr0.dump().unwrap();
assert_eq!(
a0,
[(0, Value::from("a")), (2, Value::from("c"))]
.into_iter()
.collect()
);
let arr1 = Array::from_db(arr0.commitment(), db).unwrap();
let a1 = arr1.dump().unwrap();
assert_eq!(a0, a1);
}
fn _test_nested(db: Box<dyn DB>) {
let mut nested = Dictionary::empty_with_db(db.clone());
nested.insert(&Key::from("a"), &Value::from(1)).unwrap();
nested.insert(&Key::from("b"), &Value::from(2)).unwrap();
let nested_kvs0 = nested.dump().unwrap();
let mut dict0 = Dictionary::empty_with_db(db.clone());
dict0.insert(&Key::from("x"), &Value::from(1)).unwrap();
dict0
.insert(&Key::from("y"), &Value::from(nested.clone()))
.unwrap();
let kvs0 = dict0.dump().unwrap();
assert_eq!(
kvs0,
[
("x".to_string(), Value::from(1)),
("y".to_string(), Value::from(nested))
]
.into_iter()
.collect()
);
let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap();
let kvs1 = dict1.dump().unwrap();
assert_eq!(kvs0, kvs1);
match &kvs1["y"].typed {
TypedValue::Dictionary(d) => {
let nested_kvs1 = d.dump().unwrap();
assert_eq!(nested_kvs0, nested_kvs1);
}
_ => unreachable!(),
}
}
#[test]
fn test_dict() {
test_databases(&_test_dict);
}
#[test]
fn test_set() {
test_databases(&_test_set);
}
#[test]
fn test_array() {
test_databases(&_test_array);
}
#[test]
fn test_nested() {
test_databases(&_test_nested);
}
}

View file

@ -49,6 +49,9 @@ pub enum StatementTmplArg {
// AnchoredKey where the origin is a wildcard
AnchoredKey(Wildcard, Key),
Wildcard(Wildcard),
/// Reference to a same-batch predicate's identity hash, resolved at verification time.
/// The usize is the predicate index within the batch.
SelfPredicateHash(usize),
}
#[derive(Clone, Copy)]
@ -57,6 +60,7 @@ pub enum StatementTmplArgPrefix {
Literal = 1,
AnchoredKey = 2,
WildcardLiteral = 3,
SelfPredicateHash = 4,
}
impl From<StatementTmplArgPrefix> for F {
@ -72,7 +76,8 @@ impl ToFields for StatementTmplArg {
// Literal(v) => (1, [v ], 0, 0, 0, 0)
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
// SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0)
// In all cases, we pad to 2 * hash_size + 1 = 9 field elements
match self {
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
.chain(iter::repeat(F::ZERO))
@ -97,6 +102,13 @@ impl ToFields for StatementTmplArg {
.take(Params::statement_tmpl_arg_size())
.collect_vec()
}
StatementTmplArg::SelfPredicateHash(index) => {
iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash))
.chain(iter::once(F::from_canonical_usize(*index)))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect_vec()
}
}
}
}
@ -113,6 +125,7 @@ impl fmt::Display for StatementTmplArg {
write!(f, "]")
}
Self::Wildcard(v) => v.fmt(f),
Self::SelfPredicateHash(i) => write!(f, "::self.{}", i),
}
}
}
@ -423,7 +436,7 @@ impl fmt::Display for CustomPredicate {
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)]
#[derive(Clone, PartialEq, Eq, Serialize, JsonSchema)]
enum CustomPredicateBatchData {
Full {
#[serde(skip)]
@ -436,6 +449,20 @@ enum CustomPredicateBatchData {
},
}
// Explicit implementation of Debug to skip the merkle tree
impl fmt::Debug for CustomPredicateBatchData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self {
Self::Full { mt, predicates } => f
.debug_struct("Full")
.field("id", &mt.root())
.field("predicates", &predicates)
.finish(),
Self::Opaque { id } => f.debug_struct("Opaque").field("id", &id).finish(),
}
}
}
// TODO: Rename Batch for Module everywhere in the code base
impl CustomPredicateBatchData {
fn new_full(predicates: Vec<CustomPredicate>) -> Self {
@ -569,6 +596,44 @@ impl CustomPredicateRef {
pub fn predicate(&self) -> &CustomPredicate {
&self.batch.predicates()[self.index]
}
/// Returns a copy of this predicate with all `SelfPredicateHash(i)` args
/// resolved to `Literal(hash(Custom(batch, i)))`.
pub fn normalized_predicate(&self) -> CustomPredicate {
let pred = self.predicate();
let normalized_statements = pred
.statements
.iter()
.map(|st_tmpl| {
let args = st_tmpl
.args
.iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: self.batch.clone(),
index: *i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
}
other => other.clone(),
})
.collect();
StatementTmpl {
pred_or_wc: st_tmpl.pred_or_wc.clone(),
args,
}
})
.collect();
CustomPredicate {
name: pred.name.clone(),
conjunction: pred.conjunction,
statements: normalized_statements,
args_len: pred.args_len,
wildcard_names: pred.wildcard_names.clone(),
}
}
}
#[cfg(test)]
@ -579,7 +644,7 @@ mod tests {
middleware::{
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl,
StatementTmplArg,
StatementTmplArg, ValueRef,
},
};
@ -602,6 +667,9 @@ mod tests {
fn names(names: &[&str]) -> Vec<String> {
names.iter().map(|s| s.to_string()).collect()
}
fn value_ref(v: impl Into<ValueRef>) -> ValueRef {
v.into()
}
#[allow(clippy::upper_case_acronyms)]
type STA = StatementTmplArg;
@ -650,7 +718,7 @@ mod tests {
});
let custom_statement = Statement::Custom(
CustomPredicateRef::new(cust_pred_batch.clone(), 0),
vec![Value::from(d0.clone())],
vec![value_ref(d0.clone())],
);
let custom_deduction = Operation::Custom(
@ -782,7 +850,7 @@ mod tests {
// Example statement
let ethdos_example = Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2),
vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)],
vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)],
);
// Copies should work.
@ -791,7 +859,7 @@ mod tests {
// This could arise as the inductive step.
let ethdos_ind_example = Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1),
vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)],
vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)],
);
assert!(Operation::Custom(
@ -806,12 +874,12 @@ mod tests {
let ethdos_facts = vec![
Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2),
vec![Value::from("Alice"), Value::from("Charlie"), Value::from(6)],
vec![value_ref("Alice"), value_ref("Charlie"), value_ref(6)],
),
Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)),
Statement::Custom(
CustomPredicateRef::new(eth_friend_batch.clone(), 0),
vec![Value::from("Charlie"), Value::from("Bob")],
vec![value_ref("Charlie"), value_ref("Bob")],
),
];
@ -823,4 +891,173 @@ mod tests {
Ok(())
}
#[test]
fn test_normalized_predicate() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
// Compute expected pred_A hash
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
// Normalize pred_B
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let normalized = pred_b_ref.normalized_predicate();
// The second arg should be resolved to Literal(pred_A_hash)
assert_eq!(
normalized.statements[0].args[1],
STA::Literal(expected_hash)
);
// First arg should be unchanged (still a wildcard)
assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0)));
Ok(())
}
#[test]
fn test_self_predicate_hash_check() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
// Construct a valid operation: Equal(some_value, pred_a_hash)
let some_value = Value::from(42);
let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())];
// The output statement
let output_st = Statement::Custom(
pred_b_ref.clone(),
vec![ValueRef::Literal(some_value.clone())],
);
// This should pass
assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(&params, &output_st)?);
// Now try with wrong hash, should fail
let wrong_hash = Value::from(999);
let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)];
assert!(Operation::Custom(pred_b_ref, bad_op_args)
.check(&params, &output_st)
.is_err());
Ok(())
}
#[test]
fn test_self_predicate_hash_cyclic() -> Result<()> {
let params = Params::default();
// Build a batch where pred_A references pred_B's hash and vice versa
// pred_A = Equal(x, SelfPredicateHash(1))
// pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)],
)],
1,
names(&["x"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash());
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// pred_A's normalized form should reference pred_B's hash
let norm_a = pred_a_ref.normalized_predicate();
assert_eq!(
norm_a.statements[0].args[1],
STA::Literal(pred_b_hash.clone())
);
// pred_B's normalized form should reference pred_A's hash
let norm_b = pred_b_ref.normalized_predicate();
assert_eq!(
norm_b.statements[0].args[1],
STA::Literal(pred_a_hash.clone())
);
// Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass
let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())];
let st_a = Statement::Custom(
pred_a_ref.clone(),
vec![ValueRef::Literal(pred_b_hash.clone())],
);
assert!(Operation::Custom(pred_a_ref, op_a).check(&params, &st_a)?);
// Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass
let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())];
let st_b = Statement::Custom(
pred_b_ref.clone(),
vec![ValueRef::Literal(pred_a_hash.clone())],
);
assert!(Operation::Custom(pred_b_ref, op_b).check(&params, &st_b)?);
Ok(())
}
}

62
src/middleware/db/mem.rs Normal file
View file

@ -0,0 +1,62 @@
use super::*;
/// MemDB implements the DB trait in a in-memory HashMap.
#[derive(Clone, Debug, Default)]
pub struct MemDB {
nodes: Arc<RwLock<HashMap<Hash, merkletree::Node>>>,
values: Arc<RwLock<HashMap<RawValue, Value>>>,
}
impl MemDB {
pub fn new() -> Self {
Self::default()
}
}
impl merkletree::db::DB for MemDB {
fn load_node(&self, hash: Hash) -> anyhow::Result<Option<merkletree::Node>> {
let nodes = self.nodes.read().expect("lock not poisoned");
if hash == EMPTY_HASH {
return Ok(Some(merkletree::Node::Intermediate(
merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH),
)));
}
Ok(nodes.get(&hash).cloned())
}
fn store_node(&mut self, node: merkletree::Node) -> anyhow::Result<()> {
let mut nodes = self.nodes.write().expect("lock not poisoned");
nodes.insert(node.hash(), node);
Ok(())
}
}
impl DB for MemDB {
fn load_value(&self, raw: RawValue) -> anyhow::Result<Option<Value>> {
let values = self.values.read().expect("lock not poisoned");
Ok(values.get(&raw).cloned())
}
fn store_value(&mut self, value: Value) -> anyhow::Result<()> {
let mut values = self.values.write().expect("lock not poisoned");
let value_raw = value.raw();
if let Some(old_value) = values.get(&value_raw) {
let old_is_raw = old_value.is_raw();
// If we had a non-RawValue stored don't overwrite it (specially not with a
// RawValue). Also skip redundant RawValue overwrite.
if !old_is_raw || value.is_raw() {
return Ok(());
}
}
values.insert(value_raw, value);
Ok(())
}
fn is_persistent(&self) -> bool {
false
}
fn clone_box(&self) -> Box<dyn DB> {
Box::new(self.clone())
}
}

30
src/middleware/db/mod.rs Normal file
View file

@ -0,0 +1,30 @@
use std::{
collections::HashMap,
fmt::Debug,
sync::{Arc, RwLock},
};
use dyn_clone::DynClone;
#[cfg(feature = "backend_plonky2")]
use crate::backends::plonky2::primitives::merkletree::{self};
use crate::middleware::{Hash, RawValue, Value, EMPTY_HASH};
pub mod mem;
#[cfg(feature = "db_rocksdb")]
pub mod rocks;
// Trait for database that stores values. Must be cheap to clone.
pub trait DB: Debug + DynClone + Sync + Send + merkletree::db::DB {
fn load_value(&self, raw: RawValue) -> anyhow::Result<Option<Value>>;
// If the DB is persistent, for containers only the root needs to be stored because the
// Container type makes sure the underlying merkle tree is stored in the DB independently, so
// that it can be recovered back just with the root and the DB.
// If the value is RawValue and a previous non-RawValue exists, no store overwrite it.
// should be done. If the value is non-RawValue and a previous RawValue exists, store
// should overwrite it.
fn store_value(&mut self, value: Value) -> anyhow::Result<()>;
fn is_persistent(&self) -> bool;
fn clone_box(&self) -> Box<dyn DB>;
}
dyn_clone::clone_trait_object!(DB);

107
src/middleware/db/rocks.rs Normal file
View file

@ -0,0 +1,107 @@
use std::{fmt, path::Path, sync::Arc};
use anyhow::{anyhow, Result};
use rocksdb::{Options, TransactionDB, TransactionDBOptions};
use super::*;
fn node_key(hash: Hash) -> Vec<u8> {
let mut k = Vec::with_capacity(2 + 4);
k.extend_from_slice(b"n/");
k.extend_from_slice(&RawValue::from(hash).to_bytes());
k
}
fn value_key(raw: RawValue) -> Vec<u8> {
let mut k = Vec::with_capacity(2 + 4);
k.extend_from_slice(b"v/");
k.extend_from_slice(&raw.to_bytes());
k
}
#[derive(Clone)]
pub struct RocksDB {
db: Arc<TransactionDB>,
}
impl fmt::Debug for RocksDB {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "RocksDB(path: {:?})", self.db.path())
}
}
impl RocksDB {
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let mut options = Options::default();
options.create_if_missing(true);
let txn_options = TransactionDBOptions::default();
let inner =
TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?;
Ok(Self {
db: Arc::new(inner),
})
}
}
impl merkletree::db::DB for RocksDB {
fn load_node(&self, hash: Hash) -> Result<Option<merkletree::Node>> {
if hash == EMPTY_HASH {
return Ok(Some(merkletree::Node::Intermediate(
merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH),
)));
}
match self.db.get(node_key(hash))? {
None => Ok(None),
Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)),
}
}
fn store_node(&mut self, node: merkletree::Node) -> Result<()> {
self.db
.put(node_key(node.hash()), node.encode()?)
.map_err(|e| anyhow!("rocksdb transaction put failed: {e}"))
}
}
impl DB for RocksDB {
fn load_value(&self, raw: RawValue) -> anyhow::Result<Option<Value>> {
match self.db.get(value_key(raw))? {
None => Ok(None),
Some(bytes) => Ok(Some({
if bytes.is_empty() {
Value::from(raw)
} else {
Value::from_bytes(bytes.as_ref(), self.clone_box())?
}
})),
}
}
fn store_value(&mut self, value: Value) -> anyhow::Result<()> {
let value_key = value_key(value.raw());
let tx = self.db.transaction();
if let Some(old_value_bytes) = tx.get_for_update(&value_key, true)? {
let is_raw = old_value_bytes.is_empty();
// If we had a non-RawValue stored don't overwrite it (specially not with a
// RawValue). Also skip redundant RawValue overwrite.
if !is_raw || (is_raw && value.is_raw()) {
return Ok(());
}
}
let value_bytes = if value.is_raw() {
// For RawValue we store an empty vector because it's a duplicate of the key.
// This way we can easily check for RawValue without decoding.
vec![]
} else {
Value::to_bytes(&value)
};
tx.put(value_key, value_bytes)?;
Ok(tx.commit()?)
}
fn is_persistent(&self) -> bool {
true
}
fn clone_box(&self) -> Box<dyn DB> {
Box::new(self.clone())
}
}

View file

@ -72,6 +72,10 @@ pub enum Error {
},
#[error(transparent)]
Tree(#[from] crate::backends::plonky2::primitives::merkletree::error::TreeError),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error("database error: {0}")]
Database(anyhow::Error),
}
impl Debug for Error {
@ -164,7 +168,7 @@ impl Error {
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
new!(UnsatisfiedCustomPredicateDisjunction(pred))
}
pub(crate) fn custom(s: String) -> Self {
new!(Custom(s))
pub(crate) fn custom(s: impl Into<String>) -> Self {
new!(Custom(s.into()))
}
}

View file

@ -1,16 +1,13 @@
//! The middleware includes the type definitions and the traits used to connect the frontend and
//! the backend.
use std::sync::Arc;
use hex::ToHex;
use itertools::Itertools;
use strum_macros::FromRepr;
mod basetypes;
use std::{cmp::PartialEq, hash};
use containers::{Array, Dictionary, Set};
use containers::{Array, Container, Dictionary, Set};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub mod containers;
@ -22,6 +19,7 @@ pub mod serialization;
mod statement;
use std::{any::Any, fmt};
pub mod db;
pub use basetypes::*;
pub use custom::*;
use dyn_clone::DynClone;
@ -31,14 +29,10 @@ pub use pod_deserialization::*;
use serialization::*;
pub use statement::*;
use crate::backends::plonky2::primitives::merkletree::{
MerkleProof, MerkleTreeStateTransitionProof,
};
// TODO: Move all value-related types to to `value.rs`
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
// TODO #[schemars(transform = serialization::transform_value_schema)]
pub enum TypedValue {
pub(crate) enum TypedValue {
// Serde cares about the order of the enum variants, with untagged variants
// appearing at the end.
// Variants without "untagged" will be serialized as "tagged" values by
@ -73,8 +67,6 @@ pub enum TypedValue {
Array(Array),
#[serde(untagged)]
String(String),
#[serde(untagged)]
Bool(bool),
}
impl From<&str> for TypedValue {
@ -97,7 +89,11 @@ impl From<i64> for TypedValue {
impl From<bool> for TypedValue {
fn from(b: bool) -> Self {
TypedValue::Bool(b)
if b {
TypedValue::Int(1)
} else {
TypedValue::Int(0)
}
}
}
@ -149,70 +145,6 @@ impl From<RawValue> for TypedValue {
}
}
impl TryFrom<&TypedValue> for i64 {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::Int(n) = v {
Ok(*n)
} else {
Err(Error::custom("Value not an int".to_string()))
}
}
}
impl TryFrom<&TypedValue> for String {
type Error = Error;
fn try_from(tv: &TypedValue) -> Result<Self> {
match tv {
TypedValue::String(s) => Ok(s.clone()),
_ => Err(Error::custom(format!(
"Value {} cannot be converted to a string.",
tv
))),
}
}
}
impl TryFrom<&TypedValue> for Key {
type Error = Error;
fn try_from(tv: &TypedValue) -> Result<Self> {
Ok(Key::new(String::try_from(tv)?))
}
}
impl TryFrom<&TypedValue> for PublicKey {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::PublicKey(pk) = v {
Ok(*pk)
} else {
Err(Error::custom("Value not a public key".to_string()))
}
}
}
impl TryFrom<&TypedValue> for SecretKey {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::SecretKey(sk) = v {
Ok(sk.clone())
} else {
Err(Error::custom("Value not a secret key".to_string()))
}
}
}
impl TryFrom<&TypedValue> for Predicate {
type Error = Error;
fn try_from(v: &TypedValue) -> std::result::Result<Self, Self::Error> {
if let TypedValue::Predicate(p) = v {
Ok(p.clone())
} else {
Err(Error::custom("Value not a Predicate".to_string()))
}
}
}
impl fmt::Display for TypedValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@ -224,36 +156,54 @@ impl fmt::Display for TypedValue {
Err(_) => write!(f, "\"{}\"", s),
}
}
TypedValue::Bool(b) => write!(f, "{}", b),
TypedValue::Array(a) => {
write!(f, "[")?;
for (i, v) in a.array().iter().enumerate() {
for (i, r) in a.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", v)?;
if i == 8 {
write!(f, "")?;
break;
}
match r {
Ok((index, value)) => write!(f, "{}: {}", index, value)?,
Err(e) => write!(f, "{e}")?,
}
}
write!(f, "]")
}
TypedValue::Dictionary(d) => {
write!(f, "{{ ")?;
let kvs: Vec<_> = d.kvs().iter().sorted_by_key(|(k, _)| k.name()).collect();
for (i, (k, v)) in kvs.iter().enumerate() {
for (i, r) in d.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}: {}", k, v)?;
if i == 8 {
write!(f, "")?;
break;
}
match r {
Ok((key, value)) => write!(f, "{}: {}", key, value)?,
Err(e) => write!(f, "{e}")?,
}
}
write!(f, " }}")
}
TypedValue::Set(s) => {
write!(f, "#[")?;
let values: Vec<_> = s.set().iter().sorted_by_key(|k| k.raw()).collect();
for (i, v) in values.iter().enumerate() {
for (i, r) in s.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", v)?;
if i == 8 {
write!(f, "")?;
break;
}
match r {
Ok(value) => write!(f, "{}", value)?,
Err(e) => write!(f, "{e}")?,
}
}
write!(f, "]")
}
@ -272,7 +222,6 @@ impl From<&TypedValue> for RawValue {
match v {
TypedValue::String(s) => RawValue::from(hash_str(s)),
TypedValue::Int(v) => RawValue::from(*v),
TypedValue::Bool(b) => RawValue::from(*b as i64),
TypedValue::Dictionary(d) => RawValue::from(d.commitment()),
TypedValue::Set(s) => RawValue::from(s.commitment()),
TypedValue::Array(a) => RawValue::from(a.commitment()),
@ -405,9 +354,8 @@ impl JsonSchema for TypedValue {
#[derive(Clone, Debug)]
pub struct Value {
// The `TypedValue` is under `Arc` so that cloning a `Value` is cheap.
typed: Arc<TypedValue>,
raw: RawValue,
pub(crate) typed: TypedValue,
pub(crate) raw: RawValue,
}
// Values are serialized as their TypedValue.
@ -441,6 +389,55 @@ impl JsonSchema for Value {
}
}
/// Dual of TypedValue that is not recursive: for container types no entry only the commitment
/// (merkle tree root of underlying data) is available. Used for byte serialization for
/// persistent storage.
#[derive(Serialize, Deserialize)]
enum TypedValueNoRec {
Raw(RawValue),
Int(i64),
PublicKey(PublicKey),
SecretKey(SecretKey),
Predicate(Predicate),
Set(Hash),
Dictionary(Hash),
Array(Hash),
String(String),
}
// NOTE: byte serialization is using json. Using a byte-native serialization would improve
// performance and storage usage.
impl Value {
pub fn to_bytes(&self) -> Vec<u8> {
let v = match &self.typed {
TypedValue::Int(v) => TypedValueNoRec::Int(*v),
TypedValue::Raw(v) => TypedValueNoRec::Raw(*v),
TypedValue::PublicKey(v) => TypedValueNoRec::PublicKey(*v),
TypedValue::SecretKey(v) => TypedValueNoRec::SecretKey(v.clone()),
TypedValue::Predicate(v) => TypedValueNoRec::Predicate(v.clone()),
TypedValue::Set(v) => TypedValueNoRec::Set(v.commitment()),
TypedValue::Dictionary(v) => TypedValueNoRec::Dictionary(v.commitment()),
TypedValue::Array(v) => TypedValueNoRec::Array(v.commitment()),
TypedValue::String(v) => TypedValueNoRec::String(v.clone()),
};
serde_json::to_vec(&v).expect("json serialization succeeds")
}
pub fn from_bytes(bytes: &[u8], db: Box<dyn db::DB>) -> Result<Self> {
let v: TypedValueNoRec = serde_json::from_slice(bytes)?;
Ok(match v {
TypedValueNoRec::Int(v) => Value::from(v),
TypedValueNoRec::Raw(v) => Value::from(v),
TypedValueNoRec::PublicKey(v) => Value::from(v),
TypedValueNoRec::SecretKey(v) => Value::from(v),
TypedValueNoRec::Predicate(v) => Value::from(v),
TypedValueNoRec::Set(v) => Value::from(Set::from_db(v, db)?),
TypedValueNoRec::Dictionary(v) => Value::from(Dictionary::from_db(v, db)?),
TypedValueNoRec::Array(v) => Value::from(Array::from_db(v, db)?),
TypedValueNoRec::String(v) => Value::from(v),
})
}
}
impl PartialEq for Value {
fn eq(&self, other: &Self) -> bool {
self.raw == other.raw
@ -462,106 +459,110 @@ impl fmt::Display for Value {
}
impl Value {
pub fn new(value: TypedValue) -> Self {
pub(crate) fn new(value: TypedValue) -> Self {
let raw_value = RawValue::from(&value);
Self {
typed: Arc::new(value),
typed: value,
raw: raw_value,
}
}
pub fn typed(&self) -> &TypedValue {
&self.typed
}
pub fn raw(&self) -> RawValue {
self.raw
}
/// Determines Merkle existence proof for `key` in `self` (if applicable).
pub(crate) fn prove_existence<'a>(
&'a self,
key: &'a Value,
) -> Result<(&'a Value, MerkleProof)> {
match &self.typed() {
TypedValue::Array(a) => match key.typed() {
TypedValue::Int(i) if i >= &0 => a.prove((*i) as usize),
_ => Err(Error::custom(format!(
"Invalid key {} for container {}.",
key, self
)))?,
/// Returns true if the typed value is RawValue, which means it's a generic value with no type
/// information and no extra value data.
pub fn is_raw(&self) -> bool {
matches!(self.typed, TypedValue::Raw(_))
}
pub fn as_raw(&self) -> RawValue {
self.raw
}
pub fn as_int(&self) -> Option<i64> {
match self.typed {
TypedValue::Int(i) => Some(i),
_ => None,
}
}
pub fn as_public_key(&self) -> Option<PublicKey> {
match &self.typed {
TypedValue::PublicKey(pk) => Some(*pk),
_ => None,
}
}
pub fn as_secret_key(&self) -> Option<SecretKey> {
match &self.typed {
TypedValue::SecretKey(sk) => Some(sk.clone()),
_ => None,
}
}
pub fn as_predicate(&self) -> Option<Predicate> {
match &self.typed {
TypedValue::Predicate(p) => Some(p.clone()),
_ => None,
}
}
pub fn as_set(&self) -> Option<Set> {
match &self.typed {
TypedValue::Set(s) => Some(s.clone()),
TypedValue::Dictionary(d) => Some(Set {
inner: d.inner.clone(),
}),
TypedValue::Array(a) => Some(Set {
inner: a.inner.clone(),
}),
_ => None,
}
}
pub fn as_container(&self) -> Option<Container> {
match &self.typed {
TypedValue::Set(s) => Some(s.inner.clone()),
TypedValue::Dictionary(d) => Some(d.inner.clone()),
TypedValue::Array(a) => Some(a.inner.clone()),
_ => None,
}
}
pub fn as_dictionary(&self) -> Option<Dictionary> {
match &self.typed {
TypedValue::Set(s) => Some(Dictionary {
inner: s.inner.clone(),
}),
TypedValue::Dictionary(d) => Some(d.clone()),
TypedValue::Array(a) => Some(Dictionary {
inner: a.inner.clone(),
}),
_ => None,
}
}
pub fn as_array(&self) -> Option<Array> {
match &self.typed {
TypedValue::Set(s) => Some(Array {
inner: s.inner.clone(),
}),
TypedValue::Dictionary(d) => Some(Array {
inner: d.inner.clone(),
}),
TypedValue::Array(a) => Some(a.clone()),
_ => None,
}
}
pub fn as_str(&self) -> Option<&str> {
match &self.typed {
TypedValue::String(s) => Some(s.as_str()),
_ => None,
}
}
pub fn as_string(&self) -> Option<String> {
self.as_str().map(|s| s.to_string())
}
pub fn as_bool(&self) -> Option<bool> {
match self.typed {
TypedValue::Int(i) => match i {
0 => Some(false),
1 => Some(true),
_ => None,
},
TypedValue::Dictionary(d) => d.prove(&key.typed().try_into()?),
TypedValue::Set(s) => Ok((key, s.prove(key)?)),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
}
}
/// Determines Merkle non-existence proof for `key` in `self` (if applicable).
pub(crate) fn prove_nonexistence<'a>(&'a self, key: &'a Value) -> Result<MerkleProof> {
match &self.typed() {
TypedValue::Array(_) => Err(Error::custom(
"Arrays do not support `NotContains` operation.".to_string(),
)),
TypedValue::Dictionary(d) => d.prove_nonexistence(&key.typed().try_into()?),
TypedValue::Set(s) => s.prove_nonexistence(key),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
}
}
/// Returns a Merkle state transition proof for inserting a
/// key-value pair (if applicable).
pub(crate) fn prove_insertion(
&self,
key: &Value,
value: &Value,
) -> Result<MerkleTreeStateTransitionProof> {
let container = self.typed().clone();
match container {
TypedValue::Dictionary(mut d) => d.insert(&key.typed().try_into()?, value),
TypedValue::Set(mut s) => s.insert(value),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
}
}
/// Returns a Merkle state transition proof for updating a
/// key-value pair (if applicable).
pub(crate) fn prove_update(
&self,
key: &Value,
value: &Value,
) -> Result<MerkleTreeStateTransitionProof> {
let container = self.typed().clone();
match container {
TypedValue::Array(mut a) => match key.typed() {
TypedValue::Int(i) if i >= &0 => a.update(*i as usize, value),
_ => Err(Error::custom(format!(
"Invalid key {} for container {}.",
key, self
)))?,
},
TypedValue::Dictionary(mut d) => d.update(&key.typed().try_into()?, value),
_ => Err(Error::custom(format!(
"Invalid container value {} for update op",
self.typed()
))),
}
}
/// Returns a Merkle state transition proof for deleting a
/// key (if applicable).
pub(crate) fn prove_deletion(&self, key: &Value) -> Result<MerkleTreeStateTransitionProof> {
let container = self.typed().clone();
match container {
TypedValue::Dictionary(mut d) => d.delete(&key.typed().try_into()?),
TypedValue::Set(mut s) => s.delete(key),
_ => Err(Error::custom(format!(
"Invalid container value {}",
self.typed()
))),
_ => None,
}
}
}
@ -767,6 +768,8 @@ pub struct BaseParams {
/// in a custom predicate
pub max_custom_predicate_arity: usize,
pub max_depth_custom_batch_mt: usize,
// This value depends on `max_custom_predicate_arity`
pub max_operation_args: usize,
}
pub const BASE_PARAMS: BaseParams = BaseParams {
@ -774,8 +777,53 @@ pub const BASE_PARAMS: BaseParams = BaseParams {
max_statement_args: 5,
max_custom_predicate_arity: 5,
max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch
max_operation_args: 5 + 1,
};
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)]
#[serde(rename_all = "camelCase")]
pub struct ParamsMerkleProofs {
pub max_small: usize,
pub max_medium: usize,
}
impl ParamsMerkleProofs {
pub fn max_total(&self) -> usize {
self.max_small + self.max_medium
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)]
#[serde(rename_all = "camelCase")]
pub struct ParamsContainers {
// Parameters for exists/nonexists container operations. The small set only supports exists
pub state: ParamsMerkleProofs,
// Parameters for transition container operations (insert, delete, update). The small set only
// supports update.
pub transition: ParamsMerkleProofs,
// Max depth of small proofs
pub max_depth_small: usize,
// Max depth of medium proofs
pub max_depth_medium: usize,
}
impl Default for ParamsContainers {
fn default() -> Self {
Self {
state: ParamsMerkleProofs {
max_small: 22,
max_medium: 8,
},
transition: ParamsMerkleProofs {
max_small: 12,
max_medium: 6,
},
max_depth_small: 8,
max_depth_medium: 32,
}
}
}
/// Params: non dynamic parameters that define the circuit.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)]
#[serde(rename_all = "camelCase")]
@ -784,18 +832,12 @@ pub struct Params {
pub max_input_pods_public_statements: usize,
pub max_statements: usize,
pub max_public_statements: usize,
pub max_operation_args: usize,
// max number of different custom predicates that can be used in a MainPod
pub max_custom_predicates: usize,
// max number of operations using custom predicates that can be verified in the MainPod
pub max_custom_predicate_verifications: usize,
pub max_custom_predicate_wildcards: usize,
// maximum number of merkle proofs used for container operations
pub max_merkle_proofs_containers: usize,
// maximum number of merkle tree state transition proofs used for container update operations
pub max_merkle_tree_state_transition_proofs_containers: usize,
// maximum depth for merkle tree gadget used for container operations
pub max_depth_mt_containers: usize,
pub containers: ParamsContainers,
// maximum depth of the merkle tree gadget used for verifier_data membership
// check. This allows creating verifying sets of pod circuits of size
// 2^max_depth_mt_vds. Limits the number of container operations of the type Contains,
@ -814,13 +856,10 @@ impl Default for Params {
max_input_pods_public_statements: 8,
max_statements: 48,
max_public_statements: 8,
max_operation_args: 5,
max_custom_predicates: 8,
max_custom_predicate_verifications: 8,
max_custom_predicate_wildcards: 8,
max_merkle_proofs_containers: 20,
max_merkle_tree_state_transition_proofs_containers: 6,
max_depth_mt_containers: 32,
containers: ParamsContainers::default(),
max_depth_mt_vds: 6, // up to 64 (2^6) different pod circuits
max_public_key_of: 2,
max_signed_by: 4,

View file

@ -7,17 +7,14 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::primitives::{
ec::{
curve::{Point as PublicKey, GROUP_ORDER},
schnorr::{SecretKey, Signature},
},
ec::{curve::GROUP_ORDER, schnorr::Signature},
merkletree::{MerkleProof, MerkleTree, MerkleTreeOp, MerkleTreeStateTransitionProof},
},
middleware::{
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result,
Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value,
ValueRef, Wildcard, F,
Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef,
Wildcard, BASE_PARAMS, F,
},
};
@ -92,6 +89,7 @@ pub enum NativeOperation {
ContainerInsertFromEntries = 16,
ContainerUpdateFromEntries = 17,
ContainerDeleteFromEntries = 18,
ReplaceValueWithEntry = 19,
// Syntactic sugar operations. These operations are not supported by the backend. The
// frontend compiler is responsible of translating these operations into the operations above.
@ -167,6 +165,7 @@ impl OperationType {
NativeOperation::ContainerDeleteFromEntries => {
Some(Predicate::Native(NativePredicate::ContainerDelete))
}
NativeOperation::ReplaceValueWithEntry => None,
no => unreachable!("Unexpected syntactic sugar op {:?}", no),
},
OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())),
@ -222,6 +221,10 @@ pub enum Operation {
/* key */ Statement,
/* proof */ MerkleTreeStateTransitionProof,
),
ReplaceValueWithEntry(
/* Contains/None len=max_statement_args */ Vec<Statement>,
/* to copy */ Statement,
),
Custom(CustomPredicateRef, Vec<Statement>),
}
@ -241,6 +244,10 @@ pub(crate) fn hash_op(x: Value, y: Value) -> Value {
Value::from(hash_values(&[x, y]))
}
fn ok_or_type_err<T>(o: Option<T>, v: &Value, typ: &'static str) -> Result<T> {
o.ok_or_else(|| Error::custom(format!("{v} type is not {typ}")))
}
impl Operation {
pub fn op_type(&self) -> OperationType {
type OT = OperationType;
@ -269,6 +276,7 @@ impl Operation {
OT::Native(ContainerUpdateFromEntries)
}
Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries),
Self::ReplaceValueWithEntry(_, _) => OT::Native(ReplaceValueWithEntry),
Self::Custom(cpr, _) => OT::Custom(cpr.clone()),
}
}
@ -294,6 +302,11 @@ impl Operation {
Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4],
Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4],
Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3],
Self::ReplaceValueWithEntry(args, s) => {
let mut sts = args;
sts.push(s);
sts
}
Self::Custom(_, args) => args,
}
}
@ -376,6 +389,18 @@ impl Operation {
&[s1, s2, s3],
OA::MerkleTreeStateTransitionProof(pf),
) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf),
(NO::ReplaceValueWithEntry, args, OA::None) => {
let mut args = args.to_vec();
if args.len() != BASE_PARAMS.max_statement_args + 1 {
return Err(Error::custom(format!(
"ReplaceValueWithEntry requires exactly {} args but {} were found",
BASE_PARAMS.max_statement_args + 1,
args.len()
)));
}
let st = args.pop().expect("valid vec len");
Self::ReplaceValueWithEntry(args, st)
}
_ => Err(Error::custom(format!(
"Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.",
op_code,
@ -404,23 +429,55 @@ impl Operation {
v3: &Value,
f: impl FnOnce(i64, i64) -> i64,
) -> Result<bool> {
let i1: i64 = v1.typed().try_into()?;
let i2: i64 = v2.typed().try_into()?;
let i3: i64 = v3.typed().try_into()?;
let i1 = ok_or_type_err(v1.as_int(), v1, "Int")?;
let i2 = ok_or_type_err(v2.as_int(), v2, "Int")?;
let i3 = ok_or_type_err(v3.as_int(), v3, "Int")?;
Ok(i1 == f(i2, i3))
}
pub(crate) fn check_public_key(v1: &Value, v2: &Value) -> Result<bool> {
let pk: PublicKey = v1.typed().try_into()?;
let sk: SecretKey = v2.typed().try_into()?;
let pk = ok_or_type_err(v1.as_public_key(), v1, "PublicKey")?;
let sk = ok_or_type_err(v2.as_secret_key(), v2, "SecretKey")?;
Ok(sk.0 < *GROUP_ORDER && pk == sk.public_key())
}
pub(crate) fn check_signed_by(msg: &Value, pk: &Value, sig: &Signature) -> Result<bool> {
let pk: PublicKey = pk.typed().try_into()?;
let pk = ok_or_type_err(pk.as_public_key(), pk, "PublicKey")?;
Ok(sig.verify(pk, msg.raw()))
}
fn check_replace_value_with_entry(
entries: &[Statement],
st_in: &Statement,
expected_st_out: &Statement,
) -> Result<bool> {
if entries.len() != BASE_PARAMS.max_statement_args {
return Ok(false);
}
let args = iter::zip(st_in.args(), entries)
.map(|(arg_in, entry)| match (arg_in, entry) {
(arg_in, Statement::None) => Ok(arg_in),
(
StatementArg::Literal(v_in),
Statement::Contains(
ValueRef::Literal(root),
ValueRef::Literal(key),
ValueRef::Literal(v),
),
) if v == &v_in => Ok(StatementArg::Key(AnchoredKey::new(
Hash::from(root.raw()),
Key::from(key.as_str().ok_or_else(|| Error::custom("not a string"))?),
))),
_ => Err(Error::custom(
"invalid statement argument in ReplaceValueWithEntry",
)),
})
.collect::<Result<Vec<_>>>()?;
let st_out = Statement::from_args(st_in.predicate(), args)?;
Ok(&st_out == expected_st_out)
}
/// Checks the given operation against a statement.
pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
use Statement::*;
@ -428,8 +485,8 @@ impl Operation {
let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err);
let int_val = |v, s| {
let v_op = value_from_op(s, v).ok_or_else(deduction_err)?;
match v_op.typed() {
&TypedValue::Int(i) => Ok(i),
match v_op.as_int() {
Some(i) => Ok(i),
_ => Err(deduction_err()),
}
};
@ -494,8 +551,7 @@ impl Operation {
&& pf.op_value == value.raw())
.then_some(())
.ok_or(Error::custom(
"The provided Merkle tree state transition proof does not match the claim."
.into(),
"The provided Merkle tree state transition proof does not match the claim.",
))?;
MerkleTree::verify_state_transition(pf)?;
true
@ -515,8 +571,7 @@ impl Operation {
&& pf.op_value == value.raw())
.then_some(())
.ok_or(Error::custom(
"The provided Merkle tree state transition proof does not match the claim."
.into(),
"The provided Merkle tree state transition proof does not match the claim.",
))?;
MerkleTree::verify_state_transition(pf)?;
true
@ -534,8 +589,7 @@ impl Operation {
&& pf.op_key == key.raw())
.then_some(())
.ok_or(Error::custom(
"The provided Merkle tree state transition proof does not match the claim."
.into(),
"The provided Merkle tree state transition proof does not match the claim.",
))?;
MerkleTree::verify_state_transition(pf)?;
true
@ -543,7 +597,19 @@ impl Operation {
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index =>
{
check_custom_pred(params, cpr, args, s_args).map(|_| true)?
// The custom operation outputs statements with literal arguments. They can be
// replaced by references later with ReplaceValueWithEntry.
let s_args = s_args
.iter()
.map(|arg| match arg {
ValueRef::Literal(v) => Ok(v.clone()),
_ => Err(deduction_err()),
})
.collect::<Result<Vec<_>>>()?;
check_custom_pred(params, cpr, args, &s_args).map(|_| true)?
}
(Self::ReplaceValueWithEntry(entries, st_in), st_out) => {
Self::check_replace_value_with_entry(entries, st_in, st_out)?
}
_ => return Err(deduction_err()),
};
@ -597,6 +663,11 @@ pub fn check_st_tmpl(
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
wc_check_or_set(v.clone(), wc, wildcard_map)
}
(StatementTmplArg::SelfPredicateHash(_), _) => {
unreachable!(
"SelfPredicateHash should be normalized to Literal before template matching"
)
}
_ => Err(Error::mismatched_statement_tmpl_arg(
st_tmpl_arg.clone(),
st_arg.clone(),
@ -645,9 +716,9 @@ pub fn wildcard_values_from_op_st(
params: &Params,
pred: &CustomPredicate,
op_args: &[Statement],
st_args: &[Value],
resolved_st_args: &[Value],
) -> Result<Vec<Value>> {
let mut wildcard_map = st_args
let mut wildcard_map = resolved_st_args
.iter()
.map(|v| Some(v.clone()))
.chain(core::iter::repeat(None))
@ -714,7 +785,7 @@ pub(crate) fn check_custom_pred(
args: &[Statement],
s_args: &[Value],
) -> Result<()> {
let pred = custom_pred_ref.predicate();
let pred = custom_pred_ref.normalized_predicate();
if pred.statements.len() != args.len() {
return Err(Error::diff_amount(
"custom predicate operation".to_string(),
@ -733,7 +804,7 @@ pub(crate) fn check_custom_pred(
}
// Check that the resolved wildcards match the statement arguments.
let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) {
let wc_values = match wildcard_values_from_op_st(params, &pred, args, s_args) {
Ok(wc_values) => wc_values,
Err(Error::Inner { inner, backtrace }) => match *inner {
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
@ -789,9 +860,8 @@ impl fmt::Display for Operation {
pub(crate) fn root_key_to_ak(root: &Value, key: &Value) -> Option<AnchoredKey> {
let root_hash = Hash::from(root.raw());
Key::try_from(key.typed())
.map(|key| AnchoredKey::new(root_hash, key))
.ok()
key.as_str()
.map(|s| AnchoredKey::new(root_hash, Key::from(s)))
}
/// Returns the value associated with `output_ref`.

View file

@ -311,7 +311,7 @@ pub enum Statement {
/* old_root */ ValueRef,
/* key */ ValueRef,
),
Custom(CustomPredicateRef, Vec<Value>),
Custom(CustomPredicateRef, Vec<ValueRef>),
Intro(IntroPredicateRef, Vec<Value>),
}
@ -407,7 +407,7 @@ impl Statement {
vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()]
}
Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Literal)),
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(StatementArg::from)),
Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)),
}
}
@ -478,14 +478,11 @@ impl Statement {
}
(BatchSelf(_), _) => unreachable!(),
(Custom(cpr), _) => {
let v_args: Result<Vec<Value>> = args
let v_args = args
.iter()
.map(|x| match x {
StatementArg::Literal(v) => Ok(v.clone()),
_ => Err(Error::incorrect_statements_args()),
})
.collect();
Self::Custom(cpr, v_args?)
.map(|x| x.try_into())
.collect::<Result<Vec<ValueRef>>>()?;
Self::Custom(cpr, v_args)
}
(Intro(ir), _) => {
let v_args: Result<Vec<Value>> = args