Constraints for custom predicates (#227)

* add target types for custom predicates

* simplify

* fix clippy

* fix typo

* don't use ref for NativePredicate

* fix wrong len

* precalculate CustomPredicateBatch id

* wip

* wip

* move code back

* great progress

* wip

* code complete, hopefully; missing tests

* fill aux for custom predicate op

* fix clippy warnings

* fix typos

* fix test import

* fix missing assignment in lt_mask, test custom_operation_verify_gadget

* fix mistake

* wip

* fix

* debug revert except for let entry = CustomPredicateVerifyEntryTarget

* fix batch_id calculation by fixing padding

* oops

* remove completed TODOs
This commit is contained in:
Eduard S. 2025-05-13 11:00:45 +02:00 committed by GitHub
parent 4fa9e20ecd
commit 024ed8bd04
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1597 additions and 291 deletions

View file

@ -128,13 +128,15 @@ impl StatementTmplBuilder {
}
pub struct CustomPredicateBatchBuilder {
params: Params,
pub name: String,
pub predicates: Vec<CustomPredicate>,
}
impl CustomPredicateBatchBuilder {
pub fn new(name: String) -> Self {
pub fn new(params: Params, name: String) -> Self {
Self {
params,
name,
predicates: Vec::new(),
}
@ -143,23 +145,21 @@ impl CustomPredicateBatchBuilder {
pub fn predicate_and(
&mut self,
name: &str,
params: &Params,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Result<Predicate> {
self.predicate(name, params, true, args, priv_args, sts)
self.predicate(name, true, args, priv_args, sts)
}
pub fn predicate_or(
&mut self,
name: &str,
params: &Params,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Result<Predicate> {
self.predicate(name, params, false, args, priv_args, sts)
self.predicate(name, false, args, priv_args, sts)
}
/// creates the custom predicate from the given input, adds it to the
@ -167,24 +167,23 @@ impl CustomPredicateBatchBuilder {
fn predicate(
&mut self,
name: &str,
params: &Params,
conjunction: bool,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Result<Predicate> {
if args.len() > params.max_statement_args {
if args.len() > self.params.max_statement_args {
return Err(Error::max_length(
"args.len".to_string(),
args.len(),
params.max_statement_args,
self.params.max_statement_args,
));
}
if (args.len() + priv_args.len()) > params.max_custom_predicate_wildcards {
if (args.len() + priv_args.len()) > self.params.max_custom_predicate_wildcards {
return Err(Error::max_length(
"wildcards.len".to_string(),
args.len() + priv_args.len(),
params.max_custom_predicate_wildcards,
self.params.max_custom_predicate_wildcards,
));
}
@ -197,7 +196,7 @@ impl CustomPredicateBatchBuilder {
.iter()
.map(|a| match a {
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
BuilderArg::Key(pod_id, key) => StatementTmplArg::Key(
BuilderArg::Key(pod_id, key) => StatementTmplArg::AnchoredKey(
resolve_wildcard(args, priv_args, pod_id),
resolve_key_or_wildcard(args, priv_args, key),
),
@ -212,17 +211,19 @@ impl CustomPredicateBatchBuilder {
}
})
.collect();
let custom_predicate =
CustomPredicate::new(name.into(), params, conjunction, statements, args.len())?;
let custom_predicate = CustomPredicate::new(
&self.params,
name.into(),
conjunction,
statements,
args.len(),
)?;
self.predicates.push(custom_predicate);
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
}
pub fn finish(self) -> Arc<CustomPredicateBatch> {
Arc::new(CustomPredicateBatch {
name: self.name,
predicates: self.predicates,
})
CustomPredicateBatch::new(&self.params, self.name, self.predicates)
}
}
@ -290,7 +291,7 @@ mod tests {
#[test]
fn test_desugared_gt_custom_pred() -> Result<()> {
let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new("gt_custom_pred".into());
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
.arg(("s1_origin", "s1_key"))
@ -298,7 +299,6 @@ mod tests {
builder.predicate_and(
"gt_custom_pred",
&params,
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
&[],
&[gt_stb],
@ -322,7 +322,7 @@ mod tests {
// Check that the desugared predicate is the same as the one in the statement template
assert_eq!(
desugared_gt.predicate(),
*batch_clone.predicates[0].statements[0].pred()
*batch_clone.predicates()[0].statements[0].pred()
);
// Check that our custom predicate matches the statement template
@ -339,7 +339,8 @@ mod tests {
#[test]
fn test_desugared_set_contains_custom_pred() -> Result<()> {
let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new("set_contains_custom_pred".into());
let mut builder =
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
.arg(("s1_origin", "s1_key"))
@ -347,7 +348,6 @@ mod tests {
builder.predicate_and(
"set_contains_custom_pred",
&params,
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
&[],
&[set_contains_stb],
@ -368,7 +368,7 @@ mod tests {
);
assert_eq!(
set_contains.predicate(),
*batch_clone.predicates[0].statements[0].pred()
*batch_clone.predicates()[0].statements[0].pred()
);
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);

View file

@ -466,7 +466,7 @@ impl MainPodBuilder {
)))?,
},
OperationType::Custom(cpr) => {
let pred = &cpr.batch.predicates[cpr.index];
let pred = &cpr.batch.predicates()[cpr.index];
if pred.statements.len() != args.len() {
return Err(Error::custom(format!(
"Custom predicate operation needs {} statements but has {}.",