Feat/fst order pred part3 & part4 (#457)

* support wildcard predicates in frontend

* suport wildcard predicate in podlang

* add validation test

* test full flow and apply some fixes

* fix clippy

* fix merge issues

* use desugared predicate

* Fix parsing of intro statement templates inside custom predicates

* Tidy up comments

* lang: handle wildcard predicate

* add unreachable message

---------

Co-authored-by: Rob Knight <mail@robknight.org.uk>
This commit is contained in:
Eduard S. 2026-02-02 10:59:33 +01:00 committed by GitHub
parent b66f5051b5
commit 498e946612
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 324 additions and 180 deletions

View file

@ -3255,8 +3255,12 @@ mod tests {
use NativePredicate as NP; use NativePredicate as NP;
use StatementTmplBuilder as STB; use StatementTmplBuilder as STB;
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb0 = STB::new(NP::Equal).arg(("id", "score")).arg(literal(42)); let stb0 = STB::new_from_pred(NP::Equal)
let stb1 = STB::new(NP::Equal).arg(("id", "key")).arg("secret"); .arg(("id", "score"))
.arg(literal(42));
let stb1 = STB::new_from_pred(NP::Equal)
.arg(("id", "key"))
.arg("secret");
let _ = builder.predicate_and( let _ = builder.predicate_and(
"pred_and", "pred_and",
&["id"], &["id"],
@ -3349,8 +3353,10 @@ mod tests {
use NativePredicate as NP; use NativePredicate as NP;
use StatementTmplBuilder as STB; use StatementTmplBuilder as STB;
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb0 = STB::new(NP::Equal).arg(("id", "score")).arg(literal(42)); let stb0 = STB::new_from_pred(NP::Equal)
let stb1 = STB::new(NP::Equal) .arg(("id", "score"))
.arg(literal(42));
let stb1 = STB::new_from_pred(NP::Equal)
.arg(("secret_id", "key")) .arg(("secret_id", "key"))
.arg(("id", "score")); .arg(("id", "score"));
let _ = builder.predicate_and( let _ = builder.predicate_and(

View file

@ -1083,11 +1083,11 @@ pub mod tests {
let vd_set = VDSet::new(&vds); let vd_set = VDSet::new(&vds);
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into()); let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into());
let stb0 = STB::new(NP::Contains) let stb0 = STB::new_from_pred(NP::Contains)
.arg("dict") .arg("dict")
.arg(literal("score")) .arg(literal("score"))
.arg(literal(42)); .arg(literal(42));
let stb1 = STB::new(NP::Equal) let stb1 = STB::new_from_pred(NP::Equal)
.arg(("secret_dict", "key")) .arg(("secret_dict", "key"))
.arg(("dict", "score")); .arg(("dict", "score"));
let _ = cpb_builder.predicate_and( let _ = cpb_builder.predicate_and(

View file

@ -7,7 +7,7 @@ use crate::{
frontend::{AnchoredKey, Error, Result, Statement, StatementArg}, frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
middleware::{ middleware::{
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params, self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard, Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
}, },
}; };
@ -41,16 +41,34 @@ pub fn literal(v: impl Into<Value>) -> BuilderArg {
BuilderArg::Literal(v.into()) BuilderArg::Literal(v.into())
} }
#[derive(Clone, Debug)]
pub enum PredicateOrWildcard {
Predicate(Predicate),
Wildcard(String),
}
#[derive(Clone)] #[derive(Clone)]
pub struct StatementTmplBuilder { pub struct StatementTmplBuilder {
pub(crate) predicate: Predicate, pub(crate) pred_or_wc: PredicateOrWildcard,
pub(crate) args: Vec<BuilderArg>, pub(crate) args: Vec<BuilderArg>,
} }
impl StatementTmplBuilder { impl StatementTmplBuilder {
pub fn new(p: impl Into<Predicate>) -> StatementTmplBuilder { pub fn new_from_pred(p: impl Into<Predicate>) -> StatementTmplBuilder {
StatementTmplBuilder { StatementTmplBuilder {
predicate: p.into(), pred_or_wc: PredicateOrWildcard::Predicate(p.into()),
args: Vec::new(),
}
}
pub fn new_from_wc(p: impl Into<String>) -> StatementTmplBuilder {
StatementTmplBuilder {
pred_or_wc: PredicateOrWildcard::Wildcard(p.into()),
args: Vec::new(),
}
}
pub fn new(pred_or_wc: PredicateOrWildcard) -> StatementTmplBuilder {
StatementTmplBuilder {
pred_or_wc,
args: Vec::new(), args: Vec::new(),
} }
} }
@ -62,68 +80,48 @@ impl StatementTmplBuilder {
/// Desugar the predicate to a simpler form /// Desugar the predicate to a simpler form
/// Should mirror the logic in `MainPodBuilder::lower_op` /// Should mirror the logic in `MainPodBuilder::lower_op`
pub(crate) fn desugar(self) -> StatementTmplBuilder { pub(crate) fn desugar(mut self) -> StatementTmplBuilder {
match self.predicate { let pred = match self.pred_or_wc {
Predicate::Native(NativePredicate::Gt) => { PredicateOrWildcard::Predicate(p) => p,
let mut stb = StatementTmplBuilder { PredicateOrWildcard::Wildcard(_) => return self,
predicate: Predicate::Native(NativePredicate::Lt), };
args: self.args, let pred = match pred {
}; Predicate::Native(nat_pred) => Predicate::Native(match nat_pred {
stb.args.swap(0, 1); NativePredicate::Gt => {
stb self.args.swap(0, 1);
} NativePredicate::Lt
Predicate::Native(NativePredicate::GtEq) => {
let mut stb = StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::LtEq),
args: self.args,
};
stb.args.swap(0, 1);
stb
}
Predicate::Native(NativePredicate::ArrayContains)
| Predicate::Native(NativePredicate::DictContains) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Contains),
args: self.args,
},
Predicate::Native(NativePredicate::DictNotContains)
| Predicate::Native(NativePredicate::SetNotContains) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::NotContains),
args: self.args,
},
Predicate::Native(NativePredicate::SetContains) => {
let mut new_args = self.args.clone();
new_args.push(self.args[1].clone());
StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Contains),
args: new_args,
} }
} NativePredicate::GtEq => {
Predicate::Native(NativePredicate::DictInsert) => StatementTmplBuilder { self.args.swap(0, 1);
predicate: Predicate::Native(NativePredicate::ContainerInsert), NativePredicate::LtEq
args: self.args,
},
Predicate::Native(NativePredicate::SetInsert) => {
let mut new_args = self.args.clone();
new_args.push(self.args[2].clone());
StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::ContainerInsert),
args: new_args,
} }
} NativePredicate::ArrayContains | NativePredicate::DictContains => {
Predicate::Native(NativePredicate::DictUpdate) NativePredicate::Contains
| Predicate::Native(NativePredicate::ArrayUpdate) => StatementTmplBuilder { }
predicate: Predicate::Native(NativePredicate::ContainerUpdate), NativePredicate::DictNotContains | NativePredicate::SetNotContains => {
args: self.args, NativePredicate::NotContains
}, }
Predicate::Native(NativePredicate::DictDelete) => StatementTmplBuilder { NativePredicate::SetContains => {
predicate: Predicate::Native(NativePredicate::ContainerDelete), self.args.push(self.args[1].clone());
args: self.args, NativePredicate::Contains
}, }
Predicate::Native(NativePredicate::SetDelete) => StatementTmplBuilder { NativePredicate::DictInsert => NativePredicate::ContainerInsert,
predicate: Predicate::Native(NativePredicate::ContainerDelete), NativePredicate::SetInsert => {
args: self.args, self.args.push(self.args[2].clone());
}, NativePredicate::ContainerInsert
_ => self, }
NativePredicate::DictUpdate | NativePredicate::ArrayUpdate => {
NativePredicate::ContainerUpdate
}
NativePredicate::DictDelete => NativePredicate::ContainerDelete,
NativePredicate::SetDelete => NativePredicate::ContainerDelete,
_ => nat_pred,
}),
_ => pred,
};
StatementTmplBuilder {
pred_or_wc: PredicateOrWildcard::Predicate(pred),
args: self.args,
} }
} }
} }
@ -200,7 +198,7 @@ impl CustomPredicateBatchBuilder {
.iter() .iter()
.map(|sb| { .map(|sb| {
let stb = sb.clone().desugar(); let stb = sb.clone().desugar();
let args = stb let st_tmpl_args = stb
.args .args
.iter() .iter()
.map(|a| { .map(|a| {
@ -216,10 +214,17 @@ impl CustomPredicateBatchBuilder {
}) })
}) })
.collect::<Result<_>>()?; .collect::<Result<_>>()?;
let pred_or_wc = match stb.pred_or_wc {
PredicateOrWildcard::Predicate(p) => {
middleware::PredicateOrWildcard::Predicate(p)
}
PredicateOrWildcard::Wildcard(v) => middleware::PredicateOrWildcard::Wildcard(
resolve_wildcard(args, priv_args, &v)?,
),
};
Ok(StatementTmpl { Ok(StatementTmpl {
// TODO: Support wildcard pred_or_wc,
pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()), args: st_tmpl_args,
args,
}) })
}) })
.collect::<Result<_>>()?; .collect::<Result<_>>()?;
@ -299,7 +304,7 @@ mod tests {
let vd_set = &*MOCK_VD_SET; let vd_set = &*MOCK_VD_SET;
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into()); let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt) let gt_stb = StatementTmplBuilder::new_from_pred(NativePredicate::Gt)
.arg("s1") .arg("s1")
.arg("s2"); .arg("s2");
@ -344,7 +349,7 @@ mod tests {
let mut builder = let mut builder =
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into()); CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains) let set_contains_stb = StatementTmplBuilder::new_from_pred(NativePredicate::SetContains)
.arg("s1") .arg("s1")
.arg("s2"); .arg("s2");

