diff --git a/src/frontend/error.rs b/src/frontend/error.rs index dc14dfd..45491fe 100644 --- a/src/frontend/error.rs +++ b/src/frontend/error.rs @@ -22,8 +22,13 @@ fn display_wc_map(wc_map: &[Option]) -> String { pub enum InnerError { #[error("{0} {1} is over the limit {2}")] MaxLength(String, usize, usize), - #[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}", map=display_wc_map(.2))] - StatementsDontMatch(Statement, StatementTmpl, Vec>), + #[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))] + StatementsDontMatch( + Statement, + StatementTmpl, + Vec>, + crate::middleware::Error, + ), #[error("invalid arguments to {0} operation")] OpInvalidArgs(String), // Other @@ -76,8 +81,9 @@ impl Error { s0: Statement, s1: StatementTmpl, wc_map: Vec>, + mid_error: crate::middleware::Error, ) -> Self { - new!(StatementsDontMatch(s0, s1, wc_map)) + new!(StatementsDontMatch(s0, s1, wc_map, mid_error)) } pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self { new!(MaxLength(obj, found, expect)) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 6a0d568..f62e645 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -471,11 +471,14 @@ impl MainPodBuilder { for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) { let st_args = st.args(); for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { - if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { + if let Err(st_tmpl_check_error) = + check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) + { return Err(Error::statements_dont_match( st.clone(), st_tmpl.clone(), wildcard_map, + st_tmpl_check_error, )); } } diff --git a/src/middleware/error.rs b/src/middleware/error.rs index 3ef660b..1c35bc3 100644 --- a/src/middleware/error.rs +++ b/src/middleware/error.rs @@ -2,7 +2,10 @@ use std::{backtrace::Backtrace, fmt::Debug}; -use crate::middleware::{Operation, Statement, StatementArg}; +use crate::middleware::{ + CustomPredicate, Key, Operation, PodId, Statement, StatementArg, StatementTmplArg, Value, + Wildcard, +}; pub type Result = core::result::Result; @@ -18,6 +21,22 @@ pub enum MiddlewareInnerError { MaxLength(String, usize, usize), #[error("{0} amount of {1} should be {1} but it's {2}")] DiffAmount(String, String, usize, usize), + #[error("{0} should be assigned the value {1} but has previously been assigned {2}")] + InvalidWildcardAssignment(Wildcard, Value, Value), + #[error("{0} matches POD ID {1}, yet the template key {2} does not match {3}")] + MismatchedAnchoredKeyInStatementTmplArg(Wildcard, PodId, Key, Key), + #[error("{0} does not match against {1}")] + MismatchedStatementTmplArg(StatementTmplArg, StatementArg), + #[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")] + MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate), + #[error( + "Not all statement templates of the following custom predicate have been matched:\n{0}" + )] + UnsatisfiedCustomPredicateConjunction(CustomPredicate), + #[error( + "None of the statement templates of the following custom predicate have been matched:\n{0}" + )] + UnsatisfiedCustomPredicateDisjunction(CustomPredicate), // Other #[error("{0}")] Custom(String), @@ -65,6 +84,48 @@ impl Error { pub(crate) fn diff_amount(obj: String, unit: String, expect: usize, found: usize) -> Self { new!(DiffAmount(obj, unit, expect, found)) } + pub(crate) fn invalid_wildcard_assignment( + wildcard: Wildcard, + value: Value, + prev_value: Value, + ) -> Self { + new!(InvalidWildcardAssignment(wildcard, value, prev_value)) + } + pub(crate) fn mismatched_anchored_key_in_statement_tmpl_arg( + pod_id_wildcard: Wildcard, + pod_id: PodId, + key_tmpl: Key, + key: Key, + ) -> Self { + new!(MismatchedAnchoredKeyInStatementTmplArg( + pod_id_wildcard, + pod_id, + key_tmpl, + key + )) + } + pub(crate) fn mismatched_statement_tmpl_arg( + st_tmpl_arg: StatementTmplArg, + st_arg: StatementArg, + ) -> Self { + new!(MismatchedStatementTmplArg(st_tmpl_arg, st_arg)) + } + pub(crate) fn mismatched_wildcard_value_and_statement_arg( + wc_value: Value, + st_arg: Value, + arg_index: usize, + pred: CustomPredicate, + ) -> Self { + new!(MismatchedWildcardValueAndStatementArg( + wc_value, st_arg, arg_index, pred + )) + } + pub(crate) fn unsatisfied_custom_predicate_conjunction(pred: CustomPredicate) -> Self { + new!(UnsatisfiedCustomPredicateConjunction(pred)) + } + pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self { + new!(UnsatisfiedCustomPredicateDisjunction(pred)) + } pub(crate) fn custom(s: String) -> Self { new!(Custom(s)) } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 5da69fb..2600e3d 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -355,7 +355,7 @@ impl Operation { (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) if batch == &cpr.batch && index == &cpr.index => { - check_custom_pred(params, cpr, args, s_args)? + check_custom_pred(params, cpr, args, s_args).map(|_| true)? } _ => return Err(deduction_err()), }; @@ -370,37 +370,49 @@ pub fn check_st_tmpl( st_arg: &StatementArg, // Map from wildcards to values that we have seen so far. wildcard_map: &mut [Option], -) -> bool { +) -> Result<()> { // Check that the value `v` at wildcard `wc` exists in the map or set it. - fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option]) -> bool { + fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option]) -> Result<()> { if let Some(prev) = &wildcard_map[wc.index] { if *prev != v { - // TODO: Return nice error - return false; + return Err(Error::invalid_wildcard_assignment( + wc.clone(), + v, + prev.clone(), + )); } } else { wildcard_map[wc.index] = Some(v); } - true + Ok(()) } match (st_tmpl_arg, st_arg) { - (StatementTmplArg::None, StatementArg::None) => true, - (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true, + (StatementTmplArg::None, StatementArg::None) => Ok(()), + (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()), ( StatementTmplArg::AnchoredKey(pod_id_wc, key_tmpl), StatementArg::Key(AnchoredKey { pod_id, key }), ) => { let pod_id_ok = check_or_set(Value::from(*pod_id), pod_id_wc, wildcard_map); - pod_id_ok && (key_tmpl == key) + pod_id_ok.and_then(|_| { + (key_tmpl == key).then_some(()).ok_or( + Error::mismatched_anchored_key_in_statement_tmpl_arg( + pod_id_wc.clone(), + *pod_id, + key_tmpl.clone(), + key.clone(), + ), + ) + }) } (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { check_or_set(v.clone(), wc, wildcard_map) } - _ => { - println!("DBG {:?} {:?}", st_tmpl_arg, st_arg); - false - } + _ => Err(Error::mismatched_statement_tmpl_arg( + st_tmpl_arg.clone(), + st_arg.clone(), + )), } } @@ -408,7 +420,7 @@ pub fn resolve_wildcard_values( params: &Params, pred: &CustomPredicate, args: &[Statement], -) -> Option> { +) -> Result> { // Check that all wildcard have consistent values as assigned in the statements while storing a // map of their values. // NOTE: We assume the statements have the same order as defined in the custom predicate. For @@ -416,25 +428,22 @@ pub fn resolve_wildcard_values( let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards]; for (st_tmpl, st) in pred.statements.iter().zip(args) { let st_args = st.args(); - for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { - if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { - // TODO: Better errors. Example: - // println!("{} doesn't match {}", st_arg, st_tmpl_arg); - // println!("{} doesn't match {}", st, st_tmpl); - return None; - } - } + st_tmpl + .args + .iter() + .zip(&st_args) + .try_for_each(|(st_tmpl_arg, st_arg)| { + check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) + })?; } // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because // they are beyond the number of used wildcards in this custom predicate, or they could be // private arguments that are unused in a particular disjunction. - Some( - wildcard_map - .into_iter() - .map(|opt| opt.unwrap_or(Value::from(0))) - .collect(), - ) + Ok(wildcard_map + .into_iter() + .map(|opt| opt.unwrap_or(Value::from(0))) + .collect()) } fn check_custom_pred( @@ -442,7 +451,7 @@ fn check_custom_pred( custom_pred_ref: &CustomPredicateRef, args: &[Statement], s_args: &[Value], -) -> Result { +) -> Result<()> { let pred = custom_pred_ref.predicate(); if pred.statements.len() != args.len() { return Err(Error::diff_amount( @@ -476,23 +485,33 @@ fn check_custom_pred( } } - let wildcard_map = match resolve_wildcard_values(params, pred, args) { - Some(wc_map) => wc_map, - None => return Ok(false), - }; + let wildcard_map = resolve_wildcard_values(params, pred, args)?; - // Check that the resolved wildcard match the statement arguments. - for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) { + // Check that the resolved wildcards match the statement arguments. + for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() { if *wc_value != *s_arg { - return Ok(false); + return Err(Error::mismatched_wildcard_value_and_statement_arg( + wc_value.clone(), + s_arg.clone(), + arg_index, + pred.clone(), + )); } } if pred.conjunction { - Ok(num_matches == pred.statements.len()) - } else { - Ok(num_matches > 0) + if num_matches != pred.statements.len() { + return Err(Error::unsatisfied_custom_predicate_conjunction( + pred.clone(), + )); + } + } else if num_matches == 0 { + return Err(Error::unsatisfied_custom_predicate_disjunction( + pred.clone(), + )); } + + Ok(()) } impl ToFields for Operation {