Feat/fst order pred part3 & part4 (#457)

* support wildcard predicates in frontend

* support wildcard predicate in podlang

* add validation test

* test full flow and apply some fixes

* fix clippy

* fix merge issues

* use desugared predicate

* Fix parsing of intro statement templates inside custom predicates

* Tidy up comments

* lang: handle wildcard predicate

* add unreachable message

---------

Co-authored-by: Rob Knight <mail@robknight.org.uk>
This commit is contained in:
Eduard S. 2026-02-02 10:59:33 +01:00 committed by GitHub
parent b66f5051b5
commit 498e946612
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 324 additions and 180 deletions

View file

@ -88,6 +88,9 @@ pub enum ValidationError {
first_span: Option<Span>,
second_span: Option<Span>,
},
#[error("Wildcard '{name}' collides with a predicate name")]
WildcardPredicateNameCollision { name: String },
}
/// Lowering errors from frontend AST lowering to middleware

View file

@ -619,6 +619,7 @@ fn build_single_batch(
batch_idx,
reference_map,
existing_batches,
name,
symbols,
)
})
@ -648,6 +649,7 @@ fn build_statement_with_resolved_refs(
current_batch_idx: usize,
reference_map: &HashMap<String, (usize, usize)>,
existing_batches: &[Arc<CustomPredicateBatch>],
custom_predicate_name: &str, // custom pred that defines this statement template
symbols: &SymbolTable,
) -> Result<StatementTmplBuilder, BatchingError> {
let callee_name = &stmt.predicate.name;
@ -657,16 +659,17 @@ fn build_statement_with_resolved_refs(
current_batch_idx,
reference_map,
existing_batches,
custom_predicate_name,
};
let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
let pred_or_wc = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
BatchingError::Internal {
message: format!("Unknown predicate reference: '{}'", callee_name),
}
})?;
// Build the statement template
let mut builder = StatementTmplBuilder::new(predicate);
let mut builder = StatementTmplBuilder::new(pred_or_wc);
for arg in &stmt.args {
builder = builder.arg(lower_statement_arg(arg));

View file

@ -10,7 +10,7 @@ use std::{
};
use crate::{
frontend::{BuilderArg, StatementTmplBuilder},
frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
lang::{
frontend_ast::*,
frontend_ast_batch::{self, PredicateBatches},
@ -18,8 +18,8 @@ use crate::{
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
},
middleware::{
containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
},
};
@ -35,6 +35,7 @@ pub enum ResolutionContext<'a> {
current_batch_idx: usize,
reference_map: &'a HashMap<String, (usize, usize)>,
existing_batches: &'a [Arc<CustomPredicateBatch>],
custom_predicate_name: &'a str,
},
}
@ -43,10 +44,23 @@ pub fn resolve_predicate(
pred_name: &str,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Option<Predicate> {
// 1. Try native predicate first
) -> Option<PredicateOrWildcard> {
// 0. Try wildcard first
if let ResolutionContext::Batch {
custom_predicate_name,
..
} = context
{
if let Some(wc_scope) = symbols.wildcard_scopes.get(*custom_predicate_name) {
if wc_scope.wildcards.contains_key(pred_name) {
return Some(PredicateOrWildcard::Wildcard(pred_name.to_string()));
}
}
}
// 1. Try native predicate second
if let Ok(native) = NativePredicate::from_str(pred_name) {
return Some(Predicate::Native(native));
return Some(PredicateOrWildcard::Predicate(Predicate::Native(native)));
}
// 2. Look up in symbol table
@ -64,6 +78,7 @@ pub fn resolve_predicate(
current_batch_idx,
reference_map,
existing_batches,
..
} => resolve_local_predicate(
pred_name,
*current_batch_idx,
@ -85,7 +100,7 @@ pub fn resolve_predicate(
verifier_data_hash: *verifier_data_hash,
}),
};
return Some(predicate);
return Some(PredicateOrWildcard::Predicate(predicate));
}
// 3. In batch context, also check reference_map for split chain pieces
@ -94,6 +109,7 @@ pub fn resolve_predicate(
current_batch_idx,
reference_map,
existing_batches,
..
} = context
{
if reference_map.contains_key(pred_name) {
@ -102,7 +118,8 @@ pub fn resolve_predicate(
*current_batch_idx,
reference_map,
existing_batches,
);
)
.map(PredicateOrWildcard::Predicate);
}
}
@ -328,7 +345,7 @@ impl<'a> Lowerer<'a> {
})?;
// Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate);
let mut builder = StatementTmplBuilder::new(predicate.clone());
for arg in &stmt.args {
let builder_arg = lower_statement_arg(arg);
builder = builder.arg(builder_arg);
@ -356,9 +373,14 @@ impl<'a> Lowerer<'a> {
mw_args.push(mw_arg);
}
let predicate = match desugared.pred_or_wc {
PredicateOrWildcard::Predicate(p) => p,
PredicateOrWildcard::Wildcard(_) => {
unreachable!("wildcard predicates aren't considered in requests")
}
};
Ok(MWStatementTmpl {
// TODO: Support wildcard
pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate),
pred_or_wc: middleware::PredicateOrWildcard::Predicate(predicate),
args: mw_args,
})
}
@ -424,7 +446,6 @@ impl<'a> Lowerer<'a> {
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
split_results.push(result);
}
Ok(split_results)
}
}
@ -601,7 +622,7 @@ mod tests {
// Should be BatchSelf(0) referring to pred1
assert!(matches!(
stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
middleware::PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
));
}
@ -639,10 +660,25 @@ mod tests {
// Should desugar to the Contains predicate
assert!(matches!(
stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains))
middleware::PredicateOrWildcard::Predicate(Predicate::Native(
NativePredicate::Contains
))
));
}
#[test]
fn test_wc_pred() {
    // `DynPred` is declared as an argument of `my_pred` and is then used both as a
    // statement argument and in predicate position (`DynPred(X)`). The whole
    // parse -> validate -> lower pipeline is expected to accept this input.
    let params = Params::default();
    let source = r#"
my_pred(X, DynPred) = AND (
Equal(X["pred"], DynPred)
DynPred(X)
)
"#;
    let outcome = parse_validate_and_lower(source, &params);
    outcome.unwrap();
}
#[test]
fn test_multi_batch_packing() {
// Create more predicates than fit in a single batch
@ -749,7 +785,7 @@ mod tests {
// Verify the second statement is an intro predicate reference
let intro_stmt = &pred.statements()[1];
match intro_stmt.pred_or_wc() {
PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
middleware::PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
assert_eq!(intro_ref.name, "external_check");
assert_eq!(intro_ref.args_len, 1);
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);

View file

@ -3,7 +3,11 @@
//! This module provides semantic validation for parsed AST documents,
//! including name resolution, arity checking, and wildcard validation.
use std::{collections::HashMap, str::FromStr, sync::Arc};
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
};
use hex::ToHex;
@ -411,6 +415,21 @@ impl Validator {
Ok(())
}
/// Validate that no wildcard name collides with a predicate name to avoid ambiguity when using
/// wildcard predicates.
///
/// A name is considered colliding when it parses as a [`NativePredicate`] or when it is a key
/// of `self.symbols.predicates`.
///
/// # Errors
///
/// Returns [`ValidationError::WildcardPredicateNameCollision`] carrying a colliding name. When
/// several names collide, the lexicographically smallest one is reported so that the error
/// does not depend on the arbitrary iteration order of the `HashSet`.
fn validate_wildcard_names(&self, names: &HashSet<&String>) -> Result<(), ValidationError> {
    let first_collision = names
        .iter()
        .copied()
        .filter(|name| {
            NativePredicate::from_str(name).is_ok()
                || self.symbols.predicates.contains_key(*name)
        })
        // Pick a stable representative instead of whichever entry the set yields first.
        .min();

    match first_collision {
        Some(name) => Err(ValidationError::WildcardPredicateNameCollision {
            name: name.clone(),
        }),
        None => Ok(()),
    }
}
fn validate_statement(
&self,
stmt: &StatementTmpl,
@ -418,18 +437,26 @@ impl Validator {
) -> Result<(), ValidationError> {
let pred_name = &stmt.predicate.name;
let wc_names = match wildcard_context {
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
None => HashSet::new(),
};
self.validate_wildcard_names(&wc_names)?;
// Check if predicate exists
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
// Native predicate
PredicateInfo {
Some(PredicateInfo {
kind: PredicateKind::Native(native),
arity: native.arity(),
public_arity: native.arity(),
source_span: None,
}
})
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
// Custom or imported predicate
info.clone()
Some(info.clone())
} else if wc_names.contains(pred_name) {
None
} else {
return Err(ValidationError::UndefinedPredicate {
name: pred_name.clone(),
@ -437,19 +464,20 @@ impl Validator {
});
};
let expected_arity = pred_info.public_arity;
if stmt.args.len() != expected_arity {
return Err(ValidationError::ArgumentCountMismatch {
predicate: pred_name.clone(),
expected: expected_arity,
found: stmt.args.len(),
span: stmt.span,
});
if let Some(ref pred_info) = pred_info {
let expected_arity = pred_info.public_arity;
if stmt.args.len() != expected_arity {
return Err(ValidationError::ArgumentCountMismatch {
predicate: pred_name.clone(),
expected: expected_arity,
found: stmt.args.len(),
span: stmt.span,
});
}
}
// Validate arguments
self.validate_statement_args(stmt, &pred_info, wildcard_context)?;
self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?;
Ok(())
}
@ -457,13 +485,13 @@ impl Validator {
fn validate_statement_args(
&self,
stmt: &StatementTmpl,
pred_info: &PredicateInfo,
pred_info: Option<&PredicateInfo>,
wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> {
// For custom predicates, only wildcards and literals are allowed
if matches!(
pred_info.kind,
PredicateKind::Custom { .. } | PredicateKind::BatchImported { .. }
pred_info.map(|i| &i.kind),
Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. })
) {
for arg in &stmt.args {
match arg {
@ -631,6 +659,18 @@ mod tests {
));
}
#[test]
fn test_wildcard_predicate_collision() {
    // The wildcard `Lt` is expected to be rejected by validation with a
    // wildcard/predicate name collision (presumably because `Lt` is also the
    // name of a native predicate - confirm against `NativePredicate`).
    let source = r#"
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
"#;
    assert!(matches!(
        parse_and_validate(source, &[]),
        Err(ValidationError::WildcardPredicateNameCollision { .. })
    ));
}
#[test]
fn test_custom_predicate_with_anchored_key() {
let input = r#"