chore(middleware): additional error reporting for custom predicates (#330)

* Additional error reporting for custom predicates

* Code review

* Typo
This commit is contained in:
Ahmad Afuni 2025-07-14 23:27:33 +10:00 committed by GitHub
parent aeedf55bad
commit e8468d7fa8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 133 additions and 44 deletions

View file

@ -22,8 +22,13 @@ fn display_wc_map(wc_map: &[Option<Value>]) -> String {
pub enum InnerError { pub enum InnerError {
#[error("{0} {1} is over the limit {2}")] #[error("{0} {1} is over the limit {2}")]
MaxLength(String, usize, usize), MaxLength(String, usize, usize),
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}", map=display_wc_map(.2))] #[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
StatementsDontMatch(Statement, StatementTmpl, Vec<Option<Value>>), StatementsDontMatch(
Statement,
StatementTmpl,
Vec<Option<Value>>,
crate::middleware::Error,
),
#[error("invalid arguments to {0} operation")] #[error("invalid arguments to {0} operation")]
OpInvalidArgs(String), OpInvalidArgs(String),
// Other // Other
@ -76,8 +81,9 @@ impl Error {
s0: Statement, s0: Statement,
s1: StatementTmpl, s1: StatementTmpl,
wc_map: Vec<Option<Value>>, wc_map: Vec<Option<Value>>,
mid_error: crate::middleware::Error,
) -> Self { ) -> Self {
new!(StatementsDontMatch(s0, s1, wc_map)) new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
} }
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self { pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
new!(MaxLength(obj, found, expect)) new!(MaxLength(obj, found, expect))

View file

@ -471,11 +471,14 @@ impl MainPodBuilder {
for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) { for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) {
let st_args = st.args(); let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { if let Err(st_tmpl_check_error) =
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
{
return Err(Error::statements_dont_match( return Err(Error::statements_dont_match(
st.clone(), st.clone(),
st_tmpl.clone(), st_tmpl.clone(),
wildcard_map, wildcard_map,
st_tmpl_check_error,
)); ));
} }
} }

View file

@ -2,7 +2,10 @@
use std::{backtrace::Backtrace, fmt::Debug}; use std::{backtrace::Backtrace, fmt::Debug};
use crate::middleware::{Operation, Statement, StatementArg}; use crate::middleware::{
CustomPredicate, Key, Operation, PodId, Statement, StatementArg, StatementTmplArg, Value,
Wildcard,
};
pub type Result<T, E = Error> = core::result::Result<T, E>; pub type Result<T, E = Error> = core::result::Result<T, E>;
@ -18,6 +21,22 @@ pub enum MiddlewareInnerError {
MaxLength(String, usize, usize), MaxLength(String, usize, usize),
#[error("{0} amount of {1} should be {1} but it's {2}")] #[error("{0} amount of {1} should be {1} but it's {2}")]
DiffAmount(String, String, usize, usize), DiffAmount(String, String, usize, usize),
#[error("{0} should be assigned the value {1} but has previously been assigned {2}")]
InvalidWildcardAssignment(Wildcard, Value, Value),
#[error("{0} matches POD ID {1}, yet the template key {2} does not match {3}")]
MismatchedAnchoredKeyInStatementTmplArg(Wildcard, PodId, Key, Key),
#[error("{0} does not match against {1}")]
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
#[error(
"Not all statement templates of the following custom predicate have been matched:\n{0}"
)]
UnsatisfiedCustomPredicateConjunction(CustomPredicate),
#[error(
"None of the statement templates of the following custom predicate have been matched:\n{0}"
)]
UnsatisfiedCustomPredicateDisjunction(CustomPredicate),
// Other // Other
#[error("{0}")] #[error("{0}")]
Custom(String), Custom(String),
@ -65,6 +84,48 @@ impl Error {
pub(crate) fn diff_amount(obj: String, unit: String, expect: usize, found: usize) -> Self { pub(crate) fn diff_amount(obj: String, unit: String, expect: usize, found: usize) -> Self {
new!(DiffAmount(obj, unit, expect, found)) new!(DiffAmount(obj, unit, expect, found))
} }
/// Builds the `InvalidWildcardAssignment` error: `wildcard` should be
/// assigned `value`, but it was previously assigned `prev_value`.
///
/// The argument order matches the variant's display message
/// (`{0}` = wildcard, `{1}` = new value, `{2}` = previous value).
// NOTE(review): `new!` is defined outside this view; presumably it wraps the
// inner variant into `Error` — confirm against the macro definition.
pub(crate) fn invalid_wildcard_assignment(
wildcard: Wildcard,
value: Value,
prev_value: Value,
) -> Self {
new!(InvalidWildcardAssignment(wildcard, value, prev_value))
}
/// Builds the `MismatchedAnchoredKeyInStatementTmplArg` error: the wildcard
/// `pod_id_wildcard` matched the POD ID `pod_id`, but the template key
/// `key_tmpl` does not match the statement's key `key`.
///
/// The argument order matches the variant's display message
/// (`{0}` = wildcard, `{1}` = POD ID, `{2}` = template key, `{3}` = actual key).
pub(crate) fn mismatched_anchored_key_in_statement_tmpl_arg(
pod_id_wildcard: Wildcard,
pod_id: PodId,
key_tmpl: Key,
key: Key,
) -> Self {
new!(MismatchedAnchoredKeyInStatementTmplArg(
pod_id_wildcard,
pod_id,
key_tmpl,
key
))
}
/// Builds the `MismatchedStatementTmplArg` error: the template argument
/// `st_tmpl_arg` does not match against the statement argument `st_arg`.
pub(crate) fn mismatched_statement_tmpl_arg(
st_tmpl_arg: StatementTmplArg,
st_arg: StatementArg,
) -> Self {
new!(MismatchedStatementTmplArg(st_tmpl_arg, st_arg))
}
/// Builds the `MismatchedWildcardValueAndStatementArg` error: the value
/// `wc_value` does not match the argument `st_arg` found at position
/// `arg_index` of the custom predicate `pred`.
///
/// `pred` is stored in the error so the display message can print the whole
/// custom predicate after the mismatch description.
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
wc_value: Value,
st_arg: Value,
arg_index: usize,
pred: CustomPredicate,
) -> Self {
new!(MismatchedWildcardValueAndStatementArg(
wc_value, st_arg, arg_index, pred
))
}
/// Builds the `UnsatisfiedCustomPredicateConjunction` error, reported when
/// not all statement templates of the custom predicate `pred` have been
/// matched. `pred` is included in the display message.
pub(crate) fn unsatisfied_custom_predicate_conjunction(pred: CustomPredicate) -> Self {
new!(UnsatisfiedCustomPredicateConjunction(pred))
}
/// Builds the `UnsatisfiedCustomPredicateDisjunction` error, reported when
/// none of the statement templates of the custom predicate `pred` have been
/// matched. `pred` is included in the display message.
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
new!(UnsatisfiedCustomPredicateDisjunction(pred))
}
pub(crate) fn custom(s: String) -> Self { pub(crate) fn custom(s: String) -> Self {
new!(Custom(s)) new!(Custom(s))
} }

