Fix parsing of intro statement templates inside custom predicates (#467)

* Fix parsing of intro statement templates inside custom predicates

* Tidy up comments
This commit is contained in:
Rob Knight 2026-01-30 19:30:57 +01:00 committed by GitHub
parent 337a51135e
commit 879c7201ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 263 additions and 169 deletions

View file

@ -12,7 +12,7 @@
//! cross-batch calls always point to earlier batches via `CustomPredicateRef`.
//! - Forward cross-batch references cannot occur with this planner (they are treated as unreachable).
use std::{collections::HashMap, str::FromStr, sync::Arc};
use std::{collections::HashMap, sync::Arc};
use petgraph::{algo::condensation, graph::DiGraph, prelude::NodeIndex, visit::EdgeRef};
@ -21,12 +21,11 @@ use crate::{
lang::{
error::BatchingError,
frontend_ast::{ConjunctionType, CustomPredicateDef},
frontend_ast_lower::lower_statement_arg,
frontend_ast_lower::{lower_statement_arg, resolve_predicate, ResolutionContext},
frontend_ast_split::{SplitChainInfo, SplitResult},
frontend_ast_validate::SymbolTable,
},
middleware::{
CustomPredicateBatch, CustomPredicateRef, NativePredicate, Params, Predicate, Statement,
},
middleware::{CustomPredicateBatch, CustomPredicateRef, Params, Statement},
};
/// A single step in a multi-operation sequence for split predicates
@ -318,13 +317,6 @@ struct PredicateAssignment {
index_in_batch: usize,
}
/// Information about an imported predicate for use during batching.
///
/// `batch_predicates` receives these in a map keyed by predicate name. When a
/// statement template names an imported predicate, both fields are passed to
/// `CustomPredicateRef::new` to build a `Predicate::Custom` reference.
#[derive(Debug, Clone)]
pub struct ImportedPredicateInfo {
/// Shared handle to the batch the imported predicate comes from.
pub batch: Arc<CustomPredicateBatch>,
/// Index handed to `CustomPredicateRef::new` together with `batch`
/// (presumably the predicate's position inside that batch — confirm).
pub index: usize,
}
/// Pack predicates into multiple batches
///
/// Takes a list of split results (containing predicates and optional chain info)
@ -337,13 +329,13 @@ pub struct ImportedPredicateInfo {
/// - Within a batch, predicates can reference each other freely via `BatchSelf`; cross-batch
/// references always point to earlier batches via `CustomPredicateRef`.
///
/// `imported_predicates` maps predicate names to their imported batch info,
/// allowing predicates to call imported predicates from other batches.
/// `symbols` provides the symbol table for resolving predicate references,
/// including imported predicates from other batches and intro predicates.
pub fn batch_predicates(
split_results: Vec<SplitResult>,
params: &Params,
base_batch_name: &str,
imported_predicates: &HashMap<String, ImportedPredicateInfo>,
symbols: &SymbolTable,
) -> Result<PredicateBatches, BatchingError> {
// Extract predicates and collect split chains
let mut predicates = Vec::new();
@ -403,7 +395,7 @@ pub fn batch_predicates(
batch_idx,
&reference_map,
&batches,
imported_predicates,
symbols,
params,
&batch_name,
)?;
@ -593,7 +585,7 @@ fn build_single_batch(
batch_idx: usize,
reference_map: &HashMap<String, (usize, usize)>,
existing_batches: &[Arc<CustomPredicateBatch>],
imported_predicates: &HashMap<String, ImportedPredicateInfo>,
symbols: &SymbolTable,
params: &Params,
batch_name: &str,
) -> Result<Arc<CustomPredicateBatch>, BatchingError> {
@ -624,11 +616,10 @@ fn build_single_batch(
.map(|stmt| {
build_statement_with_resolved_refs(
stmt,
name,
batch_idx,
reference_map,
existing_batches,
imported_predicates,
symbols,
)
})
.collect::<Result<_, _>>()?;
@ -654,46 +645,26 @@ fn build_single_batch(
/// Build a statement template with properly resolved predicate references
fn build_statement_with_resolved_refs(
stmt: &crate::lang::frontend_ast::StatementTmpl,
caller_name: &str,
current_batch_idx: usize,
reference_map: &HashMap<String, (usize, usize)>,
existing_batches: &[Arc<CustomPredicateBatch>],
imported_predicates: &HashMap<String, ImportedPredicateInfo>,
symbols: &SymbolTable,
) -> Result<StatementTmplBuilder, BatchingError> {
let callee_name = &stmt.predicate.name;
// Resolve the predicate
let predicate = if let Ok(native) = NativePredicate::from_str(callee_name) {
Predicate::Native(native)
} else if let Some(&(target_batch, target_idx)) = reference_map.get(callee_name) {
// Local predicate in this document
if target_batch == current_batch_idx {
// Same batch - use BatchSelf
Predicate::BatchSelf(target_idx)
} else if target_batch < current_batch_idx {
// Earlier batch - use Custom ref
let batch = &existing_batches[target_batch];
Predicate::Custom(CustomPredicateRef::new(batch.clone(), target_idx))
} else {
// Forward reference to a later batch should be impossible with the dependency-aware planner
unreachable!(
"Forward cross-batch reference: '{}' (batch {}) -> '{}' (batch {})",
caller_name, current_batch_idx, callee_name, target_batch
);
}
} else if let Some(imported) = imported_predicates.get(callee_name) {
// Imported predicate from another batch
Predicate::Custom(CustomPredicateRef::new(
imported.batch.clone(),
imported.index,
))
} else {
// Unknown predicate
return Err(BatchingError::Internal {
message: format!("Unknown predicate reference: '{}'", callee_name),
});
// Resolve the predicate using the unified resolution function
let context = ResolutionContext::Batch {
current_batch_idx,
reference_map,
existing_batches,
};
let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
BatchingError::Internal {
message: format!("Unknown predicate reference: '{}'", callee_name),
}
})?;
// Build the statement template
let mut builder = StatementTmplBuilder::new(predicate);
@ -709,24 +680,30 @@ mod tests {
use super::*;
use crate::{
lang::{
frontend_ast::parse::parse_document, frontend_ast_split::split_predicate_if_needed,
frontend_ast::parse::parse_document,
frontend_ast_split::split_predicate_if_needed,
frontend_ast_validate::{validate, ValidatedAST},
parser::parse_podlang,
},
middleware::PredicateOrWildcard,
middleware::{Predicate, PredicateOrWildcard},
};
fn parse_predicates(input: &str) -> Vec<CustomPredicateDef> {
/// Helper: parse and validate input, returning the predicates and the validated AST (its `symbols()` accessor yields the symbol table)
fn parse_and_validate(input: &str) -> (Vec<CustomPredicateDef>, ValidatedAST) {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
let validated = validate(document.clone(), &[]).expect("Failed to validate");
document
let predicates = document
.items
.into_iter()
.filter_map(|item| match item {
crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred),
_ => None,
})
.collect()
.collect();
(predicates, validated)
}
/// Helper: wrap predicates into SplitResult (without actually splitting)
@ -748,14 +725,14 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
);
assert!(result.is_ok());
@ -772,14 +749,14 @@ mod tests {
pred3(C) = AND(Equal(C["z"], 3))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
);
assert!(result.is_ok());
@ -798,14 +775,14 @@ mod tests {
pred5(E) = AND(Equal(E["v"], 5))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
);
assert!(result.is_ok());
@ -828,14 +805,14 @@ mod tests {
pred1(A) = AND(Equal(A["x"], 1))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
);
assert!(result.is_ok());
@ -861,14 +838,14 @@ mod tests {
pred2(B) = AND(pred1(B))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
);
assert!(result.is_ok());
@ -902,14 +879,14 @@ mod tests {
pred5(E) = AND(pred1(E))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
);
assert!(result.is_ok());
@ -949,7 +926,7 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
// Split the large predicate
@ -966,7 +943,7 @@ mod tests {
// That's 5 predicates, which spans 2 batches
assert_eq!(total_preds, 5);
let result = batch_predicates(all_split_results, &params, "TestBatch", &HashMap::new());
let result = batch_predicates(all_split_results, &params, "TestBatch", validated.symbols());
assert!(result.is_ok());
let batches = result.unwrap();
@ -995,14 +972,14 @@ mod tests {
pred5(E) = AND(Equal(E["v"], 5))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
let batches = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
)
.expect("Planner should avoid forward cross-batch reference");
@ -1019,8 +996,13 @@ mod tests {
fn test_empty_input() {
let split_results: Vec<SplitResult> = vec![];
let params = Params::default();
// For empty input, we need an empty symbol table
let empty_symbols = SymbolTable {
predicates: HashMap::new(),
wildcard_scopes: HashMap::new(),
};
let result = batch_predicates(split_results, &params, "TestBatch", &HashMap::new());
let result = batch_predicates(split_results, &params, "TestBatch", &empty_symbols);
assert!(result.is_ok());
let batches = result.unwrap();
@ -1035,14 +1017,14 @@ mod tests {
pred2(B) = AND(Equal(B["y"], 2))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let batches = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
)
.unwrap();
@ -1061,7 +1043,7 @@ mod tests {
pred2(B) = AND(pred1(B))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params {
max_custom_batch_size: 1, // force SCC > capacity
..Default::default()
@ -1071,7 +1053,7 @@ mod tests {
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
);
assert!(result.is_err());
assert!(result
@ -1098,7 +1080,7 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
// Split and batch
@ -1107,8 +1089,9 @@ mod tests {
let result = split_predicate_if_needed(pred, &params).expect("Split failed");
all_split_results.push(result);
}
let batches = batch_predicates(all_split_results, &params, "TestBatch", &HashMap::new())
.expect("Batch failed");
let batches =
batch_predicates(all_split_results, &params, "TestBatch", validated.symbols())
.expect("Batch failed");
assert_eq!(batches.batch_count(), 2);
@ -1147,14 +1130,14 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let batches = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
)
.unwrap();
@ -1195,7 +1178,7 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
// Split the predicate
@ -1210,7 +1193,7 @@ mod tests {
assert_eq!(split_results[0].predicates.len(), 2);
assert!(split_results[0].chain_info.is_some());
let batches = batch_predicates(split_results, &params, "TestBatch", &HashMap::new())
let batches = batch_predicates(split_results, &params, "TestBatch", validated.symbols())
.expect("Batch failed");
// Verify chain info
@ -1259,7 +1242,7 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
// Split the predicate
@ -1274,7 +1257,7 @@ mod tests {
assert_eq!(split_results[0].predicates.len(), 3);
assert!(split_results[0].chain_info.is_some());
let batches = batch_predicates(split_results, &params, "TestBatch", &HashMap::new())
let batches = batch_predicates(split_results, &params, "TestBatch", validated.symbols())
.expect("Batch failed");
// Verify chain info
@ -1320,7 +1303,7 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
// Split the predicate
@ -1330,7 +1313,7 @@ mod tests {
split_results.push(result);
}
let batches = batch_predicates(split_results, &params, "TestBatch", &HashMap::new())
let batches = batch_predicates(split_results, &params, "TestBatch", validated.symbols())
.expect("Batch failed");
// Try with wrong number of statements (3 instead of 6)
@ -1363,14 +1346,14 @@ mod tests {
my_pred(A) = AND(Equal(A["x"], 1))
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let batches = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
&HashMap::new(),
validated.symbols(),
)
.unwrap();
@ -1401,7 +1384,7 @@ mod tests {
)
"#;
let predicates = parse_predicates(input);
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let mut split_results = Vec::new();
@ -1410,7 +1393,7 @@ mod tests {
split_results.push(result);
}
let batches = batch_predicates(split_results, &params, "TestBatch", &HashMap::new())
let batches = batch_predicates(split_results, &params, "TestBatch", validated.symbols())
.expect("Batch failed");
let statements: Vec<Statement> = (0..6).map(test_statement).collect();