feat: implement custom op check (#92)

* Implement custom op check

* Example
This commit is contained in:
Ahmad Afuni 2025-02-27 22:53:23 +10:00 committed by GitHub
parent a37b96ab4f
commit af46ab7a8d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 372 additions and 5 deletions

View file

@ -1,7 +1,12 @@
use std::collections::HashMap;
use anyhow::{anyhow, Result};
use super::{CustomPredicateRef, Statement};
use crate::middleware::{AnchoredKey, SELF};
use crate::{
middleware::{AnchoredKey, CustomPredicate, PodId, Predicate, StatementTmpl, Value, SELF},
util::hashmap_insert_no_dupe,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NativeOperation {
@ -175,9 +180,69 @@ impl Operation {
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
}
(
Self::Custom(CustomPredicateRef(cpb, i), _args),
Custom(CustomPredicateRef(s_cpb, s_i), _s_args),
) if cpb == s_cpb && i == s_i => todo!(),
Self::Custom(CustomPredicateRef(cpb, i), args),
Custom(CustomPredicateRef(s_cpb, s_i), s_args),
) if cpb == s_cpb && i == s_i => {
// Bind statement arguments
let mut bindings = s_args
.into_iter()
.enumerate()
.flat_map(|(i, AnchoredKey(PodId(o), k))| {
vec![
(2 * i, Value::from(o.clone())),
(2 * i + 1, Value::from(k.clone())),
]
})
.collect::<HashMap<_, _>>();
// Single out custom predicate, replacing batch-self
// references with custom predicate references.
let custom_predicate = {
let cp = (**cpb).predicates[*i].clone();
CustomPredicate {
conjunction: cp.conjunction,
statements: cp
.statements
.into_iter()
.map(|StatementTmpl(p, args)| {
StatementTmpl(
match p {
Predicate::BatchSelf(i) => {
Predicate::Custom(CustomPredicateRef(cpb.clone(), i))
}
_ => p,
},
args,
)
})
.collect(),
args_len: cp.args_len,
}
};
match custom_predicate.conjunction {
true if custom_predicate.statements.len() == args.len() => {
// Match op args against statement templates
let match_bindings = std::iter::zip(custom_predicate.statements, args).map(
|(s_tmpl, s)| s_tmpl.match_against(s)
).collect::<Result<Vec<_>>>()
.map(|v| v.concat())?;
// Add bindings to binding table, throwing if there is an inconsistency.
match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
Ok(true)
},
false if args.len() == 1 => {
// Match op arg against each statement template
custom_predicate.statements.into_iter().map(
|s_tmpl| {
let mut bindings = bindings.clone();
s_tmpl.match_against(&args[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
Ok::<_, anyhow::Error>(true)
}
).find(|m| m.is_ok()).unwrap_or(Ok(false))
},
_ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, args))
}
}
_ => Err(anyhow!(
"Invalid deduction: {:?} ⇏ {:#}",
self,