View file

@ -355,7 +355,7 @@ impl Operation {
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index => if batch == &cpr.batch && index == &cpr.index =>
{ {
check_custom_pred(params, cpr, args, s_args)? check_custom_pred(params, cpr, args, s_args).map(|_| true)?
} }
_ => return Err(deduction_err()), _ => return Err(deduction_err()),
}; };
@ -370,37 +370,49 @@ pub fn check_st_tmpl(
st_arg: &StatementArg, st_arg: &StatementArg,
// Map from wildcards to values that we have seen so far. // Map from wildcards to values that we have seen so far.
wildcard_map: &mut [Option<Value>], wildcard_map: &mut [Option<Value>],
) -> bool { ) -> Result<()> {
// Check that the value `v` at wildcard `wc` exists in the map or set it. // Check that the value `v` at wildcard `wc` exists in the map or set it.
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> bool { fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
if let Some(prev) = &wildcard_map[wc.index] { if let Some(prev) = &wildcard_map[wc.index] {
if *prev != v { if *prev != v {
// TODO: Return nice error return Err(Error::invalid_wildcard_assignment(
return false; wc.clone(),
v,
prev.clone(),
));
} }
} else { } else {
wildcard_map[wc.index] = Some(v); wildcard_map[wc.index] = Some(v);
} }
true Ok(())
} }
match (st_tmpl_arg, st_arg) { match (st_tmpl_arg, st_arg) {
(StatementTmplArg::None, StatementArg::None) => true, (StatementTmplArg::None, StatementArg::None) => Ok(()),
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true, (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()),
( (
StatementTmplArg::AnchoredKey(pod_id_wc, key_tmpl), StatementTmplArg::AnchoredKey(pod_id_wc, key_tmpl),
StatementArg::Key(AnchoredKey { pod_id, key }), StatementArg::Key(AnchoredKey { pod_id, key }),
) => { ) => {
let pod_id_ok = check_or_set(Value::from(*pod_id), pod_id_wc, wildcard_map); let pod_id_ok = check_or_set(Value::from(*pod_id), pod_id_wc, wildcard_map);
pod_id_ok && (key_tmpl == key) pod_id_ok.and_then(|_| {
(key_tmpl == key).then_some(()).ok_or(
Error::mismatched_anchored_key_in_statement_tmpl_arg(
pod_id_wc.clone(),
*pod_id,
key_tmpl.clone(),
key.clone(),
),
)
})
} }
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
check_or_set(v.clone(), wc, wildcard_map) check_or_set(v.clone(), wc, wildcard_map)
} }
_ => { _ => Err(Error::mismatched_statement_tmpl_arg(
println!("DBG {:?} {:?}", st_tmpl_arg, st_arg); st_tmpl_arg.clone(),
false st_arg.clone(),
} )),
} }
} }
@ -408,7 +420,7 @@ pub fn resolve_wildcard_values(
params: &Params, params: &Params,
pred: &CustomPredicate, pred: &CustomPredicate,
args: &[Statement], args: &[Statement],
) -> Option<Vec<Value>> { ) -> Result<Vec<Value>> {
// Check that all wildcard have consistent values as assigned in the statements while storing a // Check that all wildcard have consistent values as assigned in the statements while storing a
// map of their values. // map of their values.
// NOTE: We assume the statements have the same order as defined in the custom predicate. For // NOTE: We assume the statements have the same order as defined in the custom predicate. For
@ -416,25 +428,22 @@ pub fn resolve_wildcard_values(
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards]; let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
for (st_tmpl, st) in pred.statements.iter().zip(args) { for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args(); let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { st_tmpl
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { .args
// TODO: Better errors. Example: .iter()
// println!("{} doesn't match {}", st_arg, st_tmpl_arg); .zip(&st_args)
// println!("{} doesn't match {}", st, st_tmpl); .try_for_each(|(st_tmpl_arg, st_arg)| {
return None; check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
} })?;
}
} }
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
// they are beyond the number of used wildcards in this custom predicate, or they could be // they are beyond the number of used wildcards in this custom predicate, or they could be
// private arguments that are unused in a particular disjunction. // private arguments that are unused in a particular disjunction.
Some( Ok(wildcard_map
wildcard_map
.into_iter() .into_iter()
.map(|opt| opt.unwrap_or(Value::from(0))) .map(|opt| opt.unwrap_or(Value::from(0)))
.collect(), .collect())
)
} }
fn check_custom_pred( fn check_custom_pred(
@ -442,7 +451,7 @@ fn check_custom_pred(
custom_pred_ref: &CustomPredicateRef, custom_pred_ref: &CustomPredicateRef,
args: &[Statement], args: &[Statement],
s_args: &[Value], s_args: &[Value],
) -> Result<bool> { ) -> Result<()> {
let pred = custom_pred_ref.predicate(); let pred = custom_pred_ref.predicate();
if pred.statements.len() != args.len() { if pred.statements.len() != args.len() {
return Err(Error::diff_amount( return Err(Error::diff_amount(
@ -476,23 +485,33 @@ fn check_custom_pred(
} }
} }
let wildcard_map = match resolve_wildcard_values(params, pred, args) { let wildcard_map = resolve_wildcard_values(params, pred, args)?;
Some(wc_map) => wc_map,
None => return Ok(false),
};
// Check that the resolved wildcard match the statement arguments. // Check that the resolved wildcards match the statement arguments.
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) { for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() {
if *wc_value != *s_arg { if *wc_value != *s_arg {
return Ok(false); return Err(Error::mismatched_wildcard_value_and_statement_arg(
wc_value.clone(),
s_arg.clone(),
arg_index,
pred.clone(),
));
} }
} }
if pred.conjunction { if pred.conjunction {
Ok(num_matches == pred.statements.len()) if num_matches != pred.statements.len() {
} else { return Err(Error::unsatisfied_custom_predicate_conjunction(
Ok(num_matches > 0) pred.clone(),
));
} }
} else if num_matches == 0 {
return Err(Error::unsatisfied_custom_predicate_disjunction(
pred.clone(),
));
}
Ok(())
} }
impl ToFields for Operation { impl ToFields for Operation {