Frontend work (#109)

This commit is contained in:
Ahmad Afuni 2025-03-05 21:02:28 +10:00 committed by GitHub
parent 7eeb595dc2
commit 9d60b0ec3a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 611 additions and 262 deletions

View file

@ -1,9 +1,12 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::{fmt, hash as h, iter::zip};
use anyhow::{anyhow, Result};
use plonky2::field::types::Field;
use crate::util::hashmap_insert_no_dupe;
use super::{
hash_fields, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg,
ToFields, Value, F,
@ -25,7 +28,11 @@ impl HashOrWildcard {
match self {
HashOrWildcard::Hash(h) if &Value::from(h.clone()) == v => Ok(None),
HashOrWildcard::Wildcard(i) => Ok(Some((*i, v.clone()))),
_ => Err(anyhow!("Failed to match {} against {}.", self, v)),
_ => Err(anyhow!(
"Failed to match hash or wildcard {} against value {}.",
self,
v
)),
}
}
}
@ -76,7 +83,11 @@ impl StatementTmplArg {
let k_corr = tmpl_k.match_against(&k.clone().into())?;
Ok([o_corr, k_corr].into_iter().flat_map(|x| x).collect())
}
_ => Err(anyhow!("Failed to match {} against {}.", self, s_arg)),
_ => Err(anyhow!(
"Failed to match statement template argument {:?} against statement argument {:?}.",
self,
s_arg
)),
}
}
}
@ -322,6 +333,62 @@ impl CustomPredicateBatch {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CustomPredicateRef(pub Arc<CustomPredicateBatch>, pub usize);
impl CustomPredicateRef {
pub fn arg_len(&self) -> usize {
(*self.0).predicates[self.1].args_len
}
pub fn match_against(&self, statements: &[Statement]) -> Result<HashMap<usize, Value>> {
let mut bindings = HashMap::new();
// Single out custom predicate, replacing batch-self
// references with custom predicate references.
let custom_predicate = {
let cp = &Arc::unwrap_or_clone(self.0.clone()).predicates[self.1];
CustomPredicate {
conjunction: cp.conjunction,
statements: cp
.statements
.iter()
.map(|StatementTmpl(p, args)| {
StatementTmpl(
match p {
Predicate::BatchSelf(i) => {
Predicate::Custom(CustomPredicateRef(self.0.clone(), *i))
}
_ => p.clone(),
},
args.to_vec(),
)
})
.collect(),
args_len: cp.args_len,
}
};
match custom_predicate.conjunction {
true if custom_predicate.statements.len() == statements.len() => {
// Match op args against statement templates
let match_bindings = std::iter::zip(custom_predicate.statements, statements).map(
|(s_tmpl, s)| s_tmpl.match_against(s)
).collect::<Result<Vec<_>>>()
.map(|v| v.concat())?;
// Add bindings to binding table, throwing if there is an inconsistency.
match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
Ok(bindings)
},
false if statements.len() == 1 => {
// Match op arg against each statement template
custom_predicate.statements.iter().map(
|s_tmpl| {
let mut bindings = bindings.clone();
s_tmpl.match_against(&statements[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
Ok::<_, anyhow::Error>(bindings)
}
).find(|m| m.is_ok()).unwrap_or(Err(anyhow!("Statement {} does not match disjunctive custom predicate {}.", &statements[0], custom_predicate)))
},
_ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, statements))
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Predicate {
Native(NativePredicate),

View file

@ -194,68 +194,34 @@ impl Operation {
let v3: i64 = v3.clone().try_into()?;
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
}
(
Self::Custom(CustomPredicateRef(cpb, i), args),
Custom(CustomPredicateRef(s_cpb, s_i), s_args),
) if cpb == s_cpb && i == s_i => {
// Bind statement arguments
let mut bindings = s_args
.into_iter()
.enumerate()
.flat_map(|(i, AnchoredKey(PodId(o), k))| {
vec![
(2 * i, Value::from(o.clone())),
(2 * i + 1, Value::from(k.clone())),
]
})
.collect::<HashMap<_, _>>();
// Single out custom predicate, replacing batch-self
// references with custom predicate references.
let custom_predicate = {
let cp = (**cpb).predicates[*i].clone();
CustomPredicate::new(
params,
cp.conjunction,
cp.statements
.into_iter()
.map(|StatementTmpl(p, args)| {
StatementTmpl(
match p {
Predicate::BatchSelf(i) => {
Predicate::Custom(CustomPredicateRef(cpb.clone(), i))
}
_ => p,
},
args,
)
})
.collect(),
cp.args_len,
)?
};
match custom_predicate.conjunction {
true if custom_predicate.statements.len() == args.len() => {
// Match op args against statement templates
let match_bindings = std::iter::zip(custom_predicate.statements, args).map(
|(s_tmpl, s)| s_tmpl.match_against(s)
).collect::<Result<Vec<_>>>()
.map(|v| v.concat())?;
// Add bindings to binding table, throwing if there is an inconsistency.
match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
Ok(true)
},
false if args.len() == 1 => {
// Match op arg against each statement template
custom_predicate.statements.into_iter().map(
|s_tmpl| {
let mut bindings = bindings.clone();
s_tmpl.match_against(&args[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
Ok::<_, anyhow::Error>(true)
}
).find(|m| m.is_ok()).unwrap_or(Ok(false))
},
_ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, args))
(Self::Custom(CustomPredicateRef(cpb, i), args), Custom(cpr, s_args))
if cpb == &cpr.0 && i == &cpr.1 =>
{
// Bind according to custom predicate pattern match against arg list.
let bindings = cpr.match_against(args)?;
// Check arg length
let arg_len = cpr.arg_len();
if arg_len != 2 * s_args.len() {
Err(anyhow!("Custom predicate arg list {:?} must have {} arguments after destructuring.", s_args, arg_len))
} else {
let bound_args = (0..arg_len)
.map(|i| {
bindings.get(&i).cloned().ok_or(anyhow!(
"Wildcard {} of custom predicate {:?} is unbound.",
i,
cpr
))
})
.collect::<Result<Vec<_>>>()?;
let s_args = s_args
.into_iter()
.flat_map(|AnchoredKey(o, k)| [Value::from(o.0.clone()), k.clone().into()])
.collect::<Vec<_>>();
if bound_args != s_args {
Err(anyhow!("Arguments to output statement {} do not match those implied by operation {:?}", output_statement,self))
} else {
Ok(true)
}
}
}
_ => Err(anyhow!(