limit the number of StatementTmpl in CustomPredicate: (#101)

* limit the number of StatementTmpl in CustomPredicate:

- add constructor method for CustomPredicate
- make size checks at the CustomPredicate creation, so that once instantiated we can assume that contains valid data

This resolves #79

* Update tests to use new interface

---------

Co-authored-by: Ahmad <root@ahmadafuni.com>
This commit is contained in:
arnaucube 2025-03-03 05:38:51 +01:00 committed by GitHub
parent c9f7427967
commit c92839d897
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 119 additions and 79 deletions

View file

@ -1,8 +1,9 @@
#![allow(unused)]
use anyhow::Result;
use std::sync::Arc;
use crate::middleware::{
hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate,
hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, Params,
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F,
};
@ -96,31 +97,34 @@ impl CustomPredicateBatchBuilder {
fn predicate_and(
&mut self,
params: &Params,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Predicate {
self.predicate(true, args, priv_args, sts)
) -> Result<Predicate> {
self.predicate(params, true, args, priv_args, sts)
}
fn predicate_or(
&mut self,
params: &Params,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Predicate {
self.predicate(false, args, priv_args, sts)
) -> Result<Predicate> {
self.predicate(params, false, args, priv_args, sts)
}
/// creates the custom predicate from the given input, adds it to the
/// self.predicates, and returns the index of the created predicate
fn predicate(
&mut self,
params: &Params,
conjunction: bool,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Predicate {
) -> Result<Predicate> {
let statements = sts
.iter()
.map(|sb| {
@ -138,13 +142,9 @@ impl CustomPredicateBatchBuilder {
StatementTmpl(sb.predicate.clone(), args)
})
.collect();
let custom_predicate = CustomPredicate {
conjunction,
statements,
args_len: args.len(),
};
let custom_predicate = CustomPredicate::new(params, conjunction, statements, args.len())?;
self.predicates.push(custom_predicate);
Predicate::BatchSelf(self.predicates.len() - 1)
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
}
fn finish(self) -> Arc<CustomPredicateBatch> {
@ -174,7 +174,7 @@ mod tests {
use crate::middleware::{CustomPredicateRef, Params, PodType};
#[test]
fn test_custom_pred() {
fn test_custom_pred() -> Result<()> {
use NativePredicate as NP;
use StatementTmplBuilder as STB;
@ -183,6 +183,7 @@ mod tests {
let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into());
let _eth_friend = builder.predicate_and(
&params,
// arguments:
&["src_ori", "src_key", "dst_ori", "dst_key"],
// private arguments:
@ -202,7 +203,7 @@ mod tests {
.arg(("attestation_pod", literal("attestation")))
.arg(("dst_ori", "dst_key")),
],
);
)?;
println!("a.0. eth_friend = {}", builder.predicates.last().unwrap());
let eth_friend = builder.finish();
@ -216,6 +217,7 @@ mod tests {
// >
let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into());
let eth_dos_distance_base = builder.predicate_and(
&params,
&[
// arguments:
"src_ori",
@ -236,7 +238,7 @@ mod tests {
.arg(("distance_ori", "distance_key"))
.arg(0),
],
);
)?;
println!(
"b.0. eth_dos_distance_base = {}",
builder.predicates.last().unwrap()
@ -246,6 +248,7 @@ mod tests {
// next chunk builds:
let eth_dos_distance_ind = builder.predicate_and(
&params,
&[
// arguments:
"src_ori",
@ -281,7 +284,7 @@ mod tests {
.arg(("intermed_ori", "intermed_key"))
.arg(("dst_ori", "dst_key")),
],
);
)?;
println!(
"b.1. eth_dos_distance_ind = {}",
@ -289,6 +292,7 @@ mod tests {
);
let _eth_dos_distance = builder.predicate_or(
&params,
&[
"src_ori",
"src_key",
@ -308,7 +312,7 @@ mod tests {
.arg(("dst_ori", "dst_key"))
.arg(("distance_ori", "distance_key")),
],
);
)?;
println!(
"b.2. eth_dos_distance = {}",
@ -318,5 +322,7 @@ mod tests {
let eth_dos_batch_b = builder.finish();
let fields = eth_dos_batch_b.to_fields(&params);
println!("Batch b, serialized: {:?}", fields);
Ok(())
}
}

View file

@ -1,7 +1,7 @@
use std::fmt;
use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value};
use crate::middleware::{hash_str, NativeOperation, NativePredicate, OperationType, Predicate};
use crate::middleware::{hash_str, NativePredicate, OperationType, Predicate};
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperationArg {