parent
1d14338351
commit
26548cf612
2 changed files with 106 additions and 47 deletions
|
|
@ -1,5 +1,6 @@
|
|||
use std::{fmt, iter};
|
||||
|
||||
use itertools::Itertools;
|
||||
use log::error;
|
||||
use plonky2::field::types::Field;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -14,8 +15,8 @@ use crate::{
|
|||
},
|
||||
middleware::{
|
||||
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
|
||||
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmpl,
|
||||
StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
|
||||
MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg,
|
||||
StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -613,16 +614,11 @@ pub fn check_st_tmpl(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn resolve_wildcard_values(
|
||||
params: &Params,
|
||||
pub fn fill_wildcard_values(
|
||||
pred: &CustomPredicate,
|
||||
args: &[Statement],
|
||||
) -> Result<Vec<Value>> {
|
||||
// Check that all wildcard have consistent values as assigned in the statements while storing a
|
||||
// map of their values.
|
||||
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
|
||||
// disjunctions we expect Statement::None for the unused statements.
|
||||
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
|
||||
wildcard_map: &mut [Option<Value>],
|
||||
) -> Result<()> {
|
||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||
let st_args = st.args();
|
||||
st_tmpl
|
||||
|
|
@ -630,10 +626,25 @@ pub fn resolve_wildcard_values(
|
|||
.iter()
|
||||
.zip(&st_args)
|
||||
.try_for_each(|(st_tmpl_arg, st_arg)| {
|
||||
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
|
||||
check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map)
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn wildcard_values_from_op_st(
|
||||
params: &Params,
|
||||
pred: &CustomPredicate,
|
||||
op_args: &[Statement],
|
||||
st_args: &[Value],
|
||||
) -> Result<Vec<Value>> {
|
||||
let mut wildcard_map = st_args
|
||||
.iter()
|
||||
.map(|v| Some(v.clone()))
|
||||
.chain(core::iter::repeat(None))
|
||||
.take(params.max_custom_predicate_wildcards)
|
||||
.collect_vec();
|
||||
fill_wildcard_values(pred, op_args, &mut wildcard_map)?;
|
||||
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
|
||||
// they are beyond the number of used wildcards in this custom predicate, or they could be
|
||||
// private arguments that are unused in a particular disjunction.
|
||||
|
|
@ -717,21 +728,24 @@ pub(crate) fn check_custom_pred(
|
|||
));
|
||||
}
|
||||
|
||||
let wildcard_map = resolve_wildcard_values(params, pred, args)?;
|
||||
|
||||
// Check that the resolved wildcards match the statement arguments.
|
||||
for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() {
|
||||
if *wc_value != *s_arg {
|
||||
return Err(Error::mismatched_wildcard_value_and_statement_arg(
|
||||
wc_value.clone(),
|
||||
s_arg.clone(),
|
||||
arg_index,
|
||||
pred.clone(),
|
||||
));
|
||||
}
|
||||
match wildcard_values_from_op_st(params, pred, args, s_args) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(Error::Inner { inner, backtrace }) => match *inner {
|
||||
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
|
||||
if wc.index <= s_args.len() =>
|
||||
{
|
||||
Err(Error::mismatched_wildcard_value_and_statement_arg(
|
||||
v,
|
||||
prev,
|
||||
wc.index,
|
||||
pred.clone(),
|
||||
))
|
||||
}
|
||||
_ => Err(Error::Inner { inner, backtrace }),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl ToFields for Operation {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue