Fix parsing of intro statement templates inside custom predicates (#467)
* Fix parsing of intro statement templates inside custom predicates * Tidy up comments
This commit is contained in:
parent
337a51135e
commit
879c7201ad
2 changed files with 263 additions and 169 deletions
|
|
@ -12,7 +12,7 @@
|
|||
//! cross-batch calls always point to earlier batches via `CustomPredicateRef`.
|
||||
//! - Forward cross-batch references cannot occur with this planner (they are treated as unreachable).
|
||||
|
||||
use std::{collections::HashMap, str::FromStr, sync::Arc};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use petgraph::{algo::condensation, graph::DiGraph, prelude::NodeIndex, visit::EdgeRef};
|
||||
|
||||
|
|
@ -21,12 +21,11 @@ use crate::{
|
|||
lang::{
|
||||
error::BatchingError,
|
||||
frontend_ast::{ConjunctionType, CustomPredicateDef},
|
||||
frontend_ast_lower::lower_statement_arg,
|
||||
frontend_ast_lower::{lower_statement_arg, resolve_predicate, ResolutionContext},
|
||||
frontend_ast_split::{SplitChainInfo, SplitResult},
|
||||
frontend_ast_validate::SymbolTable,
|
||||
},
|
||||
middleware::{
|
||||
CustomPredicateBatch, CustomPredicateRef, NativePredicate, Params, Predicate, Statement,
|
||||
},
|
||||
middleware::{CustomPredicateBatch, CustomPredicateRef, Params, Statement},
|
||||
};
|
||||
|
||||
/// A single step in a multi-operation sequence for split predicates
|
||||
|
|
@ -318,13 +317,6 @@ struct PredicateAssignment {
|
|||
index_in_batch: usize,
|
||||
}
|
||||
|
||||
/// Information about an imported predicate for use during batching
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ImportedPredicateInfo {
|
||||
pub batch: Arc<CustomPredicateBatch>,
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
/// Pack predicates into multiple batches
|
||||
///
|
||||
/// Takes a list of split results (containing predicates and optional chain info)
|
||||
|
|
@ -337,13 +329,13 @@ pub struct ImportedPredicateInfo {
|
|||
/// - Within a batch, predicates can reference each other freely via `BatchSelf`; cross-batch
|
||||
/// references always point to earlier batches via `CustomPredicateRef`.
|
||||
///
|
||||
/// `imported_predicates` maps predicate names to their imported batch info,
|
||||
/// allowing predicates to call imported predicates from other batches.
|
||||
/// `symbols` provides the symbol table for resolving predicate references,
|
||||
/// including imported predicates from other batches and intro predicates.
|
||||
pub fn batch_predicates(
|
||||
split_results: Vec<SplitResult>,
|
||||
params: &Params,
|
||||
base_batch_name: &str,
|
||||
imported_predicates: &HashMap<String, ImportedPredicateInfo>,
|
||||
symbols: &SymbolTable,
|
||||
) -> Result<PredicateBatches, BatchingError> {
|
||||
// Extract predicates and collect split chains
|
||||
let mut predicates = Vec::new();
|
||||
|
|
@ -403,7 +395,7 @@ pub fn batch_predicates(
|
|||
batch_idx,
|
||||
&reference_map,
|
||||
&batches,
|
||||
imported_predicates,
|
||||
symbols,
|
||||
params,
|
||||
&batch_name,
|
||||
)?;
|
||||
|
|
@ -593,7 +585,7 @@ fn build_single_batch(
|
|||
batch_idx: usize,
|
||||
reference_map: &HashMap<String, (usize, usize)>,
|
||||
existing_batches: &[Arc<CustomPredicateBatch>],
|
||||
imported_predicates: &HashMap<String, ImportedPredicateInfo>,
|
||||
symbols: &SymbolTable,
|
||||
params: &Params,
|
||||
batch_name: &str,
|
||||
) -> Result<Arc<CustomPredicateBatch>, BatchingError> {
|
||||
|
|
@ -624,11 +616,10 @@ fn build_single_batch(
|
|||
.map(|stmt| {
|
||||
build_statement_with_resolved_refs(
|
||||
stmt,
|
||||
name,
|
||||
batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
imported_predicates,
|
||||
symbols,
|
||||
)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
|
@ -654,46 +645,26 @@ fn build_single_batch(
|
|||
/// Build a statement template with properly resolved predicate references
|
||||
fn build_statement_with_resolved_refs(
|
||||
stmt: &crate::lang::frontend_ast::StatementTmpl,
|
||||
caller_name: &str,
|
||||
current_batch_idx: usize,
|
||||
reference_map: &HashMap<String, (usize, usize)>,
|
||||
existing_batches: &[Arc<CustomPredicateBatch>],
|
||||
imported_predicates: &HashMap<String, ImportedPredicateInfo>,
|
||||
symbols: &SymbolTable,
|
||||
) -> Result<StatementTmplBuilder, BatchingError> {
|
||||
let callee_name = &stmt.predicate.name;
|
||||
|
||||
// Resolve the predicate
|
||||
let predicate = if let Ok(native) = NativePredicate::from_str(callee_name) {
|
||||
Predicate::Native(native)
|
||||
} else if let Some(&(target_batch, target_idx)) = reference_map.get(callee_name) {
|
||||
// Local predicate in this document
|
||||
if target_batch == current_batch_idx {
|
||||
// Same batch - use BatchSelf
|
||||
Predicate::BatchSelf(target_idx)
|
||||
} else if target_batch < current_batch_idx {
|
||||
// Earlier batch - use Custom ref
|
||||
let batch = &existing_batches[target_batch];
|
||||
Predicate::Custom(CustomPredicateRef::new(batch.clone(), target_idx))
|
||||
} else {
|
||||
// Forward reference to a later batch should be impossible with the dependency-aware planner
|
||||
unreachable!(
|
||||
"Forward cross-batch reference: '{}' (batch {}) -> '{}' (batch {})",
|
||||
caller_name, current_batch_idx, callee_name, target_batch
|
||||
);
|
||||
}
|
||||
} else if let Some(imported) = imported_predicates.get(callee_name) {
|
||||
// Imported predicate from another batch
|
||||
Predicate::Custom(CustomPredicateRef::new(
|
||||
imported.batch.clone(),
|
||||
imported.index,
|
||||
))
|
||||
} else {
|
||||
// Unknown predicate
|
||||
return Err(BatchingError::Internal {
|
||||
message: format!("Unknown predicate reference: '{}'", callee_name),
|
||||
});
|
||||
// Resolve the predicate using the unified resolution function
|
||||
let context = ResolutionContext::Batch {
|
||||
current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
};
|
||||
|
||||
let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| {
|
||||
BatchingError::Internal {
|
||||
message: format!("Unknown predicate reference: '{}'", callee_name),
|
||||
}
|
||||
})?;
|
||||
|
||||
// Build the statement template
|
||||
let mut builder = StatementTmplBuilder::new(predicate);
|
||||
|
||||
|
|
@ -709,24 +680,30 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::{
|
||||
lang::{
|
||||
frontend_ast::parse::parse_document, frontend_ast_split::split_predicate_if_needed,
|
||||
frontend_ast::parse::parse_document,
|
||||
frontend_ast_split::split_predicate_if_needed,
|
||||
frontend_ast_validate::{validate, ValidatedAST},
|
||||
parser::parse_podlang,
|
||||
},
|
||||
middleware::PredicateOrWildcard,
|
||||
middleware::{Predicate, PredicateOrWildcard},
|
||||
};
|
||||
|
||||
fn parse_predicates(input: &str) -> Vec<CustomPredicateDef> {
|
||||
/// Helper: parse and validate input, returning predicates and symbol table
|
||||
fn parse_and_validate(input: &str) -> (Vec<CustomPredicateDef>, ValidatedAST) {
|
||||
let parsed = parse_podlang(input).expect("Failed to parse");
|
||||
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
||||
let validated = validate(document.clone(), &[]).expect("Failed to validate");
|
||||
|
||||
document
|
||||
let predicates = document
|
||||
.items
|
||||
.into_iter()
|
||||
.filter_map(|item| match item {
|
||||
crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
.collect();
|
||||
|
||||
(predicates, validated)
|
||||
}
|
||||
|
||||
/// Helper: wrap predicates into SplitResult (without actually splitting)
|
||||
|
|
@ -748,14 +725,14 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
|
@ -772,14 +749,14 @@ mod tests {
|
|||
pred3(C) = AND(Equal(C["z"], 3))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
|
@ -798,14 +775,14 @@ mod tests {
|
|||
pred5(E) = AND(Equal(E["v"], 5))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
|
@ -828,14 +805,14 @@ mod tests {
|
|||
pred1(A) = AND(Equal(A["x"], 1))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
|
@ -861,14 +838,14 @@ mod tests {
|
|||
pred2(B) = AND(pred1(B))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
|
@ -902,14 +879,14 @@ mod tests {
|
|||
pred5(E) = AND(pred1(E))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
|
@ -949,7 +926,7 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
// Split the large predicate
|
||||
|
|
@ -966,7 +943,7 @@ mod tests {
|
|||
// That's 5 predicates, which spans 2 batches
|
||||
assert_eq!(total_preds, 5);
|
||||
|
||||
let result = batch_predicates(all_split_results, ¶ms, "TestBatch", &HashMap::new());
|
||||
let result = batch_predicates(all_split_results, ¶ms, "TestBatch", validated.symbols());
|
||||
assert!(result.is_ok());
|
||||
|
||||
let batches = result.unwrap();
|
||||
|
|
@ -995,14 +972,14 @@ mod tests {
|
|||
pred5(E) = AND(Equal(E["v"], 5))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let batches = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
)
|
||||
.expect("Planner should avoid forward cross-batch reference");
|
||||
|
||||
|
|
@ -1019,8 +996,13 @@ mod tests {
|
|||
fn test_empty_input() {
|
||||
let split_results: Vec<SplitResult> = vec![];
|
||||
let params = Params::default();
|
||||
// For empty input, we need an empty symbol table
|
||||
let empty_symbols = SymbolTable {
|
||||
predicates: HashMap::new(),
|
||||
wildcard_scopes: HashMap::new(),
|
||||
};
|
||||
|
||||
let result = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new());
|
||||
let result = batch_predicates(split_results, ¶ms, "TestBatch", &empty_symbols);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let batches = result.unwrap();
|
||||
|
|
@ -1035,14 +1017,14 @@ mod tests {
|
|||
pred2(B) = AND(Equal(B["y"], 2))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let batches = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1061,7 +1043,7 @@ mod tests {
|
|||
pred2(B) = AND(pred1(B))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params {
|
||||
max_custom_batch_size: 1, // force SCC > capacity
|
||||
..Default::default()
|
||||
|
|
@ -1071,7 +1053,7 @@ mod tests {
|
|||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
|
|
@ -1098,7 +1080,7 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
// Split and batch
|
||||
|
|
@ -1107,7 +1089,8 @@ mod tests {
|
|||
let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed");
|
||||
all_split_results.push(result);
|
||||
}
|
||||
let batches = batch_predicates(all_split_results, ¶ms, "TestBatch", &HashMap::new())
|
||||
let batches =
|
||||
batch_predicates(all_split_results, ¶ms, "TestBatch", validated.symbols())
|
||||
.expect("Batch failed");
|
||||
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
|
|
@ -1147,14 +1130,14 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let batches = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1195,7 +1178,7 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
// Split the predicate
|
||||
|
|
@ -1210,7 +1193,7 @@ mod tests {
|
|||
assert_eq!(split_results[0].predicates.len(), 2);
|
||||
assert!(split_results[0].chain_info.is_some());
|
||||
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new())
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols())
|
||||
.expect("Batch failed");
|
||||
|
||||
// Verify chain info
|
||||
|
|
@ -1259,7 +1242,7 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
// Split the predicate
|
||||
|
|
@ -1274,7 +1257,7 @@ mod tests {
|
|||
assert_eq!(split_results[0].predicates.len(), 3);
|
||||
assert!(split_results[0].chain_info.is_some());
|
||||
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new())
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols())
|
||||
.expect("Batch failed");
|
||||
|
||||
// Verify chain info
|
||||
|
|
@ -1320,7 +1303,7 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
// Split the predicate
|
||||
|
|
@ -1330,7 +1313,7 @@ mod tests {
|
|||
split_results.push(result);
|
||||
}
|
||||
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new())
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols())
|
||||
.expect("Batch failed");
|
||||
|
||||
// Try with wrong number of statements (3 instead of 6)
|
||||
|
|
@ -1363,14 +1346,14 @@ mod tests {
|
|||
my_pred(A) = AND(Equal(A["x"], 1))
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let batches = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
¶ms,
|
||||
"TestBatch",
|
||||
&HashMap::new(),
|
||||
validated.symbols(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1401,7 +1384,7 @@ mod tests {
|
|||
)
|
||||
"#;
|
||||
|
||||
let predicates = parse_predicates(input);
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let mut split_results = Vec::new();
|
||||
|
|
@ -1410,7 +1393,7 @@ mod tests {
|
|||
split_results.push(result);
|
||||
}
|
||||
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new())
|
||||
let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols())
|
||||
.expect("Batch failed");
|
||||
|
||||
let statements: Vec<Statement> = (0..6).map(test_statement).collect();
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -14,15 +15,124 @@ use crate::{
|
|||
frontend_ast::*,
|
||||
frontend_ast_batch::{self, PredicateBatches},
|
||||
frontend_ast_split,
|
||||
frontend_ast_validate::{PredicateKind, ValidatedAST},
|
||||
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
|
||||
},
|
||||
middleware::{
|
||||
self, containers, IntroPredicateRef, NativePredicate, Params, Predicate,
|
||||
PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
||||
StatementTmplArg as MWStatementTmplArg, Wildcard,
|
||||
containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
|
||||
NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
||||
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
|
||||
},
|
||||
};
|
||||
|
||||
/// Context for predicate resolution - determines how local custom predicates are resolved
|
||||
pub enum ResolutionContext<'a> {
|
||||
/// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches
|
||||
Request {
|
||||
batches: Option<&'a PredicateBatches>,
|
||||
},
|
||||
/// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef
|
||||
Batch {
|
||||
current_batch_idx: usize,
|
||||
reference_map: &'a HashMap<String, (usize, usize)>,
|
||||
existing_batches: &'a [Arc<CustomPredicateBatch>],
|
||||
},
|
||||
}
|
||||
|
||||
/// Resolve a predicate name to a Predicate using the symbol table
|
||||
pub fn resolve_predicate(
|
||||
pred_name: &str,
|
||||
symbols: &SymbolTable,
|
||||
context: &ResolutionContext,
|
||||
) -> Option<Predicate> {
|
||||
// 1. Try native predicate first
|
||||
if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
return Some(Predicate::Native(native));
|
||||
}
|
||||
|
||||
// 2. Look up in symbol table
|
||||
if let Some(info) = symbols.predicates.get(pred_name) {
|
||||
let predicate = match &info.kind {
|
||||
PredicateKind::Native(np) => Predicate::Native(*np),
|
||||
|
||||
PredicateKind::Custom { .. } => match context {
|
||||
ResolutionContext::Request { batches } => {
|
||||
let batches = batches.as_ref()?;
|
||||
let pred_ref = batches.predicate_ref_by_name(pred_name)?;
|
||||
Predicate::Custom(pred_ref)
|
||||
}
|
||||
ResolutionContext::Batch {
|
||||
current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
} => resolve_local_predicate(
|
||||
pred_name,
|
||||
*current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
)?,
|
||||
},
|
||||
|
||||
PredicateKind::BatchImported { batch, index } => {
|
||||
Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index))
|
||||
}
|
||||
|
||||
PredicateKind::IntroImported {
|
||||
name,
|
||||
verifier_data_hash,
|
||||
} => Predicate::Intro(IntroPredicateRef {
|
||||
name: name.clone(),
|
||||
args_len: info.public_arity,
|
||||
verifier_data_hash: *verifier_data_hash,
|
||||
}),
|
||||
};
|
||||
return Some(predicate);
|
||||
}
|
||||
|
||||
// 3. In batch context, also check reference_map for split chain pieces
|
||||
// (predicates created by splitting that aren't in the original symbol table)
|
||||
if let ResolutionContext::Batch {
|
||||
current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
} = context
|
||||
{
|
||||
if reference_map.contains_key(pred_name) {
|
||||
return resolve_local_predicate(
|
||||
pred_name,
|
||||
*current_batch_idx,
|
||||
reference_map,
|
||||
existing_batches,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map
|
||||
fn resolve_local_predicate(
|
||||
pred_name: &str,
|
||||
current_batch_idx: usize,
|
||||
reference_map: &HashMap<String, (usize, usize)>,
|
||||
existing_batches: &[Arc<CustomPredicateBatch>],
|
||||
) -> Option<Predicate> {
|
||||
let &(target_batch, target_idx) = reference_map.get(pred_name)?;
|
||||
if target_batch == current_batch_idx {
|
||||
Some(Predicate::BatchSelf(target_idx))
|
||||
} else if target_batch < current_batch_idx {
|
||||
let batch = &existing_batches[target_batch];
|
||||
Some(Predicate::Custom(CustomPredicateRef::new(
|
||||
batch.clone(),
|
||||
target_idx,
|
||||
)))
|
||||
} else {
|
||||
unreachable!(
|
||||
"Forward cross-batch reference should be impossible: {} -> {}",
|
||||
current_batch_idx, target_batch
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Shared lowering utilities
|
||||
// ============================================================================
|
||||
|
|
@ -33,37 +143,37 @@ use crate::{
|
|||
/// Lower a literal value from AST to middleware Value.
|
||||
///
|
||||
/// This is a pure conversion that cannot fail.
|
||||
pub fn lower_literal(lit: &LiteralValue) -> middleware::Value {
|
||||
pub fn lower_literal(lit: &LiteralValue) -> Value {
|
||||
match lit {
|
||||
LiteralValue::Int(i) => middleware::Value::from(i.value),
|
||||
LiteralValue::Bool(b) => middleware::Value::from(b.value),
|
||||
LiteralValue::String(s) => middleware::Value::from(s.value.clone()),
|
||||
LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash),
|
||||
LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point),
|
||||
LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()),
|
||||
LiteralValue::Int(i) => Value::from(i.value),
|
||||
LiteralValue::Bool(b) => Value::from(b.value),
|
||||
LiteralValue::String(s) => Value::from(s.value.clone()),
|
||||
LiteralValue::Raw(r) => Value::from(r.hash.hash),
|
||||
LiteralValue::PublicKey(pk) => Value::from(pk.point),
|
||||
LiteralValue::SecretKey(sk) => Value::from(sk.secret_key.clone()),
|
||||
LiteralValue::Array(a) => {
|
||||
let elements: Vec<_> = a.elements.iter().map(lower_literal).collect();
|
||||
let array = containers::Array::new(elements);
|
||||
middleware::Value::from(array)
|
||||
Value::from(array)
|
||||
}
|
||||
LiteralValue::Set(s) => {
|
||||
let elements: std::collections::HashSet<_> =
|
||||
s.elements.iter().map(lower_literal).collect();
|
||||
let set = containers::Set::new(elements);
|
||||
middleware::Value::from(set)
|
||||
Value::from(set)
|
||||
}
|
||||
LiteralValue::Dict(d) => {
|
||||
let pairs: std::collections::HashMap<_, _> = d
|
||||
let pairs: HashMap<_, _> = d
|
||||
.pairs
|
||||
.iter()
|
||||
.map(|pair| {
|
||||
let key = middleware::Key::from(pair.key.value.as_str());
|
||||
let key = Key::from(pair.key.value.as_str());
|
||||
let value = lower_literal(&pair.value);
|
||||
(key, value)
|
||||
})
|
||||
.collect();
|
||||
let dict = containers::Dictionary::new(pairs);
|
||||
middleware::Value::from(dict)
|
||||
Value::from(dict)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -151,41 +261,18 @@ impl<'a> Lowerer<'a> {
|
|||
return Ok(None);
|
||||
}
|
||||
|
||||
// Build map of imported predicates for batching
|
||||
let imported_predicates = self.build_imported_predicates_map();
|
||||
|
||||
// Use the new batching module to pack predicates into batches
|
||||
// Pass the symbol table for unified predicate resolution
|
||||
let batches = frontend_ast_batch::batch_predicates(
|
||||
custom_predicates,
|
||||
self.params,
|
||||
&batch_name,
|
||||
&imported_predicates,
|
||||
self.validated.symbols(),
|
||||
)?;
|
||||
|
||||
Ok(Some(batches))
|
||||
}
|
||||
|
||||
fn build_imported_predicates_map(
|
||||
&self,
|
||||
) -> HashMap<String, frontend_ast_batch::ImportedPredicateInfo> {
|
||||
let symbols = self.validated.symbols();
|
||||
let mut imported = HashMap::new();
|
||||
|
||||
for (name, info) in &symbols.predicates {
|
||||
if let PredicateKind::BatchImported { batch, index } = &info.kind {
|
||||
imported.insert(
|
||||
name.clone(),
|
||||
frontend_ast_batch::ImportedPredicateInfo {
|
||||
batch: batch.clone(),
|
||||
index: *index,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
imported
|
||||
}
|
||||
|
||||
fn lower_request(
|
||||
&self,
|
||||
batches: Option<&PredicateBatches>,
|
||||
|
|
@ -232,42 +319,13 @@ impl<'a> Lowerer<'a> {
|
|||
let pred_name = &stmt.predicate.name;
|
||||
let symbols = self.validated.symbols();
|
||||
|
||||
// Resolve predicate - for request statements, local custom predicates
|
||||
// must be resolved to CustomPredicateRef (not BatchSelf)
|
||||
let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
Predicate::Native(native)
|
||||
} else if let Some(info) = symbols.predicates.get(pred_name) {
|
||||
match &info.kind {
|
||||
PredicateKind::Native(np) => Predicate::Native(*np),
|
||||
PredicateKind::Custom { .. } => {
|
||||
// Local custom predicates - resolve to CustomPredicateRef
|
||||
let batches = batches.ok_or_else(|| LoweringError::PredicateNotFound {
|
||||
name: pred_name.clone(),
|
||||
})?;
|
||||
let pred_ref = batches.predicate_ref_by_name(pred_name).ok_or_else(|| {
|
||||
// Resolve predicate using the unified resolution function
|
||||
let context = ResolutionContext::Request { batches };
|
||||
let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| {
|
||||
LoweringError::PredicateNotFound {
|
||||
name: pred_name.clone(),
|
||||
}
|
||||
})?;
|
||||
Predicate::Custom(pred_ref)
|
||||
}
|
||||
PredicateKind::BatchImported { batch, index } => {
|
||||
Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index))
|
||||
}
|
||||
PredicateKind::IntroImported {
|
||||
name,
|
||||
verifier_data_hash,
|
||||
} => Predicate::Intro(IntroPredicateRef {
|
||||
name: name.clone(),
|
||||
args_len: info.public_arity,
|
||||
verifier_data_hash: *verifier_data_hash,
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
return Err(LoweringError::PredicateNotFound {
|
||||
name: pred_name.clone(),
|
||||
});
|
||||
};
|
||||
|
||||
// Create a builder with the resolved predicate and desugar
|
||||
let mut builder = StatementTmplBuilder::new(predicate);
|
||||
|
|
@ -291,7 +349,7 @@ impl<'a> Lowerer<'a> {
|
|||
.get(&root_name)
|
||||
.expect("Root wildcard not found");
|
||||
let wildcard = Wildcard::new(root_name, *root_index);
|
||||
let key = middleware::Key::from(key_str.as_str());
|
||||
let key = Key::from(key_str.as_str());
|
||||
MWStatementTmplArg::AnchoredKey(wildcard, key)
|
||||
}
|
||||
};
|
||||
|
|
@ -646,4 +704,57 @@ mod tests {
|
|||
assert_eq!(batches.total_predicate_count(), 5);
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_intro_predicate_in_custom_predicate() {
|
||||
use hex::ToHex;
|
||||
|
||||
use crate::middleware::EMPTY_HASH;
|
||||
|
||||
// Import an intro predicate and use it inside a custom predicate definition
|
||||
let intro_hash = EMPTY_HASH.encode_hex::<String>();
|
||||
let input = format!(
|
||||
r#"
|
||||
use intro external_check(X) from 0x{intro_hash}
|
||||
|
||||
my_pred(A) = AND (
|
||||
Equal(A["foo"], 42)
|
||||
external_check(A)
|
||||
)
|
||||
"#
|
||||
);
|
||||
|
||||
let params = Params::default();
|
||||
|
||||
// Parse, validate, and lower
|
||||
let parsed = parse_podlang(&input).expect("Failed to parse");
|
||||
let document =
|
||||
parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document");
|
||||
let validated = validate(document, &[]).expect("Failed to validate");
|
||||
let result = lower(validated, ¶ms, "test_batch".to_string());
|
||||
|
||||
assert!(result.is_ok(), "Lowering failed: {:?}", result.err());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
let batch = expect_batch(&lowered);
|
||||
|
||||
// Should have one custom predicate
|
||||
assert_eq!(batch.predicates().len(), 1);
|
||||
|
||||
let pred = &batch.predicates()[0];
|
||||
assert_eq!(pred.name, "my_pred");
|
||||
// 2 statements: Equal and external_check
|
||||
assert_eq!(pred.statements().len(), 2);
|
||||
|
||||
// Verify the second statement is an intro predicate reference
|
||||
let intro_stmt = &pred.statements()[1];
|
||||
match intro_stmt.pred_or_wc() {
|
||||
PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
|
||||
assert_eq!(intro_ref.name, "external_check");
|
||||
assert_eq!(intro_ref.args_len, 1);
|
||||
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
|
||||
}
|
||||
other => panic!("Expected Intro predicate, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue