Remove batch splitting system (#475)

* First pass at removing batch splitting

* Refactor to separate module loading from request parsing

* Consolidate module functionality

* Tidy up comments

* Use array of modules instead of HashMap

* Formatting

* Use module hashes when importing modules
This commit is contained in:
Rob Knight 2026-02-09 10:31:47 +01:00 committed by GitHub
parent 5dab8195b4
commit acab26e5c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1425 additions and 1938 deletions

View file

@ -12,7 +12,7 @@ use std::{
use hex::ToHex;
use crate::{
lang::frontend_ast::*,
lang::{frontend_ast::*, Module},
middleware::{CustomPredicateBatch, Hash, NativePredicate},
};
@ -49,6 +49,8 @@ pub struct SymbolTable {
pub predicates: HashMap<String, PredicateInfo>,
/// Wildcard scopes for each custom predicate
pub wildcard_scopes: HashMap<String, WildcardScope>,
/// Imported modules (bound name → Module reference)
pub imported_modules: HashMap<String, Arc<Module>>,
}
/// Information about a predicate
@ -71,6 +73,11 @@ pub enum PredicateKind {
batch: Arc<CustomPredicateBatch>,
index: usize,
},
ModuleImported {
module_name: String,
predicate_name: String,
predicate_index: usize,
},
IntroImported {
name: String,
verifier_data_hash: Hash,
@ -107,39 +114,45 @@ pub enum DiagnosticLevel {
pub use crate::lang::error::ValidationError;
/// Validate an AST document
/// Mode for parsing/validation - determines what constructs are allowed
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseMode {
/// Module mode: predicate definitions allowed, REQUEST block not allowed
Module,
/// Request mode: REQUEST block required, predicate definitions not allowed
Request,
}
/// Validate an AST document in the given mode
pub fn validate(
document: Document,
available_batches: &[Arc<CustomPredicateBatch>],
available_modules: &HashMap<Hash, Arc<Module>>,
mode: ParseMode,
) -> Result<ValidatedAST, ValidationError> {
let validator = Validator::new(available_batches);
let validator = Validator::new(available_modules, mode);
validator.validate(document)
}
struct Validator {
available_batches: HashMap<String, Arc<CustomPredicateBatch>>,
available_modules: HashMap<Hash, Arc<Module>>,
symbols: SymbolTable,
diagnostics: Vec<Diagnostic>,
custom_predicate_count: usize,
mode: ParseMode,
}
impl Validator {
fn new(batches: &[Arc<CustomPredicateBatch>]) -> Self {
let mut available_batches = HashMap::new();
for batch in batches {
// Store by hex ID for lookup
let id = format!("0x{}", batch.id().encode_hex::<String>());
available_batches.insert(id, batch.clone());
}
fn new(available_modules: &HashMap<Hash, Arc<Module>>, mode: ParseMode) -> Self {
Self {
available_batches,
available_modules: available_modules.clone(),
symbols: SymbolTable {
predicates: HashMap::new(),
wildcard_scopes: HashMap::new(),
imported_modules: HashMap::new(),
},
diagnostics: Vec::new(),
custom_predicate_count: 0,
mode,
}
}
@ -160,25 +173,36 @@ impl Validator {
fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> {
// First process imports
for item in &document.items {
if let DocumentItem::UseBatchStatement(use_stmt) = item {
self.process_use_batch_statement(use_stmt)?;
if let DocumentItem::UseModuleStatement(use_stmt) = item {
self.process_use_module_statement(use_stmt)?;
}
if let DocumentItem::UseIntroStatement(use_stmt) = item {
self.process_use_intro_statement(use_stmt)?;
}
}
// Then process custom predicate definitions
// Check mode constraints for predicate definitions
let mut has_predicates = false;
for item in &document.items {
if let DocumentItem::CustomPredicateDef(pred_def) = item {
if self.mode == ParseMode::Request {
return Err(ValidationError::PredicatesNotAllowedInRequest {
span: pred_def.span,
});
}
has_predicates = true;
self.process_custom_predicate_def(pred_def)?;
}
}
// Check for multiple REQUEST definitions (only one allowed)
// Check mode constraints for REQUEST blocks
let mut has_request = false;
let mut first_request_span = None;
for item in &document.items {
if let DocumentItem::RequestDef(req) = item {
if self.mode == ParseMode::Module {
return Err(ValidationError::RequestNotAllowedInModule { span: req.span });
}
if let Some(first_span) = first_request_span {
return Err(ValidationError::MultipleRequestDefinitions {
first_span: Some(first_span),
@ -186,61 +210,44 @@ impl Validator {
});
}
first_request_span = req.span;
has_request = true;
}
}
// Enforce that modules have predicates and requests have a REQUEST block
match self.mode {
ParseMode::Module if !has_predicates => {
return Err(ValidationError::NoPredicatesInModule);
}
ParseMode::Request if !has_request => {
return Err(ValidationError::NoRequestBlock);
}
_ => {}
}
Ok(())
}
fn process_use_batch_statement(
fn process_use_module_statement(
&mut self,
use_stmt: &UseBatchStatement,
use_stmt: &UseModuleStatement,
) -> Result<(), ValidationError> {
let batch_id = format!("0x{}", use_stmt.batch_ref.hash.encode_hex::<String>());
let alias = &use_stmt.alias.name;
let hash = &use_stmt.hash.hash;
let batch = self.available_batches.get(&batch_id).ok_or_else(|| {
ValidationError::BatchNotFound {
id: batch_id.clone(),
span: use_stmt.batch_ref.span,
}
})?;
// Check if the module is available by hash
let module =
self.available_modules
.get(hash)
.ok_or_else(|| ValidationError::ModuleNotFound {
name: hash.encode_hex::<String>(),
span: use_stmt.span,
})?;
if use_stmt.imports.len() != batch.predicates().len() {
return Err(ValidationError::ImportArityMismatch {
expected: batch.predicates().len(),
found: use_stmt.imports.len(),
span: use_stmt.span,
});
}
for (i, import) in use_stmt.imports.iter().enumerate() {
if let ImportName::Named(name) = import {
if self.symbols.predicates.contains_key(name) {
return Err(ValidationError::DuplicateImport {
name: name.clone(),
span: use_stmt.span,
});
}
let pred = &batch.predicates()[i];
// CustomPredicate has args_len (public args) and wildcard_names (total args)
let total_arity = pred.wildcard_names.len();
let public_arity = pred.args_len;
self.symbols.predicates.insert(
name.clone(),
PredicateInfo {
kind: PredicateKind::BatchImported {
batch: batch.clone(),
index: i,
},
arity: total_arity,
public_arity,
source_span: use_stmt.span,
},
);
}
}
// Store the module keyed by alias for later qualified name resolution
self.symbols
.imported_modules
.insert(alias.clone(), module.clone());
Ok(())
}
@ -435,7 +442,11 @@ impl Validator {
stmt: &StatementTmpl,
wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> {
let pred_name = &stmt.predicate.name;
let pred_name = stmt.predicate.predicate_name();
let pred_span = match &stmt.predicate {
PredicateRef::Local(id) => id.span,
PredicateRef::Qualified { predicate, .. } => predicate.span,
};
let wc_names = match wildcard_context {
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
@ -444,31 +455,65 @@ impl Validator {
self.validate_wildcard_names(&wc_names)?;
// Check if predicate exists
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
// Native predicate
Some(PredicateInfo {
kind: PredicateKind::Native(native),
arity: native.arity(),
public_arity: native.arity(),
source_span: None,
})
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
// Custom or imported predicate
Some(info.clone())
} else if wc_names.contains(pred_name) {
None
} else {
return Err(ValidationError::UndefinedPredicate {
name: pred_name.clone(),
span: stmt.predicate.span,
});
let pred_info = match &stmt.predicate {
PredicateRef::Qualified { module, predicate } => {
// Look up the predicate in the imported module
let module_name = &module.name;
if let Some(imported_module) = self.symbols.imported_modules.get(module_name) {
// Find the predicate in the module
if let Some(&idx) = imported_module.predicate_index.get(&predicate.name) {
let module_pred = &imported_module.batch.predicates()[idx];
Some(PredicateInfo {
kind: PredicateKind::ModuleImported {
module_name: module_name.clone(),
predicate_name: predicate.name.clone(),
predicate_index: idx,
},
arity: module_pred.wildcard_names.len(),
public_arity: module_pred.args_len,
source_span: None,
})
} else {
return Err(ValidationError::UndefinedPredicate {
name: format!("{}::{}", module_name, predicate.name),
span: pred_span,
});
}
} else {
return Err(ValidationError::ModuleNotFound {
name: module_name.clone(),
span: module.span,
});
}
}
PredicateRef::Local(_) => {
if let Ok(native) = NativePredicate::from_str(pred_name) {
// Native predicate
Some(PredicateInfo {
kind: PredicateKind::Native(native),
arity: native.arity(),
public_arity: native.arity(),
source_span: None,
})
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
// Custom or imported predicate
Some(info.clone())
} else if wc_names.contains(&pred_name.to_string()) {
None
} else {
return Err(ValidationError::UndefinedPredicate {
name: pred_name.to_string(),
span: pred_span,
});
}
}
};
if let Some(ref pred_info) = pred_info {
let expected_arity = pred_info.public_arity;
if stmt.args.len() != expected_arity {
return Err(ValidationError::ArgumentCountMismatch {
predicate: pred_name.clone(),
predicate: pred_name.to_string(),
expected: expected_arity,
found: stmt.args.len(),
span: stmt.span,
@ -491,13 +536,15 @@ impl Validator {
// For custom predicates, only wildcards and literals are allowed
if matches!(
pred_info.map(|i| &i.kind),
Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. })
Some(PredicateKind::Custom { .. })
| Some(PredicateKind::BatchImported { .. })
| Some(PredicateKind::ModuleImported { .. })
) {
for arg in &stmt.args {
match arg {
StatementTmplArg::AnchoredKey(_) => {
return Err(ValidationError::InvalidArgumentType {
predicate: stmt.predicate.name.clone(),
predicate: stmt.predicate.predicate_name().to_string(),
span: stmt.span,
});
}
@ -552,25 +599,30 @@ impl Validator {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::{
lang::{frontend_ast::parse::parse_document, parser::parse_podlang},
lang::{frontend_ast::parse::parse_document, parser::parse_podlang, Module},
middleware::{CustomPredicate, Params, EMPTY_HASH},
};
fn parse_and_validate(
fn parse_and_validate_module(
input: &str,
batches: &[Arc<CustomPredicateBatch>],
modules: &HashMap<Hash, Arc<Module>>,
) -> Result<ValidatedAST, ValidationError> {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
validate(document, batches)
validate(document, modules, ParseMode::Module)
}
#[test]
fn test_validate_empty() {
let result = parse_and_validate("", &[]);
assert!(result.is_ok());
fn parse_and_validate_request(
input: &str,
modules: &HashMap<Hash, Arc<Module>>,
) -> Result<ValidatedAST, ValidationError> {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
validate(document, modules, ParseMode::Request)
}
#[test]
@ -578,7 +630,7 @@ mod tests {
let input = r#"REQUEST(
Equal(A["foo"], B["bar"])
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(result.is_ok());
}
@ -589,7 +641,7 @@ mod tests {
Equal(A["foo"], B["bar"])
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(result.is_ok());
let validated = result.unwrap();
@ -602,7 +654,7 @@ mod tests {
let input = r#"REQUEST(
UndefinedPred(A, B)
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::UndefinedPredicate { .. })
@ -616,7 +668,7 @@ mod tests {
Equal(A["foo"], B["bar"])
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(
matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B")
);
@ -627,7 +679,7 @@ mod tests {
let input = r#"REQUEST(
Equal(A, B, C)
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::ArgumentCountMismatch { .. })
@ -640,7 +692,7 @@ mod tests {
my_pred(A) = AND (Equal(A["x"], 1))
my_pred(B) = AND (Equal(B["y"], 2))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::DuplicatePredicate { .. })
@ -652,7 +704,7 @@ mod tests {
let input = r#"
my_pred(A, A) = AND (Equal(A["x"], 1))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::DuplicateWildcard { .. })
@ -664,7 +716,7 @@ mod tests {
let input = r#"
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::WildcardPredicateNameCollision { .. })
@ -673,16 +725,36 @@ mod tests {
#[test]
fn test_custom_predicate_with_anchored_key() {
let input = r#"
my_pred(A, B) = AND (
Equal(A["foo"], B["bar"])
)
// First create a module with the predicate
let params = Params::default();
let pred = CustomPredicate::and(
&params,
"my_pred".to_string(),
vec![],
2,
vec!["A".to_string(), "B".to_string()],
)
.unwrap();
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
let test_module = Arc::new(Module::new(batch, HashMap::new()));
let module_hash = test_module.id().encode_hex::<String>();
let mut available_modules = HashMap::new();
available_modules.insert(test_module.id(), test_module);
// Test that passing anchored key to custom predicate fails
let input = format!(
r#"
use module 0x{} as testmod
REQUEST(
my_pred(X["key"], Y)
testmod::my_pred(X["key"], Y)
)
"#;
let result = parse_and_validate(input, &[]);
"#,
module_hash
);
let result = parse_and_validate_request(&input, &available_modules);
assert!(matches!(
result,
Err(ValidationError::InvalidArgumentType { .. })
@ -695,12 +767,12 @@ mod tests {
pred1(A) = AND (
pred2(A)
)
pred2(B) = AND (
Equal(B["x"], 1)
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(result.is_ok());
}
@ -712,7 +784,7 @@ mod tests {
Equal(B["z"], C["w"])
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(result.is_ok());
let validated = result.unwrap();
@ -743,7 +815,7 @@ mod tests {
span: None,
})],
};
let result = validate(document, &[]);
let result = validate(document, &HashMap::new(), ParseMode::Module);
assert!(matches!(
result,
Err(ValidationError::EmptyStatementList { .. })
@ -756,7 +828,7 @@ mod tests {
REQUEST(Equal(A["x"], 1))
REQUEST(Equal(B["y"], 2))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::MultipleRequestDefinitions { .. })
@ -764,10 +836,14 @@ mod tests {
}
#[test]
fn test_use_statement() {
fn test_use_module_statement() {
use std::sync::Arc;
use hex::ToHex;
let params = Params::default();
// Create a batch to import
// Create a module to import
let pred = CustomPredicate::and(
&params,
"imported".to_string(),
@ -778,28 +854,33 @@ mod tests {
.unwrap();
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
let test_module = Arc::new(Module::new(batch, HashMap::new()));
let module_hash = test_module.id().encode_hex::<String>();
let mut available_modules = HashMap::new();
available_modules.insert(test_module.id(), test_module);
let batch_id = batch.id().encode_hex::<String>();
let input = format!(
r#"
use batch imported_pred from 0x{}
use module 0x{} as testmod
use intro intro_pred() from 0x{}
REQUEST(
imported_pred(A, B)
testmod::imported(A, B)
intro_pred()
)
"#,
batch_id,
module_hash,
EMPTY_HASH.encode_hex::<String>()
);
let result = parse_and_validate(&input, &[batch]);
let result = parse_and_validate_request(&input, &available_modules);
assert!(result.is_ok());
let validated = result.unwrap();
assert!(validated.symbols.predicates.contains_key("imported_pred"));
// Module predicates are accessed via qualified names, so no local binding
assert!(validated.symbols.predicates.contains_key("intro_pred"));
assert!(validated.symbols.imported_modules.contains_key("testmod"));
}
#[test]
@ -809,7 +890,7 @@ mod tests {
DictContains(D, K, V)
SetNotContains(S, E)
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(result.is_ok());
}
}