Feat/fst order pred part3 & part4 (#457)
* support wildcard predicates in frontend * suport wildcard predicate in podlang * add validation test * test full flow and apply some fixes * fix clippy * fix merge issues * use desugared predicate * Fix parsing of intro statement templates inside custom predicates * Tidy up comments * lang: handle wildcard predicate * add unreachable message --------- Co-authored-by: Rob Knight <mail@robknight.org.uk>
This commit is contained in:
parent
b66f5051b5
commit
498e946612
11 changed files with 324 additions and 180 deletions
|
|
@ -3255,8 +3255,12 @@ mod tests {
|
||||||
use NativePredicate as NP;
|
use NativePredicate as NP;
|
||||||
use StatementTmplBuilder as STB;
|
use StatementTmplBuilder as STB;
|
||||||
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
||||||
let stb0 = STB::new(NP::Equal).arg(("id", "score")).arg(literal(42));
|
let stb0 = STB::new_from_pred(NP::Equal)
|
||||||
let stb1 = STB::new(NP::Equal).arg(("id", "key")).arg("secret");
|
.arg(("id", "score"))
|
||||||
|
.arg(literal(42));
|
||||||
|
let stb1 = STB::new_from_pred(NP::Equal)
|
||||||
|
.arg(("id", "key"))
|
||||||
|
.arg("secret");
|
||||||
let _ = builder.predicate_and(
|
let _ = builder.predicate_and(
|
||||||
"pred_and",
|
"pred_and",
|
||||||
&["id"],
|
&["id"],
|
||||||
|
|
@ -3349,8 +3353,10 @@ mod tests {
|
||||||
use NativePredicate as NP;
|
use NativePredicate as NP;
|
||||||
use StatementTmplBuilder as STB;
|
use StatementTmplBuilder as STB;
|
||||||
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
||||||
let stb0 = STB::new(NP::Equal).arg(("id", "score")).arg(literal(42));
|
let stb0 = STB::new_from_pred(NP::Equal)
|
||||||
let stb1 = STB::new(NP::Equal)
|
.arg(("id", "score"))
|
||||||
|
.arg(literal(42));
|
||||||
|
let stb1 = STB::new_from_pred(NP::Equal)
|
||||||
.arg(("secret_id", "key"))
|
.arg(("secret_id", "key"))
|
||||||
.arg(("id", "score"));
|
.arg(("id", "score"));
|
||||||
let _ = builder.predicate_and(
|
let _ = builder.predicate_and(
|
||||||
|
|
|
||||||
|
|
@ -1083,11 +1083,11 @@ pub mod tests {
|
||||||
let vd_set = VDSet::new(&vds);
|
let vd_set = VDSet::new(&vds);
|
||||||
|
|
||||||
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into());
|
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into());
|
||||||
let stb0 = STB::new(NP::Contains)
|
let stb0 = STB::new_from_pred(NP::Contains)
|
||||||
.arg("dict")
|
.arg("dict")
|
||||||
.arg(literal("score"))
|
.arg(literal("score"))
|
||||||
.arg(literal(42));
|
.arg(literal(42));
|
||||||
let stb1 = STB::new(NP::Equal)
|
let stb1 = STB::new_from_pred(NP::Equal)
|
||||||
.arg(("secret_dict", "key"))
|
.arg(("secret_dict", "key"))
|
||||||
.arg(("dict", "score"));
|
.arg(("dict", "score"));
|
||||||
let _ = cpb_builder.predicate_and(
|
let _ = cpb_builder.predicate_and(
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ use crate::{
|
||||||
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
|
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
|
||||||
middleware::{
|
middleware::{
|
||||||
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
|
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
|
||||||
Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
|
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -41,16 +41,34 @@ pub fn literal(v: impl Into<Value>) -> BuilderArg {
|
||||||
BuilderArg::Literal(v.into())
|
BuilderArg::Literal(v.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum PredicateOrWildcard {
|
||||||
|
Predicate(Predicate),
|
||||||
|
Wildcard(String),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct StatementTmplBuilder {
|
pub struct StatementTmplBuilder {
|
||||||
pub(crate) predicate: Predicate,
|
pub(crate) pred_or_wc: PredicateOrWildcard,
|
||||||
pub(crate) args: Vec<BuilderArg>,
|
pub(crate) args: Vec<BuilderArg>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StatementTmplBuilder {
|
impl StatementTmplBuilder {
|
||||||
pub fn new(p: impl Into<Predicate>) -> StatementTmplBuilder {
|
pub fn new_from_pred(p: impl Into<Predicate>) -> StatementTmplBuilder {
|
||||||
StatementTmplBuilder {
|
StatementTmplBuilder {
|
||||||
predicate: p.into(),
|
pred_or_wc: PredicateOrWildcard::Predicate(p.into()),
|
||||||
|
args: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn new_from_wc(p: impl Into<String>) -> StatementTmplBuilder {
|
||||||
|
StatementTmplBuilder {
|
||||||
|
pred_or_wc: PredicateOrWildcard::Wildcard(p.into()),
|
||||||
|
args: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn new(pred_or_wc: PredicateOrWildcard) -> StatementTmplBuilder {
|
||||||
|
StatementTmplBuilder {
|
||||||
|
pred_or_wc,
|
||||||
args: Vec::new(),
|
args: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -62,68 +80,48 @@ impl StatementTmplBuilder {
|
||||||
|
|
||||||
/// Desugar the predicate to a simpler form
|
/// Desugar the predicate to a simpler form
|
||||||
/// Should mirror the logic in `MainPodBuilder::lower_op`
|
/// Should mirror the logic in `MainPodBuilder::lower_op`
|
||||||
pub(crate) fn desugar(self) -> StatementTmplBuilder {
|
pub(crate) fn desugar(mut self) -> StatementTmplBuilder {
|
||||||
match self.predicate {
|
let pred = match self.pred_or_wc {
|
||||||
Predicate::Native(NativePredicate::Gt) => {
|
PredicateOrWildcard::Predicate(p) => p,
|
||||||
let mut stb = StatementTmplBuilder {
|
PredicateOrWildcard::Wildcard(_) => return self,
|
||||||
predicate: Predicate::Native(NativePredicate::Lt),
|
|
||||||
args: self.args,
|
|
||||||
};
|
};
|
||||||
stb.args.swap(0, 1);
|
let pred = match pred {
|
||||||
stb
|
Predicate::Native(nat_pred) => Predicate::Native(match nat_pred {
|
||||||
|
NativePredicate::Gt => {
|
||||||
|
self.args.swap(0, 1);
|
||||||
|
NativePredicate::Lt
|
||||||
}
|
}
|
||||||
Predicate::Native(NativePredicate::GtEq) => {
|
NativePredicate::GtEq => {
|
||||||
let mut stb = StatementTmplBuilder {
|
self.args.swap(0, 1);
|
||||||
predicate: Predicate::Native(NativePredicate::LtEq),
|
NativePredicate::LtEq
|
||||||
args: self.args,
|
}
|
||||||
|
NativePredicate::ArrayContains | NativePredicate::DictContains => {
|
||||||
|
NativePredicate::Contains
|
||||||
|
}
|
||||||
|
NativePredicate::DictNotContains | NativePredicate::SetNotContains => {
|
||||||
|
NativePredicate::NotContains
|
||||||
|
}
|
||||||
|
NativePredicate::SetContains => {
|
||||||
|
self.args.push(self.args[1].clone());
|
||||||
|
NativePredicate::Contains
|
||||||
|
}
|
||||||
|
NativePredicate::DictInsert => NativePredicate::ContainerInsert,
|
||||||
|
NativePredicate::SetInsert => {
|
||||||
|
self.args.push(self.args[2].clone());
|
||||||
|
NativePredicate::ContainerInsert
|
||||||
|
}
|
||||||
|
NativePredicate::DictUpdate | NativePredicate::ArrayUpdate => {
|
||||||
|
NativePredicate::ContainerUpdate
|
||||||
|
}
|
||||||
|
NativePredicate::DictDelete => NativePredicate::ContainerDelete,
|
||||||
|
NativePredicate::SetDelete => NativePredicate::ContainerDelete,
|
||||||
|
_ => nat_pred,
|
||||||
|
}),
|
||||||
|
_ => pred,
|
||||||
};
|
};
|
||||||
stb.args.swap(0, 1);
|
|
||||||
stb
|
|
||||||
}
|
|
||||||
Predicate::Native(NativePredicate::ArrayContains)
|
|
||||||
| Predicate::Native(NativePredicate::DictContains) => StatementTmplBuilder {
|
|
||||||
predicate: Predicate::Native(NativePredicate::Contains),
|
|
||||||
args: self.args,
|
|
||||||
},
|
|
||||||
Predicate::Native(NativePredicate::DictNotContains)
|
|
||||||
| Predicate::Native(NativePredicate::SetNotContains) => StatementTmplBuilder {
|
|
||||||
predicate: Predicate::Native(NativePredicate::NotContains),
|
|
||||||
args: self.args,
|
|
||||||
},
|
|
||||||
Predicate::Native(NativePredicate::SetContains) => {
|
|
||||||
let mut new_args = self.args.clone();
|
|
||||||
new_args.push(self.args[1].clone());
|
|
||||||
StatementTmplBuilder {
|
StatementTmplBuilder {
|
||||||
predicate: Predicate::Native(NativePredicate::Contains),
|
pred_or_wc: PredicateOrWildcard::Predicate(pred),
|
||||||
args: new_args,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Predicate::Native(NativePredicate::DictInsert) => StatementTmplBuilder {
|
|
||||||
predicate: Predicate::Native(NativePredicate::ContainerInsert),
|
|
||||||
args: self.args,
|
args: self.args,
|
||||||
},
|
|
||||||
Predicate::Native(NativePredicate::SetInsert) => {
|
|
||||||
let mut new_args = self.args.clone();
|
|
||||||
new_args.push(self.args[2].clone());
|
|
||||||
StatementTmplBuilder {
|
|
||||||
predicate: Predicate::Native(NativePredicate::ContainerInsert),
|
|
||||||
args: new_args,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Predicate::Native(NativePredicate::DictUpdate)
|
|
||||||
| Predicate::Native(NativePredicate::ArrayUpdate) => StatementTmplBuilder {
|
|
||||||
predicate: Predicate::Native(NativePredicate::ContainerUpdate),
|
|
||||||
args: self.args,
|
|
||||||
},
|
|
||||||
Predicate::Native(NativePredicate::DictDelete) => StatementTmplBuilder {
|
|
||||||
predicate: Predicate::Native(NativePredicate::ContainerDelete),
|
|
||||||
args: self.args,
|
|
||||||
},
|
|
||||||
Predicate::Native(NativePredicate::SetDelete) => StatementTmplBuilder {
|
|
||||||
predicate: Predicate::Native(NativePredicate::ContainerDelete),
|
|
||||||
args: self.args,
|
|
||||||
},
|
|
||||||
_ => self,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -200,7 +198,7 @@ impl CustomPredicateBatchBuilder {
|
||||||
.iter()
|
.iter()
|
||||||
.map(|sb| {
|
.map(|sb| {
|
||||||
let stb = sb.clone().desugar();
|
let stb = sb.clone().desugar();
|
||||||
let args = stb
|
let st_tmpl_args = stb
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|a| {
|
.map(|a| {
|
||||||
|
|
@ -216,10 +214,17 @@ impl CustomPredicateBatchBuilder {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Result<_>>()?;
|
.collect::<Result<_>>()?;
|
||||||
|
let pred_or_wc = match stb.pred_or_wc {
|
||||||
|
PredicateOrWildcard::Predicate(p) => {
|
||||||
|
middleware::PredicateOrWildcard::Predicate(p)
|
||||||
|
}
|
||||||
|
PredicateOrWildcard::Wildcard(v) => middleware::PredicateOrWildcard::Wildcard(
|
||||||
|
resolve_wildcard(args, priv_args, &v)?,
|
||||||
|
),
|
||||||
|
};
|
||||||
Ok(StatementTmpl {
|
Ok(StatementTmpl {
|
||||||
// TODO: Support wildcard
|
pred_or_wc,
|
||||||
pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()),
|
args: st_tmpl_args,
|
||||||
args,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Result<_>>()?;
|
.collect::<Result<_>>()?;
|
||||||
|
|
@ -299,7 +304,7 @@ mod tests {
|
||||||
let vd_set = &*MOCK_VD_SET;
|
let vd_set = &*MOCK_VD_SET;
|
||||||
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
|
||||||
|
|
||||||
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
|
let gt_stb = StatementTmplBuilder::new_from_pred(NativePredicate::Gt)
|
||||||
.arg("s1")
|
.arg("s1")
|
||||||
.arg("s2");
|
.arg("s2");
|
||||||
|
|
||||||
|
|
@ -344,7 +349,7 @@ mod tests {
|
||||||
let mut builder =
|
let mut builder =
|
||||||
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
|
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
|
||||||
|
|
||||||
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
|
let set_contains_stb = StatementTmplBuilder::new_from_pred(NativePredicate::SetContains)
|
||||||
.arg("s1")
|
.arg("s1")
|
||||||
.arg("s2");
|
.arg("s2");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,13 @@
|
||||||
use std::{backtrace::Backtrace, fmt::Debug};
|
use std::{backtrace::Backtrace, fmt::Debug};
|
||||||
|
|
||||||
use crate::middleware::{BackendError, Statement, StatementTmpl, Value};
|
use crate::middleware::BackendError;
|
||||||
|
|
||||||
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
||||||
|
|
||||||
fn display_wc_map(wc_map: &[Option<Value>]) -> String {
|
|
||||||
let mut out = String::new();
|
|
||||||
use std::fmt::Write;
|
|
||||||
for (i, v) in wc_map.iter().enumerate() {
|
|
||||||
write!(out, "- {}: ", i).unwrap();
|
|
||||||
if let Some(v) = v {
|
|
||||||
writeln!(out, "{}", v).unwrap();
|
|
||||||
} else {
|
|
||||||
writeln!(out, "none").unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum InnerError {
|
pub enum InnerError {
|
||||||
#[error("{0} {1} is over the limit {2}")]
|
#[error("{0} {1} is over the limit {2}")]
|
||||||
MaxLength(String, usize, usize),
|
MaxLength(String, usize, usize),
|
||||||
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
|
|
||||||
StatementsDontMatch(
|
|
||||||
Statement,
|
|
||||||
StatementTmpl,
|
|
||||||
Vec<Option<Value>>,
|
|
||||||
crate::middleware::Error,
|
|
||||||
),
|
|
||||||
#[error("invalid arguments to {0} operation")]
|
#[error("invalid arguments to {0} operation")]
|
||||||
OpInvalidArgs(String),
|
OpInvalidArgs(String),
|
||||||
#[error("Podlang parse error: {0}")]
|
#[error("Podlang parse error: {0}")]
|
||||||
|
|
@ -99,14 +78,6 @@ impl Error {
|
||||||
pub(crate) fn op_invalid_args(s: String) -> Self {
|
pub(crate) fn op_invalid_args(s: String) -> Self {
|
||||||
new!(OpInvalidArgs(s))
|
new!(OpInvalidArgs(s))
|
||||||
}
|
}
|
||||||
pub(crate) fn statements_dont_match(
|
|
||||||
s0: Statement,
|
|
||||||
s1: StatementTmpl,
|
|
||||||
wc_map: Vec<Option<Value>>,
|
|
||||||
mid_error: crate::middleware::Error,
|
|
||||||
) -> Self {
|
|
||||||
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
|
|
||||||
}
|
|
||||||
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
|
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
|
||||||
new!(MaxLength(obj, found, expect))
|
new!(MaxLength(obj, found, expect))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,10 @@ use serde::{Deserialize, Serialize};
|
||||||
pub use serialization::SerializedMainPod;
|
pub use serialization::SerializedMainPod;
|
||||||
|
|
||||||
use crate::middleware::{
|
use crate::middleware::{
|
||||||
self, check_custom_pred, check_st_tmpl, containers::Dictionary, hash_op, max_op, prod_op,
|
self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op,
|
||||||
sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, OperationAux,
|
prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation,
|
||||||
OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, StatementArg, VDSet,
|
OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement,
|
||||||
Value, ValueRef,
|
StatementArg, VDSet, Value, ValueRef,
|
||||||
};
|
};
|
||||||
|
|
||||||
mod custom;
|
mod custom;
|
||||||
|
|
@ -600,21 +600,7 @@ impl MainPodBuilder {
|
||||||
}
|
}
|
||||||
wildcard_map[index] = Some(value);
|
wildcard_map[index] = Some(value);
|
||||||
}
|
}
|
||||||
for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) {
|
fill_wildcard_values(&self.params, pred, &args, &mut wildcard_map)?;
|
||||||
let st_args = st.args();
|
|
||||||
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
|
||||||
if let Err(st_tmpl_check_error) =
|
|
||||||
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
|
|
||||||
{
|
|
||||||
return Err(Error::statements_dont_match(
|
|
||||||
st.clone(),
|
|
||||||
st_tmpl.clone(),
|
|
||||||
wildcard_map,
|
|
||||||
st_tmpl_check_error,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let v_default = Value::from(0);
|
let v_default = Value::from(0);
|
||||||
let st_args: Vec<_> = wildcard_map
|
let st_args: Vec<_> = wildcard_map
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
@ -817,7 +803,8 @@ pub mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
backends::plonky2::{
|
backends::plonky2::{
|
||||||
mock::mainpod::MockProver, primitives::ec::schnorr::SecretKey, signer::Signer,
|
basetypes::DEFAULT_VD_SET, mainpod::Prover, mock::mainpod::MockProver,
|
||||||
|
primitives::ec::schnorr::SecretKey, signer::Signer,
|
||||||
},
|
},
|
||||||
dict,
|
dict,
|
||||||
examples::{
|
examples::{
|
||||||
|
|
@ -1423,6 +1410,75 @@ pub mod tests {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wildcard_predicate() -> Result<()> {
|
||||||
|
let mock = true;
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
let mock_prover = MockProver {};
|
||||||
|
let real_prover = Prover {};
|
||||||
|
let (vd_set, prover): (_, &dyn MainPodProver) = if mock {
|
||||||
|
(&VDSet::new(&[]), &mock_prover)
|
||||||
|
} else {
|
||||||
|
println!("Prebuilding circuits to calculate vd_set...");
|
||||||
|
let vd_set = &*DEFAULT_VD_SET;
|
||||||
|
println!("vd_set calculation complete");
|
||||||
|
(vd_set, &real_prover)
|
||||||
|
};
|
||||||
|
|
||||||
|
let input = r#"
|
||||||
|
Test(a, b, private: c) = AND(
|
||||||
|
Equal(a, 5)
|
||||||
|
b(1, 5)
|
||||||
|
c(6, 3)
|
||||||
|
)
|
||||||
|
"#;
|
||||||
|
let batch = parse(input, ¶ms, &[])
|
||||||
|
.unwrap()
|
||||||
|
.first_batch()
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
||||||
|
|
||||||
|
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
||||||
|
let st0 = builder.priv_op(Operation::eq(5, 5)).unwrap();
|
||||||
|
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
|
||||||
|
let st2 = builder.priv_op(Operation::ne(6, 3)).unwrap();
|
||||||
|
let _st = builder
|
||||||
|
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let pod = builder.prove(prover).unwrap();
|
||||||
|
pod.pod.verify().unwrap();
|
||||||
|
|
||||||
|
let input = r#"
|
||||||
|
Test(a, b, private: c) = OR(
|
||||||
|
Equal(a, 5)
|
||||||
|
b(1, 5)
|
||||||
|
c(6, 3)
|
||||||
|
)
|
||||||
|
"#;
|
||||||
|
let batch = parse(input, ¶ms, &[])
|
||||||
|
.unwrap()
|
||||||
|
.first_batch()
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
||||||
|
|
||||||
|
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
||||||
|
let st0 = Statement::None;
|
||||||
|
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
|
||||||
|
let st2 = Statement::None;
|
||||||
|
let _st = builder
|
||||||
|
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let pod = builder.prove(prover).unwrap();
|
||||||
|
pod.pod.verify().unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_apply_predicate_e2e() -> Result<()> {
|
fn test_apply_predicate_e2e() -> Result<()> {
|
||||||
// End-to-end test of apply_predicate with MockProver
|
// End-to-end test of apply_predicate with MockProver
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,9 @@ pub enum ValidationError {
|
||||||
first_span: Option<Span>,
|
first_span: Option<Span>,
|
||||||
second_span: Option<Span>,
|
second_span: Option<Span>,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
#[error("Wildcard '{name}' collides with a predicate name")]
|
||||||
|
WildcardPredicateNameCollision { name: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lowering errors from frontend AST lowering to middleware
|
/// Lowering errors from frontend AST lowering to middleware
|
||||||
|
|
|
||||||
|
|
@ -619,6 +619,7 @@ fn build_single_batch(
|
||||||
batch_idx,
|
batch_idx,
|
||||||
reference_map,
|
reference_map,
|
||||||
existing_batches,
|
existing_batches,
|
||||||
|
name,
|
||||||
symbols,
|
symbols,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
@ -648,6 +649,7 @@ fn build_statement_with_resolved_refs(
|
||||||
current_batch_idx: usize,
|
current_batch_idx: usize,
|
||||||
reference_map: &HashMap<String, (usize, usize)>,
|
reference_map: &HashMap<String, (usize, usize)>,
|
||||||
existing_batches: &[Arc<CustomPredicateBatch>],
|
existing_batches: &[Arc<CustomPredicateBatch>],
|
||||||
|
custom_predicate_name: &str, // custom pred that defines this statement template
|
||||||
symbols: &SymbolTable,
|
symbols: &SymbolTable,
|
||||||
) -> Result<StatementTmplBuilder, BatchingError> {
|
) -> Result<StatementTmplBuilder, BatchingError> {
|
||||||
let callee_name = &stmt.predicate.name;
|
let callee_name = &stmt.predicate.name;
|
||||||
|
|
@ -657,16 +659,17 @@ fn build_statement_with_resolved_refs(
|
||||||
current_batch_idx,
|
current_batch_idx,
|
||||||
reference_map,
|
reference_map,
|
||||||
existing_batches,
|
existing_batches,
|
||||||
|
custom_predicate_name,
|
||||||
};
|
};
|
||||||
|
|
||||||
let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
|
let pred_or_wc = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
|
||||||
BatchingError::Internal {
|
BatchingError::Internal {
|
||||||
message: format!("Unknown predicate reference: '{}'", callee_name),
|
message: format!("Unknown predicate reference: '{}'", callee_name),
|
||||||
}
|
}
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Build the statement template
|
// Build the statement template
|
||||||
let mut builder = StatementTmplBuilder::new(predicate);
|
let mut builder = StatementTmplBuilder::new(pred_or_wc);
|
||||||
|
|
||||||
for arg in &stmt.args {
|
for arg in &stmt.args {
|
||||||
builder = builder.arg(lower_statement_arg(arg));
|
builder = builder.arg(lower_statement_arg(arg));
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
frontend::{BuilderArg, StatementTmplBuilder},
|
frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
|
||||||
lang::{
|
lang::{
|
||||||
frontend_ast::*,
|
frontend_ast::*,
|
||||||
frontend_ast_batch::{self, PredicateBatches},
|
frontend_ast_batch::{self, PredicateBatches},
|
||||||
|
|
@ -18,8 +18,8 @@ use crate::{
|
||||||
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
|
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
|
self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
|
||||||
NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl,
|
||||||
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
|
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
@ -35,6 +35,7 @@ pub enum ResolutionContext<'a> {
|
||||||
current_batch_idx: usize,
|
current_batch_idx: usize,
|
||||||
reference_map: &'a HashMap<String, (usize, usize)>,
|
reference_map: &'a HashMap<String, (usize, usize)>,
|
||||||
existing_batches: &'a [Arc<CustomPredicateBatch>],
|
existing_batches: &'a [Arc<CustomPredicateBatch>],
|
||||||
|
custom_predicate_name: &'a str,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -43,10 +44,23 @@ pub fn resolve_predicate(
|
||||||
pred_name: &str,
|
pred_name: &str,
|
||||||
symbols: &SymbolTable,
|
symbols: &SymbolTable,
|
||||||
context: &ResolutionContext,
|
context: &ResolutionContext,
|
||||||
) -> Option<Predicate> {
|
) -> Option<PredicateOrWildcard> {
|
||||||
// 1. Try native predicate first
|
// 0. Try wildcard first
|
||||||
|
if let ResolutionContext::Batch {
|
||||||
|
custom_predicate_name,
|
||||||
|
..
|
||||||
|
} = context
|
||||||
|
{
|
||||||
|
if let Some(wc_scope) = symbols.wildcard_scopes.get(*custom_predicate_name) {
|
||||||
|
if wc_scope.wildcards.contains_key(pred_name) {
|
||||||
|
return Some(PredicateOrWildcard::Wildcard(pred_name.to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Try native predicate second
|
||||||
if let Ok(native) = NativePredicate::from_str(pred_name) {
|
if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||||
return Some(Predicate::Native(native));
|
return Some(PredicateOrWildcard::Predicate(Predicate::Native(native)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Look up in symbol table
|
// 2. Look up in symbol table
|
||||||
|
|
@ -64,6 +78,7 @@ pub fn resolve_predicate(
|
||||||
current_batch_idx,
|
current_batch_idx,
|
||||||
reference_map,
|
reference_map,
|
||||||
existing_batches,
|
existing_batches,
|
||||||
|
..
|
||||||
} => resolve_local_predicate(
|
} => resolve_local_predicate(
|
||||||
pred_name,
|
pred_name,
|
||||||
*current_batch_idx,
|
*current_batch_idx,
|
||||||
|
|
@ -85,7 +100,7 @@ pub fn resolve_predicate(
|
||||||
verifier_data_hash: *verifier_data_hash,
|
verifier_data_hash: *verifier_data_hash,
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
return Some(predicate);
|
return Some(PredicateOrWildcard::Predicate(predicate));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. In batch context, also check reference_map for split chain pieces
|
// 3. In batch context, also check reference_map for split chain pieces
|
||||||
|
|
@ -94,6 +109,7 @@ pub fn resolve_predicate(
|
||||||
current_batch_idx,
|
current_batch_idx,
|
||||||
reference_map,
|
reference_map,
|
||||||
existing_batches,
|
existing_batches,
|
||||||
|
..
|
||||||
} = context
|
} = context
|
||||||
{
|
{
|
||||||
if reference_map.contains_key(pred_name) {
|
if reference_map.contains_key(pred_name) {
|
||||||
|
|
@ -102,7 +118,8 @@ pub fn resolve_predicate(
|
||||||
*current_batch_idx,
|
*current_batch_idx,
|
||||||
reference_map,
|
reference_map,
|
||||||
existing_batches,
|
existing_batches,
|
||||||
);
|
)
|
||||||
|
.map(PredicateOrWildcard::Predicate);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -328,7 +345,7 @@ impl<'a> Lowerer<'a> {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Create a builder with the resolved predicate and desugar
|
// Create a builder with the resolved predicate and desugar
|
||||||
let mut builder = StatementTmplBuilder::new(predicate);
|
let mut builder = StatementTmplBuilder::new(predicate.clone());
|
||||||
for arg in &stmt.args {
|
for arg in &stmt.args {
|
||||||
let builder_arg = lower_statement_arg(arg);
|
let builder_arg = lower_statement_arg(arg);
|
||||||
builder = builder.arg(builder_arg);
|
builder = builder.arg(builder_arg);
|
||||||
|
|
@ -356,9 +373,14 @@ impl<'a> Lowerer<'a> {
|
||||||
mw_args.push(mw_arg);
|
mw_args.push(mw_arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let predicate = match desugared.pred_or_wc {
|
||||||
|
PredicateOrWildcard::Predicate(p) => p,
|
||||||
|
PredicateOrWildcard::Wildcard(_) => {
|
||||||
|
unreachable!("wildcard predicates aren't considered in requests")
|
||||||
|
}
|
||||||
|
};
|
||||||
Ok(MWStatementTmpl {
|
Ok(MWStatementTmpl {
|
||||||
// TODO: Support wildcard
|
pred_or_wc: middleware::PredicateOrWildcard::Predicate(predicate),
|
||||||
pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate),
|
|
||||||
args: mw_args,
|
args: mw_args,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -424,7 +446,6 @@ impl<'a> Lowerer<'a> {
|
||||||
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
|
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
|
||||||
split_results.push(result);
|
split_results.push(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(split_results)
|
Ok(split_results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -601,7 +622,7 @@ mod tests {
|
||||||
// Should be BatchSelf(0) referring to pred1
|
// Should be BatchSelf(0) referring to pred1
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
stmt.pred_or_wc,
|
stmt.pred_or_wc,
|
||||||
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
|
middleware::PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -639,10 +660,25 @@ mod tests {
|
||||||
// Should desugar to the Contains predicate
|
// Should desugar to the Contains predicate
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
stmt.pred_or_wc,
|
stmt.pred_or_wc,
|
||||||
PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains))
|
middleware::PredicateOrWildcard::Predicate(Predicate::Native(
|
||||||
|
NativePredicate::Contains
|
||||||
|
))
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wc_pred() {
|
||||||
|
let input = r#"
|
||||||
|
my_pred(X, DynPred) = AND (
|
||||||
|
Equal(X["pred"], DynPred)
|
||||||
|
DynPred(X)
|
||||||
|
)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let params = Params::default();
|
||||||
|
parse_validate_and_lower(input, ¶ms).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_multi_batch_packing() {
|
fn test_multi_batch_packing() {
|
||||||
// Create more predicates than fit in a single batch
|
// Create more predicates than fit in a single batch
|
||||||
|
|
@ -749,7 +785,7 @@ mod tests {
|
||||||
// Verify the second statement is an intro predicate reference
|
// Verify the second statement is an intro predicate reference
|
||||||
let intro_stmt = &pred.statements()[1];
|
let intro_stmt = &pred.statements()[1];
|
||||||
match intro_stmt.pred_or_wc() {
|
match intro_stmt.pred_or_wc() {
|
||||||
PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
|
middleware::PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
|
||||||
assert_eq!(intro_ref.name, "external_check");
|
assert_eq!(intro_ref.name, "external_check");
|
||||||
assert_eq!(intro_ref.args_len, 1);
|
assert_eq!(intro_ref.args_len, 1);
|
||||||
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
|
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,11 @@
|
||||||
//! This module provides semantic validation for parsed AST documents,
|
//! This module provides semantic validation for parsed AST documents,
|
||||||
//! including name resolution, arity checking, and wildcard validation.
|
//! including name resolution, arity checking, and wildcard validation.
|
||||||
|
|
||||||
use std::{collections::HashMap, str::FromStr, sync::Arc};
|
use std::{
|
||||||
|
collections::{HashMap, HashSet},
|
||||||
|
str::FromStr,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
use hex::ToHex;
|
use hex::ToHex;
|
||||||
|
|
||||||
|
|
@ -411,6 +415,21 @@ impl Validator {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate that no wildcard name collides with a predicate name to avoid ambiguity when using
|
||||||
|
/// wildcard predicates.
|
||||||
|
fn validate_wildcard_names(&self, names: &HashSet<&String>) -> Result<(), ValidationError> {
|
||||||
|
for name in names {
|
||||||
|
if NativePredicate::from_str(name).is_ok()
|
||||||
|
|| self.symbols.predicates.contains_key(*name)
|
||||||
|
{
|
||||||
|
return Err(ValidationError::WildcardPredicateNameCollision {
|
||||||
|
name: (*name).clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn validate_statement(
|
fn validate_statement(
|
||||||
&self,
|
&self,
|
||||||
stmt: &StatementTmpl,
|
stmt: &StatementTmpl,
|
||||||
|
|
@ -418,18 +437,26 @@ impl Validator {
|
||||||
) -> Result<(), ValidationError> {
|
) -> Result<(), ValidationError> {
|
||||||
let pred_name = &stmt.predicate.name;
|
let pred_name = &stmt.predicate.name;
|
||||||
|
|
||||||
|
let wc_names = match wildcard_context {
|
||||||
|
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
|
||||||
|
None => HashSet::new(),
|
||||||
|
};
|
||||||
|
self.validate_wildcard_names(&wc_names)?;
|
||||||
|
|
||||||
// Check if predicate exists
|
// Check if predicate exists
|
||||||
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||||
// Native predicate
|
// Native predicate
|
||||||
PredicateInfo {
|
Some(PredicateInfo {
|
||||||
kind: PredicateKind::Native(native),
|
kind: PredicateKind::Native(native),
|
||||||
arity: native.arity(),
|
arity: native.arity(),
|
||||||
public_arity: native.arity(),
|
public_arity: native.arity(),
|
||||||
source_span: None,
|
source_span: None,
|
||||||
}
|
})
|
||||||
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
|
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
|
||||||
// Custom or imported predicate
|
// Custom or imported predicate
|
||||||
info.clone()
|
Some(info.clone())
|
||||||
|
} else if wc_names.contains(pred_name) {
|
||||||
|
None
|
||||||
} else {
|
} else {
|
||||||
return Err(ValidationError::UndefinedPredicate {
|
return Err(ValidationError::UndefinedPredicate {
|
||||||
name: pred_name.clone(),
|
name: pred_name.clone(),
|
||||||
|
|
@ -437,8 +464,8 @@ impl Validator {
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Some(ref pred_info) = pred_info {
|
||||||
let expected_arity = pred_info.public_arity;
|
let expected_arity = pred_info.public_arity;
|
||||||
|
|
||||||
if stmt.args.len() != expected_arity {
|
if stmt.args.len() != expected_arity {
|
||||||
return Err(ValidationError::ArgumentCountMismatch {
|
return Err(ValidationError::ArgumentCountMismatch {
|
||||||
predicate: pred_name.clone(),
|
predicate: pred_name.clone(),
|
||||||
|
|
@ -447,9 +474,10 @@ impl Validator {
|
||||||
span: stmt.span,
|
span: stmt.span,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Validate arguments
|
// Validate arguments
|
||||||
self.validate_statement_args(stmt, &pred_info, wildcard_context)?;
|
self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -457,13 +485,13 @@ impl Validator {
|
||||||
fn validate_statement_args(
|
fn validate_statement_args(
|
||||||
&self,
|
&self,
|
||||||
stmt: &StatementTmpl,
|
stmt: &StatementTmpl,
|
||||||
pred_info: &PredicateInfo,
|
pred_info: Option<&PredicateInfo>,
|
||||||
wildcard_context: Option<(&str, &WildcardScope)>,
|
wildcard_context: Option<(&str, &WildcardScope)>,
|
||||||
) -> Result<(), ValidationError> {
|
) -> Result<(), ValidationError> {
|
||||||
// For custom predicates, only wildcards and literals are allowed
|
// For custom predicates, only wildcards and literals are allowed
|
||||||
if matches!(
|
if matches!(
|
||||||
pred_info.kind,
|
pred_info.map(|i| &i.kind),
|
||||||
PredicateKind::Custom { .. } | PredicateKind::BatchImported { .. }
|
Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. })
|
||||||
) {
|
) {
|
||||||
for arg in &stmt.args {
|
for arg in &stmt.args {
|
||||||
match arg {
|
match arg {
|
||||||
|
|
@ -631,6 +659,18 @@ mod tests {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wildcard_predicate_collision() {
|
||||||
|
let input = r#"
|
||||||
|
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
|
||||||
|
"#;
|
||||||
|
let result = parse_and_validate(input, &[]);
|
||||||
|
assert!(matches!(
|
||||||
|
result,
|
||||||
|
Err(ValidationError::WildcardPredicateNameCollision { .. })
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_custom_predicate_with_anchored_key() {
|
fn test_custom_predicate_with_anchored_key() {
|
||||||
let input = r#"
|
let input = r#"
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,26 @@
|
||||||
use std::{backtrace::Backtrace, fmt::Debug};
|
use std::{backtrace::Backtrace, fmt::Debug};
|
||||||
|
|
||||||
use crate::middleware::{
|
use crate::middleware::{
|
||||||
CustomPredicate, Hash, Key, Operation, Predicate, Statement, StatementArg, StatementTmplArg,
|
CustomPredicate, Hash, Key, Operation, Predicate, Statement, StatementArg, StatementTmpl,
|
||||||
Value, Wildcard,
|
StatementTmplArg, Value, Wildcard,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
||||||
|
|
||||||
|
fn display_wc_map(wc_map: &[Option<Value>]) -> String {
|
||||||
|
let mut out = String::new();
|
||||||
|
use std::fmt::Write;
|
||||||
|
for (i, v) in wc_map.iter().enumerate() {
|
||||||
|
write!(out, "- {}: ", i).unwrap();
|
||||||
|
if let Some(v) = v {
|
||||||
|
writeln!(out, "{}", v).unwrap();
|
||||||
|
} else {
|
||||||
|
writeln!(out, "none").unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum MiddlewareInnerError {
|
pub enum MiddlewareInnerError {
|
||||||
#[error("incorrect statement args")]
|
#[error("incorrect statement args")]
|
||||||
|
|
@ -33,6 +47,13 @@ pub enum MiddlewareInnerError {
|
||||||
MismatchedStatementWildcardPredicate(Value, Value, Predicate),
|
MismatchedStatementWildcardPredicate(Value, Value, Predicate),
|
||||||
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
|
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
|
||||||
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
|
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
|
||||||
|
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
|
||||||
|
StatementsDontMatch(
|
||||||
|
Statement,
|
||||||
|
StatementTmpl,
|
||||||
|
Vec<Option<Value>>,
|
||||||
|
crate::middleware::Error,
|
||||||
|
),
|
||||||
#[error(
|
#[error(
|
||||||
"None of the statement templates of the following custom predicate have been matched:\n{0}"
|
"None of the statement templates of the following custom predicate have been matched:\n{0}"
|
||||||
)]
|
)]
|
||||||
|
|
@ -132,6 +153,14 @@ impl Error {
|
||||||
wc_value, st_arg, arg_index, pred
|
wc_value, st_arg, arg_index, pred
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
pub(crate) fn statements_dont_match(
|
||||||
|
s0: Statement,
|
||||||
|
s1: StatementTmpl,
|
||||||
|
wc_map: Vec<Option<Value>>,
|
||||||
|
mid_error: crate::middleware::Error,
|
||||||
|
) -> Self {
|
||||||
|
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
|
||||||
|
}
|
||||||
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
|
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
|
||||||
new!(UnsatisfiedCustomPredicateDisjunction(pred))
|
new!(UnsatisfiedCustomPredicateDisjunction(pred))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -611,17 +611,21 @@ pub fn fill_wildcard_values(
|
||||||
wildcard_map: &mut [Option<Value>],
|
wildcard_map: &mut [Option<Value>],
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||||
let st_args = st.args();
|
|
||||||
if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc {
|
if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc {
|
||||||
wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?;
|
wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?;
|
||||||
}
|
}
|
||||||
st_tmpl
|
let st_args = st.args();
|
||||||
.args
|
|
||||||
.iter()
|
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
||||||
.zip(&st_args)
|
if let Err(st_tmpl_check_error) = check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map) {
|
||||||
.try_for_each(|(st_tmpl_arg, st_arg)| {
|
return Err(Error::statements_dont_match(
|
||||||
check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map)
|
st.clone(),
|
||||||
})?;
|
st_tmpl.clone(),
|
||||||
|
wildcard_map.to_vec(),
|
||||||
|
st_tmpl_check_error,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -741,16 +745,7 @@ pub(crate) fn check_custom_pred(
|
||||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||||
// For `or` predicates, only one statement needs to match the template.
|
// For `or` predicates, only one statement needs to match the template.
|
||||||
// The rest of the statements can be `None`.
|
// The rest of the statements can be `None`.
|
||||||
let expected_pred_is_none = match &st_tmpl.pred_or_wc {
|
if !pred.conjunction && matches!(st, Statement::None) {
|
||||||
PredicateOrWildcard::Predicate(st_tmpl_pred) => {
|
|
||||||
*st_tmpl_pred == Predicate::Native(NativePredicate::None)
|
|
||||||
}
|
|
||||||
PredicateOrWildcard::Wildcard(wc) => {
|
|
||||||
wc_values[wc.index]
|
|
||||||
== Value::from(Predicate::Native(NativePredicate::None).hash(params))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if !pred.conjunction && matches!(st, Statement::None) && !expected_pred_is_none {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?;
|
check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue