From c92839d897b79fef760b84d638b16848cd7ee284 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 3 Mar 2025 05:38:51 +0100 Subject: [PATCH] limit the number of StatementTmpl in CustomPredicate: (#101) * limit the number of StatementTmpl in CustomPredicate: - add constructor method for CustomPredicate - make size checks at the CustomPredicate creation, so that once instantiated we can assume that contains valid data This resolves #79 * Update tests to use new interface --------- Co-authored-by: Ahmad --- src/backends/plonky2/mock_main/mod.rs | 2 +- src/frontend/custom.rs | 40 +++++----- src/frontend/operation.rs | 2 +- src/middleware/custom.rs | 104 +++++++++++++++++--------- src/middleware/mod.rs | 32 ++++---- src/middleware/operation.rs | 18 +++-- 6 files changed, 119 insertions(+), 79 deletions(-) diff --git a/src/backends/plonky2/mock_main/mod.rs b/src/backends/plonky2/mock_main/mod.rs index d134a5d..f5352e5 100644 --- a/src/backends/plonky2/mock_main/mod.rs +++ b/src/backends/plonky2/mock_main/mod.rs @@ -435,7 +435,7 @@ impl Pod for MockMainPod { self.operations[i] .deref(&self.statements[..input_statement_offset + i]) .unwrap() - .check(&s.clone().try_into().unwrap()) + .check(&self.params, &s.clone().try_into().unwrap()) }) .collect::>>() .unwrap(); diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 5cdba95..a589036 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -1,8 +1,9 @@ #![allow(unused)] +use anyhow::Result; use std::sync::Arc; use crate::middleware::{ - hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, + hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, Params, Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F, }; @@ -96,31 +97,34 @@ impl CustomPredicateBatchBuilder { fn predicate_and( &mut self, + params: &Params, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], - ) -> Predicate { - self.predicate(true, args, priv_args, sts) + ) -> Result { + self.predicate(params, true, args, priv_args, sts) } fn predicate_or( &mut self, + params: &Params, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], - ) -> Predicate { - self.predicate(false, args, priv_args, sts) + ) -> Result { + self.predicate(params, false, args, priv_args, sts) } /// creates the custom predicate from the given input, adds it to the /// self.predicates, and returns the index of the created predicate fn predicate( &mut self, + params: &Params, conjunction: bool, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], - ) -> Predicate { + ) -> Result { let statements = sts .iter() .map(|sb| { @@ -138,13 +142,9 @@ impl CustomPredicateBatchBuilder { StatementTmpl(sb.predicate.clone(), args) }) .collect(); - let custom_predicate = CustomPredicate { - conjunction, - statements, - args_len: args.len(), - }; + let custom_predicate = CustomPredicate::new(params, conjunction, statements, args.len())?; self.predicates.push(custom_predicate); - Predicate::BatchSelf(self.predicates.len() - 1) + Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } fn finish(self) -> Arc { @@ -174,7 +174,7 @@ mod tests { use crate::middleware::{CustomPredicateRef, Params, PodType}; #[test] - fn test_custom_pred() { + fn test_custom_pred() -> Result<()> { use NativePredicate as NP; use StatementTmplBuilder as STB; @@ -183,6 +183,7 @@ mod tests { let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); let _eth_friend = builder.predicate_and( + ¶ms, // arguments: &["src_ori", "src_key", "dst_ori", "dst_key"], // private arguments: @@ -202,7 +203,7 @@ mod tests { .arg(("attestation_pod", literal("attestation"))) .arg(("dst_ori", "dst_key")), ], - ); + )?; println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); let eth_friend = builder.finish(); @@ -216,6 +217,7 @@ mod tests { // > let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); let eth_dos_distance_base = builder.predicate_and( + ¶ms, &[ // arguments: "src_ori", @@ -236,7 +238,7 @@ mod tests { .arg(("distance_ori", "distance_key")) .arg(0), ], - ); + )?; println!( "b.0. eth_dos_distance_base = {}", builder.predicates.last().unwrap() @@ -246,6 +248,7 @@ mod tests { // next chunk builds: let eth_dos_distance_ind = builder.predicate_and( + ¶ms, &[ // arguments: "src_ori", @@ -281,7 +284,7 @@ mod tests { .arg(("intermed_ori", "intermed_key")) .arg(("dst_ori", "dst_key")), ], - ); + )?; println!( "b.1. eth_dos_distance_ind = {}", @@ -289,6 +292,7 @@ mod tests { ); let _eth_dos_distance = builder.predicate_or( + ¶ms, &[ "src_ori", "src_key", @@ -308,7 +312,7 @@ mod tests { .arg(("dst_ori", "dst_key")) .arg(("distance_ori", "distance_key")), ], - ); + )?; println!( "b.2. eth_dos_distance = {}", @@ -318,5 +322,7 @@ mod tests { let eth_dos_batch_b = builder.finish(); let fields = eth_dos_batch_b.to_fields(¶ms); println!("Batch b, serialized: {:?}", fields); + + Ok(()) } } diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index ef8114f..85ccf4b 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,7 +1,7 @@ use std::fmt; use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value}; -use crate::middleware::{hash_str, NativeOperation, NativePredicate, OperationType, Predicate}; +use crate::middleware::{hash_str, NativePredicate, OperationType, Predicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 8756b97..5dbb1f4 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -195,14 +195,42 @@ impl ToFields for StatementTmpl { #[derive(Clone, Debug, PartialEq, Eq)] pub struct CustomPredicate { + /// NOTE: fields are not public (outside of crate) to enforce the struct instantiation through + /// the `::and/or` methods, which performs checks on the values. + /// true for "and", false for "or" - pub conjunction: bool, - pub statements: Vec, - pub args_len: usize, + pub(crate) conjunction: bool, + pub(crate) statements: Vec, + pub(crate) args_len: usize, // TODO: Add private args length? // TODO: Add args type information? } +impl CustomPredicate { + pub fn and(params: &Params, statements: Vec, args_len: usize) -> Result { + Self::new(params, true, statements, args_len) + } + pub fn or(params: &Params, statements: Vec, args_len: usize) -> Result { + Self::new(params, false, statements, args_len) + } + pub fn new( + params: &Params, + conjunction: bool, + statements: Vec, + args_len: usize, + ) -> Result { + if statements.len() > params.max_custom_predicate_arity { + return Err(anyhow!("Custom predicate depends on too many statements")); + } + + Ok(Self { + conjunction, + statements, + args_len, + }) + } +} + impl ToFields for CustomPredicate { fn to_fields(&self, params: &Params) -> (Vec, usize) { // serialize as: @@ -212,9 +240,9 @@ impl ToFields for CustomPredicate { // (params.max_custom_predicate_arity * params.statement_tmpl_size()) // field elements - // TODO think if this check should go into the StatementTmpl creation, - // instead of at the `to_fields` method, where we should assume that the - // values are already valid + // NOTE: this method assumes that the self.params.len() is inside the + // expected bound, as Self should be instantiated with the constructor + // method `new` which performs the check. if self.statements.len() > params.max_custom_predicate_arity { panic!("Custom predicate depends on too many statements"); } @@ -353,7 +381,7 @@ mod tests { use crate::middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash, - HashOrWildcard, NativePredicate, Operation, PodId, PodType, Predicate, Statement, + HashOrWildcard, NativePredicate, Operation, Params, PodId, PodType, Predicate, Statement, StatementTmpl, StatementTmplArg, SELF, }; @@ -368,6 +396,8 @@ mod tests { #[test] fn is_double_test() -> Result<()> { + let params = Params::default(); + /* is_double(S1, S2) :- p:value_of(Constant, 2), @@ -375,9 +405,9 @@ mod tests { */ let cust_pred_batch = Arc::new(CustomPredicateBatch { name: "is_double".to_string(), - predicates: vec![CustomPredicate { - conjunction: true, - statements: vec![ + predicates: vec![CustomPredicate::and( + ¶ms, + vec![ st( P::Native(NP::ValueOf), vec![ @@ -394,8 +424,8 @@ mod tests { ], ), ], - args_len: 4, - }], + 4, + )?], }); let custom_statement = Statement::Custom( @@ -418,16 +448,18 @@ mod tests { ], ); - assert!(custom_deduction.check(&custom_statement)?); + assert!(custom_deduction.check(¶ms, &custom_statement)?); Ok(()) } #[test] fn ethdos_test() -> Result<()> { - let eth_friend_cp = CustomPredicate { - conjunction: true, - statements: vec![ + let params = Params::default(); + + let eth_friend_cp = CustomPredicate::and( + ¶ms, + vec![ st( P::Native(NP::ValueOf), vec![ @@ -450,17 +482,17 @@ mod tests { ], ), ], - args_len: 4, - }; + 4, + )?; let eth_friend_batch = Arc::new(CustomPredicateBatch { name: "eth_friend".to_string(), predicates: vec![eth_friend_cp], }); - let eth_dos_base = CustomPredicate { - conjunction: true, - statements: vec![ + let eth_dos_base = CustomPredicate::and( + ¶ms, + vec![ st( P::Native(NP::Equal), vec![ @@ -476,12 +508,12 @@ mod tests { ], ), ], - args_len: 6, - }; + 6, + )?; - let eth_dos_ind = CustomPredicate { - conjunction: true, - statements: vec![ + let eth_dos_ind = CustomPredicate::and( + ¶ms, + vec![ st( P::BatchSelf(2), vec![ @@ -513,12 +545,12 @@ mod tests { ], ), ], - args_len: 6, - }; + 6, + )?; - let eth_dos_distance_either = CustomPredicate { - conjunction: false, - statements: vec![ + let eth_dos_distance_either = CustomPredicate::or( + ¶ms, + vec![ st( P::BatchSelf(0), vec![ @@ -536,8 +568,8 @@ mod tests { ], ), ], - args_len: 6, - }; + 6, + )?; let eth_dos_distance_batch = Arc::new(CustomPredicateBatch { name: "ETHDoS_distance".to_string(), @@ -561,7 +593,7 @@ mod tests { ); // Copies should work. - assert!(Operation::CopyStatement(ethdos_example.clone()).check(ðdos_example)?); + assert!(Operation::CopyStatement(ethdos_example.clone()).check(¶ms, ðdos_example)?); // This could arise as the inductive step. let ethdos_ind_example = Statement::Custom( @@ -577,7 +609,7 @@ mod tests { CustomPredicateRef(eth_dos_distance_batch.clone(), 2), vec![ethdos_ind_example.clone()] ) - .check(ðdos_example)?); + .check(¶ms, ðdos_example)?); // And the inductive step would arise as follows: Say the // ETHDoS distance from Alice to Charlie is 6, which is one @@ -610,7 +642,7 @@ mod tests { CustomPredicateRef(eth_dos_distance_batch.clone(), 1), ethdos_facts ) - .check(ðdos_ind_example)?); + .check(¶ms, ðdos_ind_example)?); Ok(()) } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 4fcbeba..217db74 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -92,6 +92,22 @@ pub struct Params { pub max_custom_batch_size: usize, } +impl Default for Params { + fn default() -> Self { + Self { + max_input_signed_pods: 3, + max_input_main_pods: 3, + max_statements: 20, + max_signed_pod_values: 8, + max_public_statements: 10, + max_statement_args: 5, + max_operation_args: 5, + max_custom_predicate_arity: 5, + max_custom_batch_size: 5, + } + } +} + impl Params { pub fn max_priv_statements(&self) -> usize { self.max_statements - self.max_public_statements @@ -134,22 +150,6 @@ impl Params { } } -impl Default for Params { - fn default() -> Self { - Self { - max_input_signed_pods: 3, - max_input_main_pods: 3, - max_statements: 20, - max_signed_pod_values: 8, - max_public_statements: 10, - max_statement_args: 5, - max_operation_args: 5, - max_custom_predicate_arity: 5, - max_custom_batch_size: 5, - } - } -} - pub trait Pod: fmt::Debug + DynClone { fn verify(&self) -> bool; fn id(&self) -> PodId; diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index dd22cc3..3b921b1 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -4,7 +4,9 @@ use anyhow::{anyhow, Result}; use super::{CustomPredicateRef, Statement}; use crate::{ - middleware::{AnchoredKey, CustomPredicate, PodId, Predicate, StatementTmpl, Value, SELF}, + middleware::{ + AnchoredKey, CustomPredicate, Params, PodId, Predicate, StatementTmpl, Value, SELF, + }, util::hashmap_insert_no_dupe, }; @@ -145,7 +147,7 @@ impl Operation { }) } /// Checks the given operation against a statement. - pub fn check(&self, output_statement: &Statement) -> Result { + pub fn check(&self, params: &Params, output_statement: &Statement) -> Result { use Statement::*; match (self, output_statement) { (Self::None, None) => Ok(true), @@ -211,10 +213,10 @@ impl Operation { // references with custom predicate references. let custom_predicate = { let cp = (**cpb).predicates[*i].clone(); - CustomPredicate { - conjunction: cp.conjunction, - statements: cp - .statements + CustomPredicate::new( + params, + cp.conjunction, + cp.statements .into_iter() .map(|StatementTmpl(p, args)| { StatementTmpl( @@ -228,8 +230,8 @@ impl Operation { ) }) .collect(), - args_len: cp.args_len, - } + cp.args_len, + )? }; match custom_predicate.conjunction { true if custom_predicate.statements.len() == args.len() => {