Remove batch splitting system (#475)

* First pass at removing batch splitting

* Refactor to separate module loading from request parsing

* Consolidate module functionality

* Tidy up comments

* Use array of modules instead of HashMap

* Formatting

* Use module hashes when importing modules
This commit is contained in:
Rob Knight 2026-02-09 10:31:47 +01:00 committed by GitHub
parent 5dab8195b4
commit acab26e5c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1425 additions and 1938 deletions

View file

@ -1,52 +1,69 @@
//! Lowering from frontend AST to middleware structures
//!
//! This module converts validated frontend AST to middleware data structures.
//! Supports automatic predicate splitting and multi-batch packing.
//! Supports automatic predicate splitting.
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
};
use crate::{
frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
lang::{
frontend_ast::*,
frontend_ast_batch::{self, PredicateBatches},
frontend_ast_split,
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
module, Module,
},
middleware::{
self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
self, containers, CustomPredicateRef, IntroPredicateRef, Key, NativePredicate, Params,
Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Value,
Wildcard,
},
};
/// Context for predicate resolution - determines how local custom predicates are resolved
/// Context for predicate resolution - determines how predicates are resolved
pub enum ResolutionContext<'a> {
/// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches
Request {
batches: Option<&'a PredicateBatches>,
},
/// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef
Batch {
current_batch_idx: usize,
reference_map: &'a HashMap<String, (usize, usize)>,
existing_batches: &'a [Arc<CustomPredicateBatch>],
/// Request context: predicates resolve via imports only (no local definitions)
Request,
/// Module context: local predicates resolve to BatchSelf
Module {
/// Maps predicate name to index within the module
reference_map: &'a HashMap<String, usize>,
/// Name of the custom predicate being defined (for wildcard scope lookup)
custom_predicate_name: &'a str,
},
}
/// Resolve an AST predicate reference to a `PredicateOrWildcard`.
///
/// * A qualified reference is looked up by module name in
///   `symbols.imported_modules`, then by predicate name in that module's
///   `predicate_index`; the result is a `Predicate::Custom` built from the
///   module's batch and the index found.
/// * A local reference is handed to [`resolve_predicate`] by name.
///
/// Returns `None` if the module name or the predicate name is not found, or
/// if [`resolve_predicate`] returns `None` for a local reference.
pub fn resolve_predicate_ref(
    pred_ref: &PredicateRef,
    symbols: &SymbolTable,
    context: &ResolutionContext,
) -> Option<PredicateOrWildcard> {
    let (module, predicate) = match pred_ref {
        // Unqualified names go through the general name-based resolver.
        PredicateRef::Local(id) => return resolve_predicate(&id.name, symbols, context),
        PredicateRef::Qualified { module, predicate } => (module, predicate),
    };

    // Qualified path: module lookup first, then the predicate's index within it.
    let imported = symbols.imported_modules.get(&module.name)?;
    let index = *imported.predicate_index.get(&predicate.name)?;
    let custom_ref = CustomPredicateRef::new(imported.batch.clone(), index);
    Some(PredicateOrWildcard::Predicate(Predicate::Custom(custom_ref)))
}
/// Resolve a predicate name to a Predicate using the symbol table
pub fn resolve_predicate(
pred_name: &str,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Option<PredicateOrWildcard> {
// 0. Try wildcard first
if let ResolutionContext::Batch {
// 0. Try wildcard first (only in module context where we're defining predicates)
if let ResolutionContext::Module {
custom_predicate_name,
..
} = context
@ -69,28 +86,35 @@ pub fn resolve_predicate(
PredicateKind::Native(np) => Predicate::Native(*np),
PredicateKind::Custom { .. } => match context {
ResolutionContext::Request { batches } => {
let batches = batches.as_ref()?;
let pred_ref = batches.predicate_ref_by_name(pred_name)?;
Predicate::Custom(pred_ref)
ResolutionContext::Request => {
// Requests can't define local predicates, so this shouldn't happen
return None;
}
ResolutionContext::Module { reference_map, .. } => {
resolve_local_predicate(pred_name, reference_map)?
}
ResolutionContext::Batch {
current_batch_idx,
reference_map,
existing_batches,
..
} => resolve_local_predicate(
pred_name,
*current_batch_idx,
reference_map,
existing_batches,
)?,
},
PredicateKind::BatchImported { batch, index } => {
Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index))
}
PredicateKind::ModuleImported {
module_name,
predicate_index,
..
} => {
// Look up the module in the imported_modules
let module = symbols
.imported_modules
.get(module_name)
.expect("Module should exist if ModuleImported predicate kind exists");
Predicate::Custom(CustomPredicateRef::new(
module.batch.clone(),
*predicate_index,
))
}
PredicateKind::IntroImported {
name,
verifier_data_hash,
@ -103,51 +127,25 @@ pub fn resolve_predicate(
return Some(PredicateOrWildcard::Predicate(predicate));
}
// 3. In batch context, also check reference_map for split chain pieces
// 3. In module context, also check reference_map for split chain pieces
// (predicates created by splitting that aren't in the original symbol table)
if let ResolutionContext::Batch {
current_batch_idx,
reference_map,
existing_batches,
..
} = context
{
if let ResolutionContext::Module { reference_map, .. } = context {
if reference_map.contains_key(pred_name) {
return resolve_local_predicate(
pred_name,
*current_batch_idx,
reference_map,
existing_batches,
)
.map(PredicateOrWildcard::Predicate);
return resolve_local_predicate(pred_name, reference_map)
.map(PredicateOrWildcard::Predicate);
}
}
None
}
/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map
/// Resolve a local predicate (one in this module or a split chain piece) using the reference_map
fn resolve_local_predicate(
pred_name: &str,
current_batch_idx: usize,
reference_map: &HashMap<String, (usize, usize)>,
existing_batches: &[Arc<CustomPredicateBatch>],
reference_map: &HashMap<String, usize>,
) -> Option<Predicate> {
let &(target_batch, target_idx) = reference_map.get(pred_name)?;
if target_batch == current_batch_idx {
Some(Predicate::BatchSelf(target_idx))
} else if target_batch < current_batch_idx {
let batch = &existing_batches[target_batch];
Some(Predicate::Custom(CustomPredicateRef::new(
batch.clone(),
target_idx,
)))
} else {
unreachable!(
"Forward cross-batch reference should be impossible: {} -> {}",
current_batch_idx, target_batch
);
}
let &idx = reference_map.get(pred_name)?;
Some(Predicate::BatchSelf(idx))
}
// ============================================================================
@ -155,7 +153,7 @@ fn resolve_local_predicate(
// ============================================================================
// These functions convert AST types to middleware/builder types and are used
// by both the request lowering (in this module) and predicate batching
// (in frontend_ast_batch).
// (in module.rs).
/// Lower a literal value from AST to middleware Value.
///
@ -215,38 +213,37 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
}
}
/// Result of lowering: optional custom predicate batches and optional request
///
/// A Podlang file can contain:
/// - Just custom predicates (batches: Some, request: None)
/// - Just a request (batches: None, request: Some)
/// - Both (batches: Some, request: Some)
/// - Neither (batches: None, request: None) - just imports
#[derive(Debug, Clone)]
pub struct LoweredOutput {
pub batches: Option<PredicateBatches>,
pub request: Option<crate::frontend::PodRequest>,
}
pub use crate::lang::error::LoweringError;
/// Lower a validated AST to middleware structures
/// Lower a validated module AST to a Module
///
/// Returns both the custom predicate batch (if any) and the request (if any).
/// At least one will be Some if the document contains custom predicates or a request.
pub fn lower(
/// The validated AST must have been validated in Module mode.
pub fn lower_module(
validated: ValidatedAST,
params: &Params,
batch_name: String,
) -> Result<LoweredOutput, LoweringError> {
module_name: &str,
) -> Result<Module, LoweringError> {
if !validated.diagnostics().is_empty() {
// For now, treat any diagnostics as errors
// In future we could allow warnings
return Err(LoweringError::ValidationErrors);
}
let lowerer = Lowerer::new(validated, params);
lowerer.lower(batch_name)
lowerer.lower_module(module_name)
}
/// Lower a validated request AST into a `PodRequest`.
///
/// The caller is expected to have validated the AST in Request mode.
///
/// # Errors
///
/// Returns `LoweringError::ValidationErrors` if `validated` carries any
/// diagnostics; otherwise returns whatever `Lowerer::lower_request` returns.
pub fn lower_request(
    validated: ValidatedAST,
    params: &Params,
) -> Result<crate::frontend::PodRequest, LoweringError> {
    if validated.diagnostics().is_empty() {
        Lowerer::new(validated, params).lower_request()
    } else {
        Err(LoweringError::ValidationErrors)
    }
}
struct Lowerer<'a> {
@ -259,52 +256,33 @@ impl<'a> Lowerer<'a> {
Self { validated, params }
}
fn lower(self, batch_name: String) -> Result<LoweredOutput, LoweringError> {
// Lower custom predicates (if any) - now supports multiple batches
let batches = self.lower_batches(batch_name)?;
// Lower request (if any) - pass batches so refs can be resolved
let request = self.lower_request(batches.as_ref())?;
Ok(LoweredOutput { batches, request })
}
fn lower_batches(&self, batch_name: String) -> Result<Option<PredicateBatches>, LoweringError> {
fn lower_module(self, module_name: &str) -> Result<Module, LoweringError> {
// Extract and split custom predicates from document
let custom_predicates = self.extract_and_split_predicates()?;
// If no custom predicates, return None
if custom_predicates.is_empty() {
return Ok(None);
}
// Use the new batching module to pack predicates into batches
// Pass the symbol table for unified predicate resolution
let batches = frontend_ast_batch::batch_predicates(
// Build the module from split predicates
let module = module::build_module(
custom_predicates,
self.params,
&batch_name,
module_name,
self.validated.symbols(),
)?;
Ok(Some(batches))
Ok(module)
}
fn lower_request(
&self,
batches: Option<&PredicateBatches>,
) -> Result<Option<crate::frontend::PodRequest>, LoweringError> {
fn lower_request(self) -> Result<crate::frontend::PodRequest, LoweringError> {
let doc = self.validated.document();
// Find request definition (if any)
let request_def = doc.items.iter().find_map(|item| match item {
DocumentItem::RequestDef(req) => Some(req),
_ => None,
});
let Some(request_def) = request_def else {
return Ok(None);
};
// Find request definition
let request_def = doc
.items
.iter()
.find_map(|item| match item {
DocumentItem::RequestDef(req) => Some(req),
_ => None,
})
.expect("Request mode validation ensures REQUEST block exists");
// Build wildcard map from all wildcards used in the request statements
let wildcard_map = self.build_request_wildcard_map(request_def);
@ -312,18 +290,17 @@ impl<'a> Lowerer<'a> {
// Lower each statement to middleware templates, resolving predicates
let mut request_templates = Vec::new();
for stmt in &request_def.statements {
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map, batches)?;
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map)?;
request_templates.push(mw_stmt);
}
Ok(Some(crate::frontend::PodRequest::new(request_templates)))
Ok(crate::frontend::PodRequest::new(request_templates))
}
fn lower_request_statement(
&self,
stmt: &StatementTmpl,
wildcard_map: &HashMap<String, usize>,
batches: Option<&PredicateBatches>,
) -> Result<MWStatementTmpl, LoweringError> {
// Enforce argument count limit for request statements
if stmt.args.len() > Params::max_statement_args() {
@ -333,16 +310,16 @@ impl<'a> Lowerer<'a> {
});
}
let pred_name = &stmt.predicate.name;
let symbols = self.validated.symbols();
// Resolve predicate using the unified resolution function
let context = ResolutionContext::Request { batches };
let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| {
LoweringError::PredicateNotFound {
name: pred_name.clone(),
}
})?;
let context = ResolutionContext::Request;
let predicate =
resolve_predicate_ref(&stmt.predicate, symbols, &context).ok_or_else(|| {
LoweringError::PredicateNotFound {
name: format!("{}", stmt.predicate),
}
})?;
// Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate.clone());
@ -453,31 +430,24 @@ impl<'a> Lowerer<'a> {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::lang::{
frontend_ast::parse::parse_document, frontend_ast_validate::validate, parser::parse_podlang,
frontend_ast::parse::parse_document,
frontend_ast_validate::{validate, ParseMode},
parser::parse_podlang,
};
fn parse_validate_and_lower(
fn parse_validate_and_lower_module(
input: &str,
params: &Params,
) -> Result<LoweredOutput, LoweringError> {
) -> Result<Module, LoweringError> {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
let validated = validate(document, &[]).expect("Failed to validate");
lower(validated, params, "test_batch".to_string())
}
// Helper to get the first batch from the output (expecting it to exist)
fn expect_batch(
output: &LoweredOutput,
) -> &std::sync::Arc<crate::middleware::CustomPredicateBatch> {
output
.batches
.as_ref()
.expect("Expected batches to be present")
.first_batch()
.expect("Expected at least one batch")
let validated =
validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate");
lower_module(validated, params, "test_batch")
}
#[test]
@ -489,16 +459,16 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
if let Err(e) = &result {
eprintln!("Error: {:?}", e);
}
assert!(result.is_ok());
let lowered = result.unwrap();
assert_eq!(expect_batch(&lowered).predicates().len(), 1);
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 1);
let pred = &expect_batch(&lowered).predicates()[0];
let pred = &module.batch.predicates()[0];
assert_eq!(pred.name, "my_pred");
assert_eq!(pred.args_len(), 2);
assert_eq!(pred.wildcard_names().len(), 2);
@ -515,11 +485,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred = &expect_batch(&lowered).predicates()[0];
let module = result.unwrap();
let pred = &module.batch.predicates()[0];
assert_eq!(pred.args_len(), 1); // Only A is public
assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total
}
@ -534,11 +504,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred = &expect_batch(&lowered).predicates()[0];
let module = result.unwrap();
let pred = &module.batch.predicates()[0];
assert!(pred.is_disjunction());
}
@ -556,23 +526,22 @@ mod tests {
"#;
let params = Params::default(); // max_custom_predicate_arity = 5
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
if let Err(e) = &result {
eprintln!("Splitting error: {:?}", e);
}
assert!(result.is_ok());
let lowered = result.unwrap();
let module = result.unwrap();
// Should be automatically split into 2 predicates (my_pred and my_pred_1)
let batches = lowered.batches.as_ref().expect("Expected batches");
assert_eq!(batches.total_predicate_count(), 2);
assert_eq!(module.batch.predicates().len(), 2);
// With topological sorting, my_pred_1 comes first (since my_pred depends on it)
// my_pred_1 has 2 statements
// my_pred has 5 statements (4 + chain call)
// Just verify we have the right total statement counts
let batch = batches.first_batch().unwrap();
let total_statements: usize = batch
let total_statements: usize = module
.batch
.predicates()
.iter()
.map(|p| p.statements().len())
@ -593,11 +562,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
assert_eq!(expect_batch(&lowered).predicates().len(), 2);
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 2);
}
#[test]
@ -613,11 +582,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred2 = &expect_batch(&lowered).predicates()[1];
let module = result.unwrap();
let pred2 = &module.batch.predicates()[1];
let stmt = &pred2.statements()[0];
// Should be BatchSelf(0) referring to pred1
@ -638,7 +607,7 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
}
@ -651,11 +620,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred = &expect_batch(&lowered).predicates()[0];
let module = result.unwrap();
let pred = &module.batch.predicates()[0];
let stmt = &pred.statements()[0];
// Should desugar to the Contains predicate
@ -677,7 +646,7 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
}
@ -706,18 +675,18 @@ mod tests {
let parsed = parse_podlang(&input).expect("Failed to parse");
let document =
parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document");
let validated = validate(document, &[]).expect("Failed to validate");
let result = lower(validated, &params, "test_batch".to_string());
let validated =
validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate");
let result = lower_module(validated, &params, "test_batch");
assert!(result.is_ok(), "Lowering failed: {:?}", result.err());
let lowered = result.unwrap();
let batch = expect_batch(&lowered);
let module = result.unwrap();
// Should have one custom predicate
assert_eq!(batch.predicates().len(), 1);
assert_eq!(module.batch.predicates().len(), 1);
let pred = &batch.predicates()[0];
let pred = &module.batch.predicates()[0];
assert_eq!(pred.name, "my_pred");
// 2 statements: Equal and external_check
assert_eq!(pred.statements().len(), 2);