limit the number of StatementTmpl in CustomPredicate: (#101)

* limit the number of StatementTmpl in CustomPredicate:

- add constructor method for CustomPredicate
- make size checks at the CustomPredicate creation, so that once instantiated we can assume that contains valid data

This resolves #79

* Update tests to use new interface

---------

Co-authored-by: Ahmad <root@ahmadafuni.com>
This commit is contained in:
arnaucube 2025-03-03 05:38:51 +01:00 committed by GitHub
parent c9f7427967
commit c92839d897
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 119 additions and 79 deletions

View file

@ -435,7 +435,7 @@ impl Pod for MockMainPod {
self.operations[i] self.operations[i]
.deref(&self.statements[..input_statement_offset + i]) .deref(&self.statements[..input_statement_offset + i])
.unwrap() .unwrap()
.check(&s.clone().try_into().unwrap()) .check(&self.params, &s.clone().try_into().unwrap())
}) })
.collect::<Result<Vec<_>>>() .collect::<Result<Vec<_>>>()
.unwrap(); .unwrap();

View file

@ -1,8 +1,9 @@
#![allow(unused)] #![allow(unused)]
use anyhow::Result;
use std::sync::Arc; use std::sync::Arc;
use crate::middleware::{ use crate::middleware::{
hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, Params,
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F, Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F,
}; };
@ -96,31 +97,34 @@ impl CustomPredicateBatchBuilder {
fn predicate_and( fn predicate_and(
&mut self, &mut self,
params: &Params,
args: &[&str], args: &[&str],
priv_args: &[&str], priv_args: &[&str],
sts: &[StatementTmplBuilder], sts: &[StatementTmplBuilder],
) -> Predicate { ) -> Result<Predicate> {
self.predicate(true, args, priv_args, sts) self.predicate(params, true, args, priv_args, sts)
} }
fn predicate_or( fn predicate_or(
&mut self, &mut self,
params: &Params,
args: &[&str], args: &[&str],
priv_args: &[&str], priv_args: &[&str],
sts: &[StatementTmplBuilder], sts: &[StatementTmplBuilder],
) -> Predicate { ) -> Result<Predicate> {
self.predicate(false, args, priv_args, sts) self.predicate(params, false, args, priv_args, sts)
} }
/// creates the custom predicate from the given input, adds it to the /// creates the custom predicate from the given input, adds it to the
/// self.predicates, and returns the index of the created predicate /// self.predicates, and returns the index of the created predicate
fn predicate( fn predicate(
&mut self, &mut self,
params: &Params,
conjunction: bool, conjunction: bool,
args: &[&str], args: &[&str],
priv_args: &[&str], priv_args: &[&str],
sts: &[StatementTmplBuilder], sts: &[StatementTmplBuilder],
) -> Predicate { ) -> Result<Predicate> {
let statements = sts let statements = sts
.iter() .iter()
.map(|sb| { .map(|sb| {
@ -138,13 +142,9 @@ impl CustomPredicateBatchBuilder {
StatementTmpl(sb.predicate.clone(), args) StatementTmpl(sb.predicate.clone(), args)
}) })
.collect(); .collect();
let custom_predicate = CustomPredicate { let custom_predicate = CustomPredicate::new(params, conjunction, statements, args.len())?;
conjunction,
statements,
args_len: args.len(),
};
self.predicates.push(custom_predicate); self.predicates.push(custom_predicate);
Predicate::BatchSelf(self.predicates.len() - 1) Ok(Predicate::BatchSelf(self.predicates.len() - 1))
} }
fn finish(self) -> Arc<CustomPredicateBatch> { fn finish(self) -> Arc<CustomPredicateBatch> {
@ -174,7 +174,7 @@ mod tests {
use crate::middleware::{CustomPredicateRef, Params, PodType}; use crate::middleware::{CustomPredicateRef, Params, PodType};
#[test] #[test]
fn test_custom_pred() { fn test_custom_pred() -> Result<()> {
use NativePredicate as NP; use NativePredicate as NP;
use StatementTmplBuilder as STB; use StatementTmplBuilder as STB;
@ -183,6 +183,7 @@ mod tests {
let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into());
let _eth_friend = builder.predicate_and( let _eth_friend = builder.predicate_and(
&params,
// arguments: // arguments:
&["src_ori", "src_key", "dst_ori", "dst_key"], &["src_ori", "src_key", "dst_ori", "dst_key"],
// private arguments: // private arguments:
@ -202,7 +203,7 @@ mod tests {
.arg(("attestation_pod", literal("attestation"))) .arg(("attestation_pod", literal("attestation")))
.arg(("dst_ori", "dst_key")), .arg(("dst_ori", "dst_key")),
], ],
); )?;
println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); println!("a.0. eth_friend = {}", builder.predicates.last().unwrap());
let eth_friend = builder.finish(); let eth_friend = builder.finish();
@ -216,6 +217,7 @@ mod tests {
// > // >
let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into());
let eth_dos_distance_base = builder.predicate_and( let eth_dos_distance_base = builder.predicate_and(
&params,
&[ &[
// arguments: // arguments:
"src_ori", "src_ori",
@ -236,7 +238,7 @@ mod tests {
.arg(("distance_ori", "distance_key")) .arg(("distance_ori", "distance_key"))
.arg(0), .arg(0),
], ],
); )?;
println!( println!(
"b.0. eth_dos_distance_base = {}", "b.0. eth_dos_distance_base = {}",
builder.predicates.last().unwrap() builder.predicates.last().unwrap()
@ -246,6 +248,7 @@ mod tests {
// next chunk builds: // next chunk builds:
let eth_dos_distance_ind = builder.predicate_and( let eth_dos_distance_ind = builder.predicate_and(
&params,
&[ &[
// arguments: // arguments:
"src_ori", "src_ori",
@ -281,7 +284,7 @@ mod tests {
.arg(("intermed_ori", "intermed_key")) .arg(("intermed_ori", "intermed_key"))
.arg(("dst_ori", "dst_key")), .arg(("dst_ori", "dst_key")),
], ],
); )?;
println!( println!(
"b.1. eth_dos_distance_ind = {}", "b.1. eth_dos_distance_ind = {}",
@ -289,6 +292,7 @@ mod tests {
); );
let _eth_dos_distance = builder.predicate_or( let _eth_dos_distance = builder.predicate_or(
&params,
&[ &[
"src_ori", "src_ori",
"src_key", "src_key",
@ -308,7 +312,7 @@ mod tests {
.arg(("dst_ori", "dst_key")) .arg(("dst_ori", "dst_key"))
.arg(("distance_ori", "distance_key")), .arg(("distance_ori", "distance_key")),
], ],
); )?;
println!( println!(
"b.2. eth_dos_distance = {}", "b.2. eth_dos_distance = {}",
@ -318,5 +322,7 @@ mod tests {
let eth_dos_batch_b = builder.finish(); let eth_dos_batch_b = builder.finish();
let fields = eth_dos_batch_b.to_fields(&params); let fields = eth_dos_batch_b.to_fields(&params);
println!("Batch b, serialized: {:?}", fields); println!("Batch b, serialized: {:?}", fields);
Ok(())
} }
} }

View file

@ -1,7 +1,7 @@
use std::fmt; use std::fmt;
use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value}; use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value};
use crate::middleware::{hash_str, NativeOperation, NativePredicate, OperationType, Predicate}; use crate::middleware::{hash_str, NativePredicate, OperationType, Predicate};
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperationArg { pub enum OperationArg {

View file

@ -195,14 +195,42 @@ impl ToFields for StatementTmpl {
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct CustomPredicate { pub struct CustomPredicate {
/// NOTE: fields are not public (outside of crate) to enforce the struct instantiation through
/// the `::and/or` methods, which performs checks on the values.
/// true for "and", false for "or" /// true for "and", false for "or"
pub conjunction: bool, pub(crate) conjunction: bool,
pub statements: Vec<StatementTmpl>, pub(crate) statements: Vec<StatementTmpl>,
pub args_len: usize, pub(crate) args_len: usize,
// TODO: Add private args length? // TODO: Add private args length?
// TODO: Add args type information? // TODO: Add args type information?
} }
impl CustomPredicate {
pub fn and(params: &Params, statements: Vec<StatementTmpl>, args_len: usize) -> Result<Self> {
Self::new(params, true, statements, args_len)
}
pub fn or(params: &Params, statements: Vec<StatementTmpl>, args_len: usize) -> Result<Self> {
Self::new(params, false, statements, args_len)
}
pub fn new(
params: &Params,
conjunction: bool,
statements: Vec<StatementTmpl>,
args_len: usize,
) -> Result<Self> {
if statements.len() > params.max_custom_predicate_arity {
return Err(anyhow!("Custom predicate depends on too many statements"));
}
Ok(Self {
conjunction,
statements,
args_len,
})
}
}
impl ToFields for CustomPredicate { impl ToFields for CustomPredicate {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
// serialize as: // serialize as:
@ -212,9 +240,9 @@ impl ToFields for CustomPredicate {
// (params.max_custom_predicate_arity * params.statement_tmpl_size()) // (params.max_custom_predicate_arity * params.statement_tmpl_size())
// field elements // field elements
// TODO think if this check should go into the StatementTmpl creation, // NOTE: this method assumes that the self.params.len() is inside the
// instead of at the `to_fields` method, where we should assume that the // expected bound, as Self should be instantiated with the constructor
// values are already valid // method `new` which performs the check.
if self.statements.len() > params.max_custom_predicate_arity { if self.statements.len() > params.max_custom_predicate_arity {
panic!("Custom predicate depends on too many statements"); panic!("Custom predicate depends on too many statements");
} }
@ -353,7 +381,7 @@ mod tests {
use crate::middleware::{ use crate::middleware::{
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash, AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash,
HashOrWildcard, NativePredicate, Operation, PodId, PodType, Predicate, Statement, HashOrWildcard, NativePredicate, Operation, Params, PodId, PodType, Predicate, Statement,
StatementTmpl, StatementTmplArg, SELF, StatementTmpl, StatementTmplArg, SELF,
}; };
@ -368,6 +396,8 @@ mod tests {
#[test] #[test]
fn is_double_test() -> Result<()> { fn is_double_test() -> Result<()> {
let params = Params::default();
/* /*
is_double(S1, S2) :- is_double(S1, S2) :-
p:value_of(Constant, 2), p:value_of(Constant, 2),
@ -375,9 +405,9 @@ mod tests {
*/ */
let cust_pred_batch = Arc::new(CustomPredicateBatch { let cust_pred_batch = Arc::new(CustomPredicateBatch {
name: "is_double".to_string(), name: "is_double".to_string(),
predicates: vec![CustomPredicate { predicates: vec![CustomPredicate::and(
conjunction: true, &params,
statements: vec![ vec![
st( st(
P::Native(NP::ValueOf), P::Native(NP::ValueOf),
vec![ vec![
@ -394,8 +424,8 @@ mod tests {
], ],
), ),
], ],
args_len: 4, 4,
}], )?],
}); });
let custom_statement = Statement::Custom( let custom_statement = Statement::Custom(
@ -418,16 +448,18 @@ mod tests {
], ],
); );
assert!(custom_deduction.check(&custom_statement)?); assert!(custom_deduction.check(&params, &custom_statement)?);
Ok(()) Ok(())
} }
#[test] #[test]
fn ethdos_test() -> Result<()> { fn ethdos_test() -> Result<()> {
let eth_friend_cp = CustomPredicate { let params = Params::default();
conjunction: true,
statements: vec![ let eth_friend_cp = CustomPredicate::and(
&params,
vec![
st( st(
P::Native(NP::ValueOf), P::Native(NP::ValueOf),
vec![ vec![
@ -450,17 +482,17 @@ mod tests {
], ],
), ),
], ],
args_len: 4, 4,
}; )?;
let eth_friend_batch = Arc::new(CustomPredicateBatch { let eth_friend_batch = Arc::new(CustomPredicateBatch {
name: "eth_friend".to_string(), name: "eth_friend".to_string(),
predicates: vec![eth_friend_cp], predicates: vec![eth_friend_cp],
}); });
let eth_dos_base = CustomPredicate { let eth_dos_base = CustomPredicate::and(
conjunction: true, &params,
statements: vec![ vec![
st( st(
P::Native(NP::Equal), P::Native(NP::Equal),
vec![ vec![
@ -476,12 +508,12 @@ mod tests {
], ],
), ),
], ],
args_len: 6, 6,
}; )?;
let eth_dos_ind = CustomPredicate { let eth_dos_ind = CustomPredicate::and(
conjunction: true, &params,
statements: vec![ vec![
st( st(
P::BatchSelf(2), P::BatchSelf(2),
vec![ vec![
@ -513,12 +545,12 @@ mod tests {
], ],
), ),
], ],
args_len: 6, 6,
}; )?;
let eth_dos_distance_either = CustomPredicate { let eth_dos_distance_either = CustomPredicate::or(
conjunction: false, &params,
statements: vec![ vec![
st( st(
P::BatchSelf(0), P::BatchSelf(0),
vec![ vec![
@ -536,8 +568,8 @@ mod tests {
], ],
), ),
], ],
args_len: 6, 6,
}; )?;
let eth_dos_distance_batch = Arc::new(CustomPredicateBatch { let eth_dos_distance_batch = Arc::new(CustomPredicateBatch {
name: "ETHDoS_distance".to_string(), name: "ETHDoS_distance".to_string(),
@ -561,7 +593,7 @@ mod tests {
); );
// Copies should work. // Copies should work.
assert!(Operation::CopyStatement(ethdos_example.clone()).check(&ethdos_example)?); assert!(Operation::CopyStatement(ethdos_example.clone()).check(&params, &ethdos_example)?);
// This could arise as the inductive step. // This could arise as the inductive step.
let ethdos_ind_example = Statement::Custom( let ethdos_ind_example = Statement::Custom(
@ -577,7 +609,7 @@ mod tests {
CustomPredicateRef(eth_dos_distance_batch.clone(), 2), CustomPredicateRef(eth_dos_distance_batch.clone(), 2),
vec![ethdos_ind_example.clone()] vec![ethdos_ind_example.clone()]
) )
.check(&ethdos_example)?); .check(&params, &ethdos_example)?);
// And the inductive step would arise as follows: Say the // And the inductive step would arise as follows: Say the
// ETHDoS distance from Alice to Charlie is 6, which is one // ETHDoS distance from Alice to Charlie is 6, which is one
@ -610,7 +642,7 @@ mod tests {
CustomPredicateRef(eth_dos_distance_batch.clone(), 1), CustomPredicateRef(eth_dos_distance_batch.clone(), 1),
ethdos_facts ethdos_facts
) )
.check(&ethdos_ind_example)?); .check(&params, &ethdos_ind_example)?);
Ok(()) Ok(())
} }

View file

@ -92,6 +92,22 @@ pub struct Params {
pub max_custom_batch_size: usize, pub max_custom_batch_size: usize,
} }
impl Default for Params {
fn default() -> Self {
Self {
max_input_signed_pods: 3,
max_input_main_pods: 3,
max_statements: 20,
max_signed_pod_values: 8,
max_public_statements: 10,
max_statement_args: 5,
max_operation_args: 5,
max_custom_predicate_arity: 5,
max_custom_batch_size: 5,
}
}
}
impl Params { impl Params {
pub fn max_priv_statements(&self) -> usize { pub fn max_priv_statements(&self) -> usize {
self.max_statements - self.max_public_statements self.max_statements - self.max_public_statements
@ -134,22 +150,6 @@ impl Params {
} }
} }
impl Default for Params {
fn default() -> Self {
Self {
max_input_signed_pods: 3,
max_input_main_pods: 3,
max_statements: 20,
max_signed_pod_values: 8,
max_public_statements: 10,
max_statement_args: 5,
max_operation_args: 5,
max_custom_predicate_arity: 5,
max_custom_batch_size: 5,
}
}
}
pub trait Pod: fmt::Debug + DynClone { pub trait Pod: fmt::Debug + DynClone {
fn verify(&self) -> bool; fn verify(&self) -> bool;
fn id(&self) -> PodId; fn id(&self) -> PodId;

View file

@ -4,7 +4,9 @@ use anyhow::{anyhow, Result};
use super::{CustomPredicateRef, Statement}; use super::{CustomPredicateRef, Statement};
use crate::{ use crate::{
middleware::{AnchoredKey, CustomPredicate, PodId, Predicate, StatementTmpl, Value, SELF}, middleware::{
AnchoredKey, CustomPredicate, Params, PodId, Predicate, StatementTmpl, Value, SELF,
},
util::hashmap_insert_no_dupe, util::hashmap_insert_no_dupe,
}; };
@ -145,7 +147,7 @@ impl Operation {
}) })
} }
/// Checks the given operation against a statement. /// Checks the given operation against a statement.
pub fn check(&self, output_statement: &Statement) -> Result<bool> { pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
use Statement::*; use Statement::*;
match (self, output_statement) { match (self, output_statement) {
(Self::None, None) => Ok(true), (Self::None, None) => Ok(true),
@ -211,10 +213,10 @@ impl Operation {
// references with custom predicate references. // references with custom predicate references.
let custom_predicate = { let custom_predicate = {
let cp = (**cpb).predicates[*i].clone(); let cp = (**cpb).predicates[*i].clone();
CustomPredicate { CustomPredicate::new(
conjunction: cp.conjunction, params,
statements: cp cp.conjunction,
.statements cp.statements
.into_iter() .into_iter()
.map(|StatementTmpl(p, args)| { .map(|StatementTmpl(p, args)| {
StatementTmpl( StatementTmpl(
@ -228,8 +230,8 @@ impl Operation {
) )
}) })
.collect(), .collect(),
args_len: cp.args_len, cp.args_len,
} )?
}; };
match custom_predicate.conjunction { match custom_predicate.conjunction {
true if custom_predicate.statements.len() == args.len() => { true if custom_predicate.statements.len() == args.len() => {