View file

@ -1,34 +1,13 @@
use std::{backtrace::Backtrace, fmt::Debug}; use std::{backtrace::Backtrace, fmt::Debug};
use crate::middleware::{BackendError, Statement, StatementTmpl, Value}; use crate::middleware::BackendError;
pub type Result<T, E = Error> = core::result::Result<T, E>; pub type Result<T, E = Error> = core::result::Result<T, E>;
fn display_wc_map(wc_map: &[Option<Value>]) -> String {
let mut out = String::new();
use std::fmt::Write;
for (i, v) in wc_map.iter().enumerate() {
write!(out, "- {}: ", i).unwrap();
if let Some(v) = v {
writeln!(out, "{}", v).unwrap();
} else {
writeln!(out, "none").unwrap();
}
}
out
}
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum InnerError { pub enum InnerError {
#[error("{0} {1} is over the limit {2}")] #[error("{0} {1} is over the limit {2}")]
MaxLength(String, usize, usize), MaxLength(String, usize, usize),
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
StatementsDontMatch(
Statement,
StatementTmpl,
Vec<Option<Value>>,
crate::middleware::Error,
),
#[error("invalid arguments to {0} operation")] #[error("invalid arguments to {0} operation")]
OpInvalidArgs(String), OpInvalidArgs(String),
#[error("Podlang parse error: {0}")] #[error("Podlang parse error: {0}")]
@ -99,14 +78,6 @@ impl Error {
pub(crate) fn op_invalid_args(s: String) -> Self { pub(crate) fn op_invalid_args(s: String) -> Self {
new!(OpInvalidArgs(s)) new!(OpInvalidArgs(s))
} }
pub(crate) fn statements_dont_match(
s0: Statement,
s1: StatementTmpl,
wc_map: Vec<Option<Value>>,
mid_error: crate::middleware::Error,
) -> Self {
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
}
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self { pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
new!(MaxLength(obj, found, expect)) new!(MaxLength(obj, found, expect))
} }

View file

@ -13,10 +13,10 @@ use serde::{Deserialize, Serialize};
pub use serialization::SerializedMainPod; pub use serialization::SerializedMainPod;
use crate::middleware::{ use crate::middleware::{
self, check_custom_pred, check_st_tmpl, containers::Dictionary, hash_op, max_op, prod_op, self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op,
sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, OperationAux, prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation,
OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, StatementArg, VDSet, OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement,
Value, ValueRef, StatementArg, VDSet, Value, ValueRef,
}; };
mod custom; mod custom;
@ -600,21 +600,7 @@ impl MainPodBuilder {
} }
wildcard_map[index] = Some(value); wildcard_map[index] = Some(value);
} }
for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) { fill_wildcard_values(&self.params, pred, &args, &mut wildcard_map)?;
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if let Err(st_tmpl_check_error) =
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
{
return Err(Error::statements_dont_match(
st.clone(),
st_tmpl.clone(),
wildcard_map,
st_tmpl_check_error,
));
}
}
}
let v_default = Value::from(0); let v_default = Value::from(0);
let st_args: Vec<_> = wildcard_map let st_args: Vec<_> = wildcard_map
.into_iter() .into_iter()
@ -817,7 +803,8 @@ pub mod tests {
use super::*; use super::*;
use crate::{ use crate::{
backends::plonky2::{ backends::plonky2::{
mock::mainpod::MockProver, primitives::ec::schnorr::SecretKey, signer::Signer, basetypes::DEFAULT_VD_SET, mainpod::Prover, mock::mainpod::MockProver,
primitives::ec::schnorr::SecretKey, signer::Signer,
}, },
dict, dict,
examples::{ examples::{
@ -1423,6 +1410,75 @@ pub mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_wildcard_predicate() -> Result<()> {
let mock = true;
let params = Params::default();
let mock_prover = MockProver {};
let real_prover = Prover {};
let (vd_set, prover): (_, &dyn MainPodProver) = if mock {
(&VDSet::new(&[]), &mock_prover)
} else {
println!("Prebuilding circuits to calculate vd_set...");
let vd_set = &*DEFAULT_VD_SET;
println!("vd_set calculation complete");
(vd_set, &real_prover)
};
let input = r#"
Test(a, b, private: c) = AND(
Equal(a, 5)
b(1, 5)
c(6, 3)
)
"#;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
let st0 = builder.priv_op(Operation::eq(5, 5)).unwrap();
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
let st2 = builder.priv_op(Operation::ne(6, 3)).unwrap();
let _st = builder
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
.unwrap();
let pod = builder.prove(prover).unwrap();
pod.pod.verify().unwrap();
let input = r#"
Test(a, b, private: c) = OR(
Equal(a, 5)
b(1, 5)
c(6, 3)
)
"#;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
let st0 = Statement::None;
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
let st2 = Statement::None;
let _st = builder
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
.unwrap();
let pod = builder.prove(prover).unwrap();
pod.pod.verify().unwrap();
Ok(())
}
#[test] #[test]
fn test_apply_predicate_e2e() -> Result<()> { fn test_apply_predicate_e2e() -> Result<()> {
// End-to-end test of apply_predicate with MockProver // End-to-end test of apply_predicate with MockProver

View file

@ -88,6 +88,9 @@ pub enum ValidationError {
first_span: Option<Span>, first_span: Option<Span>,
second_span: Option<Span>, second_span: Option<Span>,
}, },
#[error("Wildcard '{name}' collides with a predicate name")]
WildcardPredicateNameCollision { name: String },
} }
/// Lowering errors from frontend AST lowering to middleware /// Lowering errors from frontend AST lowering to middleware

View file

@ -619,6 +619,7 @@ fn build_single_batch(
batch_idx, batch_idx,
reference_map, reference_map,
existing_batches, existing_batches,
name,
symbols, symbols,
) )
}) })
@ -648,6 +649,7 @@ fn build_statement_with_resolved_refs(
current_batch_idx: usize, current_batch_idx: usize,
reference_map: &HashMap<String, (usize, usize)>, reference_map: &HashMap<String, (usize, usize)>,
existing_batches: &[Arc<CustomPredicateBatch>], existing_batches: &[Arc<CustomPredicateBatch>],
custom_predicate_name: &str, // custom pred that defines this statement template
symbols: &SymbolTable, symbols: &SymbolTable,
) -> Result<StatementTmplBuilder, BatchingError> { ) -> Result<StatementTmplBuilder, BatchingError> {
let callee_name = &stmt.predicate.name; let callee_name = &stmt.predicate.name;
@ -657,16 +659,17 @@ fn build_statement_with_resolved_refs(
current_batch_idx, current_batch_idx,
reference_map, reference_map,
existing_batches, existing_batches,
custom_predicate_name,
}; };
let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| { let pred_or_wc = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
BatchingError::Internal { BatchingError::Internal {
message: format!("Unknown predicate reference: '{}'", callee_name), message: format!("Unknown predicate reference: '{}'", callee_name),
} }
})?; })?;
// Build the statement template // Build the statement template
let mut builder = StatementTmplBuilder::new(predicate); let mut builder = StatementTmplBuilder::new(pred_or_wc);
for arg in &stmt.args { for arg in &stmt.args {
builder = builder.arg(lower_statement_arg(arg)); builder = builder.arg(lower_statement_arg(arg));

View file

@ -10,7 +10,7 @@ use std::{
}; };
use crate::{ use crate::{
frontend::{BuilderArg, StatementTmplBuilder}, frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
lang::{ lang::{
frontend_ast::*, frontend_ast::*,
frontend_ast_batch::{self, PredicateBatches}, frontend_ast_batch::{self, PredicateBatches},
@ -18,8 +18,8 @@ use crate::{
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST}, frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
}, },
middleware::{ middleware::{
containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key, self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl, NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Value, Wildcard, StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
}, },
}; };
@ -35,6 +35,7 @@ pub enum ResolutionContext<'a> {
current_batch_idx: usize, current_batch_idx: usize,
reference_map: &'a HashMap<String, (usize, usize)>, reference_map: &'a HashMap<String, (usize, usize)>,
existing_batches: &'a [Arc<CustomPredicateBatch>], existing_batches: &'a [Arc<CustomPredicateBatch>],
custom_predicate_name: &'a str,
}, },
} }
@ -43,10 +44,23 @@ pub fn resolve_predicate(
pred_name: &str, pred_name: &str,
symbols: &SymbolTable, symbols: &SymbolTable,
context: &ResolutionContext, context: &ResolutionContext,
) -> Option<Predicate> { ) -> Option<PredicateOrWildcard> {
// 1. Try native predicate first // 0. Try wildcard first
if let ResolutionContext::Batch {
custom_predicate_name,
..
} = context
{
if let Some(wc_scope) = symbols.wildcard_scopes.get(*custom_predicate_name) {
if wc_scope.wildcards.contains_key(pred_name) {
return Some(PredicateOrWildcard::Wildcard(pred_name.to_string()));
}
}
}
// 1. Try native predicate second
if let Ok(native) = NativePredicate::from_str(pred_name) { if let Ok(native) = NativePredicate::from_str(pred_name) {
return Some(Predicate::Native(native)); return Some(PredicateOrWildcard::Predicate(Predicate::Native(native)));
} }
// 2. Look up in symbol table // 2. Look up in symbol table
@ -64,6 +78,7 @@ pub fn resolve_predicate(
current_batch_idx, current_batch_idx,
reference_map, reference_map,
existing_batches, existing_batches,
..
} => resolve_local_predicate( } => resolve_local_predicate(
pred_name, pred_name,
*current_batch_idx, *current_batch_idx,
@ -85,7 +100,7 @@ pub fn resolve_predicate(
verifier_data_hash: *verifier_data_hash, verifier_data_hash: *verifier_data_hash,
}), }),
}; };
return Some(predicate); return Some(PredicateOrWildcard::Predicate(predicate));
} }
// 3. In batch context, also check reference_map for split chain pieces // 3. In batch context, also check reference_map for split chain pieces
@ -94,6 +109,7 @@ pub fn resolve_predicate(
current_batch_idx, current_batch_idx,
reference_map, reference_map,
existing_batches, existing_batches,
..
} = context } = context
{ {
if reference_map.contains_key(pred_name) { if reference_map.contains_key(pred_name) {
@ -102,7 +118,8 @@ pub fn resolve_predicate(
*current_batch_idx, *current_batch_idx,
reference_map, reference_map,
existing_batches, existing_batches,
); )
.map(PredicateOrWildcard::Predicate);
} }
} }
@ -328,7 +345,7 @@ impl<'a> Lowerer<'a> {
})?; })?;
// Create a builder with the resolved predicate and desugar // Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate); let mut builder = StatementTmplBuilder::new(predicate.clone());
for arg in &stmt.args { for arg in &stmt.args {
let builder_arg = lower_statement_arg(arg); let builder_arg = lower_statement_arg(arg);
builder = builder.arg(builder_arg); builder = builder.arg(builder_arg);
@ -356,9 +373,14 @@ impl<'a> Lowerer<'a> {
mw_args.push(mw_arg); mw_args.push(mw_arg);
} }
let predicate = match desugared.pred_or_wc {
PredicateOrWildcard::Predicate(p) => p,
PredicateOrWildcard::Wildcard(_) => {
unreachable!("wildcard predicates aren't considered in requests")
}
};
Ok(MWStatementTmpl { Ok(MWStatementTmpl {
// TODO: Support wildcard pred_or_wc: middleware::PredicateOrWildcard::Predicate(predicate),
pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate),
args: mw_args, args: mw_args,
}) })
} }
@ -424,7 +446,6 @@ impl<'a> Lowerer<'a> {
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
split_results.push(result); split_results.push(result);
} }
Ok(split_results) Ok(split_results)
} }
} }
@ -601,7 +622,7 @@ mod tests {
// Should be BatchSelf(0) referring to pred1 // Should be BatchSelf(0) referring to pred1
assert!(matches!( assert!(matches!(
stmt.pred_or_wc, stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0)) middleware::PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
)); ));
} }
@ -639,10 +660,25 @@ mod tests {
// Should desugar to the Contains predicate // Should desugar to the Contains predicate
assert!(matches!( assert!(matches!(
stmt.pred_or_wc, stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains)) middleware::PredicateOrWildcard::Predicate(Predicate::Native(
NativePredicate::Contains
))
)); ));
} }
#[test]
fn test_wc_pred() {
let input = r#"
my_pred(X, DynPred) = AND (
Equal(X["pred"], DynPred)
DynPred(X)
)
"#;
let params = Params::default();
parse_validate_and_lower(input, &params).unwrap();
}
#[test] #[test]
fn test_multi_batch_packing() { fn test_multi_batch_packing() {
// Create more predicates than fit in a single batch // Create more predicates than fit in a single batch
@ -749,7 +785,7 @@ mod tests {
// Verify the second statement is an intro predicate reference // Verify the second statement is an intro predicate reference
let intro_stmt = &pred.statements()[1]; let intro_stmt = &pred.statements()[1];
match intro_stmt.pred_or_wc() { match intro_stmt.pred_or_wc() {
PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => { middleware::PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
assert_eq!(intro_ref.name, "external_check"); assert_eq!(intro_ref.name, "external_check");
assert_eq!(intro_ref.args_len, 1); assert_eq!(intro_ref.args_len, 1);
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH); assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);

View file

@ -3,7 +3,11 @@
//! This module provides semantic validation for parsed AST documents, //! This module provides semantic validation for parsed AST documents,
//! including name resolution, arity checking, and wildcard validation. //! including name resolution, arity checking, and wildcard validation.
use std::{collections::HashMap, str::FromStr, sync::Arc}; use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
};
use hex::ToHex; use hex::ToHex;
@ -411,6 +415,21 @@ impl Validator {
Ok(()) Ok(())
} }
/// Validate that no wildcard name collides with a predicate name to avoid ambiguity when using
/// wildcard predicates.
fn validate_wildcard_names(&self, names: &HashSet<&String>) -> Result<(), ValidationError> {
for name in names {
if NativePredicate::from_str(name).is_ok()
|| self.symbols.predicates.contains_key(*name)
{
return Err(ValidationError::WildcardPredicateNameCollision {
name: (*name).clone(),
});
}
}
Ok(())
}
fn validate_statement( fn validate_statement(
&self, &self,
stmt: &StatementTmpl, stmt: &StatementTmpl,
@ -418,18 +437,26 @@ impl Validator {
) -> Result<(), ValidationError> { ) -> Result<(), ValidationError> {
let pred_name = &stmt.predicate.name; let pred_name = &stmt.predicate.name;
let wc_names = match wildcard_context {
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
None => HashSet::new(),
};
self.validate_wildcard_names(&wc_names)?;
// Check if predicate exists // Check if predicate exists
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) { let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
// Native predicate // Native predicate
PredicateInfo { Some(PredicateInfo {
kind: PredicateKind::Native(native), kind: PredicateKind::Native(native),
arity: native.arity(), arity: native.arity(),
public_arity: native.arity(), public_arity: native.arity(),
source_span: None, source_span: None,
} })
} else if let Some(info) = self.symbols.predicates.get(pred_name) { } else if let Some(info) = self.symbols.predicates.get(pred_name) {
// Custom or imported predicate // Custom or imported predicate
info.clone() Some(info.clone())
} else if wc_names.contains(pred_name) {
None
} else { } else {
return Err(ValidationError::UndefinedPredicate { return Err(ValidationError::UndefinedPredicate {
name: pred_name.clone(), name: pred_name.clone(),
@ -437,19 +464,20 @@ impl Validator {
}); });
}; };
let expected_arity = pred_info.public_arity; if let Some(ref pred_info) = pred_info {
let expected_arity = pred_info.public_arity;
if stmt.args.len() != expected_arity { if stmt.args.len() != expected_arity {
return Err(ValidationError::ArgumentCountMismatch { return Err(ValidationError::ArgumentCountMismatch {
predicate: pred_name.clone(), predicate: pred_name.clone(),
expected: expected_arity, expected: expected_arity,
found: stmt.args.len(), found: stmt.args.len(),
span: stmt.span, span: stmt.span,
}); });
}
} }
// Validate arguments // Validate arguments
self.validate_statement_args(stmt, &pred_info, wildcard_context)?; self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?;
Ok(()) Ok(())
} }
@ -457,13 +485,13 @@ impl Validator {
fn validate_statement_args( fn validate_statement_args(
&self, &self,
stmt: &StatementTmpl, stmt: &StatementTmpl,
pred_info: &PredicateInfo, pred_info: Option<&PredicateInfo>,
wildcard_context: Option<(&str, &WildcardScope)>, wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> { ) -> Result<(), ValidationError> {
// For custom predicates, only wildcards and literals are allowed // For custom predicates, only wildcards and literals are allowed
if matches!( if matches!(
pred_info.kind, pred_info.map(|i| &i.kind),
PredicateKind::Custom { .. } | PredicateKind::BatchImported { .. } Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. })
) { ) {
for arg in &stmt.args { for arg in &stmt.args {
match arg { match arg {
@ -631,6 +659,18 @@ mod tests {
)); ));
} }
#[test]
fn test_wildcard_predicate_collision() {
let input = r#"
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
"#;
let result = parse_and_validate(input, &[]);
assert!(matches!(
result,
Err(ValidationError::WildcardPredicateNameCollision { .. })
));
}
#[test] #[test]
fn test_custom_predicate_with_anchored_key() { fn test_custom_predicate_with_anchored_key() {
let input = r#" let input = r#"

View file

@ -3,12 +3,26 @@
use std::{backtrace::Backtrace, fmt::Debug}; use std::{backtrace::Backtrace, fmt::Debug};
use crate::middleware::{ use crate::middleware::{
CustomPredicate, Hash, Key, Operation, Predicate, Statement, StatementArg, StatementTmplArg, CustomPredicate, Hash, Key, Operation, Predicate, Statement, StatementArg, StatementTmpl,
Value, Wildcard, StatementTmplArg, Value, Wildcard,
}; };
pub type Result<T, E = Error> = core::result::Result<T, E>; pub type Result<T, E = Error> = core::result::Result<T, E>;
fn display_wc_map(wc_map: &[Option<Value>]) -> String {
let mut out = String::new();
use std::fmt::Write;
for (i, v) in wc_map.iter().enumerate() {
write!(out, "- {}: ", i).unwrap();
if let Some(v) = v {
writeln!(out, "{}", v).unwrap();
} else {
writeln!(out, "none").unwrap();
}
}
out
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum MiddlewareInnerError { pub enum MiddlewareInnerError {
#[error("incorrect statement args")] #[error("incorrect statement args")]
@ -33,6 +47,13 @@ pub enum MiddlewareInnerError {
MismatchedStatementWildcardPredicate(Value, Value, Predicate), MismatchedStatementWildcardPredicate(Value, Value, Predicate),
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")] #[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate), MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
StatementsDontMatch(
Statement,
StatementTmpl,
Vec<Option<Value>>,
crate::middleware::Error,
),
#[error( #[error(
"None of the statement templates of the following custom predicate have been matched:\n{0}" "None of the statement templates of the following custom predicate have been matched:\n{0}"
)] )]
@ -132,6 +153,14 @@ impl Error {
wc_value, st_arg, arg_index, pred wc_value, st_arg, arg_index, pred
)) ))
} }
pub(crate) fn statements_dont_match(
s0: Statement,
s1: StatementTmpl,
wc_map: Vec<Option<Value>>,
mid_error: crate::middleware::Error,
) -> Self {
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
}
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self { pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
new!(UnsatisfiedCustomPredicateDisjunction(pred)) new!(UnsatisfiedCustomPredicateDisjunction(pred))
} }

View file

@ -611,17 +611,21 @@ pub fn fill_wildcard_values(
wildcard_map: &mut [Option<Value>], wildcard_map: &mut [Option<Value>],
) -> Result<()> { ) -> Result<()> {
for (st_tmpl, st) in pred.statements.iter().zip(args) { for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc { if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc {
wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?; wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?;
} }
st_tmpl let st_args = st.args();
.args
.iter() for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
.zip(&st_args) if let Err(st_tmpl_check_error) = check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map) {
.try_for_each(|(st_tmpl_arg, st_arg)| { return Err(Error::statements_dont_match(
check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map) st.clone(),
})?; st_tmpl.clone(),
wildcard_map.to_vec(),
st_tmpl_check_error,
));
}
}
} }
Ok(()) Ok(())
} }
@ -741,16 +745,7 @@ pub(crate) fn check_custom_pred(
for (st_tmpl, st) in pred.statements.iter().zip(args) { for (st_tmpl, st) in pred.statements.iter().zip(args) {
// For `or` predicates, only one statement needs to match the template. // For `or` predicates, only one statement needs to match the template.
// The rest of the statements can be `None`. // The rest of the statements can be `None`.
let expected_pred_is_none = match &st_tmpl.pred_or_wc { if !pred.conjunction && matches!(st, Statement::None) {
PredicateOrWildcard::Predicate(st_tmpl_pred) => {
*st_tmpl_pred == Predicate::Native(NativePredicate::None)
}
PredicateOrWildcard::Wildcard(wc) => {
wc_values[wc.index]
== Value::from(Predicate::Native(NativePredicate::None).hash(params))
}
};
if !pred.conjunction && matches!(st, Statement::None) && !expected_pred_is_none {
continue; continue;
} }
check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?; check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?;