Feat/fst order pred part3 & part4 (#457)
* support wildcard predicates in frontend * suport wildcard predicate in podlang * add validation test * test full flow and apply some fixes * fix clippy * fix merge issues * use desugared predicate * Fix parsing of intro statement templates inside custom predicates * Tidy up comments * lang: handle wildcard predicate * add unreachable message --------- Co-authored-by: Rob Knight <mail@robknight.org.uk>
This commit is contained in:
parent
b66f5051b5
commit
498e946612
11 changed files with 324 additions and 180 deletions
|
|
@ -88,6 +88,9 @@ pub enum ValidationError {
|
|||
first_span: Option<Span>,
|
||||
second_span: Option<Span>,
|
||||
},
|
||||
|
||||
#[error("Wildcard '{name}' collides with a predicate name")]
|
||||
WildcardPredicateNameCollision { name: String },
|
||||
}
|
||||
|
||||
/// Lowering errors from frontend AST lowering to middleware
|
||||
|
|
|
|||
|
|
@ -619,6 +619,7 @@ fn build_single_batch(
|
|||
batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
name,
|
||||
symbols,
|
||||
)
|
||||
})
|
||||
|
|
@ -648,6 +649,7 @@ fn build_statement_with_resolved_refs(
|
|||
current_batch_idx: usize,
|
||||
reference_map: &HashMap<String, (usize, usize)>,
|
||||
existing_batches: &[Arc<CustomPredicateBatch>],
|
||||
custom_predicate_name: &str, // custom pred that defines this statement template
|
||||
symbols: &SymbolTable,
|
||||
) -> Result<StatementTmplBuilder, BatchingError> {
|
||||
let callee_name = &stmt.predicate.name;
|
||||
|
|
@ -657,16 +659,17 @@ fn build_statement_with_resolved_refs(
|
|||
current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
custom_predicate_name,
|
||||
};
|
||||
|
||||
let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
|
||||
let pred_or_wc = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
|
||||
BatchingError::Internal {
|
||||
message: format!("Unknown predicate reference: '{}'", callee_name),
|
||||
}
|
||||
})?;
|
||||
|
||||
// Build the statement template
|
||||
let mut builder = StatementTmplBuilder::new(predicate);
|
||||
let mut builder = StatementTmplBuilder::new(pred_or_wc);
|
||||
|
||||
for arg in &stmt.args {
|
||||
builder = builder.arg(lower_statement_arg(arg));
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use std::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
frontend::{BuilderArg, StatementTmplBuilder},
|
||||
frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
|
||||
lang::{
|
||||
frontend_ast::*,
|
||||
frontend_ast_batch::{self, PredicateBatches},
|
||||
|
|
@ -18,8 +18,8 @@ use crate::{
|
|||
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
|
||||
},
|
||||
middleware::{
|
||||
containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
|
||||
NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
||||
self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
|
||||
NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl,
|
||||
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
|
||||
},
|
||||
};
|
||||
|
|
@ -35,6 +35,7 @@ pub enum ResolutionContext<'a> {
|
|||
current_batch_idx: usize,
|
||||
reference_map: &'a HashMap<String, (usize, usize)>,
|
||||
existing_batches: &'a [Arc<CustomPredicateBatch>],
|
||||
custom_predicate_name: &'a str,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -43,10 +44,23 @@ pub fn resolve_predicate(
|
|||
pred_name: &str,
|
||||
symbols: &SymbolTable,
|
||||
context: &ResolutionContext,
|
||||
) -> Option<Predicate> {
|
||||
// 1. Try native predicate first
|
||||
) -> Option<PredicateOrWildcard> {
|
||||
// 0. Try wildcard first
|
||||
if let ResolutionContext::Batch {
|
||||
custom_predicate_name,
|
||||
..
|
||||
} = context
|
||||
{
|
||||
if let Some(wc_scope) = symbols.wildcard_scopes.get(*custom_predicate_name) {
|
||||
if wc_scope.wildcards.contains_key(pred_name) {
|
||||
return Some(PredicateOrWildcard::Wildcard(pred_name.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Try native predicate second
|
||||
if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
return Some(Predicate::Native(native));
|
||||
return Some(PredicateOrWildcard::Predicate(Predicate::Native(native)));
|
||||
}
|
||||
|
||||
// 2. Look up in symbol table
|
||||
|
|
@ -64,6 +78,7 @@ pub fn resolve_predicate(
|
|||
current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
..
|
||||
} => resolve_local_predicate(
|
||||
pred_name,
|
||||
*current_batch_idx,
|
||||
|
|
@ -85,7 +100,7 @@ pub fn resolve_predicate(
|
|||
verifier_data_hash: *verifier_data_hash,
|
||||
}),
|
||||
};
|
||||
return Some(predicate);
|
||||
return Some(PredicateOrWildcard::Predicate(predicate));
|
||||
}
|
||||
|
||||
// 3. In batch context, also check reference_map for split chain pieces
|
||||
|
|
@ -94,6 +109,7 @@ pub fn resolve_predicate(
|
|||
current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
..
|
||||
} = context
|
||||
{
|
||||
if reference_map.contains_key(pred_name) {
|
||||
|
|
@ -102,7 +118,8 @@ pub fn resolve_predicate(
|
|||
*current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
);
|
||||
)
|
||||
.map(PredicateOrWildcard::Predicate);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -328,7 +345,7 @@ impl<'a> Lowerer<'a> {
|
|||
})?;
|
||||
|
||||
// Create a builder with the resolved predicate and desugar
|
||||
let mut builder = StatementTmplBuilder::new(predicate);
|
||||
let mut builder = StatementTmplBuilder::new(predicate.clone());
|
||||
for arg in &stmt.args {
|
||||
let builder_arg = lower_statement_arg(arg);
|
||||
builder = builder.arg(builder_arg);
|
||||
|
|
@ -356,9 +373,14 @@ impl<'a> Lowerer<'a> {
|
|||
mw_args.push(mw_arg);
|
||||
}
|
||||
|
||||
let predicate = match desugared.pred_or_wc {
|
||||
PredicateOrWildcard::Predicate(p) => p,
|
||||
PredicateOrWildcard::Wildcard(_) => {
|
||||
unreachable!("wildcard predicates aren't considered in requests")
|
||||
}
|
||||
};
|
||||
Ok(MWStatementTmpl {
|
||||
// TODO: Support wildcard
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate),
|
||||
pred_or_wc: middleware::PredicateOrWildcard::Predicate(predicate),
|
||||
args: mw_args,
|
||||
})
|
||||
}
|
||||
|
|
@ -424,7 +446,6 @@ impl<'a> Lowerer<'a> {
|
|||
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
|
||||
split_results.push(result);
|
||||
}
|
||||
|
||||
Ok(split_results)
|
||||
}
|
||||
}
|
||||
|
|
@ -601,7 +622,7 @@ mod tests {
|
|||
// Should be BatchSelf(0) referring to pred1
|
||||
assert!(matches!(
|
||||
stmt.pred_or_wc,
|
||||
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
|
||||
middleware::PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
|
||||
));
|
||||
}
|
||||
|
||||
|
|
@ -639,10 +660,25 @@ mod tests {
|
|||
// Should desugar to the Contains predicate
|
||||
assert!(matches!(
|
||||
stmt.pred_or_wc,
|
||||
PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains))
|
||||
middleware::PredicateOrWildcard::Predicate(Predicate::Native(
|
||||
NativePredicate::Contains
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wc_pred() {
|
||||
let input = r#"
|
||||
my_pred(X, DynPred) = AND (
|
||||
Equal(X["pred"], DynPred)
|
||||
DynPred(X)
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
parse_validate_and_lower(input, ¶ms).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_batch_packing() {
|
||||
// Create more predicates than fit in a single batch
|
||||
|
|
@ -749,7 +785,7 @@ mod tests {
|
|||
// Verify the second statement is an intro predicate reference
|
||||
let intro_stmt = &pred.statements()[1];
|
||||
match intro_stmt.pred_or_wc() {
|
||||
PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
|
||||
middleware::PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
|
||||
assert_eq!(intro_ref.name, "external_check");
|
||||
assert_eq!(intro_ref.args_len, 1);
|
||||
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
|
||||
|
|
|
|||
|
|
@ -3,7 +3,11 @@
|
|||
//! This module provides semantic validation for parsed AST documents,
|
||||
//! including name resolution, arity checking, and wildcard validation.
|
||||
|
||||
use std::{collections::HashMap, str::FromStr, sync::Arc};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use hex::ToHex;
|
||||
|
||||
|
|
@ -411,6 +415,21 @@ impl Validator {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate that no wildcard name collides with a predicate name to avoid ambiguity when using
|
||||
/// wildcard predicates.
|
||||
fn validate_wildcard_names(&self, names: &HashSet<&String>) -> Result<(), ValidationError> {
|
||||
for name in names {
|
||||
if NativePredicate::from_str(name).is_ok()
|
||||
|| self.symbols.predicates.contains_key(*name)
|
||||
{
|
||||
return Err(ValidationError::WildcardPredicateNameCollision {
|
||||
name: (*name).clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_statement(
|
||||
&self,
|
||||
stmt: &StatementTmpl,
|
||||
|
|
@ -418,18 +437,26 @@ impl Validator {
|
|||
) -> Result<(), ValidationError> {
|
||||
let pred_name = &stmt.predicate.name;
|
||||
|
||||
let wc_names = match wildcard_context {
|
||||
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
|
||||
None => HashSet::new(),
|
||||
};
|
||||
self.validate_wildcard_names(&wc_names)?;
|
||||
|
||||
// Check if predicate exists
|
||||
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
// Native predicate
|
||||
PredicateInfo {
|
||||
Some(PredicateInfo {
|
||||
kind: PredicateKind::Native(native),
|
||||
arity: native.arity(),
|
||||
public_arity: native.arity(),
|
||||
source_span: None,
|
||||
}
|
||||
})
|
||||
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
|
||||
// Custom or imported predicate
|
||||
info.clone()
|
||||
Some(info.clone())
|
||||
} else if wc_names.contains(pred_name) {
|
||||
None
|
||||
} else {
|
||||
return Err(ValidationError::UndefinedPredicate {
|
||||
name: pred_name.clone(),
|
||||
|
|
@ -437,19 +464,20 @@ impl Validator {
|
|||
});
|
||||
};
|
||||
|
||||
let expected_arity = pred_info.public_arity;
|
||||
|
||||
if stmt.args.len() != expected_arity {
|
||||
return Err(ValidationError::ArgumentCountMismatch {
|
||||
predicate: pred_name.clone(),
|
||||
expected: expected_arity,
|
||||
found: stmt.args.len(),
|
||||
span: stmt.span,
|
||||
});
|
||||
if let Some(ref pred_info) = pred_info {
|
||||
let expected_arity = pred_info.public_arity;
|
||||
if stmt.args.len() != expected_arity {
|
||||
return Err(ValidationError::ArgumentCountMismatch {
|
||||
predicate: pred_name.clone(),
|
||||
expected: expected_arity,
|
||||
found: stmt.args.len(),
|
||||
span: stmt.span,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Validate arguments
|
||||
self.validate_statement_args(stmt, &pred_info, wildcard_context)?;
|
||||
self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -457,13 +485,13 @@ impl Validator {
|
|||
fn validate_statement_args(
|
||||
&self,
|
||||
stmt: &StatementTmpl,
|
||||
pred_info: &PredicateInfo,
|
||||
pred_info: Option<&PredicateInfo>,
|
||||
wildcard_context: Option<(&str, &WildcardScope)>,
|
||||
) -> Result<(), ValidationError> {
|
||||
// For custom predicates, only wildcards and literals are allowed
|
||||
if matches!(
|
||||
pred_info.kind,
|
||||
PredicateKind::Custom { .. } | PredicateKind::BatchImported { .. }
|
||||
pred_info.map(|i| &i.kind),
|
||||
Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. })
|
||||
) {
|
||||
for arg in &stmt.args {
|
||||
match arg {
|
||||
|
|
@ -631,6 +659,18 @@ mod tests {
|
|||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_predicate_collision() {
|
||||
let input = r#"
|
||||
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::WildcardPredicateNameCollision { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_predicate_with_anchored_key() {
|
||||
let input = r#"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue