From def0730462527bd0f87cc49ae382ff1b642ea314 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Fri, 16 May 2025 13:17:14 +0200 Subject: [PATCH] Fix custom predicate circuits and add tests for them (#235) * add tests, fix custom predicates * wip * wip * fix custom predicates * modularize code * fix typos * remove scratch file * update * Update src/backends/plonky2/circuits/mainpod.rs Co-authored-by: Ahmad Afuni --------- Co-authored-by: Ahmad Afuni --- book/src/statements.md | 25 +- src/backends/plonky2/circuits/common.rs | 42 +- src/backends/plonky2/circuits/mainpod.rs | 408 ++++++++++++++---- src/backends/plonky2/circuits/mod.rs | 1 + src/backends/plonky2/circuits/utils.rs | 73 ++++ src/backends/plonky2/mainpod/mod.rs | 163 ++++++- src/backends/plonky2/mock/mainpod.rs | 11 +- src/backends/plonky2/mock/signedpod.rs | 4 +- .../plonky2/primitives/signature/mod.rs | 2 +- src/backends/plonky2/signedpod.rs | 4 + src/examples/custom.rs | 13 +- src/examples/mod.rs | 12 +- src/frontend/custom.rs | 4 +- src/frontend/mod.rs | 9 +- src/frontend/serialization.rs | 7 +- src/middleware/custom.rs | 4 +- 16 files changed, 629 insertions(+), 153 deletions(-) create mode 100644 src/backends/plonky2/circuits/utils.rs diff --git a/book/src/statements.md b/book/src/statements.md index 45429c9..db09873 100644 --- a/book/src/statements.md +++ b/book/src/statements.md @@ -24,18 +24,19 @@ The following table summarises the natively-supported statements, where we write | Code | Identifier | Args | Meaning | |------|---------------|---------------------|-------------------------------------------------------------------| -| 0 | `None` | | no statement (useful for padding) | -| 1 | `ValueOf` | `ak`, `value` | `value_of(ak) = value` | -| 2 | `Eq` | `ak1`, `ak2` | `value_of(ak1) = value_of(ak2)` | -| 3 | `NEq` | `ak1`, `ak2` | `value_of(ak1) != value_of(ak2)` | -| 4 | `Gt` | `ak1`, `ak2` | `value_of(ak1) > value_of(ak2)` | -| 5 | `LEq` | `ak1`, `ak2` | `value_of(ak1) <= value_of(ak2)` | -| 6 | `Contains` | `ak1`, `ak2` | `(key_of(ak2), value_of(ak2)) ∈ value_of(ak1)` (Merkle inclusion) | -| 7 | `NotContains` | `ak1`, `ak2` | `(key_of(ak2), value_of(ak2)) ∉ value_of(ak1)` (Merkle exclusion) | -| 8 | `SumOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = value_of(ak2) + value_of(ak3)` | -| 9 | `ProductOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = value_of(ak2) * value_of(ak3)` | -| 10 | `MaxOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = max(value_of(ak2), value_of(ak3))` | -| 11 | `HashOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = hash(value_of(ak2), value_of(ak3))` | +| 0 | `None` | | no statement, always true (useful for padding) | +| 1 | `False` | | always false (useful for padding disjunctions) | +| 2 | `ValueOf` | `ak`, `value` | `value_of(ak) = value` | +| 3 | `Eq` | `ak1`, `ak2` | `value_of(ak1) = value_of(ak2)` | +| 4 | `NEq` | `ak1`, `ak2` | `value_of(ak1) != value_of(ak2)` | +| 5 | `Gt` | `ak1`, `ak2` | `value_of(ak1) > value_of(ak2)` | +| 6 | `LEq` | `ak1`, `ak2` | `value_of(ak1) <= value_of(ak2)` | +| 7 | `Contains` | `ak1`, `ak2` | `(key_of(ak2), value_of(ak2)) ∈ value_of(ak1)` (Merkle inclusion) | +| 8 | `NotContains` | `ak1`, `ak2` | `(key_of(ak2), value_of(ak2)) ∉ value_of(ak1)` (Merkle exclusion) | +| 9 | `SumOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = value_of(ak2) + value_of(ak3)` | +| 10 | `ProductOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = value_of(ak2) * value_of(ak3)` | +| 11 | `MaxOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = max(value_of(ak2), value_of(ak3))` | +| 12 | `HashOf` | `ak1`, `ak2`, `ak3` | `value_of(ak1) = hash(value_of(ak2), value_of(ak3))` | ### Frontend statements diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 440803a..633ed61 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -153,7 +153,7 @@ impl Build for T { } impl StatementTarget { - /// Build a new native StatementTarget + /// Build a new native StatementTarget. Pads the arguments. pub fn new_native( builder: &mut CircuitBuilder, params: &Params, @@ -311,7 +311,7 @@ impl NativePredicateTarget { #[derive(Clone)] pub struct PredicateTarget { - elements: [Target; Params::predicate_size()], + pub(crate) elements: [Target; Params::predicate_size()], } impl PredicateTarget { @@ -520,8 +520,35 @@ impl CustomPredicateEntryTarget { ) -> Result<()> { pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; pw.set_target(self.index, F::from_canonical_usize(predicate.index))?; - self.predicate - .set_targets(pw, params, predicate.predicate())?; + + // Replace statement templates of batch-self with (id,index) + let batch = &predicate.batch; + let predicate = predicate.predicate(); + let statements = predicate + .statements + .clone() + .into_iter() + .map(|st_tmpl| { + let pred = match st_tmpl.pred { + Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { + batch: batch.clone(), + index: i, + }), + p => p, + }; + StatementTmpl { + pred, + args: st_tmpl.args, + } + }) + .collect_vec(); + let predicate = CustomPredicate { + name: predicate.name.clone(), + conjunction: predicate.conjunction, + statements, + args_len: predicate.args_len, + }; + self.predicate.set_targets(pw, params, &predicate)?; Ok(()) } } @@ -570,6 +597,7 @@ impl CustomPredicateVerifyEntryTarget { self.custom_predicate_table_index, F::from_canonical_usize(cpv.custom_predicate_table_index), )?; + // Replace statement templates of batch-self with (id,index) self.custom_predicate .set_targets(pw, params, &cpv.custom_predicate)?; let pad_arg = WildcardValue::None; @@ -1439,7 +1467,7 @@ pub(crate) mod tests { let params = Params::default(); let config = CircuitConfig::standard_recursion_config(); - let custom_predicate_batch = eth_friend_batch(¶ms)?; + let custom_predicate_batch = eth_friend_batch(¶ms, false)?; for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() { let mut builder = CircuitBuilder::::new(config.clone()); @@ -1502,10 +1530,10 @@ pub(crate) mod tests { helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); // Some cases from the examples - let custom_predicate_batch = eth_friend_batch(¶ms)?; + let custom_predicate_batch = eth_friend_batch(¶ms, false)?; helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); - let custom_predicate_batch = eth_dos_batch(¶ms)?; + let custom_predicate_batch = eth_dos_batch(¶ms, false)?; helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); let custom_predicate_batch = diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 3b8b5db..8e7bce5 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -14,10 +14,10 @@ use crate::{ circuits::{ common::{ CircuitBuilderPod, CustomPredicateBatchTarget, CustomPredicateEntryTarget, - CustomPredicateVerifyEntryTarget, CustomPredicateVerifyQueryTarget, Flattenable, - MerkleClaimTarget, OperationTarget, OperationTypeTarget, PredicateTarget, - StatementArgTarget, StatementTarget, StatementTmplArgTarget, StatementTmplTarget, - ValueTarget, + CustomPredicateTarget, CustomPredicateVerifyEntryTarget, + CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget, OperationTarget, + OperationTypeTarget, PredicateTarget, StatementArgTarget, StatementTarget, + StatementTmplArgTarget, StatementTmplTarget, ValueTarget, }, signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget}, }, @@ -30,8 +30,8 @@ use crate::{ }, middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, Params, PodType, Statement, StatementArg, ToFields, Value, WildcardValue, - F, KEY_TYPE, SELF, VALUE_SIZE, + NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields, + Value, WildcardValue, F, KEY_TYPE, SELF, VALUE_SIZE, }, }; @@ -188,8 +188,7 @@ impl OperationVerifyGadget { .concat(); let ok = builder.any(op_checks); - - builder.connect(ok.target, _true.target); + builder.assert_one(ok.target); Ok(()) } @@ -813,6 +812,7 @@ impl CustomOperationVerifyGadget { // expected_sts.len() == self.params.max_custom_predicate_arity // op_args.len() == self.params.max_operation_args; assert!(self.params.max_custom_predicate_arity <= self.params.max_operation_args); + let sts_eq: Vec<_> = expected_sts .iter() .zip(op_args.iter()) @@ -837,6 +837,119 @@ struct MainPodVerifyGadget { } impl MainPodVerifyGadget { + // Replace predicates of batch-self with the corresponding global custom predicate batch_id and + // index + fn normalize_st_tmpl( + &self, + builder: &mut CircuitBuilder, + st_tmpl: &StatementTmplTarget, + id: HashOutTarget, + ) -> StatementTmplTarget { + let params = &self.params; + let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf)); + let is_batch_self = builder.is_equal(st_tmpl.pred.elements[0], prefix_batch_self); + let pred_index = st_tmpl.pred.elements[1]; + let custom_pred = PredicateTarget::new_custom(builder, id, pred_index); + let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, &st_tmpl.pred); + StatementTmplTarget { + pred, + args: st_tmpl.args.clone(), + } + } + /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as + /// hash([batch_id, custom_predicate_index, custom_predicate]). While building the table we + /// calculate the id of each batch. + fn build_custom_predicate_table( + &self, + builder: &mut CircuitBuilder, + ) -> Result<(Vec, Vec)> { + let params = &self.params; + let mut custom_predicate_table = + Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size); + let mut custom_predicate_batches = Vec::with_capacity(params.max_custom_predicate_batches); + for _ in 0..params.max_custom_predicate_batches { + let cpb = builder.add_virtual_custom_predicate_batch(params); + let id = cpb.id(builder); // constrain the id + for (index, cp) in cpb.predicates.iter().enumerate() { + let statements = cp + .statements + .iter() + .map(|st_tmpl| self.normalize_st_tmpl(builder, st_tmpl, id)) + .collect_vec(); + let cp = CustomPredicateTarget { + conjunction: cp.conjunction, + statements, + args_len: cp.args_len, + }; + let entry = CustomPredicateEntryTarget { + id, // output + index: builder.constant(F::from_canonical_usize(index)), // constant + predicate: cp.clone(), // input + }; + + let in_query_hash = entry.hash(builder); + custom_predicate_table.push(in_query_hash); + } + custom_predicate_batches.push(cpb); // We keep this for witness assignment + } + Ok((custom_predicate_table, custom_predicate_batches)) + } + + /// Build table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args] + /// with queryable part as hash([st, op, op_args]). While building the table we verify each + /// custom predicate against the operation and statement. + fn build_custom_predicate_verification_table( + &self, + builder: &mut CircuitBuilder, + custom_predicate_table: &[HashOutTarget], + ) -> Result<(Vec, Vec)> { + let params = &self.params; + let mut custom_predicate_verifications = + Vec::with_capacity(params.max_custom_predicate_verifications); + let mut custom_predicate_verification_table = + Vec::with_capacity(params.max_custom_predicate_verifications); + for _ in 0..params.max_custom_predicate_verifications { + let custom_predicate_table_index = builder.add_virtual_target(); + let custom_predicate = builder.add_virtual_custom_predicate_entry(params); + let args = (0..params.max_custom_predicate_wildcards) + .map(|_| builder.add_virtual_value()) + .collect_vec(); + let op_args = (0..params.max_operation_args) + .map(|_| builder.add_virtual_statement(params)) + .collect_vec(); + + // Verify the custom predicate operation + let (statement, op_type) = CustomOperationVerifyGadget { + params: params.clone(), + } + .eval(builder, &custom_predicate, &op_args, &args)?; + + // Check that the batch id is correct by querying the custom predicate batches table + let table_query_hash = + builder.vec_ref(params, custom_predicate_table, custom_predicate_table_index); + let out_query_hash = custom_predicate.hash(builder); + builder.connect_array(table_query_hash.elements, out_query_hash.elements); + + let entry = CustomPredicateVerifyEntryTarget { + custom_predicate_table_index, // input + custom_predicate, // input + args, // input + query: CustomPredicateVerifyQueryTarget { + statement, // output + op_type, // output + op_args, // input + }, + }; + let in_query_hash = entry.query.hash(builder); + custom_predicate_verification_table.push(in_query_hash); + custom_predicate_verifications.push(entry); // We keep this for witness assignment + } + Ok(( + custom_predicate_verification_table, + custom_predicate_verifications, + )) + } + fn eval(&self, builder: &mut CircuitBuilder) -> Result { let params = &self.params; // 1. Verify all input signed pods @@ -851,12 +964,17 @@ impl MainPodVerifyGadget { // Build the statement array let mut statements = Vec::new(); + // Statement at index 0 is always None to be used for padding operation arguments in custom + // predicate statements + let st_none = + StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]); + statements.push(st_none); for signed_pod in &signed_pods { statements.extend_from_slice(signed_pod.pub_statements(builder, false).as_slice()); } debug_assert_eq!( statements.len(), - self.params.max_input_signed_pods * self.params.max_signed_pod_values + 1 + self.params.max_input_signed_pods * self.params.max_signed_pod_values ); // TODO: Fill with input main pods for _main_pod in 0..self.params.max_input_main_pods { @@ -895,73 +1013,13 @@ impl MainPodVerifyGadget { .map(|pf| pf.into()) .collect(); - // Table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as - // hash([batch_id, custom_predicate_index, custom_predicate]). While building the table we - // calculate the id of each batch. - let mut custom_predicate_table = - Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size); - let mut custom_predicate_batches = Vec::with_capacity(params.max_custom_predicate_batches); - for _ in 0..params.max_custom_predicate_batches { - let cpb = builder.add_virtual_custom_predicate_batch(&self.params); - let id = cpb.id(builder); // constrain the id - for (index, cp) in cpb.predicates.iter().enumerate() { - let entry = CustomPredicateEntryTarget { - id, // output - index: builder.constant(F::from_canonical_usize(index)), // constant - predicate: cp.clone(), // input - }; - let in_query_hash = entry.hash(builder); - custom_predicate_table.push(in_query_hash); - } - custom_predicate_batches.push(cpb); // We keep this for witness assignment - } + // Table of custom predicate batches with batch_id calculation + let (custom_predicate_table, custom_predicate_batches) = + self.build_custom_predicate_table(builder)?; - // Table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args] - // with queryable part as hash([st, op, op_args]). While building the table we verify each - // custom predicate against the operation and statement. - let mut custom_predicate_verifications = - Vec::with_capacity(params.max_custom_predicate_verifications); - let mut custom_predicate_verification_table = - Vec::with_capacity(params.max_custom_predicate_verifications); - for _ in 0..params.max_custom_predicate_verifications { - let custom_predicate_table_index = builder.add_virtual_target(); - let custom_predicate = builder.add_virtual_custom_predicate_entry(&self.params); - let args = (0..params.max_custom_predicate_wildcards) - .map(|_| builder.add_virtual_value()) - .collect_vec(); - let op_args = (0..params.max_operation_args) - .map(|_| builder.add_virtual_statement(&self.params)) - .collect_vec(); - - // Verify the custom predicate operation - let (statement, op_type) = CustomOperationVerifyGadget { - params: params.clone(), - } - .eval(builder, &custom_predicate, &op_args, &args)?; - - // Check that the batch id is correct by querying the custom predicate batches table - let table_query_hash = builder.vec_ref( - &self.params, - &custom_predicate_table, - custom_predicate_table_index, - ); - let out_query_hash = custom_predicate.hash(builder); - builder.connect_array(table_query_hash.elements, out_query_hash.elements); - - let entry = CustomPredicateVerifyEntryTarget { - custom_predicate_table_index, // input - custom_predicate, // input - args, // input - query: CustomPredicateVerifyQueryTarget { - statement, // output - op_type, // output - op_args, // input - }, - }; - let in_query_hash = entry.query.hash(builder); - custom_predicate_verification_table.push(in_query_hash); - custom_predicate_verifications.push(entry); // We keep this for witness assignment - } + // Table of custom predicate statements verification against operations + let (custom_predicate_verification_table, custom_predicate_verifications) = + self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?; // 2. Calculate the Pod Id from the public statements let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect(); @@ -2193,7 +2251,7 @@ mod tests { custom_predicate: CustomPredicateRef, op_args: Vec, args: Vec, - expected_st: Statement, + expected_st: Option, ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); @@ -2226,22 +2284,22 @@ mod tests { arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; } // Expected Output - st_target.set_targets(&mut pw, params, &expected_st.into())?; + if let Some(expected_st) = expected_st { + st_target.set_targets(&mut pw, params, &expected_st.into())?; + } let expected_op_type = OperationType::Custom(custom_predicate); op_type_target.set_targets(&mut pw, params, &expected_op_type)?; // generate & verify proof let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); - - Ok(()) + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) } // TODO: Add negative tests #[test] - fn test_custom_operation_verify_gadget() -> frontend::Result<()> { + fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { // We set the parameters to the exact sizes we have in the test so that we don't have to // pad. let params = Params { @@ -2298,7 +2356,7 @@ mod tests { custom_predicate, op_args, args, - expected_st, + Some(expected_st), ) .unwrap(); @@ -2322,7 +2380,7 @@ mod tests { custom_predicate, op_args, args, - expected_st, + Some(expected_st), ) .unwrap(); @@ -2349,10 +2407,190 @@ mod tests { custom_predicate, op_args, args, - expected_st, + Some(expected_st), ) .unwrap(); Ok(()) } + + #[test] + fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { + // We set the parameters to the exact sizes we have in the test so that we don't have to + // pad. + let params = Params { + max_custom_predicate_arity: 2, + max_custom_predicate_wildcards: 2, + max_operation_args: 2, + max_statement_args: 2, + ..Default::default() + }; + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new(NP::ValueOf) + .arg(("id", key("score"))) + .arg(literal(42)); + let stb1 = STB::new(NP::Equal) + .arg(("id", "secret_key")) + .arg(("id", key("score"))); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret_key"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret_key"], &[stb0, stb1])?; + let batch = builder.finish(); + + let pod_id = PodId(hash_str("pod_id")); + + // AND (0) Sanity check with correct values + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("score")), + Value::from(42), + ), + Statement::Equal( + AnchoredKey::new(pod_id, Key::from("foo")), + AnchoredKey::new(pod_id, Key::from("score")), + ), + ]; + let args = vec![ + WildcardValue::PodId(pod_id), + WildcardValue::Key(Key::from("foo")), + ]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), WildcardValue::None], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // AND (1) Different pod_id for same wildcard + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("score")), + Value::from(42), + ), + Statement::Equal( + AnchoredKey::new(PodId(hash_str("BAD")), Key::from("foo")), + AnchoredKey::new(pod_id, Key::from("score")), + ), + ]; + let args = vec![ + WildcardValue::PodId(pod_id), + WildcardValue::Key(Key::from("foo")), + ]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // AND (2) key doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::ValueOf(AnchoredKey::new(pod_id, Key::from("BAD")), Value::from(42)), + Statement::Equal( + AnchoredKey::new(pod_id, Key::from("foo")), + AnchoredKey::new(pod_id, Key::from("score")), + ), + ]; + let args = vec![ + WildcardValue::PodId(pod_id), + WildcardValue::Key(Key::from("foo")), + ]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // AND (3) literal doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("score")), + Value::from(0xbad), + ), + Statement::Equal( + AnchoredKey::new(pod_id, Key::from("foo")), + AnchoredKey::new(pod_id, Key::from("score")), + ), + ]; + let args = vec![ + WildcardValue::PodId(pod_id), + WildcardValue::Key(Key::from("foo")), + ]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // AND (4) predicate doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("score")), + Value::from(42), + ), + Statement::NotEqual( + AnchoredKey::new(pod_id, Key::from("foo")), + AnchoredKey::new(pod_id, Key::from("score")), + ), + ]; + let args = vec![ + WildcardValue::PodId(pod_id), + WildcardValue::Key(Key::from("foo")), + ]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None, + ) + .is_err()); + + // OR (1) Two Nones + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![Statement::None, Statement::None]; + let args = vec![WildcardValue::PodId(pod_id), WildcardValue::None]; + + assert!(helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + None + ) + .is_err()); + + Ok(()) + } } diff --git a/src/backends/plonky2/circuits/mod.rs b/src/backends/plonky2/circuits/mod.rs index 1afbc5e..ea1f498 100644 --- a/src/backends/plonky2/circuits/mod.rs +++ b/src/backends/plonky2/circuits/mod.rs @@ -1,3 +1,4 @@ pub mod common; pub mod mainpod; pub mod signedpod; +pub mod utils; diff --git a/src/backends/plonky2/circuits/utils.rs b/src/backends/plonky2/circuits/utils.rs new file mode 100644 index 0000000..fa4c5ad --- /dev/null +++ b/src/backends/plonky2/circuits/utils.rs @@ -0,0 +1,73 @@ +use plonky2::{ + field::extension::Extendable, + hash::hash_types::RichField, + iop::{ + generator::{GeneratedValues, SimpleGenerator}, + target::Target, + witness::{PartitionWitness, Witness}, + }, + plonk::circuit_data::CommonCircuitData, + util::serialization::{Buffer, IoResult, Read, Write}, +}; + +/// Plonky2 generator that allows debugging values assigned to targets. This generator doesn't +/// actually generate any value and doesn't assign any witness. Instead it can be registered to +/// monitor targets and print their values once they are available. +/// +/// Example usage: +/// ```rust,ignore +/// builder.add_simple_generator(DebugGenerator::new( +/// format!("values_{}", i), +/// vec![v1, v2, v3], +/// )); +/// ``` +#[derive(Debug, Default)] +pub struct DebugGenerator { + pub(crate) name: String, + pub(crate) xs: Vec, +} + +impl DebugGenerator { + pub fn new(name: String, xs: Vec) -> Self { + Self { name, xs } + } +} + +impl, const D: usize> SimpleGenerator for DebugGenerator { + fn id(&self) -> String { + "DebugGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + self.xs.clone() + } + + fn run_once( + &self, + witness: &PartitionWitness, + _out_buffer: &mut GeneratedValues, + ) -> anyhow::Result<()> { + let xs = witness.get_targets(&self.xs); + + println!("debug: values of {}", self.name); + for (i, x) in xs.iter().enumerate() { + println!("- {:03}: {}", i, x); + } + Ok(()) + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.name.len())?; + dst.write_all(self.name.as_bytes())?; + dst.write_target_vec(&self.xs) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let name_len = src.read_usize()?; + let mut name_buf = vec![0; name_len]; + src.read_exact(&mut name_buf)?; + let name = unsafe { String::from_utf8_unchecked(name_buf) }; + let xs = src.read_target_vec()?; + Ok(Self { name, xs }) + } +} diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 6bb9218..73b14d5 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -87,7 +87,7 @@ pub(crate) fn extract_custom_predicate_verifications( .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) .expect("find the custom predicate from the extracted unique list"); let custom_predicate_table_index = - batch_index * params.max_custom_predicate_batches + cpr.index; + batch_index * params.max_custom_batch_size + cpr.index; CustomPredicateVerification { custom_predicate_table_index, custom_predicate: cpr.clone(), @@ -150,20 +150,18 @@ pub(crate) fn extract_merkle_proofs( /// Find the operation argument statement in the list of previous statements and return the index. fn find_op_arg(statements: &[Statement], op_arg: &middleware::Statement) -> Result { - match op_arg { - middleware::Statement::None => Ok(OperationArg::None), - _ => statements - .iter() - .enumerate() - .find_map(|(i, s)| { - (&middleware::Statement::try_from(s.clone()).ok()? == op_arg).then_some(i) - }) - .map(OperationArg::Index) - .ok_or(Error::custom(format!( - "Statement corresponding to op arg {} not found", - op_arg - ))), - } + // NOTE: The `None` `Statement` always exists as a constant at index 0 + statements + .iter() + .enumerate() + .find_map(|(i, s)| { + (&middleware::Statement::try_from(s.clone()).ok()? == op_arg).then_some(i) + }) + .map(OperationArg::Index) + .ok_or(Error::custom(format!( + "Statement corresponding to op arg {} not found", + op_arg + ))) } /// Find the operation auxiliary data in the list of auxiliary data and return the index. @@ -179,15 +177,21 @@ fn find_op_aux( op: &middleware::Operation, ) -> Result { let op_aux = op.aux(); - let op_type = op.op_type(); - if let (OperationType::Custom(cpr), Some(cpvs)) = (op_type, custom_predicate_verifications) { + if let (middleware::Operation::Custom(cpr, op_args), Some(cpvs)) = + (op, custom_predicate_verifications) + { return Ok(cpvs .iter() .enumerate() .find_map(|(i, cpv)| { (cpv.custom_predicate.batch.id() == cpr.batch.id() - && cpv.custom_predicate.index == cpr.index) - .then_some(i) + && cpv.custom_predicate.index == cpr.index + && cpv + .op_args + .iter() + .zip_eq(op_args.iter()) + .all(|(a0, a1)| a0.0 == a1.predicate() && a0.1 == a1.args())) + .then_some(i) }) .map(OperationAux::CustomPredVerifyIndex) .expect("custom predicate verification in the list")); @@ -228,6 +232,10 @@ fn pad_operation_args(params: &Params, args: &mut Vec) { pub(crate) fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec { let mut statements = Vec::new(); + // Statement at index 0 is always None to be used for padding operation arguments in custom + // predicate statements + statements.push(middleware::Statement::None.into()); + // Input signed pods region let none_sig_pod_box: Box = Box::new(NonePod {}); let none_sig_pod = none_sig_pod_box.as_ref(); @@ -531,10 +539,16 @@ pub mod tests { primitives::signature::SecretKey, signedpod::Signer, }, - examples::{zu_kyc_pod_builder, zu_kyc_sign_pod_builders}, - frontend::{self}, + examples::{ + eth_dos_pod_builder, eth_friend_signed_pod_builder, zu_kyc_pod_builder, + zu_kyc_sign_pod_builders, + }, + frontend::{ + key, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB, + {self}, + }, middleware, - middleware::RawValue, + middleware::{CustomPredicateRef, NativePredicate as NP, RawValue}, op, }; @@ -644,4 +658,109 @@ pub mod tests { let pod = (kyc_pod.pod as Box).downcast::().unwrap(); pod.verify().unwrap() } + + #[test] + fn test_main_ethdos() -> frontend::Result<()> { + let params = Params { + max_input_signed_pods: 2, + max_input_main_pods: 1, + max_statements: 26, + max_public_statements: 5, + max_signed_pod_values: 8, + max_statement_args: 6, + max_operation_args: 4, + max_custom_predicate_arity: 4, + max_custom_batch_size: 3, + max_custom_predicate_wildcards: 12, + max_custom_predicate_verifications: 8, + ..Default::default() + }; + + let mut alice = Signer(SecretKey(RawValue::from(1))); + let bob = Signer(SecretKey(RawValue::from(2))); + let mut charlie = Signer(SecretKey(RawValue::from(3))); + + // Alice attests that she is ETH friends with Charlie and Charlie + // attests that he is ETH friends with Bob. + let alice_attestation = + eth_friend_signed_pod_builder(¶ms, charlie.public_key().0.into()) + .sign(&mut alice)?; + let charlie_attestation = + eth_friend_signed_pod_builder(¶ms, bob.public_key().0.into()).sign(&mut charlie)?; + + let alice_bob_ethdos_builder = eth_dos_pod_builder( + ¶ms, + false, + &alice_attestation, + &charlie_attestation, + bob.public_key().0.into(), + )?; + + let mut prover = MockProver {}; + let pod = alice_bob_ethdos_builder.prove(&mut prover, ¶ms)?; + assert!(pod.pod.verify().is_ok()); + + let mut prover = Prover {}; + let alice_bob_ethdos = alice_bob_ethdos_builder.prove(&mut prover, ¶ms)?; + let pod = (alice_bob_ethdos.pod as Box) + .downcast::() + .unwrap(); + + Ok(pod.verify()?) + } + + #[test] + fn test_main_mini_custom_1() -> frontend::Result<()> { + let params = Params { + max_input_signed_pods: 0, + max_input_main_pods: 0, + max_statements: 9, + max_public_statements: 4, + max_statement_args: 3, + max_operation_args: 3, + max_custom_predicate_arity: 3, + max_custom_batch_size: 3, + max_custom_predicate_wildcards: 4, + max_custom_predicate_verifications: 2, + ..Default::default() + }; + + let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into()); + let stb0 = STB::new(NP::ValueOf) + .arg(("id", key("score"))) + .arg(literal(42)); + let stb1 = STB::new(NP::Equal) + .arg(("id", "secret_key")) + .arg(("id", key("score"))); + let _ = cpb_builder.predicate_and( + "pred_and", + &["id"], + &["secret_key"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = cpb_builder.predicate_or("pred_or", &["id"], &["secret_key"], &[stb0, stb1])?; + let cpb = cpb_builder.finish(); + + let cpb_and = CustomPredicateRef::new(cpb.clone(), 0); + let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1); + + let mut pod_builder = MainPodBuilder::new(¶ms); + + let st0 = pod_builder.priv_op(op!(new_entry, ("score", 42)))?; + let st1 = pod_builder.priv_op(op!(new_entry, ("foo", 42)))?; + let st2 = pod_builder.priv_op(op!(eq, st1.clone(), st0.clone()))?; + + let _st3 = pod_builder.priv_op(op!(custom, cpb_and.clone(), st0, st2))?; + + let mut prover = MockProver {}; + let pod = pod_builder.prove(&mut prover, ¶ms)?; + assert!(pod.pod.verify().is_ok()); + + let mut prover = Prover {}; + let pod = pod_builder.prove(&mut prover, ¶ms)?; + + let pod = (pod.pod as Box).downcast::().unwrap(); + + Ok(pod.verify()?) + } } diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index c838c8c..d7528ed 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -57,11 +57,14 @@ impl fmt::Display for MockMainPod { writeln!(f, "MockMainPod ({}):", self.id)?; // TODO print input signed pods id and type // TODO print input main pods id and type + let offset_input_signed_pods = Self::offset_input_signed_pods(); let offset_input_main_pods = self.offset_input_main_pods(); let offset_input_statements = self.offset_input_statements(); let offset_public_statements = self.offset_public_statements(); for (i, st) in self.statements.iter().enumerate() { - if (i < self.offset_input_main_pods()) && (i % self.params.max_signed_pod_values == 0) { + if (i >= offset_input_signed_pods && i < offset_input_main_pods) + && ((i - offset_input_signed_pods) % self.params.max_signed_pod_values == 0) + { writeln!( f, " from input SignedPod {}:", @@ -125,8 +128,12 @@ fn fmt_statement_index( /// - private Statements /// - public Statements impl MockMainPod { + fn offset_input_signed_pods() -> usize { + 1 + } fn offset_input_main_pods(&self) -> usize { - self.params.max_input_signed_pods * self.params.max_signed_pod_values + Self::offset_input_signed_pods() + + self.params.max_input_signed_pods * self.params.max_signed_pod_values } fn offset_input_statements(&self) -> usize { self.offset_input_main_pods() diff --git a/src/backends/plonky2/mock/signedpod.rs b/src/backends/plonky2/mock/signedpod.rs index 0f862ee..e148388 100644 --- a/src/backends/plonky2/mock/signedpod.rs +++ b/src/backends/plonky2/mock/signedpod.rs @@ -19,7 +19,7 @@ pub struct MockSigner { } impl MockSigner { - pub fn pubkey(&self) -> Hash { + pub fn public_key(&self) -> Hash { hash_str(&self.pk) } } @@ -27,7 +27,7 @@ impl MockSigner { impl MockSigner { fn _sign(&mut self, _params: &Params, kvs: &HashMap) -> Result { let mut kvs = kvs.clone(); - let pubkey = self.pubkey(); + let pubkey = self.public_key(); kvs.insert(Key::from(KEY_SIGNER), Value::from(pubkey)); kvs.insert(Key::from(KEY_TYPE), Value::from(PodType::MockSigned)); diff --git a/src/backends/plonky2/primitives/signature/mod.rs b/src/backends/plonky2/primitives/signature/mod.rs index bb3a4e1..eef2dfb 100644 --- a/src/backends/plonky2/primitives/signature/mod.rs +++ b/src/backends/plonky2/primitives/signature/mod.rs @@ -55,7 +55,7 @@ pub struct VerifierParams(pub(crate) VerifierCircuitData); pub struct SecretKey(pub(crate) RawValue); #[derive(Clone, Debug)] -pub struct PublicKey(pub(crate) RawValue); +pub struct PublicKey(pub RawValue); #[derive(Clone, Debug)] pub struct Signature(pub(crate) Proof); diff --git a/src/backends/plonky2/signedpod.rs b/src/backends/plonky2/signedpod.rs index 12b98e5..5fd2cd8 100644 --- a/src/backends/plonky2/signedpod.rs +++ b/src/backends/plonky2/signedpod.rs @@ -36,6 +36,10 @@ impl Signer { dict, }) } + + pub fn public_key(&self) -> PublicKey { + self.0.public_key() + } } impl PodSigner for Signer { diff --git a/src/examples/custom.rs b/src/examples/custom.rs index 65c4c57..f059b12 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -11,7 +11,12 @@ use crate::{ }; /// Instantiates an ETH friend batch -pub fn eth_friend_batch(params: &Params) -> Result> { +pub fn eth_friend_batch(params: &Params, mock: bool) -> Result> { + let pod_type = if mock { + PodType::MockSigned + } else { + PodType::Signed + }; let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "eth_friend".into()); let _eth_friend = builder.predicate_and( "eth_friend", @@ -24,7 +29,7 @@ pub fn eth_friend_batch(params: &Params) -> Result> { // there is an attestation pod that's a SignedPod STB::new(NP::ValueOf) .arg(("attestation_pod", key(KEY_TYPE))) - .arg(literal(PodType::MockSigned)), // TODO + .arg(literal(pod_type)), // the attestation pod is signed by (src_or, src_key) STB::new(NP::Equal) .arg(("attestation_pod", key(KEY_SIGNER))) @@ -41,8 +46,8 @@ pub fn eth_friend_batch(params: &Params) -> Result> { } /// Instantiates an ETHDoS batch -pub fn eth_dos_batch(params: &Params) -> Result> { - let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params)?, 0)); +pub fn eth_dos_batch(params: &Params, mock: bool) -> Result> { + let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params, mock)?, 0)); let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "eth_dos_distance_base".into()); diff --git a/src/examples/mod.rs b/src/examples/mod.rs index f94d099..49342be 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -69,10 +69,7 @@ pub fn zu_kyc_pod_builder( // ETHDoS -pub fn eth_friend_signed_pod_builder( - params: &Params, - friend_pubkey: TypedValue, -) -> SignedPodBuilder { +pub fn eth_friend_signed_pod_builder(params: &Params, friend_pubkey: Value) -> SignedPodBuilder { let mut attestation = SignedPodBuilder::new(params); attestation.insert("attestation", friend_pubkey); @@ -81,13 +78,14 @@ pub fn eth_friend_signed_pod_builder( pub fn eth_dos_pod_builder( params: &Params, + mock: bool, alice_attestation: &SignedPod, charlie_attestation: &SignedPod, - bob_pubkey: &TypedValue, + bob_pubkey: Value, ) -> Result { // Will need ETH friend and ETH DoS custom predicate batches. - let eth_friend = CustomPredicateRef::new(eth_friend_batch(params)?, 0); - let eth_dos_batch = eth_dos_batch(params)?; + let eth_friend = CustomPredicateRef::new(eth_friend_batch(params, mock)?, 0); + let eth_dos_batch = eth_dos_batch(params, mock)?; let eth_dos_base = CustomPredicateRef::new(eth_dos_batch.clone(), 0); let eth_dos_ind = CustomPredicateRef::new(eth_dos_batch.clone(), 1); let eth_dos = CustomPredicateRef::new(eth_dos_batch.clone(), 2); diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index ace5720..c08beaa 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -274,12 +274,12 @@ mod tests { params.print_serialized_sizes(); // ETH friend custom predicate batch - let eth_friend = eth_friend_batch(¶ms)?; + let eth_friend = eth_friend_batch(¶ms, true)?; // This batch only has 1 predicate, so we pick it already for convenience let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend, 0)); - let eth_dos_batch = eth_dos_batch(¶ms)?; + let eth_dos_batch = eth_dos_batch(¶ms, true)?; let eth_dos_batch_mw: middleware::CustomPredicateBatch = Arc::unwrap_or_clone(eth_dos_batch); let fields = eth_dos_batch_mw.to_fields(¶ms); diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 41efddc..08ce601 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -793,7 +793,7 @@ pub mod build_utils { (max_of, $($arg:expr),+) => { $crate::frontend::Operation( $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::MaxOf), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; - (custom, $op:expr, $($arg:expr),+) => { $crate::frontend::Operation( + (custom, $op:expr, $($arg:expr),*) => { $crate::frontend::Operation( $crate::middleware::OperationType::Custom($op), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (dict_contains, $dict:expr, $key:expr, $value:expr) => { $crate::frontend::Operation( @@ -925,18 +925,19 @@ pub mod tests { // Alice attests that she is ETH friends with Charlie and Charlie // attests that he is ETH friends with Bob. let alice_attestation = - eth_friend_signed_pod_builder(¶ms, charlie.pubkey().into()).sign(&mut alice)?; + eth_friend_signed_pod_builder(¶ms, charlie.public_key().into()).sign(&mut alice)?; check_kvs(&alice_attestation)?; let charlie_attestation = - eth_friend_signed_pod_builder(¶ms, bob.pubkey().into()).sign(&mut charlie)?; + eth_friend_signed_pod_builder(¶ms, bob.public_key().into()).sign(&mut charlie)?; check_kvs(&charlie_attestation)?; let mut prover = MockProver {}; let alice_bob_ethdos = eth_dos_pod_builder( ¶ms, + true, &alice_attestation, &charlie_attestation, - &bob.pubkey().into(), + bob.public_key().into(), )? .prove(&mut prover, ¶ms)?; diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 64158f4..960b0cc 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -276,16 +276,17 @@ mod tests { // Alice attests that she is ETH friends with Charlie and Charlie // attests that he is ETH friends with Bob. let alice_attestation = - eth_friend_signed_pod_builder(¶ms, charlie.pubkey().into()).sign(&mut alice)?; + eth_friend_signed_pod_builder(¶ms, charlie.public_key().into()).sign(&mut alice)?; let charlie_attestation = - eth_friend_signed_pod_builder(¶ms, bob.pubkey().into()).sign(&mut charlie)?; + eth_friend_signed_pod_builder(¶ms, bob.public_key().into()).sign(&mut charlie)?; let mut prover = MockProver {}; let alice_bob_ethdos = eth_dos_pod_builder( ¶ms, + true, &alice_attestation, &charlie_attestation, - &bob.pubkey().into(), + bob.public_key().into(), )? .prove(&mut prover, ¶ms)?; diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 3f2fe2e..27109b9 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -276,9 +276,9 @@ impl CustomPredicate { pub fn pad_statement_tmpl(&self) -> StatementTmpl { StatementTmpl { pred: Predicate::Native(if self.conjunction { - NativePredicate::False - } else { NativePredicate::None + } else { + NativePredicate::False }), args: vec![], }