From 498e94661237f3da04e4085b4a52c45debe7fedb Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Mon, 2 Feb 2026 10:59:33 +0100 Subject: [PATCH] Feat/fst order pred part3 & part4 (#457) * support wildcard predicates in frontend * support wildcard predicate in podlang * add validation test * test full flow and apply some fixes * fix clippy * fix merge issues * use desugared predicate * Fix parsing of intro statement templates inside custom predicates * Tidy up comments * lang: handle wildcard predicate * add unreachable message --------- Co-authored-by: Rob Knight --- src/backends/plonky2/circuits/mainpod.rs | 14 ++- src/backends/plonky2/mainpod/mod.rs | 4 +- src/frontend/custom.rs | 145 ++++++++++++----------- src/frontend/error.rs | 31 +---- src/frontend/mod.rs | 96 +++++++++++---- src/lang/error.rs | 3 + src/lang/frontend_ast_batch.rs | 7 +- src/lang/frontend_ast_lower.rs | 66 ++++++++--- src/lang/frontend_ast_validate.rs | 74 +++++++++--- src/middleware/error.rs | 33 +++++- src/middleware/operation.rs | 31 ++--- 11 files changed, 324 insertions(+), 180 deletions(-) diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index fe23403..2bb6ee5 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -3255,8 +3255,12 @@ mod tests { use NativePredicate as NP; use StatementTmplBuilder as STB; let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new(NP::Equal).arg(("id", "score")).arg(literal(42)); - let stb1 = STB::new(NP::Equal).arg(("id", "key")).arg("secret"); + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) + .arg(("id", "key")) + .arg("secret"); let _ = builder.predicate_and( "pred_and", &["id"], @@ -3349,8 +3353,10 @@ mod tests { use NativePredicate as NP; use StatementTmplBuilder as STB; let mut builder =
CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new(NP::Equal).arg(("id", "score")).arg(literal(42)); - let stb1 = STB::new(NP::Equal) + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) .arg(("secret_id", "key")) .arg(("id", "score")); let _ = builder.predicate_and( diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 78b617c..6c20a09 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1083,11 +1083,11 @@ pub mod tests { let vd_set = VDSet::new(&vds); let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into()); - let stb0 = STB::new(NP::Contains) + let stb0 = STB::new_from_pred(NP::Contains) .arg("dict") .arg(literal("score")) .arg(literal(42)); - let stb1 = STB::new(NP::Equal) + let stb1 = STB::new_from_pred(NP::Equal) .arg(("secret_dict", "key")) .arg(("dict", "score")); let _ = cpb_builder.predicate_and( diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index be40a90..c0ee1ba 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -7,7 +7,7 @@ use crate::{ frontend::{AnchoredKey, Error, Result, Statement, StatementArg}, middleware::{ self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params, - Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard, + Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard, }, }; @@ -41,16 +41,34 @@ pub fn literal(v: impl Into) -> BuilderArg { BuilderArg::Literal(v.into()) } +#[derive(Clone, Debug)] +pub enum PredicateOrWildcard { + Predicate(Predicate), + Wildcard(String), +} + #[derive(Clone)] pub struct StatementTmplBuilder { - pub(crate) predicate: Predicate, + pub(crate) pred_or_wc: PredicateOrWildcard, pub(crate) args: Vec, } impl StatementTmplBuilder { - pub fn new(p: impl Into) -> 
StatementTmplBuilder { + pub fn new_from_pred(p: impl Into) -> StatementTmplBuilder { StatementTmplBuilder { - predicate: p.into(), + pred_or_wc: PredicateOrWildcard::Predicate(p.into()), + args: Vec::new(), + } + } + pub fn new_from_wc(p: impl Into) -> StatementTmplBuilder { + StatementTmplBuilder { + pred_or_wc: PredicateOrWildcard::Wildcard(p.into()), + args: Vec::new(), + } + } + pub fn new(pred_or_wc: PredicateOrWildcard) -> StatementTmplBuilder { + StatementTmplBuilder { + pred_or_wc, args: Vec::new(), } } @@ -62,68 +80,48 @@ impl StatementTmplBuilder { /// Desugar the predicate to a simpler form /// Should mirror the logic in `MainPodBuilder::lower_op` - pub(crate) fn desugar(self) -> StatementTmplBuilder { - match self.predicate { - Predicate::Native(NativePredicate::Gt) => { - let mut stb = StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::Lt), - args: self.args, - }; - stb.args.swap(0, 1); - stb - } - Predicate::Native(NativePredicate::GtEq) => { - let mut stb = StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::LtEq), - args: self.args, - }; - stb.args.swap(0, 1); - stb - } - Predicate::Native(NativePredicate::ArrayContains) - | Predicate::Native(NativePredicate::DictContains) => StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::Contains), - args: self.args, - }, - Predicate::Native(NativePredicate::DictNotContains) - | Predicate::Native(NativePredicate::SetNotContains) => StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::NotContains), - args: self.args, - }, - Predicate::Native(NativePredicate::SetContains) => { - let mut new_args = self.args.clone(); - new_args.push(self.args[1].clone()); - StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::Contains), - args: new_args, + pub(crate) fn desugar(mut self) -> StatementTmplBuilder { + let pred = match self.pred_or_wc { + PredicateOrWildcard::Predicate(p) => p, + PredicateOrWildcard::Wildcard(_) => return 
self, + }; + let pred = match pred { + Predicate::Native(nat_pred) => Predicate::Native(match nat_pred { + NativePredicate::Gt => { + self.args.swap(0, 1); + NativePredicate::Lt } - } - Predicate::Native(NativePredicate::DictInsert) => StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::ContainerInsert), - args: self.args, - }, - Predicate::Native(NativePredicate::SetInsert) => { - let mut new_args = self.args.clone(); - new_args.push(self.args[2].clone()); - StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::ContainerInsert), - args: new_args, + NativePredicate::GtEq => { + self.args.swap(0, 1); + NativePredicate::LtEq } - } - Predicate::Native(NativePredicate::DictUpdate) - | Predicate::Native(NativePredicate::ArrayUpdate) => StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::ContainerUpdate), - args: self.args, - }, - Predicate::Native(NativePredicate::DictDelete) => StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::ContainerDelete), - args: self.args, - }, - Predicate::Native(NativePredicate::SetDelete) => StatementTmplBuilder { - predicate: Predicate::Native(NativePredicate::ContainerDelete), - args: self.args, - }, - _ => self, + NativePredicate::ArrayContains | NativePredicate::DictContains => { + NativePredicate::Contains + } + NativePredicate::DictNotContains | NativePredicate::SetNotContains => { + NativePredicate::NotContains + } + NativePredicate::SetContains => { + self.args.push(self.args[1].clone()); + NativePredicate::Contains + } + NativePredicate::DictInsert => NativePredicate::ContainerInsert, + NativePredicate::SetInsert => { + self.args.push(self.args[2].clone()); + NativePredicate::ContainerInsert + } + NativePredicate::DictUpdate | NativePredicate::ArrayUpdate => { + NativePredicate::ContainerUpdate + } + NativePredicate::DictDelete => NativePredicate::ContainerDelete, + NativePredicate::SetDelete => NativePredicate::ContainerDelete, + _ => nat_pred, + }), + _ => 
pred, + }; + StatementTmplBuilder { + pred_or_wc: PredicateOrWildcard::Predicate(pred), + args: self.args, } } } @@ -200,7 +198,7 @@ impl CustomPredicateBatchBuilder { .iter() .map(|sb| { let stb = sb.clone().desugar(); - let args = stb + let st_tmpl_args = stb .args .iter() .map(|a| { @@ -216,10 +214,17 @@ impl CustomPredicateBatchBuilder { }) }) .collect::>()?; + let pred_or_wc = match stb.pred_or_wc { + PredicateOrWildcard::Predicate(p) => { + middleware::PredicateOrWildcard::Predicate(p) + } + PredicateOrWildcard::Wildcard(v) => middleware::PredicateOrWildcard::Wildcard( + resolve_wildcard(args, priv_args, &v)?, + ), + }; Ok(StatementTmpl { - // TODO: Support wildcard - pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()), - args, + pred_or_wc, + args: st_tmpl_args, }) }) .collect::>()?; @@ -299,7 +304,7 @@ mod tests { let vd_set = &*MOCK_VD_SET; let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into()); - let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt) + let gt_stb = StatementTmplBuilder::new_from_pred(NativePredicate::Gt) .arg("s1") .arg("s2"); @@ -344,7 +349,7 @@ mod tests { let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into()); - let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains) + let set_contains_stb = StatementTmplBuilder::new_from_pred(NativePredicate::SetContains) .arg("s1") .arg("s2"); diff --git a/src/frontend/error.rs b/src/frontend/error.rs index 3d162d5..21264d7 100644 --- a/src/frontend/error.rs +++ b/src/frontend/error.rs @@ -1,34 +1,13 @@ use std::{backtrace::Backtrace, fmt::Debug}; -use crate::middleware::{BackendError, Statement, StatementTmpl, Value}; +use crate::middleware::BackendError; pub type Result = core::result::Result; -fn display_wc_map(wc_map: &[Option]) -> String { - let mut out = String::new(); - use std::fmt::Write; - for (i, v) in wc_map.iter().enumerate() { - write!(out, "- {}: ", i).unwrap(); 
- if let Some(v) = v { - writeln!(out, "{}", v).unwrap(); - } else { - writeln!(out, "none").unwrap(); - } - } - out -} - #[derive(thiserror::Error, Debug)] pub enum InnerError { #[error("{0} {1} is over the limit {2}")] MaxLength(String, usize, usize), - #[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))] - StatementsDontMatch( - Statement, - StatementTmpl, - Vec>, - crate::middleware::Error, - ), #[error("invalid arguments to {0} operation")] OpInvalidArgs(String), #[error("Podlang parse error: {0}")] @@ -99,14 +78,6 @@ impl Error { pub(crate) fn op_invalid_args(s: String) -> Self { new!(OpInvalidArgs(s)) } - pub(crate) fn statements_dont_match( - s0: Statement, - s1: StatementTmpl, - wc_map: Vec>, - mid_error: crate::middleware::Error, - ) -> Self { - new!(StatementsDontMatch(s0, s1, wc_map, mid_error)) - } pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self { new!(MaxLength(obj, found, expect)) } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index f600f7c..c8ea847 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -13,10 +13,10 @@ use serde::{Deserialize, Serialize}; pub use serialization::SerializedMainPod; use crate::middleware::{ - self, check_custom_pred, check_st_tmpl, containers::Dictionary, hash_op, max_op, prod_op, - sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, OperationAux, - OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, StatementArg, VDSet, - Value, ValueRef, + self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op, + prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, + OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, + StatementArg, VDSet, Value, ValueRef, }; mod custom; @@ -600,21 +600,7 @@ impl MainPodBuilder { } wildcard_map[index] = Some(value); } - for (st_tmpl, st) in 
pred.statements.iter().zip(args.iter()) { - let st_args = st.args(); - for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { - if let Err(st_tmpl_check_error) = - check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) - { - return Err(Error::statements_dont_match( - st.clone(), - st_tmpl.clone(), - wildcard_map, - st_tmpl_check_error, - )); - } - } - } + fill_wildcard_values(&self.params, pred, &args, &mut wildcard_map)?; let v_default = Value::from(0); let st_args: Vec<_> = wildcard_map .into_iter() @@ -817,7 +803,8 @@ pub mod tests { use super::*; use crate::{ backends::plonky2::{ - mock::mainpod::MockProver, primitives::ec::schnorr::SecretKey, signer::Signer, + basetypes::DEFAULT_VD_SET, mainpod::Prover, mock::mainpod::MockProver, + primitives::ec::schnorr::SecretKey, signer::Signer, }, dict, examples::{ @@ -1423,6 +1410,75 @@ pub mod tests { Ok(()) } + #[test] + fn test_wildcard_predicate() -> Result<()> { + let mock = true; + let params = Params::default(); + + let mock_prover = MockProver {}; + let real_prover = Prover {}; + let (vd_set, prover): (_, &dyn MainPodProver) = if mock { + (&VDSet::new(&[]), &mock_prover) + } else { + println!("Prebuilding circuits to calculate vd_set..."); + let vd_set = &*DEFAULT_VD_SET; + println!("vd_set calculation complete"); + (vd_set, &real_prover) + }; + + let input = r#" + Test(a, b, private: c) = AND( + Equal(a, 5) + b(1, 5) + c(6, 3) + ) + "#; + let batch = parse(input, &params, &[]) + .unwrap() + .first_batch() + .unwrap() + .clone(); + let pred_test = batch.predicate_ref_by_name("Test").unwrap(); + + let mut builder = MainPodBuilder::new(&params, vd_set); + let st0 = builder.priv_op(Operation::eq(5, 5)).unwrap(); + let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap(); + let st2 = builder.priv_op(Operation::ne(6, 3)).unwrap(); + let _st = builder + .op(true, vec![], Operation::custom(pred_test, [st0, st1, st2])) + .unwrap(); + + let pod = builder.prove(prover).unwrap(); + pod.pod.verify().unwrap(); + + let input =
r#" + Test(a, b, private: c) = OR( + Equal(a, 5) + b(1, 5) + c(6, 3) + ) + "#; + let batch = parse(input, &params, &[]) + .unwrap() + .first_batch() + .unwrap() + .clone(); + let pred_test = batch.predicate_ref_by_name("Test").unwrap(); + + let mut builder = MainPodBuilder::new(&params, vd_set); + let st0 = Statement::None; + let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap(); + let st2 = Statement::None; + let _st = builder + .op(true, vec![], Operation::custom(pred_test, [st0, st1, st2])) + .unwrap(); + + let pod = builder.prove(prover).unwrap(); + pod.pod.verify().unwrap(); + + Ok(()) + } + #[test] fn test_apply_predicate_e2e() -> Result<()> { // End-to-end test of apply_predicate with MockProver diff --git a/src/lang/error.rs b/src/lang/error.rs index 318e715..2ae7c25 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -88,6 +88,9 @@ pub enum ValidationError { first_span: Option, second_span: Option, }, + + #[error("Wildcard '{name}' collides with a predicate name")] + WildcardPredicateNameCollision { name: String }, } /// Lowering errors from frontend AST lowering to middleware diff --git a/src/lang/frontend_ast_batch.rs b/src/lang/frontend_ast_batch.rs index fb58748..6b5f375 100644 --- a/src/lang/frontend_ast_batch.rs +++ b/src/lang/frontend_ast_batch.rs @@ -619,6 +619,7 @@ fn build_single_batch( batch_idx, reference_map, existing_batches, + name, symbols, ) }) @@ -648,6 +649,7 @@ fn build_statement_with_resolved_refs( current_batch_idx: usize, reference_map: &HashMap, existing_batches: &[Arc], + custom_predicate_name: &str, // custom pred that defines this statement template symbols: &SymbolTable, ) -> Result { let callee_name = &stmt.predicate.name; @@ -657,16 +659,17 @@ fn build_statement_with_resolved_refs( current_batch_idx, reference_map, existing_batches, + custom_predicate_name, }; - let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| { + let pred_or_wc = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
BatchingError::Internal { message: format!("Unknown predicate reference: '{}'", callee_name), } })?; // Build the statement template - let mut builder = StatementTmplBuilder::new(predicate); + let mut builder = StatementTmplBuilder::new(pred_or_wc); for arg in &stmt.args { builder = builder.arg(lower_statement_arg(arg)); diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index ce693b9..185681e 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -10,7 +10,7 @@ use std::{ }; use crate::{ - frontend::{BuilderArg, StatementTmplBuilder}, + frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder}, lang::{ frontend_ast::*, frontend_ast_batch::{self, PredicateBatches}, @@ -18,8 +18,8 @@ use crate::{ frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST}, }, middleware::{ - containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key, - NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl, + self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key, + NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Value, Wildcard, }, }; @@ -35,6 +35,7 @@ pub enum ResolutionContext<'a> { current_batch_idx: usize, reference_map: &'a HashMap, existing_batches: &'a [Arc], + custom_predicate_name: &'a str, }, } @@ -43,10 +44,23 @@ pub fn resolve_predicate( pred_name: &str, symbols: &SymbolTable, context: &ResolutionContext, -) -> Option { - // 1. Try native predicate first +) -> Option { + // 0. Try wildcard first + if let ResolutionContext::Batch { + custom_predicate_name, + .. + } = context + { + if let Some(wc_scope) = symbols.wildcard_scopes.get(*custom_predicate_name) { + if wc_scope.wildcards.contains_key(pred_name) { + return Some(PredicateOrWildcard::Wildcard(pred_name.to_string())); + } + } + } + + // 1. 
Try native predicate second if let Ok(native) = NativePredicate::from_str(pred_name) { - return Some(Predicate::Native(native)); + return Some(PredicateOrWildcard::Predicate(Predicate::Native(native))); } // 2. Look up in symbol table @@ -64,6 +78,7 @@ pub fn resolve_predicate( current_batch_idx, reference_map, existing_batches, + .. } => resolve_local_predicate( pred_name, *current_batch_idx, @@ -85,7 +100,7 @@ pub fn resolve_predicate( verifier_data_hash: *verifier_data_hash, }), }; - return Some(predicate); + return Some(PredicateOrWildcard::Predicate(predicate)); } // 3. In batch context, also check reference_map for split chain pieces @@ -94,6 +109,7 @@ pub fn resolve_predicate( current_batch_idx, reference_map, existing_batches, + .. } = context { if reference_map.contains_key(pred_name) { @@ -102,7 +118,8 @@ pub fn resolve_predicate( *current_batch_idx, reference_map, existing_batches, - ); + ) + .map(PredicateOrWildcard::Predicate); } } @@ -328,7 +345,7 @@ impl<'a> Lowerer<'a> { })?; // Create a builder with the resolved predicate and desugar - let mut builder = StatementTmplBuilder::new(predicate); + let mut builder = StatementTmplBuilder::new(predicate.clone()); for arg in &stmt.args { let builder_arg = lower_statement_arg(arg); builder = builder.arg(builder_arg); @@ -356,9 +373,14 @@ impl<'a> Lowerer<'a> { mw_args.push(mw_arg); } + let predicate = match desugared.pred_or_wc { + PredicateOrWildcard::Predicate(p) => p, + PredicateOrWildcard::Wildcard(_) => { + unreachable!("wildcard predicates aren't considered in requests") + } + }; Ok(MWStatementTmpl { - // TODO: Support wildcard - pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate), + pred_or_wc: middleware::PredicateOrWildcard::Predicate(predicate), args: mw_args, }) } @@ -424,7 +446,6 @@ impl<'a> Lowerer<'a> { let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; split_results.push(result); } - Ok(split_results) } } @@ -601,7 +622,7 @@ mod tests { // Should be 
BatchSelf(0) referring to pred1 assert!(matches!( stmt.pred_or_wc, - PredicateOrWildcard::Predicate(Predicate::BatchSelf(0)) + middleware::PredicateOrWildcard::Predicate(Predicate::BatchSelf(0)) )); } @@ -639,10 +660,25 @@ mod tests { // Should desugar to the Contains predicate assert!(matches!( stmt.pred_or_wc, - PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains)) + middleware::PredicateOrWildcard::Predicate(Predicate::Native( + NativePredicate::Contains + )) )); } + #[test] + fn test_wc_pred() { + let input = r#" + my_pred(X, DynPred) = AND ( + Equal(X["pred"], DynPred) + DynPred(X) + ) + "#; + + let params = Params::default(); + parse_validate_and_lower(input, &params).unwrap(); + } + #[test] fn test_multi_batch_packing() { // Create more predicates than fit in a single batch @@ -749,7 +785,7 @@ mod tests { // Verify the second statement is an intro predicate reference let intro_stmt = &pred.statements()[1]; match intro_stmt.pred_or_wc() { - PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => { + middleware::PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => { assert_eq!(intro_ref.name, "external_check"); assert_eq!(intro_ref.args_len, 1); assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH); diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index 6fd2349..8939c8d 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -3,7 +3,11 @@ //! This module provides semantic validation for parsed AST documents, //! including name resolution, arity checking, and wildcard validation. -use std::{collections::HashMap, str::FromStr, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + str::FromStr, + sync::Arc, +}; use hex::ToHex; @@ -411,6 +415,21 @@ impl Validator { Ok(()) } + /// Validate that no wildcard name collides with a predicate name to avoid ambiguity when using + /// wildcard predicates.
+ fn validate_wildcard_names(&self, names: &HashSet<&String>) -> Result<(), ValidationError> { + for name in names { + if NativePredicate::from_str(name).is_ok() + || self.symbols.predicates.contains_key(*name) + { + return Err(ValidationError::WildcardPredicateNameCollision { + name: (*name).clone(), + }); + } + } + Ok(()) + } + fn validate_statement( &self, stmt: &StatementTmpl, @@ -418,18 +437,26 @@ impl Validator { ) -> Result<(), ValidationError> { let pred_name = &stmt.predicate.name; + let wc_names = match wildcard_context { + Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(), + None => HashSet::new(), + }; + self.validate_wildcard_names(&wc_names)?; + // Check if predicate exists let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) { // Native predicate - PredicateInfo { + Some(PredicateInfo { kind: PredicateKind::Native(native), arity: native.arity(), public_arity: native.arity(), source_span: None, - } + }) } else if let Some(info) = self.symbols.predicates.get(pred_name) { // Custom or imported predicate - info.clone() + Some(info.clone()) + } else if wc_names.contains(pred_name) { + None } else { return Err(ValidationError::UndefinedPredicate { name: pred_name.clone(), @@ -437,19 +464,20 @@ impl Validator { }); }; - let expected_arity = pred_info.public_arity; - - if stmt.args.len() != expected_arity { - return Err(ValidationError::ArgumentCountMismatch { - predicate: pred_name.clone(), - expected: expected_arity, - found: stmt.args.len(), - span: stmt.span, - }); + if let Some(ref pred_info) = pred_info { + let expected_arity = pred_info.public_arity; + if stmt.args.len() != expected_arity { + return Err(ValidationError::ArgumentCountMismatch { + predicate: pred_name.clone(), + expected: expected_arity, + found: stmt.args.len(), + span: stmt.span, + }); + } } // Validate arguments - self.validate_statement_args(stmt, &pred_info, wildcard_context)?; + self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?; 
Ok(()) } @@ -457,13 +485,13 @@ impl Validator { fn validate_statement_args( &self, stmt: &StatementTmpl, - pred_info: &PredicateInfo, + pred_info: Option<&PredicateInfo>, wildcard_context: Option<(&str, &WildcardScope)>, ) -> Result<(), ValidationError> { // For custom predicates, only wildcards and literals are allowed if matches!( - pred_info.kind, - PredicateKind::Custom { .. } | PredicateKind::BatchImported { .. } + pred_info.map(|i| &i.kind), + Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. }) ) { for arg in &stmt.args { match arg { @@ -631,6 +659,18 @@ mod tests { )); } + #[test] + fn test_wildcard_predicate_collision() { + let input = r#" + my_pred(A, Lt) = AND (Equal(A["x"], Lt)) + "#; + let result = parse_and_validate(input, &[]); + assert!(matches!( + result, + Err(ValidationError::WildcardPredicateNameCollision { .. }) + )); + } + #[test] fn test_custom_predicate_with_anchored_key() { let input = r#" diff --git a/src/middleware/error.rs b/src/middleware/error.rs index e71544b..23650ce 100644 --- a/src/middleware/error.rs +++ b/src/middleware/error.rs @@ -3,12 +3,26 @@ use std::{backtrace::Backtrace, fmt::Debug}; use crate::middleware::{ - CustomPredicate, Hash, Key, Operation, Predicate, Statement, StatementArg, StatementTmplArg, - Value, Wildcard, + CustomPredicate, Hash, Key, Operation, Predicate, Statement, StatementArg, StatementTmpl, + StatementTmplArg, Value, Wildcard, }; pub type Result = core::result::Result; +fn display_wc_map(wc_map: &[Option]) -> String { + let mut out = String::new(); + use std::fmt::Write; + for (i, v) in wc_map.iter().enumerate() { + write!(out, "- {}: ", i).unwrap(); + if let Some(v) = v { + writeln!(out, "{}", v).unwrap(); + } else { + writeln!(out, "none").unwrap(); + } + } + out +} + #[derive(Debug, thiserror::Error)] pub enum MiddlewareInnerError { #[error("incorrect statement args")] @@ -33,6 +47,13 @@ pub enum MiddlewareInnerError { MismatchedStatementWildcardPredicate(Value, Value, 
Predicate), #[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")] MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate), + #[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))] + StatementsDontMatch( + Statement, + StatementTmpl, + Vec>, + crate::middleware::Error, + ), #[error( "None of the statement templates of the following custom predicate have been matched:\n{0}" )] @@ -132,6 +153,14 @@ impl Error { wc_value, st_arg, arg_index, pred )) } + pub(crate) fn statements_dont_match( + s0: Statement, + s1: StatementTmpl, + wc_map: Vec>, + mid_error: crate::middleware::Error, + ) -> Self { + new!(StatementsDontMatch(s0, s1, wc_map, mid_error)) + } pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self { new!(UnsatisfiedCustomPredicateDisjunction(pred)) } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 6435a92..7d06d9c 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -611,17 +611,21 @@ pub fn fill_wildcard_values( wildcard_map: &mut [Option], ) -> Result<()> { for (st_tmpl, st) in pred.statements.iter().zip(args) { - let st_args = st.args(); if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc { wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?; } - st_tmpl - .args - .iter() - .zip(&st_args) - .try_for_each(|(st_tmpl_arg, st_arg)| { - check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map) - })?; + let st_args = st.args(); + + for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { + if let Err(st_tmpl_check_error) = check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map) { + return Err(Error::statements_dont_match( + st.clone(), + st_tmpl.clone(), + wildcard_map.to_vec(), + st_tmpl_check_error, + )); + } + } } Ok(()) } @@ -741,16 +745,7 @@ pub(crate) fn check_custom_pred( for (st_tmpl, st) in pred.statements.iter().zip(args) { 
// For `or` predicates, only one statement needs to match the template. // The rest of the statements can be `None`. - let expected_pred_is_none = match &st_tmpl.pred_or_wc { - PredicateOrWildcard::Predicate(st_tmpl_pred) => { - *st_tmpl_pred == Predicate::Native(NativePredicate::None) - } - PredicateOrWildcard::Wildcard(wc) => { - wc_values[wc.index] - == Value::from(Predicate::Native(NativePredicate::None).hash(params)) - } - }; - if !pred.conjunction && matches!(st, Statement::None) && !expected_pred_is_none { + if !pred.conjunction && matches!(st, Statement::None) { continue; } check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?;