diff --git a/examples/main_pod_points.rs b/examples/main_pod_points.rs index 2b5f257..08c238f 100644 --- a/examples/main_pod_points.rs +++ b/examples/main_pod_points.rs @@ -16,7 +16,7 @@ use pod2::{ primitives::ec::schnorr::SecretKey, signer::Signer, }, frontend::{MainPodBuilder, Operation, SignedDictBuilder}, - lang::parse, + lang::load_module, middleware::{MainPodProver, Params, VDSet}, }; @@ -88,10 +88,8 @@ fn main() -> Result<(), Box> { game_pk = game_pk, ); println!("# custom predicate batch:{}", input); - let batch = parse(&input, ¶ms, &[])? - .first_batch() - .expect("Expected batch") - .clone(); + let module = load_module(&input, "points_module", ¶ms, vec![])?; + let batch = module.batch.clone(); let points_pred = batch.predicate_ref_by_name("points").unwrap(); let over_9000_pred = batch.predicate_ref_by_name("over_9000").unwrap(); diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index e6f5329..2f25d38 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -836,7 +836,7 @@ pub mod tests { frontend::{ self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB, }, - lang::parse, + lang::load_module, middleware::{ self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _, DEFAULT_VD_LIST, DEFAULT_VD_SET, @@ -1165,7 +1165,7 @@ pub mod tests { #[test] fn test_undetermined_values() { let params = Default::default(); - let batch = parse( + let module = load_module( r#" two_equal(x,y,z) = OR( Equal(x,y) @@ -1173,13 +1173,12 @@ pub mod tests { Equal(x,z) ) "#, + "test", ¶ms, - &[], + vec![], ) - .unwrap() - .first_batch() - .unwrap() - .clone(); + .unwrap(); + let batch = module.batch.clone(); let mut builder = MainPodBuilder::new(¶ms, &DEFAULT_VD_SET); let cpr = CustomPredicateRef { batch, index: 0 }; let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap(); diff --git a/src/examples/custom.rs b/src/examples/custom.rs index b64b8a4..9a68b78 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - -use hex::ToHex; +use std::{collections::HashMap, sync::Arc}; use crate::{ frontend::{PodRequest, Result}, - lang::parse, + lang::{load_module, parse_request, Module}, middleware::{CustomPredicateBatch, Params}, }; @@ -32,11 +30,8 @@ pub fn eth_dos_batch(params: &Params) -> Result> { eth_dos_ind(src, dst, distance) ) "#; - let batch = parse(input, params, &[]) - .expect("lang parse") - .first_batch() - .expect("Expected batch") - .clone(); + let module = load_module(input, "eth_dos", params, vec![]).expect("lang parse"); + let batch = module.batch.clone(); println!("a.0. {}", batch.predicates()[0]); println!("a.1. {}", batch.predicates()[1]); println!("a.2. {}", batch.predicates()[2]); @@ -45,18 +40,26 @@ pub fn eth_dos_batch(params: &Params) -> Result> { } pub fn eth_dos_request() -> Result { + use hex::ToHex; + let batch = eth_dos_batch(&Params::default())?; - let batch_id = batch.id().encode_hex::(); + let eth_dos_module = Arc::new(Module::new(batch, HashMap::new())); + let module_hash = eth_dos_module.id().encode_hex::(); + let input = format!( r#" - use batch _, _, _, eth_dos from 0x{batch_id} + use module 0x{} as eth_dos REQUEST( - eth_dos(src, dst, distance) + eth_dos::eth_dos(src, dst, distance) ) "#, + module_hash ); - let parsed = parse(&input, &Params::default(), &[batch])?; - Ok(parsed.request) + Ok(parse_request( + &input, + &Params::default(), + &[eth_dos_module], + )?) } #[cfg(test)] diff --git a/src/examples/mod.rs b/src/examples/mod.rs index 5a5775f..2b490f9 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -12,7 +12,7 @@ use crate::{ frontend::{ MainPod, MainPodBuilder, Operation, PodRequest, Result, SignedDict, SignedDictBuilder, }, - lang::parse, + lang::parse_request, middleware::{ self, containers::Set, hash_values, CustomPredicateRef, Params, Predicate, PublicKey, Signer as _, Statement, StatementArg, TypedValue, VDSet, Value, @@ -90,8 +90,7 @@ pub fn zu_kyc_pod_request(gov_signer: &Value, pay_signer: &Value) -> Result { assert_eq!(pred_ref, &predicate); diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index bb80438..bd411ef 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -632,7 +632,7 @@ mod tests { dict, examples::MOCK_VD_SET, frontend::{Operation as FrontendOp, SignedDictBuilder}, - lang::parse, + lang::load_module, }; #[test] @@ -756,18 +756,17 @@ mod tests { // pred_a accepts a Contains statement // pred_b accepts a pred_a statement (Custom statement from pred_a) - let parsed = parse( + let module = load_module( r#" pred_a(X) = AND(Contains(X, "k", 1)) pred_b(X) = AND(pred_a(X)) "#, + "test", ¶ms, - &[], + vec![], ) - .expect("parse predicates"); - let batch = parsed - .first_batch() - .expect("parse predicates should have a batch"); + .expect("load module"); + let batch = &module.batch; let mut builder = MultiPodBuilder::new(¶ms, vd_set); @@ -1484,20 +1483,19 @@ mod tests { let vd_set = &*MOCK_VD_SET; // Chain of predicates: each accepts the output of the previous - let parsed = parse( + let module = load_module( r#" pred_a(X) = AND(Contains(X, "k", 1)) pred_b(X) = AND(pred_a(X)) pred_c(X) = AND(pred_b(X)) pred_d(X) = AND(pred_c(X)) "#, + "test", ¶ms, - &[], + vec![], ) - .expect("parse predicates"); - let batch = parsed - .first_batch() - .expect("parse predicates should have a batch"); + .expect("load module"); + let batch = &module.batch; let mut builder = MultiPodBuilder::new(¶ms, vd_set); @@ -1612,7 +1610,7 @@ mod tests { // pred_a takes TWO custom statement arguments (b_out and c_out) // pred_b and pred_c each take a Contains // Note: AND clauses are newline-separated, not comma-separated - let parsed = parse( + let module = load_module( r#" pred_b(X) = AND(Contains(X, "k", 1)) pred_c(X) = AND(Contains(X, "k", 1)) @@ -1621,13 +1619,12 @@ mod tests { pred_c(Y) ) "#, + "test", ¶ms, - &[], + vec![], ) - .expect("parse predicates"); - let batch = parsed - .first_batch() - .expect("parse predicates should have a batch"); + .expect("load module"); + let batch = &module.batch; let mut builder = MultiPodBuilder::new(¶ms, vd_set); diff --git a/src/frontend/pod_request.rs b/src/frontend/pod_request.rs index c804610..bd4a32c 100644 --- a/src/frontend/pod_request.rs +++ b/src/frontend/pod_request.rs @@ -185,7 +185,7 @@ mod tests { zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, MOCK_VD_SET, }, frontend::{MainPodBuilder, Operation}, - lang::parse, + lang::parse_request, middleware::{Params, Value}, }; @@ -210,7 +210,7 @@ mod tests { assert!(request.exact_match_pod(&*kyc.pod).is_ok()); // This request does not match the POD, because the POD does not contain a NotEqual statement. - let non_matching_request = parse( + let non_matching_request = parse_request( r#" REQUEST( NotEqual(4, 5) @@ -219,8 +219,7 @@ mod tests { ¶ms, &[], ) - .unwrap() - .request; + .unwrap(); assert!(non_matching_request.exact_match_pod(&*kyc.pod).is_err()); } @@ -240,7 +239,7 @@ mod tests { println!("{pod}"); - let request = parse( + let request = parse_request( r#" REQUEST( SumOf(a, b, c) @@ -252,7 +251,7 @@ mod tests { ) .unwrap(); - let bindings = request.request.exact_match_pod(&*pod.pod).unwrap(); + let bindings = request.exact_match_pod(&*pod.pod).unwrap(); assert_eq!(*bindings.get("a").unwrap(), 10.into()); assert_eq!(*bindings.get("b").unwrap(), 9.into()); assert_eq!(*bindings.get("c").unwrap(), 1.into()); diff --git a/src/lang/error.rs b/src/lang/error.rs index 2ae7c25..3cf1eb6 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -50,8 +50,8 @@ pub enum ValidationError { span: Option, }, - #[error("Batch not found: {id}")] - BatchNotFound { id: String, span: Option }, + #[error("Module not found: {name}")] + ModuleNotFound { name: String, span: Option }, #[error("Undefined predicate: {name}")] UndefinedPredicate { name: String, span: Option }, @@ -91,6 +91,18 @@ pub enum ValidationError { #[error("Wildcard '{name}' collides with a predicate name")] WildcardPredicateNameCollision { name: String }, + + #[error("Predicate definitions are not allowed in requests")] + PredicatesNotAllowedInRequest { span: Option }, + + #[error("REQUEST block is not allowed in modules")] + RequestNotAllowedInModule { span: Option }, + + #[error("Modules must contain at least one predicate definition")] + NoPredicatesInModule, + + #[error("Requests must contain a REQUEST block")] + NoRequestBlock, } /// Lowering errors from frontend AST lowering to middleware diff --git a/src/lang/frontend_ast.rs b/src/lang/frontend_ast.rs index d820d7f..4ca7fe4 100644 --- a/src/lang/frontend_ast.rs +++ b/src/lang/frontend_ast.rs @@ -18,17 +18,17 @@ pub struct Document { /// Top-level items that can appear in a document #[derive(Debug, Clone, PartialEq)] pub enum DocumentItem { - UseBatchStatement(UseBatchStatement), + UseModuleStatement(UseModuleStatement), UseIntroStatement(UseIntroStatement), CustomPredicateDef(CustomPredicateDef), RequestDef(RequestDef), } -/// Import statement: `use batch pred1, pred2, _ from 0x...` +/// Module import statement: `use module 0xHASH as alias` #[derive(Debug, Clone, PartialEq)] -pub struct UseBatchStatement { - pub imports: Vec, - pub batch_ref: HashHex, +pub struct UseModuleStatement { + pub hash: HashHex, + pub alias: Identifier, pub span: Option, } @@ -40,19 +40,6 @@ pub struct UseIntroStatement { pub intro_hash: HashHex, pub span: Option, } -/// Individual import name (identifier or unused "_") -#[derive(Debug, Clone, PartialEq)] -pub enum ImportName { - Named(String), - Unused, // "_" -} - -/// Batch reference (hash) -#[derive(Debug, Clone, PartialEq)] -pub struct BatchRef { - pub hash: HashHex, - pub span: Option, -} /// Intro predicate reference (hash) #[derive(Debug, Clone, PartialEq)] @@ -96,11 +83,33 @@ pub enum ConjunctionType { /// Statement template: predicate call with arguments #[derive(Debug, Clone, PartialEq)] pub struct StatementTmpl { - pub predicate: Identifier, + pub predicate: PredicateRef, pub args: Vec, pub span: Option, } +/// Reference to a predicate (local or qualified with module name) +#[derive(Debug, Clone, PartialEq)] +pub enum PredicateRef { + /// Unqualified name (local or native predicate) + Local(Identifier), + /// Qualified name (module::predicate) + Qualified { + module: Identifier, + predicate: Identifier, + }, +} + +impl PredicateRef { + /// Get the predicate name (without module qualifier) + pub fn predicate_name(&self) -> &str { + match self { + PredicateRef::Local(id) => &id.name, + PredicateRef::Qualified { predicate, .. } => &predicate.name, + } + } +} + /// Arguments that can be passed to statements #[derive(Debug, Clone, PartialEq)] pub enum StatementTmplArg { @@ -256,7 +265,7 @@ impl fmt::Display for Document { impl fmt::Display for DocumentItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - DocumentItem::UseBatchStatement(u) => write!(f, "{}", u), + DocumentItem::UseModuleStatement(u) => write!(f, "{}", u), DocumentItem::UseIntroStatement(u) => write!(f, "{}", u), DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c), DocumentItem::RequestDef(r) => write!(f, "{}", r), @@ -264,16 +273,9 @@ impl fmt::Display for DocumentItem { } } -impl fmt::Display for UseBatchStatement { +impl fmt::Display for UseModuleStatement { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "use batch ")?; - for (i, import) in self.imports.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", import)?; - } - write!(f, " from {}", self.batch_ref) + write!(f, "use module {} as {}", self.hash, self.alias) } } @@ -290,21 +292,17 @@ impl fmt::Display for UseIntroStatement { } } -impl fmt::Display for ImportName { +impl fmt::Display for PredicateRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ImportName::Named(name) => write!(f, "{}", name), - ImportName::Unused => write!(f, "_"), + PredicateRef::Local(id) => write!(f, "{}", id), + PredicateRef::Qualified { module, predicate } => { + write!(f, "{}::{}", module, predicate) + } } } } -impl fmt::Display for BatchRef { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.hash) - } -} - impl fmt::Display for IntroPredicateRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.hash) @@ -536,10 +534,10 @@ pub mod parse { for inner_pair in pair.into_inner() { match inner_pair.as_rule() { - Rule::use_batch_statement => { - items.push(DocumentItem::UseBatchStatement(parse_use_batch_statement( - inner_pair, - ))); + Rule::use_module_statement => { + items.push(DocumentItem::UseModuleStatement( + parse_use_module_statement(inner_pair), + )); } Rule::use_intro_statement => { items.push(DocumentItem::UseIntroStatement(parse_use_intro_statement( @@ -562,25 +560,17 @@ pub mod parse { Ok(Document { items }) } - fn parse_use_batch_statement(pair: Pair) -> UseBatchStatement { - assert_eq!(pair.as_rule(), Rule::use_batch_statement); + fn parse_use_module_statement(pair: Pair) -> UseModuleStatement { + assert_eq!(pair.as_rule(), Rule::use_module_statement); let span = get_span(&pair); let mut inner = pair.into_inner(); - let use_list_pair = inner - .find(|p| p.as_rule() == Rule::use_predicate_list) - .unwrap(); - let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap(); + let hash = parse_hash_hex(inner.next().unwrap()); + let alias = parse_identifier(inner.next().unwrap()); - let imports = use_list_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::import_name) - .map(parse_import_name) - .collect(); - - UseBatchStatement { - imports, - batch_ref: parse_hash_hex(batch_ref_pair.into_inner().next().unwrap()), + UseModuleStatement { + hash, + alias, span: Some(span), } } @@ -622,16 +612,6 @@ pub mod parse { } } - fn parse_import_name(pair: Pair) -> ImportName { - assert_eq!(pair.as_rule(), Rule::import_name); - let s = pair.as_str(); - if s == "_" { - ImportName::Unused - } else { - ImportName::Named(s.to_string()) - } - } - fn parse_hash_hex(pair: Pair) -> HashHex { assert_eq!(pair.as_rule(), Rule::hash_hex); let span = get_span(&pair); @@ -748,7 +728,7 @@ pub mod parse { let span = get_span(&pair); let mut inner = pair.into_inner(); - let predicate = parse_identifier(inner.next().unwrap()); + let predicate = parse_predicate_ref(inner.next().unwrap()); let mut args = Vec::new(); if let Some(arg_list) = inner.next() { @@ -768,6 +748,22 @@ pub mod parse { }) } + fn parse_predicate_ref(pair: Pair) -> PredicateRef { + assert_eq!(pair.as_rule(), Rule::predicate_ref); + let inner = pair.into_inner().next().unwrap(); + + match inner.as_rule() { + Rule::qualified_predicate_ref => { + let mut parts = inner.into_inner(); + let module = parse_identifier(parts.next().unwrap()); + let predicate = parse_identifier(parts.next().unwrap()); + PredicateRef::Qualified { module, predicate } + } + Rule::identifier => PredicateRef::Local(parse_identifier(inner)), + _ => unreachable!("Unexpected predicate_ref rule: {:?}", inner.as_rule()), + } + } + fn parse_statement_arg(pair: Pair) -> Result { assert_eq!(pair.as_rule(), Rule::statement_arg); let inner = pair.into_inner().next().unwrap(); @@ -1047,9 +1043,10 @@ mod tests { fn clear_spans(doc: &mut Document) { for item in &mut doc.items { match item { - DocumentItem::UseBatchStatement(u) => { + DocumentItem::UseModuleStatement(u) => { u.span = None; - u.batch_ref.span = None; + u.hash.span = None; + u.alias.span = None; } DocumentItem::UseIntroStatement(u) => { u.span = None; @@ -1082,9 +1079,19 @@ mod tests { } } + fn clear_predicate_ref_spans(pred_ref: &mut PredicateRef) { + match pred_ref { + PredicateRef::Local(id) => id.span = None, + PredicateRef::Qualified { module, predicate } => { + module.span = None; + predicate.span = None; + } + } + } + fn clear_statement_spans(stmt: &mut StatementTmpl) { stmt.span = None; - stmt.predicate.span = None; + clear_predicate_ref_spans(&mut stmt.predicate); for arg in &mut stmt.args { match arg { StatementTmplArg::Literal(lit) => clear_literal_spans(lit), @@ -1168,8 +1175,8 @@ mod tests { } #[test] - fn test_use_batch_statement() { - let input = r#"use batch pred1, pred2, _ from 0x0000000000000000000000000000000000000000000000000000000000000000"#; + fn test_use_module_statement() { + let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as helpers"#; test_roundtrip(input); } @@ -1223,11 +1230,11 @@ mod tests { #[test] fn test_complete_document() { - let input = r#"use batch imported_pred from 0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd + let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as imported is_valid(User, private: Config) = AND ( Equal(User["age"], Config["min_age"]) - imported_pred(User, Config) + imported::some_pred(User, Config) ) check_both(A, B, C) = OR ( @@ -1306,7 +1313,7 @@ REQUEST( // Check request structure if let DocumentItem::RequestDef(req) = &ast.items[1] { assert_eq!(req.statements.len(), 1); - assert_eq!(req.statements[0].predicate.name, "my_pred"); + assert_eq!(req.statements[0].predicate.predicate_name(), "my_pred"); assert_eq!(req.statements[0].args.len(), 2); } else { panic!("Expected RequestDef"); diff --git a/src/lang/frontend_ast_batch.rs b/src/lang/frontend_ast_batch.rs deleted file mode 100644 index 37450cd..0000000 --- a/src/lang/frontend_ast_batch.rs +++ /dev/null @@ -1,1198 +0,0 @@ -//! Multi-batch packing for predicates -//! -//! This module implements packing of multiple predicates (including split chains) -//! into multiple CustomPredicateBatches when they exceed single-batch limits. -//! -//! Packing strategy (dependency-aware): -//! - Build a dependency graph of predicates (edges: callee → caller for local refs). -//! - Condense strongly connected components (SCCs) to ensure mutually-recursive preds stay together. -//! - Topologically order the SCC DAG; within each topological layer, pack larger components first -//! (ties broken by declaration order) to reduce wasted space. -//! - Within a batch, intra-batch calls use `BatchSelf` and work regardless of declaration order; -//! cross-batch calls always point to earlier batches via `CustomPredicateRef`. -//! - Forward cross-batch references cannot occur with this planner (they are treated as unreachable). - -use std::{collections::HashMap, sync::Arc}; - -use petgraph::{algo::condensation, graph::DiGraph, prelude::NodeIndex, visit::EdgeRef}; - -use crate::{ - frontend::{CustomPredicateBatchBuilder, Operation, OperationArg, StatementTmplBuilder}, - lang::{ - error::BatchingError, - frontend_ast::{ConjunctionType, CustomPredicateDef}, - frontend_ast_lower::{lower_statement_arg, resolve_predicate, ResolutionContext}, - frontend_ast_split::{SplitChainInfo, SplitResult}, - frontend_ast_validate::SymbolTable, - }, - middleware::{CustomPredicateBatch, CustomPredicateRef, Params, Statement}, -}; - -/// A single step in a multi-operation sequence for split predicates -#[derive(Debug, Clone)] -struct OperationStep { - /// The operation to perform - operation: Operation, - /// Whether this step's result should be public - public: bool, -} - -/// Errors that can occur when building multi-operations -#[derive(Debug, Clone, thiserror::Error)] -pub enum MultiOperationError { - #[error("Predicate not found: {0}")] - PredicateNotFound(String), - - #[error("Chain piece not found: {0}")] - ChainPieceNotFound(String), - - #[error( - "Wrong statement count for predicate '{predicate}': expected {expected}, got {actual}" - )] - WrongStatementCount { - predicate: String, - expected: usize, - actual: usize, - }, - - #[error("No operation steps to apply")] - NoSteps, -} - -/// Container for multiple predicate batches -#[derive(Debug, Clone)] -pub struct PredicateBatches { - batches: Vec>, - /// Maps predicate name to (batch_index, predicate_index_within_batch) - predicate_index: HashMap, - /// Split chain metadata for predicates that were split - /// Maps original predicate name to its chain info - split_chains: HashMap, -} - -impl Default for PredicateBatches { - fn default() -> Self { - Self::new() - } -} - -impl PredicateBatches { - pub fn new() -> Self { - Self { - batches: Vec::new(), - predicate_index: HashMap::new(), - split_chains: HashMap::new(), - } - } - - /// Get split chain info for a predicate (if it was split) - pub fn split_chain(&self, name: &str) -> Option<&SplitChainInfo> { - self.split_chains.get(name) - } - - /// Get a reference to a predicate by name - pub fn predicate_ref_by_name(&self, name: &str) -> Option { - let (batch_idx, pred_idx) = self.predicate_index.get(name)?; - let batch = self.batches.get(*batch_idx)?; - Some(CustomPredicateRef::new(batch.clone(), *pred_idx)) - } - - /// Get all batches - pub fn batches(&self) -> &[Arc] { - &self.batches - } - - /// Get the first batch (for backwards compatibility) - pub fn first_batch(&self) -> Option<&Arc> { - self.batches.first() - } - - /// Get batch count - pub fn batch_count(&self) -> usize { - self.batches.len() - } - - /// Check if empty - pub fn is_empty(&self) -> bool { - self.batches.is_empty() - } - - /// Total predicate count across all batches - pub fn total_predicate_count(&self) -> usize { - self.batches.iter().map(|b| b.predicates().len()).sum() - } - - /// Build operation steps for a predicate (internal helper) - /// - /// For non-split predicates, returns a single operation. - /// For split predicates, returns the chain of operations in execution order - /// (innermost first), with chain link placeholders. - fn build_steps( - &self, - predicate_name: &str, - statements: Vec, - public: bool, - ) -> Result, MultiOperationError> { - // Check if this predicate was split - let chain_info = match self.split_chains.get(predicate_name) { - Some(info) => info, - None => { - // Not split - single operation with all statements - let pred_ref = self.predicate_ref_by_name(predicate_name).ok_or_else(|| { - MultiOperationError::PredicateNotFound(predicate_name.to_string()) - })?; - - return Ok(vec![OperationStep { - operation: Operation::custom(pred_ref, statements), - public, - }]); - } - }; - - // Validate statement count - if statements.len() != chain_info.real_statement_count { - return Err(MultiOperationError::WrongStatementCount { - predicate: predicate_name.to_string(), - expected: chain_info.real_statement_count, - actual: statements.len(), - }); - } - - // Reorder statements from original order to split order - // reorder_map[original_idx] = split_idx - // So we need to place statements[i] at position reorder_map[i] - let mut reordered = vec![Statement::None; statements.len()]; - for (original_idx, stmt) in statements.into_iter().enumerate() { - let split_idx = chain_info.reorder_map[original_idx]; - reordered[split_idx] = stmt; - } - - // Build operations for each piece in execution order (innermost first) - // - // chain_pieces are in execution order: [continuation_N, ..., continuation_1, main] - // But in split order, statements are laid out: [main's stmts, cont_1's stmts, ..., cont_N's stmts] - // So we need to compute offsets from the END for the first pieces. - // - // Example with 6 statements, max_arity 5: - // split order: [stmt0, stmt1, stmt2, stmt3, stmt4, stmt5] - // chain_pieces[0] (large_pred_1): takes stmt5 (the last 1) - // chain_pieces[1] (large_pred): takes stmt0-4 (the first 5) - // - // We compute offsets by going through pieces in reverse order (matching split order). - - let num_pieces = chain_info.chain_pieces.len(); - - // Compute the starting offset for each piece by iterating in reverse - // (reverse of chain_pieces = same order as split layout) - let mut piece_offsets = vec![0usize; num_pieces]; - let mut offset = 0; - for i in (0..num_pieces).rev() { - piece_offsets[i] = offset; - offset += chain_info.chain_pieces[i].real_statement_count; - } - - let mut steps = Vec::new(); - for (piece_idx, piece) in chain_info.chain_pieces.iter().enumerate() { - let is_final = piece_idx == num_pieces - 1; - - // Get predicate ref for this piece - let piece_ref = self - .predicate_ref_by_name(&piece.name) - .ok_or_else(|| MultiOperationError::ChainPieceNotFound(piece.name.clone()))?; - - // Slice the reordered statements for this piece - let start = piece_offsets[piece_idx]; - let end = start + piece.real_statement_count; - let piece_statements: Vec = reordered[start..end].to_vec(); - - // Build the operation - // For non-final pieces, we'll add a placeholder that will be replaced - // with the previous step's result when applied - let mut args = piece_statements; - if piece.has_chain_call { - // Add placeholder for chain link - will be replaced by apply_multi_operation - args.push(Statement::None); - } - - steps.push(OperationStep { - operation: Operation::custom(piece_ref, args), - public: public && is_final, // Only final piece is public - }); - } - - Ok(steps) - } - - /// Apply a predicate directly into a `MainPodBuilder` (common case). - /// - /// For split predicates, earlier chain links are applied as private, and only the final - /// piece is applied as public when `public` is true. For non-split predicates, the single - /// operation is applied with the provided `public` flag. - /// - /// Arguments: - /// - `builder`: target builder to receive operations - /// - `name`: predicate name - /// - `statements`: user statements in original declaration order - /// - `public`: whether the final result should be public - pub fn apply_predicate( - &self, - builder: &mut crate::frontend::MainPodBuilder, - name: &str, - statements: Vec, - public: bool, - ) -> crate::frontend::Result { - self.apply_predicate_with(name, statements, public, |is_public, op| { - if is_public { - builder.pub_op(op) - } else { - builder.priv_op(op) - } - }) - } - - /// Advanced variant: apply using a custom closure. - /// - /// Prefer `apply_predicate` for common usage. This method allows callers to intercept each - /// operation (with its `public` flag) and decide how to execute it. - /// - /// Arguments: - /// - `name`: predicate name - /// - `statements`: user statements in original declaration order - /// - `public`: whether the final result should be public - /// - `apply_op`: closure `(is_public, operation) -> Result` used to execute each step - pub fn apply_predicate_with( - &self, - name: &str, - statements: Vec, - public: bool, - mut apply_op: F, - ) -> Result - where - F: FnMut(bool, Operation) -> Result, - E: From, - { - let steps = self.build_steps(name, statements, public)?; - - if steps.is_empty() { - return Err(MultiOperationError::NoSteps.into()); - } - - let mut prev_result: Option = None; - - for step in steps { - let op = if let Some(prev) = prev_result { - // Replace the last Statement::None arg with the previous result. - // By construction, all steps after the first include a chain placeholder - // as their last argument. - let mut args = step.operation.1; - let last = args - .last_mut() - .expect("chain statement should include placeholder arg"); - assert!( - matches!(last, OperationArg::Statement(Statement::None)), - "expected last arg to be a Statement::None placeholder" - ); - *last = OperationArg::Statement(prev); - Operation(step.operation.0, args, step.operation.2) - } else { - step.operation - }; - - prev_result = Some(apply_op(step.public, op)?); - } - - // Safe to unwrap because we checked steps.is_empty() above - Ok(prev_result.unwrap()) - } -} - -/// Assignment of a predicate to a batch -#[derive(Debug, Clone)] -struct PredicateAssignment { - /// Full name (e.g., "my_pred_1" for split link) - full_name: String, - /// Which batch this goes into - batch_index: usize, - /// Index within that batch - index_in_batch: usize, -} - -/// Pack predicates into multiple batches -/// -/// Takes a list of split results (containing predicates and optional chain info) -/// and packs them into batches, handling cross-batch references correctly. -/// -/// Predicates are packed dependency‑aware: -/// - Mutually recursive predicates (SCCs) are kept together. -/// - Components are ordered topologically; within each layer, larger components are packed first -/// (ties by declaration order) to reduce wasted space. -/// - Within a batch, predicates can reference each other freely via `BatchSelf`; cross-batch -/// references always point to earlier batches via `CustomPredicateRef`. -/// -/// `symbols` provides the symbol table for resolving predicate references, -/// including imported predicates from other batches and intro predicates. -pub fn batch_predicates( - split_results: Vec, - params: &Params, - base_batch_name: &str, - symbols: &SymbolTable, -) -> Result { - // Extract predicates and collect split chains - let mut predicates = Vec::new(); - let mut split_chains = HashMap::new(); - - for result in split_results { - // Collect chain info if present - if let Some(chain_info) = result.chain_info { - split_chains.insert(chain_info.original_name.clone(), chain_info); - } - // Flatten predicates - predicates.extend(result.predicates); - } - - if predicates.is_empty() { - return Ok(PredicateBatches::new()); - } - - // Plan batch assignments in declaration order - let assignments = plan_batch_assignments(&predicates, Params::max_custom_batch_size())?; - - // Build reference map: name -> (batch_idx, idx_in_batch) - let reference_map: HashMap = assignments - .iter() - .map(|a| (a.full_name.clone(), (a.batch_index, a.index_in_batch))) - .collect(); - - // Determine number of batches - let num_batches = assignments - .iter() - .map(|a| a.batch_index) - .max() - .map(|m| m + 1) - .unwrap_or(0); - - // Build batches in order - let mut batches = Vec::new(); - let mut predicate_index = HashMap::new(); - - for batch_idx in 0..num_batches { - // Collect predicates for this batch (in assignment order) - let batch_predicates: Vec<_> = predicates - .iter() - .zip(assignments.iter()) - .filter(|(_, a)| a.batch_index == batch_idx) - .map(|(p, _)| p.clone()) - .collect(); - - let batch_name = if num_batches == 1 { - base_batch_name.to_string() - } else { - format!("{}_{}", base_batch_name, batch_idx) - }; - - let batch = build_single_batch( - &batch_predicates, - batch_idx, - &reference_map, - &batches, - symbols, - params, - &batch_name, - )?; - - // Update predicate index - for (idx, pred) in batch_predicates.iter().enumerate() { - predicate_index.insert(pred.name.name.clone(), (batch_idx, idx)); - } - - batches.push(batch); - } - - Ok(PredicateBatches { - batches, - predicate_index, - split_chains, - }) -} - -/// Plan batch assignments (greedy fill in declaration order) -fn plan_batch_assignments( - predicates: &[CustomPredicateDef], - max_batch_size: usize, -) -> Result, BatchingError> { - // Map name -> original index - let mut name_to_index: HashMap = HashMap::new(); - let index_to_name: Vec = predicates - .iter() - .enumerate() - .map(|(i, pred)| { - name_to_index.insert(pred.name.name.clone(), i); - pred.name.name.clone() - }) - .collect(); - - let n = predicates.len(); - // Build graph with nodes 0..n and edges callee -> caller for local refs - let mut graph: DiGraph = DiGraph::new(); - let nodes: Vec = (0..n).map(|i| graph.add_node(i)).collect(); - for (caller_idx, pred) in predicates.iter().enumerate() { - for stmt in &pred.statements { - if let Some(&callee_idx) = name_to_index.get(&stmt.predicate.name) { - graph.add_edge(nodes[callee_idx], nodes[caller_idx], ()); - } - } - } - - // Condense SCCs into DAG; each node weight is Vec of members - // Pass `true` to remove self-loops, ensuring acyclicity for topo sort - let mut condensed = condensation(graph, /*make_acyclic=*/ true); - - // Verify each component fits in a batch and sort members by original index - for comp_members in condensed.node_weights_mut() { - comp_members.sort_unstable(); - if comp_members.len() > max_batch_size { - let members = comp_members - .iter() - .map(|&i| index_to_name[i].clone()) - .collect::>() - .join(", "); - // An SCC larger than the per-batch capacity cannot be packed: all members of a - // mutually-recursive group must live in the same batch. Splitting reduces per‑predicate - // arity but does not break cycles, and the split chain for a single predicate remains - // acyclic (so it does not increase the SCC size). Users must refactor to break the - // cycle or increase `max_custom_batch_size`. - return Err(BatchingError::Internal { - message: format!( - "Mutually recursive group of size {} exceeds batch capacity {}. Predicates: [{}]. \\n+ Consider breaking the cycle or increasing max_custom_batch_size.", - comp_members.len(), - max_batch_size, - members - ), - }); - } - } - - // Topological sort using a layer-wise variant of Kahn's algorithm. - // - // Standard Kahn's algorithm processes nodes one at a time from a queue. This variant - // instead processes entire "layers" (all nodes at the same topological depth) together, - // which allows sorting within each layer for better bin-packing while still respecting - // dependency order. - // - // Algorithm: - // 1. Compute in-degree for each node - // 2. Initialize first layer with all zero in-degree nodes (no dependencies) - // 3. For each layer: - // a. Sort by component size (desc) for bin-packing, then by key for determinism - // b. Add to output order - // c. Decrement in-degree of all neighbors; those hitting zero form the next layer - // 4. Assert all nodes visited (would fail if graph had cycles, but condensation ensures DAG) - - let node_count = condensed.node_count(); - - // Step 1: Compute in-degrees - let mut indeg = vec![0usize; node_count]; - for e in condensed.edge_references() { - indeg[e.target().index()] += 1; - } - - // Stable key per component: minimal original index inside the component - // Used as tiebreaker when sorting layers for deterministic output - let mut comp_key: Vec = vec![0; node_count]; - for ni in condensed.node_indices() { - let members = &condensed[ni]; - let key = members.iter().copied().min().expect("non-empty component"); - comp_key[ni.index()] = key; - } - - // Step 2: Initialize with zero in-degree nodes - let mut current_layer: Vec = condensed - .node_indices() - .filter(|&ni| indeg[ni.index()] == 0) - .collect(); - - let mut order: Vec = Vec::with_capacity(node_count); - use std::cmp::Reverse; - - // Step 3: Process layer by layer - while !current_layer.is_empty() { - // Sort by size desc (for bin-packing), then by comp_key asc (for determinism) - current_layer.sort_by_key(|&ni| { - let size = condensed[ni].len(); - (Reverse(size), comp_key[ni.index()]) - }); - - // Add this layer to the output order - order.extend(current_layer.iter().copied()); - - // Build next layer: decrement in-degrees, collect nodes that hit zero - let mut next_layer: Vec = Vec::new(); - for &u in ¤t_layer { - for v in condensed.neighbors(u) { - let idx = v.index(); - indeg[idx] -= 1; - if indeg[idx] == 0 { - next_layer.push(v); - } - } - } - current_layer = next_layer; - } - - // Step 4: Verify all nodes were visited (cycle detection) - assert_eq!(order.len(), node_count, "condensed graph must be acyclic"); - - // Greedy pack components by the layer-aware order - let mut pred_batch: Vec = vec![0; n]; - let mut current_batch = 0usize; - let mut current_count = 0usize; - for cid in order { - let comp = &condensed[cid]; - let comp_size = comp.len(); - // If the next component doesn't fit in the remaining capacity, start a new batch. - // This is the normal batch boundary; precedence is preserved, and we mitigate wasted - // space by sorting components within each topo layer by size (desc) earlier. - if current_count + comp_size > max_batch_size { - current_batch += 1; - current_count = 0; - } - for &pi in comp { - pred_batch[pi] = current_batch; - } - current_count += comp_size; - } - - // Compute index_in_batch by original order to match builder's enumeration - let mut per_batch_counts: HashMap = HashMap::new(); - let mut assignments = Vec::with_capacity(n); - for (i, pred) in predicates.iter().enumerate() { - let b = pred_batch[i]; - let idx = per_batch_counts.get(&b).cloned().unwrap_or(0); - per_batch_counts.insert(b, idx + 1); - assignments.push(PredicateAssignment { - full_name: pred.name.name.clone(), - batch_index: b, - index_in_batch: idx, - }); - } - - Ok(assignments) -} - -/// Build a single batch with properly resolved references -fn build_single_batch( - predicates: &[CustomPredicateDef], - batch_idx: usize, - reference_map: &HashMap, - existing_batches: &[Arc], - symbols: &SymbolTable, - params: &Params, - batch_name: &str, -) -> Result, BatchingError> { - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), batch_name.to_string()); - - for pred in predicates { - let name = &pred.name.name; - - // Collect argument names - let public_args: Vec<&str> = pred - .args - .public_args - .iter() - .map(|a| a.name.as_str()) - .collect(); - - let private_args: Vec<&str> = pred - .args - .private_args - .as_ref() - .map(|args| args.iter().map(|a| a.name.as_str()).collect()) - .unwrap_or_default(); - - // Build statement templates with resolved predicates - let statement_builders: Vec = pred - .statements - .iter() - .map(|stmt| { - build_statement_with_resolved_refs( - stmt, - batch_idx, - reference_map, - existing_batches, - name, - symbols, - ) - }) - .collect::>()?; - - let conjunction = pred.conjunction_type == ConjunctionType::And; - - builder - .predicate( - name, - conjunction, - &public_args, - &private_args, - &statement_builders, - ) - .map_err(|e| BatchingError::Internal { - message: format!("Failed to add predicate '{}': {}", name, e), - })?; - } - - Ok(builder.finish()) -} - -/// Build a statement template with properly resolved predicate references -fn build_statement_with_resolved_refs( - stmt: &crate::lang::frontend_ast::StatementTmpl, - current_batch_idx: usize, - reference_map: &HashMap, - existing_batches: &[Arc], - custom_predicate_name: &str, // custom pred that defines this statement template - symbols: &SymbolTable, -) -> Result { - let callee_name = &stmt.predicate.name; - - // Resolve the predicate using the unified resolution function - let context = ResolutionContext::Batch { - current_batch_idx, - reference_map, - existing_batches, - custom_predicate_name, - }; - - let pred_or_wc = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| { - BatchingError::Internal { - message: format!("Unknown predicate reference: '{}'", callee_name), - } - })?; - - // Build the statement template - let mut builder = StatementTmplBuilder::new(pred_or_wc); - - for arg in &stmt.args { - builder = builder.arg(lower_statement_arg(arg)); - } - - Ok(builder) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - lang::{ - frontend_ast::parse::parse_document, - frontend_ast_split::split_predicate_if_needed, - frontend_ast_validate::{validate, ValidatedAST}, - parser::parse_podlang, - }, - middleware::{Predicate, PredicateOrWildcard}, - }; - - /// Helper: parse and validate input, returning predicates and symbol table - fn parse_and_validate(input: &str) -> (Vec, ValidatedAST) { - let parsed = parse_podlang(input).expect("Failed to parse"); - let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - let validated = validate(document.clone(), &[]).expect("Failed to validate"); - - let predicates = document - .items - .into_iter() - .filter_map(|item| match item { - crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred), - _ => None, - }) - .collect(); - - (predicates, validated) - } - - /// Helper: wrap predicates into SplitResult (without actually splitting) - fn preds_to_split_results(predicates: Vec) -> Vec { - predicates - .into_iter() - .map(|pred| SplitResult { - predicates: vec![pred], - chain_info: None, - }) - .collect() - } - - #[test] - fn test_single_predicate_single_batch() { - let input = r#" - my_pred(A, B) = AND( - Equal(A["x"], B["y"]) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let result = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert_eq!(batches.batch_count(), 1); - assert_eq!(batches.total_predicate_count(), 1); - } - - #[test] - fn test_multiple_predicates_single_batch() { - let input = r#" - pred1(A) = AND(Equal(A["x"], 1)) - pred2(B) = AND(Equal(B["y"], 2)) - pred3(C) = AND(Equal(C["z"], 3)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); // max_custom_batch_size = 4 - - let result = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert_eq!(batches.batch_count(), 1); - assert_eq!(batches.total_predicate_count(), 3); - } - - #[test] - fn test_intra_batch_forward_reference() { - // pred2 calls pred1, but pred2 is declared first - // This should work because they're in the same batch - let input = r#" - pred2(B) = AND(pred1(B)) - pred1(A) = AND(Equal(A["x"], 1)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let result = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert_eq!(batches.batch_count(), 1); - - // pred2 should reference pred1 via BatchSelf - use crate::middleware::PredicateOrWildcard; - let pred2 = &batches.batches()[0].predicates()[0]; - let stmt = &pred2.statements[0]; - assert!(matches!( - stmt.pred_or_wc(), - PredicateOrWildcard::Predicate(Predicate::BatchSelf(1)) - )); // pred1 is at index 1 - } - - #[test] - fn test_mutual_recursion_in_same_batch() { - // pred1 calls pred2, pred2 calls pred1 - mutual recursion - // This should work because they're in the same batch - let input = r#" - pred1(A) = AND(pred2(A)) - pred2(B) = AND(pred1(B)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let result = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert_eq!(batches.batch_count(), 1); - assert_eq!(batches.total_predicate_count(), 2); - - // Both should use BatchSelf references - let pred1 = &batches.batches()[0].predicates()[0]; - let pred2 = &batches.batches()[0].predicates()[1]; - assert!(matches!( - pred1.statements[0].pred_or_wc(), - PredicateOrWildcard::Predicate(Predicate::BatchSelf(1)) - )); // calls pred2 - assert!(matches!( - pred2.statements[0].pred_or_wc(), - PredicateOrWildcard::Predicate(Predicate::BatchSelf(0)) - )); // calls pred1 - } - - #[test] - fn test_empty_input() { - let split_results: Vec = vec![]; - let params = Params::default(); - // For empty input, we need an empty symbol table - let empty_symbols = SymbolTable { - predicates: HashMap::new(), - wildcard_scopes: HashMap::new(), - }; - - let result = batch_predicates(split_results, ¶ms, "TestBatch", &empty_symbols); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert!(batches.is_empty()); - assert_eq!(batches.batch_count(), 0); - } - - #[test] - fn test_predicate_ref_by_name() { - let input = r#" - pred1(A) = AND(Equal(A["x"], 1)) - pred2(B) = AND(Equal(B["y"], 2)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let batches = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ) - .unwrap(); - - // Should be able to look up both predicates - assert!(batches.predicate_ref_by_name("pred1").is_some()); - assert!(batches.predicate_ref_by_name("pred2").is_some()); - assert!(batches.predicate_ref_by_name("nonexistent").is_none()); - } - - /// Helper: create a unique Statement for testing - /// Uses Equal with distinct literal values to create distinguishable statements - fn test_statement(id: usize) -> Statement { - use crate::middleware::ValueRef; - Statement::Equal( - ValueRef::Literal((id as i64).into()), - ValueRef::Literal((id as i64).into()), - ) - } - - #[test] - fn test_apply_predicate_non_split() { - // A simple predicate that doesn't need splitting - let input = r#" - my_pred(A, B) = AND( - Equal(A["x"], B["y"]) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let batches = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ) - .unwrap(); - - // Create fake statements - let statements = vec![Statement::None, Statement::None]; - - // Track operations applied - let mut operations_applied: Vec<(bool, usize)> = Vec::new(); - let mut stmt_counter = 0; - - let result: Result = - batches.apply_predicate_with("my_pred", statements, true, |public, op| { - operations_applied.push((public, op.1.len())); - stmt_counter += 1; - Ok(test_statement(stmt_counter)) - }); - - assert!(result.is_ok()); - // Should be exactly one operation - assert_eq!(operations_applied.len(), 1); - // Should be public - assert!(operations_applied[0].0); - // Should have 2 arguments - assert_eq!(operations_applied[0].1, 2); - } - - #[test] - fn test_apply_predicate_2_piece_split() { - // A predicate that will split into 2 pieces - let input = r#" - large_pred(A) = AND( - Equal(A["a"], 1) - Equal(A["b"], 2) - Equal(A["c"], 3) - Equal(A["d"], 4) - Equal(A["e"], 5) - Equal(A["f"], 6) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - // Split the predicate - let mut split_results = Vec::new(); - for pred in predicates { - let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); - split_results.push(result); - } - - // Should split into 2 pieces - assert_eq!(split_results.len(), 1); - assert_eq!(split_results[0].predicates.len(), 2); - assert!(split_results[0].chain_info.is_some()); - - let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) - .expect("Batch failed"); - - // Verify chain info - let chain_info = batches.split_chain("large_pred").unwrap(); - assert_eq!(chain_info.chain_pieces.len(), 2); - assert_eq!(chain_info.real_statement_count, 6); - - // Create fake statements (6 for the 6 Equal statements) - let statements: Vec = (0..6).map(test_statement).collect(); - - // Track operations - let mut operations_applied: Vec<(bool, usize)> = Vec::new(); - let mut stmt_counter = 100; - - let result: Result = - batches.apply_predicate_with("large_pred", statements, true, |public, op| { - operations_applied.push((public, op.1.len())); - stmt_counter += 1; - Ok(test_statement(stmt_counter)) - }); - - assert!(result.is_ok()); - // Should be exactly 2 operations (innermost continuation first, then main) - assert_eq!(operations_applied.len(), 2); - // First operation (continuation) should be private - assert!(!operations_applied[0].0); - // Second operation (main) should be public - assert!(operations_applied[1].0); - } - - #[test] - fn test_apply_predicate_3_piece_split() { - // A predicate that will split into 3 pieces (needs more statements) - let input = r#" - very_large_pred(A) = AND( - Equal(A["a"], 1) - Equal(A["b"], 2) - Equal(A["c"], 3) - Equal(A["d"], 4) - Equal(A["e"], 5) - Equal(A["f"], 6) - Equal(A["g"], 7) - Equal(A["h"], 8) - Equal(A["i"], 9) - Equal(A["j"], 10) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - // Split the predicate - let mut split_results = Vec::new(); - for pred in predicates { - let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); - split_results.push(result); - } - - // Should split into 3 pieces - assert_eq!(split_results.len(), 1); - assert_eq!(split_results[0].predicates.len(), 3); - assert!(split_results[0].chain_info.is_some()); - - let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) - .expect("Batch failed"); - - // Verify chain info - let chain_info = batches.split_chain("very_large_pred").unwrap(); - assert_eq!(chain_info.chain_pieces.len(), 3); - assert_eq!(chain_info.real_statement_count, 10); - - // Create fake statements (10 for the 10 Equal statements) - let statements: Vec = (0..10).map(test_statement).collect(); - - // Track operations - let mut operations_applied: Vec<(bool, usize)> = Vec::new(); - let mut stmt_counter = 100; - - let result: Result = - batches.apply_predicate_with("very_large_pred", statements, true, |public, op| { - operations_applied.push((public, op.1.len())); - stmt_counter += 1; - Ok(test_statement(stmt_counter)) - }); - - assert!(result.is_ok()); - // Should be exactly 3 operations - assert_eq!(operations_applied.len(), 3); - // First two operations (continuations) should be private - assert!(!operations_applied[0].0); - assert!(!operations_applied[1].0); - // Final operation (main) should be public - assert!(operations_applied[2].0); - } - - #[test] - fn test_apply_predicate_wrong_statement_count() { - // A predicate that will split - let input = r#" - large_pred(A) = AND( - Equal(A["a"], 1) - Equal(A["b"], 2) - Equal(A["c"], 3) - Equal(A["d"], 4) - Equal(A["e"], 5) - Equal(A["f"], 6) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - // Split the predicate - let mut split_results = Vec::new(); - for pred in predicates { - let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); - split_results.push(result); - } - - let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) - .expect("Batch failed"); - - // Try with wrong number of statements (3 instead of 6) - let statements: Vec = (0..3).map(test_statement).collect(); - - let result: Result = - batches.apply_predicate_with("large_pred", statements, true, |_, _| { - Ok(test_statement(999)) - }); - - assert!(result.is_err()); - let err = result.unwrap_err(); - match err { - MultiOperationError::WrongStatementCount { - predicate, - expected, - actual, - } => { - assert_eq!(predicate, "large_pred"); - assert_eq!(expected, 6); - assert_eq!(actual, 3); - } - _ => panic!("Expected WrongStatementCount error, got {:?}", err), - } - } - - #[test] - fn test_apply_predicate_not_found() { - let input = r#" - my_pred(A) = AND(Equal(A["x"], 1)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let batches = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ) - .unwrap(); - - let result: Result = - batches - .apply_predicate_with("nonexistent", vec![], true, |_, _| Ok(test_statement(999))); - - assert!(result.is_err()); - match result.unwrap_err() { - MultiOperationError::PredicateNotFound(name) => { - assert_eq!(name, "nonexistent"); - } - e => panic!("Expected PredicateNotFound error, got {:?}", e), - } - } - - #[test] - fn test_apply_predicate_chain_wiring() { - // Test that chain links are properly wired (previous result replaces Statement::None) - let input = r#" - large_pred(A) = AND( - Equal(A["a"], 1) - Equal(A["b"], 2) - Equal(A["c"], 3) - Equal(A["d"], 4) - Equal(A["e"], 5) - Equal(A["f"], 6) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let mut split_results = Vec::new(); - for pred in predicates { - let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); - split_results.push(result); - } - - let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) - .expect("Batch failed"); - - let statements: Vec = (0..6).map(test_statement).collect(); - - // Track whether the second operation has the first result as its last arg - let mut last_args_of_ops: Vec> = Vec::new(); - let mut stmt_counter = 100; - - let result: Result = - batches.apply_predicate_with("large_pred", statements, true, |_, op| { - // Check the last argument - let last_arg = op.1.last().map(|arg| { - if let OperationArg::Statement(s) = arg { - s.clone() - } else { - Statement::None - } - }); - last_args_of_ops.push(last_arg); - stmt_counter += 1; - Ok(test_statement(stmt_counter)) - }); - - assert!(result.is_ok()); - assert_eq!(last_args_of_ops.len(), 2); - - // First operation's last arg should NOT be the result of previous (no previous) - // It might be Statement::None if no chain call, or a regular arg - - // Second operation's last arg SHOULD be test_statement(101) - the result from first op - assert_eq!(last_args_of_ops[1], Some(test_statement(101))); - } -} diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index d020e37..b429f4a 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -1,52 +1,69 @@ //! Lowering from frontend AST to middleware structures //! //! This module converts validated frontend AST to middleware data structures. -//! Supports automatic predicate splitting and multi-batch packing. +//! Supports automatic predicate splitting. use std::{ collections::{HashMap, HashSet}, str::FromStr, - sync::Arc, }; use crate::{ frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder}, lang::{ frontend_ast::*, - frontend_ast_batch::{self, PredicateBatches}, frontend_ast_split, frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST}, + module, Module, }, middleware::{ - self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key, - NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl, - StatementTmplArg as MWStatementTmplArg, Value, Wildcard, + self, containers, CustomPredicateRef, IntroPredicateRef, Key, NativePredicate, Params, + Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Value, + Wildcard, }, }; -/// Context for predicate resolution - determines how local custom predicates are resolved +/// Context for predicate resolution - determines how predicates are resolved pub enum ResolutionContext<'a> { - /// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches - Request { - batches: Option<&'a PredicateBatches>, - }, - /// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef - Batch { - current_batch_idx: usize, - reference_map: &'a HashMap, - existing_batches: &'a [Arc], + /// Request context: predicates resolve via imports only (no local definitions) + Request, + /// Module context: local predicates resolve to BatchSelf + Module { + /// Maps predicate name to index within the module + reference_map: &'a HashMap, + /// Name of the custom predicate being defined (for wildcard scope lookup) custom_predicate_name: &'a str, }, } +/// Resolve a predicate reference to a Predicate using the symbol table +pub fn resolve_predicate_ref( + pred_ref: &PredicateRef, + symbols: &SymbolTable, + context: &ResolutionContext, +) -> Option { + match pred_ref { + PredicateRef::Qualified { module, predicate } => { + // Look up the module in the imported_modules + let imported_module = symbols.imported_modules.get(&module.name)?; + // Find the predicate index in the module + let idx = *imported_module.predicate_index.get(&predicate.name)?; + Some(PredicateOrWildcard::Predicate(Predicate::Custom( + CustomPredicateRef::new(imported_module.batch.clone(), idx), + ))) + } + PredicateRef::Local(id) => resolve_predicate(&id.name, symbols, context), + } +} + /// Resolve a predicate name to a Predicate using the symbol table pub fn resolve_predicate( pred_name: &str, symbols: &SymbolTable, context: &ResolutionContext, ) -> Option { - // 0. Try wildcard first - if let ResolutionContext::Batch { + // 0. Try wildcard first (only in module context where we're defining predicates) + if let ResolutionContext::Module { custom_predicate_name, .. } = context @@ -69,28 +86,35 @@ pub fn resolve_predicate( PredicateKind::Native(np) => Predicate::Native(*np), PredicateKind::Custom { .. } => match context { - ResolutionContext::Request { batches } => { - let batches = batches.as_ref()?; - let pred_ref = batches.predicate_ref_by_name(pred_name)?; - Predicate::Custom(pred_ref) + ResolutionContext::Request => { + // Requests can't define local predicates, so this shouldn't happen + return None; + } + ResolutionContext::Module { reference_map, .. } => { + resolve_local_predicate(pred_name, reference_map)? } - ResolutionContext::Batch { - current_batch_idx, - reference_map, - existing_batches, - .. - } => resolve_local_predicate( - pred_name, - *current_batch_idx, - reference_map, - existing_batches, - )?, }, PredicateKind::BatchImported { batch, index } => { Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index)) } + PredicateKind::ModuleImported { + module_name, + predicate_index, + .. + } => { + // Look up the module in the imported_modules + let module = symbols + .imported_modules + .get(module_name) + .expect("Module should exist if ModuleImported predicate kind exists"); + Predicate::Custom(CustomPredicateRef::new( + module.batch.clone(), + *predicate_index, + )) + } + PredicateKind::IntroImported { name, verifier_data_hash, @@ -103,51 +127,25 @@ pub fn resolve_predicate( return Some(PredicateOrWildcard::Predicate(predicate)); } - // 3. In batch context, also check reference_map for split chain pieces + // 3. In module context, also check reference_map for split chain pieces // (predicates created by splitting that aren't in the original symbol table) - if let ResolutionContext::Batch { - current_batch_idx, - reference_map, - existing_batches, - .. - } = context - { + if let ResolutionContext::Module { reference_map, .. } = context { if reference_map.contains_key(pred_name) { - return resolve_local_predicate( - pred_name, - *current_batch_idx, - reference_map, - existing_batches, - ) - .map(PredicateOrWildcard::Predicate); + return resolve_local_predicate(pred_name, reference_map) + .map(PredicateOrWildcard::Predicate); } } None } -/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map +/// Resolve a local predicate (one in this module or a split chain piece) using the reference_map fn resolve_local_predicate( pred_name: &str, - current_batch_idx: usize, - reference_map: &HashMap, - existing_batches: &[Arc], + reference_map: &HashMap, ) -> Option { - let &(target_batch, target_idx) = reference_map.get(pred_name)?; - if target_batch == current_batch_idx { - Some(Predicate::BatchSelf(target_idx)) - } else if target_batch < current_batch_idx { - let batch = &existing_batches[target_batch]; - Some(Predicate::Custom(CustomPredicateRef::new( - batch.clone(), - target_idx, - ))) - } else { - unreachable!( - "Forward cross-batch reference should be impossible: {} -> {}", - current_batch_idx, target_batch - ); - } + let &idx = reference_map.get(pred_name)?; + Some(Predicate::BatchSelf(idx)) } // ============================================================================ @@ -155,7 +153,7 @@ fn resolve_local_predicate( // ============================================================================ // These functions convert AST types to middleware/builder types and are used // by both the request lowering (in this module) and predicate batching -// (in frontend_ast_batch). +// (in module.rs). /// Lower a literal value from AST to middleware Value. /// @@ -215,38 +213,37 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { } } -/// Result of lowering: optional custom predicate batches and optional request -/// -/// A Podlang file can contain: -/// - Just custom predicates (batches: Some, request: None) -/// - Just a request (batches: None, request: Some) -/// - Both (batches: Some, request: Some) -/// - Neither (batches: None, request: None) - just imports -#[derive(Debug, Clone)] -pub struct LoweredOutput { - pub batches: Option, - pub request: Option, -} - pub use crate::lang::error::LoweringError; -/// Lower a validated AST to middleware structures +/// Lower a validated module AST to a Module /// -/// Returns both the custom predicate batch (if any) and the request (if any). -/// At least one will be Some if the document contains custom predicates or a request. -pub fn lower( +/// The validated AST must have been validated in Module mode. +pub fn lower_module( validated: ValidatedAST, params: &Params, - batch_name: String, -) -> Result { + module_name: &str, +) -> Result { if !validated.diagnostics().is_empty() { - // For now, treat any diagnostics as errors - // In future we could allow warnings return Err(LoweringError::ValidationErrors); } let lowerer = Lowerer::new(validated, params); - lowerer.lower(batch_name) + lowerer.lower_module(module_name) +} + +/// Lower a validated request AST to a PodRequest +/// +/// The validated AST must have been validated in Request mode. +pub fn lower_request( + validated: ValidatedAST, + params: &Params, +) -> Result { + if !validated.diagnostics().is_empty() { + return Err(LoweringError::ValidationErrors); + } + + let lowerer = Lowerer::new(validated, params); + lowerer.lower_request() } struct Lowerer<'a> { @@ -259,52 +256,33 @@ impl<'a> Lowerer<'a> { Self { validated, params } } - fn lower(self, batch_name: String) -> Result { - // Lower custom predicates (if any) - now supports multiple batches - let batches = self.lower_batches(batch_name)?; - - // Lower request (if any) - pass batches so refs can be resolved - let request = self.lower_request(batches.as_ref())?; - - Ok(LoweredOutput { batches, request }) - } - - fn lower_batches(&self, batch_name: String) -> Result, LoweringError> { + fn lower_module(self, module_name: &str) -> Result { // Extract and split custom predicates from document let custom_predicates = self.extract_and_split_predicates()?; - // If no custom predicates, return None - if custom_predicates.is_empty() { - return Ok(None); - } - - // Use the new batching module to pack predicates into batches - // Pass the symbol table for unified predicate resolution - let batches = frontend_ast_batch::batch_predicates( + // Build the module from split predicates + let module = module::build_module( custom_predicates, self.params, - &batch_name, + module_name, self.validated.symbols(), )?; - Ok(Some(batches)) + Ok(module) } - fn lower_request( - &self, - batches: Option<&PredicateBatches>, - ) -> Result, LoweringError> { + fn lower_request(self) -> Result { let doc = self.validated.document(); - // Find request definition (if any) - let request_def = doc.items.iter().find_map(|item| match item { - DocumentItem::RequestDef(req) => Some(req), - _ => None, - }); - - let Some(request_def) = request_def else { - return Ok(None); - }; + // Find request definition + let request_def = doc + .items + .iter() + .find_map(|item| match item { + DocumentItem::RequestDef(req) => Some(req), + _ => None, + }) + .expect("Request mode validation ensures REQUEST block exists"); // Build wildcard map from all wildcards used in the request statements let wildcard_map = self.build_request_wildcard_map(request_def); @@ -312,18 +290,17 @@ impl<'a> Lowerer<'a> { // Lower each statement to middleware templates, resolving predicates let mut request_templates = Vec::new(); for stmt in &request_def.statements { - let mw_stmt = self.lower_request_statement(stmt, &wildcard_map, batches)?; + let mw_stmt = self.lower_request_statement(stmt, &wildcard_map)?; request_templates.push(mw_stmt); } - Ok(Some(crate::frontend::PodRequest::new(request_templates))) + Ok(crate::frontend::PodRequest::new(request_templates)) } fn lower_request_statement( &self, stmt: &StatementTmpl, wildcard_map: &HashMap, - batches: Option<&PredicateBatches>, ) -> Result { // Enforce argument count limit for request statements if stmt.args.len() > Params::max_statement_args() { @@ -333,16 +310,16 @@ impl<'a> Lowerer<'a> { }); } - let pred_name = &stmt.predicate.name; let symbols = self.validated.symbols(); // Resolve predicate using the unified resolution function - let context = ResolutionContext::Request { batches }; - let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| { - LoweringError::PredicateNotFound { - name: pred_name.clone(), - } - })?; + let context = ResolutionContext::Request; + let predicate = + resolve_predicate_ref(&stmt.predicate, symbols, &context).ok_or_else(|| { + LoweringError::PredicateNotFound { + name: format!("{}", stmt.predicate), + } + })?; // Create a builder with the resolved predicate and desugar let mut builder = StatementTmplBuilder::new(predicate.clone()); @@ -453,31 +430,24 @@ impl<'a> Lowerer<'a> { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::lang::{ - frontend_ast::parse::parse_document, frontend_ast_validate::validate, parser::parse_podlang, + frontend_ast::parse::parse_document, + frontend_ast_validate::{validate, ParseMode}, + parser::parse_podlang, }; - fn parse_validate_and_lower( + fn parse_validate_and_lower_module( input: &str, params: &Params, - ) -> Result { + ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - let validated = validate(document, &[]).expect("Failed to validate"); - lower(validated, params, "test_batch".to_string()) - } - - // Helper to get the first batch from the output (expecting it to exist) - fn expect_batch( - output: &LoweredOutput, - ) -> &std::sync::Arc { - output - .batches - .as_ref() - .expect("Expected batches to be present") - .first_batch() - .expect("Expected at least one batch") + let validated = + validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate"); + lower_module(validated, params, "test_batch") } #[test] @@ -489,16 +459,16 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); if let Err(e) = &result { eprintln!("Error: {:?}", e); } assert!(result.is_ok()); - let lowered = result.unwrap(); - assert_eq!(expect_batch(&lowered).predicates().len(), 1); + let module = result.unwrap(); + assert_eq!(module.batch.predicates().len(), 1); - let pred = &expect_batch(&lowered).predicates()[0]; + let pred = &module.batch.predicates()[0]; assert_eq!(pred.name, "my_pred"); assert_eq!(pred.args_len(), 2); assert_eq!(pred.wildcard_names().len(), 2); @@ -515,11 +485,11 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); assert!(result.is_ok()); - let lowered = result.unwrap(); - let pred = &expect_batch(&lowered).predicates()[0]; + let module = result.unwrap(); + let pred = &module.batch.predicates()[0]; assert_eq!(pred.args_len(), 1); // Only A is public assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total } @@ -534,11 +504,11 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); assert!(result.is_ok()); - let lowered = result.unwrap(); - let pred = &expect_batch(&lowered).predicates()[0]; + let module = result.unwrap(); + let pred = &module.batch.predicates()[0]; assert!(pred.is_disjunction()); } @@ -556,23 +526,22 @@ mod tests { "#; let params = Params::default(); // max_custom_predicate_arity = 5 - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); if let Err(e) = &result { eprintln!("Splitting error: {:?}", e); } assert!(result.is_ok()); - let lowered = result.unwrap(); + let module = result.unwrap(); // Should be automatically split into 2 predicates (my_pred and my_pred_1) - let batches = lowered.batches.as_ref().expect("Expected batches"); - assert_eq!(batches.total_predicate_count(), 2); + assert_eq!(module.batch.predicates().len(), 2); // With topological sorting, my_pred_1 comes first (since my_pred depends on it) // my_pred_1 has 2 statements // my_pred has 5 statements (4 + chain call) // Just verify we have the right total statement counts - let batch = batches.first_batch().unwrap(); - let total_statements: usize = batch + let total_statements: usize = module + .batch .predicates() .iter() .map(|p| p.statements().len()) @@ -593,11 +562,11 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); assert!(result.is_ok()); - let lowered = result.unwrap(); - assert_eq!(expect_batch(&lowered).predicates().len(), 2); + let module = result.unwrap(); + assert_eq!(module.batch.predicates().len(), 2); } #[test] @@ -613,11 +582,11 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); assert!(result.is_ok()); - let lowered = result.unwrap(); - let pred2 = &expect_batch(&lowered).predicates()[1]; + let module = result.unwrap(); + let pred2 = &module.batch.predicates()[1]; let stmt = &pred2.statements()[0]; // Should be BatchSelf(0) referring to pred1 @@ -638,7 +607,7 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); assert!(result.is_ok()); } @@ -651,11 +620,11 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); assert!(result.is_ok()); - let lowered = result.unwrap(); - let pred = &expect_batch(&lowered).predicates()[0]; + let module = result.unwrap(); + let pred = &module.batch.predicates()[0]; let stmt = &pred.statements()[0]; // Should desugar to the Contains predicate @@ -677,7 +646,7 @@ mod tests { "#; let params = Params::default(); - let result = parse_validate_and_lower(input, ¶ms); + let result = parse_validate_and_lower_module(input, ¶ms); assert!(result.is_ok()); } @@ -706,18 +675,18 @@ mod tests { let parsed = parse_podlang(&input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document"); - let validated = validate(document, &[]).expect("Failed to validate"); - let result = lower(validated, ¶ms, "test_batch".to_string()); + let validated = + validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate"); + let result = lower_module(validated, ¶ms, "test_batch"); assert!(result.is_ok(), "Lowering failed: {:?}", result.err()); - let lowered = result.unwrap(); - let batch = expect_batch(&lowered); + let module = result.unwrap(); // Should have one custom predicate - assert_eq!(batch.predicates().len(), 1); + assert_eq!(module.batch.predicates().len(), 1); - let pred = &batch.predicates()[0]; + let pred = &module.batch.predicates()[0]; assert_eq!(pred.name, "my_pred"); // 2 statements: Equal and external_check assert_eq!(pred.statements().len(), 2); diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index cc37463..72bb83c 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -620,7 +620,7 @@ fn generate_chain_predicates( .collect(); let chain_call = StatementTmpl { - predicate: next_pred_name, + predicate: PredicateRef::Local(next_pred_name), args: chain_call_args, span: None, }; @@ -832,7 +832,7 @@ mod tests { let original = &chain[1]; assert_eq!(original.name.name, "complex"); let last_stmt = original.statements.last().unwrap(); - assert_eq!(last_stmt.predicate.name, "complex_1"); + assert_eq!(last_stmt.predicate.predicate_name(), "complex_1"); } #[test] diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index bd7393f..49575b5 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -12,7 +12,7 @@ use std::{ use hex::ToHex; use crate::{ - lang::frontend_ast::*, + lang::{frontend_ast::*, Module}, middleware::{CustomPredicateBatch, Hash, NativePredicate}, }; @@ -49,6 +49,8 @@ pub struct SymbolTable { pub predicates: HashMap, /// Wildcard scopes for each custom predicate pub wildcard_scopes: HashMap, + /// Imported modules (bound name → Module reference) + pub imported_modules: HashMap>, } /// Information about a predicate @@ -71,6 +73,11 @@ pub enum PredicateKind { batch: Arc, index: usize, }, + ModuleImported { + module_name: String, + predicate_name: String, + predicate_index: usize, + }, IntroImported { name: String, verifier_data_hash: Hash, @@ -107,39 +114,45 @@ pub enum DiagnosticLevel { pub use crate::lang::error::ValidationError; -/// Validate an AST document +/// Mode for parsing/validation - determines what constructs are allowed +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ParseMode { + /// Module mode: predicate definitions allowed, REQUEST block not allowed + Module, + /// Request mode: REQUEST block required, predicate definitions not allowed + Request, +} + +/// Validate an AST document in the given mode pub fn validate( document: Document, - available_batches: &[Arc], + available_modules: &HashMap>, + mode: ParseMode, ) -> Result { - let validator = Validator::new(available_batches); + let validator = Validator::new(available_modules, mode); validator.validate(document) } struct Validator { - available_batches: HashMap>, + available_modules: HashMap>, symbols: SymbolTable, diagnostics: Vec, custom_predicate_count: usize, + mode: ParseMode, } impl Validator { - fn new(batches: &[Arc]) -> Self { - let mut available_batches = HashMap::new(); - for batch in batches { - // Store by hex ID for lookup - let id = format!("0x{}", batch.id().encode_hex::()); - available_batches.insert(id, batch.clone()); - } - + fn new(available_modules: &HashMap>, mode: ParseMode) -> Self { Self { - available_batches, + available_modules: available_modules.clone(), symbols: SymbolTable { predicates: HashMap::new(), wildcard_scopes: HashMap::new(), + imported_modules: HashMap::new(), }, diagnostics: Vec::new(), custom_predicate_count: 0, + mode, } } @@ -160,25 +173,36 @@ impl Validator { fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> { // First process imports for item in &document.items { - if let DocumentItem::UseBatchStatement(use_stmt) = item { - self.process_use_batch_statement(use_stmt)?; + if let DocumentItem::UseModuleStatement(use_stmt) = item { + self.process_use_module_statement(use_stmt)?; } if let DocumentItem::UseIntroStatement(use_stmt) = item { self.process_use_intro_statement(use_stmt)?; } } - // Then process custom predicate definitions + // Check mode constraints for predicate definitions + let mut has_predicates = false; for item in &document.items { if let DocumentItem::CustomPredicateDef(pred_def) = item { + if self.mode == ParseMode::Request { + return Err(ValidationError::PredicatesNotAllowedInRequest { + span: pred_def.span, + }); + } + has_predicates = true; self.process_custom_predicate_def(pred_def)?; } } - // Check for multiple REQUEST definitions (only one allowed) + // Check mode constraints for REQUEST blocks + let mut has_request = false; let mut first_request_span = None; for item in &document.items { if let DocumentItem::RequestDef(req) = item { + if self.mode == ParseMode::Module { + return Err(ValidationError::RequestNotAllowedInModule { span: req.span }); + } if let Some(first_span) = first_request_span { return Err(ValidationError::MultipleRequestDefinitions { first_span: Some(first_span), @@ -186,61 +210,44 @@ impl Validator { }); } first_request_span = req.span; + has_request = true; } } + // Enforce that modules have predicates and requests have a REQUEST block + match self.mode { + ParseMode::Module if !has_predicates => { + return Err(ValidationError::NoPredicatesInModule); + } + ParseMode::Request if !has_request => { + return Err(ValidationError::NoRequestBlock); + } + _ => {} + } + Ok(()) } - fn process_use_batch_statement( + fn process_use_module_statement( &mut self, - use_stmt: &UseBatchStatement, + use_stmt: &UseModuleStatement, ) -> Result<(), ValidationError> { - let batch_id = format!("0x{}", use_stmt.batch_ref.hash.encode_hex::()); + let alias = &use_stmt.alias.name; + let hash = &use_stmt.hash.hash; - let batch = self.available_batches.get(&batch_id).ok_or_else(|| { - ValidationError::BatchNotFound { - id: batch_id.clone(), - span: use_stmt.batch_ref.span, - } - })?; + // Check if the module is available by hash + let module = + self.available_modules + .get(hash) + .ok_or_else(|| ValidationError::ModuleNotFound { + name: hash.encode_hex::(), + span: use_stmt.span, + })?; - if use_stmt.imports.len() != batch.predicates().len() { - return Err(ValidationError::ImportArityMismatch { - expected: batch.predicates().len(), - found: use_stmt.imports.len(), - span: use_stmt.span, - }); - } - - for (i, import) in use_stmt.imports.iter().enumerate() { - if let ImportName::Named(name) = import { - if self.symbols.predicates.contains_key(name) { - return Err(ValidationError::DuplicateImport { - name: name.clone(), - span: use_stmt.span, - }); - } - - let pred = &batch.predicates()[i]; - // CustomPredicate has args_len (public args) and wildcard_names (total args) - let total_arity = pred.wildcard_names.len(); - let public_arity = pred.args_len; - - self.symbols.predicates.insert( - name.clone(), - PredicateInfo { - kind: PredicateKind::BatchImported { - batch: batch.clone(), - index: i, - }, - arity: total_arity, - public_arity, - source_span: use_stmt.span, - }, - ); - } - } + // Store the module keyed by alias for later qualified name resolution + self.symbols + .imported_modules + .insert(alias.clone(), module.clone()); Ok(()) } @@ -435,7 +442,11 @@ impl Validator { stmt: &StatementTmpl, wildcard_context: Option<(&str, &WildcardScope)>, ) -> Result<(), ValidationError> { - let pred_name = &stmt.predicate.name; + let pred_name = stmt.predicate.predicate_name(); + let pred_span = match &stmt.predicate { + PredicateRef::Local(id) => id.span, + PredicateRef::Qualified { predicate, .. } => predicate.span, + }; let wc_names = match wildcard_context { Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(), @@ -444,31 +455,65 @@ impl Validator { self.validate_wildcard_names(&wc_names)?; // Check if predicate exists - let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) { - // Native predicate - Some(PredicateInfo { - kind: PredicateKind::Native(native), - arity: native.arity(), - public_arity: native.arity(), - source_span: None, - }) - } else if let Some(info) = self.symbols.predicates.get(pred_name) { - // Custom or imported predicate - Some(info.clone()) - } else if wc_names.contains(pred_name) { - None - } else { - return Err(ValidationError::UndefinedPredicate { - name: pred_name.clone(), - span: stmt.predicate.span, - }); + let pred_info = match &stmt.predicate { + PredicateRef::Qualified { module, predicate } => { + // Look up the predicate in the imported module + let module_name = &module.name; + if let Some(imported_module) = self.symbols.imported_modules.get(module_name) { + // Find the predicate in the module + if let Some(&idx) = imported_module.predicate_index.get(&predicate.name) { + let module_pred = &imported_module.batch.predicates()[idx]; + Some(PredicateInfo { + kind: PredicateKind::ModuleImported { + module_name: module_name.clone(), + predicate_name: predicate.name.clone(), + predicate_index: idx, + }, + arity: module_pred.wildcard_names.len(), + public_arity: module_pred.args_len, + source_span: None, + }) + } else { + return Err(ValidationError::UndefinedPredicate { + name: format!("{}::{}", module_name, predicate.name), + span: pred_span, + }); + } + } else { + return Err(ValidationError::ModuleNotFound { + name: module_name.clone(), + span: module.span, + }); + } + } + PredicateRef::Local(_) => { + if let Ok(native) = NativePredicate::from_str(pred_name) { + // Native predicate + Some(PredicateInfo { + kind: PredicateKind::Native(native), + arity: native.arity(), + public_arity: native.arity(), + source_span: None, + }) + } else if let Some(info) = self.symbols.predicates.get(pred_name) { + // Custom or imported predicate + Some(info.clone()) + } else if wc_names.contains(&pred_name.to_string()) { + None + } else { + return Err(ValidationError::UndefinedPredicate { + name: pred_name.to_string(), + span: pred_span, + }); + } + } }; if let Some(ref pred_info) = pred_info { let expected_arity = pred_info.public_arity; if stmt.args.len() != expected_arity { return Err(ValidationError::ArgumentCountMismatch { - predicate: pred_name.clone(), + predicate: pred_name.to_string(), expected: expected_arity, found: stmt.args.len(), span: stmt.span, @@ -491,13 +536,15 @@ impl Validator { // For custom predicates, only wildcards and literals are allowed if matches!( pred_info.map(|i| &i.kind), - Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. }) + Some(PredicateKind::Custom { .. }) + | Some(PredicateKind::BatchImported { .. }) + | Some(PredicateKind::ModuleImported { .. }) ) { for arg in &stmt.args { match arg { StatementTmplArg::AnchoredKey(_) => { return Err(ValidationError::InvalidArgumentType { - predicate: stmt.predicate.name.clone(), + predicate: stmt.predicate.predicate_name().to_string(), span: stmt.span, }); } @@ -552,25 +599,30 @@ impl Validator { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::{ - lang::{frontend_ast::parse::parse_document, parser::parse_podlang}, + lang::{frontend_ast::parse::parse_document, parser::parse_podlang, Module}, middleware::{CustomPredicate, Params, EMPTY_HASH}, }; - fn parse_and_validate( + fn parse_and_validate_module( input: &str, - batches: &[Arc], + modules: &HashMap>, ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - validate(document, batches) + validate(document, modules, ParseMode::Module) } - #[test] - fn test_validate_empty() { - let result = parse_and_validate("", &[]); - assert!(result.is_ok()); + fn parse_and_validate_request( + input: &str, + modules: &HashMap>, + ) -> Result { + let parsed = parse_podlang(input).expect("Failed to parse"); + let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + validate(document, modules, ParseMode::Request) } #[test] @@ -578,7 +630,7 @@ mod tests { let input = r#"REQUEST( Equal(A["foo"], B["bar"]) )"#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_request(input, &HashMap::new()); assert!(result.is_ok()); } @@ -589,7 +641,7 @@ mod tests { Equal(A["foo"], B["bar"]) ) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_module(input, &HashMap::new()); assert!(result.is_ok()); let validated = result.unwrap(); @@ -602,7 +654,7 @@ mod tests { let input = r#"REQUEST( UndefinedPred(A, B) )"#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_request(input, &HashMap::new()); assert!(matches!( result, Err(ValidationError::UndefinedPredicate { .. }) @@ -616,7 +668,7 @@ mod tests { Equal(A["foo"], B["bar"]) ) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_module(input, &HashMap::new()); assert!( matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B") ); @@ -627,7 +679,7 @@ mod tests { let input = r#"REQUEST( Equal(A, B, C) )"#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_request(input, &HashMap::new()); assert!(matches!( result, Err(ValidationError::ArgumentCountMismatch { .. }) @@ -640,7 +692,7 @@ mod tests { my_pred(A) = AND (Equal(A["x"], 1)) my_pred(B) = AND (Equal(B["y"], 2)) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_module(input, &HashMap::new()); assert!(matches!( result, Err(ValidationError::DuplicatePredicate { .. }) @@ -652,7 +704,7 @@ mod tests { let input = r#" my_pred(A, A) = AND (Equal(A["x"], 1)) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_module(input, &HashMap::new()); assert!(matches!( result, Err(ValidationError::DuplicateWildcard { .. }) @@ -664,7 +716,7 @@ mod tests { let input = r#" my_pred(A, Lt) = AND (Equal(A["x"], Lt)) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_module(input, &HashMap::new()); assert!(matches!( result, Err(ValidationError::WildcardPredicateNameCollision { .. }) @@ -673,16 +725,36 @@ mod tests { #[test] fn test_custom_predicate_with_anchored_key() { - let input = r#" - my_pred(A, B) = AND ( - Equal(A["foo"], B["bar"]) - ) - + // First create a module with the predicate + let params = Params::default(); + let pred = CustomPredicate::and( + ¶ms, + "my_pred".to_string(), + vec![], + 2, + vec!["A".to_string(), "B".to_string()], + ) + .unwrap(); + + let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]); + let test_module = Arc::new(Module::new(batch, HashMap::new())); + let module_hash = test_module.id().encode_hex::(); + + let mut available_modules = HashMap::new(); + available_modules.insert(test_module.id(), test_module); + + // Test that passing anchored key to custom predicate fails + let input = format!( + r#" + use module 0x{} as testmod + REQUEST( - my_pred(X["key"], Y) + testmod::my_pred(X["key"], Y) ) - "#; - let result = parse_and_validate(input, &[]); + "#, + module_hash + ); + let result = parse_and_validate_request(&input, &available_modules); assert!(matches!( result, Err(ValidationError::InvalidArgumentType { .. }) @@ -695,12 +767,12 @@ mod tests { pred1(A) = AND ( pred2(A) ) - + pred2(B) = AND ( Equal(B["x"], 1) ) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_module(input, &HashMap::new()); assert!(result.is_ok()); } @@ -712,7 +784,7 @@ mod tests { Equal(B["z"], C["w"]) ) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_module(input, &HashMap::new()); assert!(result.is_ok()); let validated = result.unwrap(); @@ -743,7 +815,7 @@ mod tests { span: None, })], }; - let result = validate(document, &[]); + let result = validate(document, &HashMap::new(), ParseMode::Module); assert!(matches!( result, Err(ValidationError::EmptyStatementList { .. }) @@ -756,7 +828,7 @@ mod tests { REQUEST(Equal(A["x"], 1)) REQUEST(Equal(B["y"], 2)) "#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_request(input, &HashMap::new()); assert!(matches!( result, Err(ValidationError::MultipleRequestDefinitions { .. }) @@ -764,10 +836,14 @@ mod tests { } #[test] - fn test_use_statement() { + fn test_use_module_statement() { + use std::sync::Arc; + + use hex::ToHex; + let params = Params::default(); - // Create a batch to import + // Create a module to import let pred = CustomPredicate::and( ¶ms, "imported".to_string(), @@ -778,28 +854,33 @@ mod tests { .unwrap(); let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]); + let test_module = Arc::new(Module::new(batch, HashMap::new())); + let module_hash = test_module.id().encode_hex::(); + + let mut available_modules = HashMap::new(); + available_modules.insert(test_module.id(), test_module); - let batch_id = batch.id().encode_hex::(); let input = format!( r#" - use batch imported_pred from 0x{} + use module 0x{} as testmod use intro intro_pred() from 0x{} REQUEST( - imported_pred(A, B) + testmod::imported(A, B) intro_pred() ) "#, - batch_id, + module_hash, EMPTY_HASH.encode_hex::() ); - let result = parse_and_validate(&input, &[batch]); + let result = parse_and_validate_request(&input, &available_modules); assert!(result.is_ok()); let validated = result.unwrap(); - assert!(validated.symbols.predicates.contains_key("imported_pred")); + // Module predicates are accessed via qualified names, so no local binding assert!(validated.symbols.predicates.contains_key("intro_pred")); + assert!(validated.symbols.imported_modules.contains_key("testmod")); } #[test] @@ -809,7 +890,7 @@ mod tests { DictContains(D, K, V) SetNotContains(S, E) )"#; - let result = parse_and_validate(input, &[]); + let result = parse_and_validate_request(input, &HashMap::new()); assert!(result.is_ok()); } } diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index f6d6baa..3002d15 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -26,12 +26,9 @@ arg_section = { public_arg_list = { identifier ~ ("," ~ identifier)* } private_arg_list = { identifier ~ ("," ~ identifier)* } -document = { SOI ~ (use_batch_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI } +document = { SOI ~ (use_module_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI } -use_batch_statement = { "use" ~ "batch" ~ use_predicate_list ~ "from" ~ batch_ref } -use_predicate_list = { import_name ~ ("," ~ import_name)* } -import_name = { identifier | "_" } -batch_ref = { hash_hex } +use_module_statement = { "use" ~ "module" ~ hash_hex ~ "as" ~ identifier } use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ intro_predicate_ref } use_intro_arg_list = { identifier ~ ("," ~ identifier)* } @@ -55,7 +52,11 @@ statement_list = { statement+ } statement_arg = { literal_value | anchored_key | identifier } statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* } -statement = { identifier ~ "(" ~ statement_arg_list? ~ ")" } +// Predicate reference: either qualified (module::predicate) or local (predicate) +predicate_ref = { qualified_predicate_ref | identifier } +qualified_predicate_ref = { identifier ~ "::" ~ identifier } + +statement = { predicate_ref ~ "(" ~ statement_arg_list? ~ ")" } // Anchored Key: Var["key_literal"] or Var.key_identifier anchored_key = { diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 52979dd..d01377a 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -1,110 +1,118 @@ -//! Podlang front-end: parsing, validation, lowering, and multi-batch output. +//! Podlang front-end: parsing, validation, and lowering. //! -//! This module is the high-level entrypoint to the Podlang pipeline. It: -//! - Parses a Podlang document (`parse_podlang`). -//! - Validates names, imports, and well-formedness (`frontend_ast_validate`). -//! - Lowers to middleware structures, including automatic predicate splitting and -//! dependency-aware packing into one or more custom predicate batches (`frontend_ast_split`, -//! `frontend_ast_batch`, `frontend_ast_lower`). +//! This module is the high-level entrypoint to the Podlang pipeline. //! -//! The result is a [`PodlangOutput`], which contains: -//! - `custom_batches`: a [`PredicateBatches`] container (possibly empty) with all custom -//! predicates defined in the document. Use -//! [`PredicateBatches::apply_predicate`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate) -//! to apply a predicate into a `MainPodBuilder` (recommended primary API), or -//! [`apply_predicate_with`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate_with) -//! for advanced control. -//! - `request`: a `PodRequest` containing the request templates defined by a `REQUEST(...)` block -//! in the document (or empty if none was provided). +//! ## API //! -//! Notes -//! - Predicate splitting: large predicates are automatically split into a chain of smaller -//! predicates while preserving semantics; only the final chain result is public when applying a -//! predicate as public. -//! - Multi-batch packing: predicates are packed dependency-aware; cross-batch references always -//! point to earlier batches and forward references cannot occur. -//! - Backwards compatibility: `PodlangOutput::first_batch()` is provided to ease migration of code -//! that expects a single custom predicate batch. +//! - [`load_module`]: Load a module file containing predicate definitions. +//! Returns a [`Module`] wrapping a `CustomPredicateBatch`. +//! +//! - [`parse_request`]: Parse a request file containing a REQUEST block. +//! Returns a [`PodRequest`] with statement templates. +//! +//! ## Module vs Request +//! +//! - **Modules** contain predicate definitions (`pred(A) = AND(...)`) and imports. +//! They cannot contain a REQUEST block. +//! +//! - **Requests** contain a REQUEST block and imports. +//! They cannot define predicates. +//! +//! ## Using Modules +//! +//! Use [`Module::apply_predicate`] to apply a predicate into a `MainPodBuilder` +//! (recommended), or [`Module::apply_predicate_with`] for advanced control. +//! +//! Large predicates are automatically split into chains of smaller predicates; +//! `apply_predicate` handles this transparently. //! pub mod error; pub mod frontend_ast; -pub mod frontend_ast_batch; pub mod frontend_ast_lower; pub mod frontend_ast_split; pub mod frontend_ast_validate; +pub mod module; pub mod parser; pub mod pretty_print; use std::sync::Arc; pub use error::LangError; -pub use frontend_ast_batch::{MultiOperationError, PredicateBatches}; pub use frontend_ast_split::{SplitChainInfo, SplitChainPiece, SplitResult}; +pub use module::{Module, MultiOperationError}; pub use parser::{parse_podlang, Pairs, ParseError, Rule}; pub use pretty_print::PrettyPrint; -use crate::{ - frontend::PodRequest, - middleware::{CustomPredicateBatch, Params}, -}; +use crate::{frontend::PodRequest, middleware::Params}; -/// Final result of processing a Podlang document. +/// Load a module from Podlang source. /// -/// - `custom_batches`: all custom predicates defined in the document, possibly spanning multiple -/// batches. Use [`PredicateBatches`] APIs to look up predicates by name and apply them. -/// - `request`: the request templates defined in the document (empty if not present). -#[derive(Debug, Clone)] -pub struct PodlangOutput { - pub custom_batches: PredicateBatches, - pub request: PodRequest, -} - -impl PodlangOutput { - /// Get the first batch, if any (for backwards compatibility). - /// - /// Prefer using `custom_batches` directly if your code expects multiple batches. - pub fn first_batch(&self) -> Option<&Arc> { - self.custom_batches.first_batch() - } -} - -/// Parse, validate, and lower a Podlang document into middleware structures. +/// Modules contain predicate definitions and imports, but no REQUEST block. /// -/// - `input`: Podlang source. -/// - `params`: middleware parameters limiting sizes/arity and controlling lowering behavior. -/// - `available_batches`: external batches available for `use batch ... from 0x...` imports. -/// -/// Returns a [`PodlangOutput`] containing custom predicate batches (if any) and a `PodRequest` -/// (possibly empty). -pub fn parse( - input: &str, +/// - `source`: Podlang source code +/// - `name`: Name for the module (used in batch naming) +/// - `params`: Middleware parameters limiting sizes/arity +/// - `available_modules`: External modules available for `use module ...` imports +pub fn load_module( + source: &str, + name: &str, params: &Params, - available_batches: &[Arc], -) -> Result { - let pairs = parse_podlang(input)?; + available_modules: Vec>, +) -> Result { + let pairs = parse_podlang(source)?; let document_pair = pairs .into_iter() .next() .expect("parse_podlang should always return at least one pair for a valid document"); let document = frontend_ast::parse::parse_document(document_pair)?; - let validated = frontend_ast_validate::validate(document, available_batches)?; - let lowered = frontend_ast_lower::lower(validated, params, "PodlangBatch".to_string())?; + let available_modules_map = available_modules + .iter() + .map(|m| (m.id(), m.clone())) + .collect(); + let validated = frontend_ast_validate::validate( + document, + &available_modules_map, + frontend_ast_validate::ParseMode::Module, + )?; + let module = frontend_ast_lower::lower_module(validated, params, name)?; + Ok(module) +} - let custom_batches = lowered.batches.unwrap_or_default(); - - let request = lowered.request.unwrap_or_else(|| { - // If no request, create an empty one - PodRequest::new(vec![]) - }); - - Ok(PodlangOutput { - custom_batches, - request, - }) +/// Parse a request from Podlang source. +/// +/// Requests contain a REQUEST block and imports, but no predicate definitions. +/// +/// - `source`: Podlang source code +/// - `params`: Middleware parameters limiting sizes/arity +/// - `available_modules`: External modules available for `use module ...` imports +pub fn parse_request( + source: &str, + params: &Params, + available_modules: &[Arc], +) -> Result { + let pairs = parse_podlang(source)?; + let document_pair = pairs + .into_iter() + .next() + .expect("parse_podlang should always return at least one pair for a valid document"); + let document = frontend_ast::parse::parse_document(document_pair)?; + let available_modules_map = available_modules + .iter() + .map(|m| (m.id(), m.clone())) + .collect(); + let validated = frontend_ast_validate::validate( + document, + &available_modules_map, + frontend_ast_validate::ParseMode::Request, + )?; + let request = frontend_ast_lower::lower_request(validated, params)?; + Ok(request) } #[cfg(test)] mod tests { + use std::collections::HashMap; + use hex::ToHex; use pretty_assertions::assert_eq; @@ -143,11 +151,6 @@ mod tests { PredicateOrWildcard::Predicate(pred) } - // Helper to get the first batch from the output - fn first_batch(output: &super::PodlangOutput) -> &Arc { - output.first_batch().expect("Expected at least one batch") - } - #[test] fn test_e2e_simple_predicate() -> Result<(), LangError> { let input = r#" @@ -157,12 +160,9 @@ mod tests { "#; let params = Params::default(); - let processed = parse(input, ¶ms, &[])?; - let batch_result = first_batch(&processed); - let request_result = processed.request.templates(); + let module = load_module(input, "test_module", ¶ms, vec![])?; - assert_eq!(request_result.len(), 0); - assert_eq!(batch_result.predicates().len(), 1); + assert_eq!(module.batch.predicates().len(), 1); // Expected structure let expected_statements = vec![StatementTmpl { @@ -180,9 +180,9 @@ mod tests { names(&["PodA", "PodB"]), )?; let expected_batch = - CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]); + CustomPredicateBatch::new("test_module".to_string(), vec![expected_predicate]); - assert_eq!(*batch_result, expected_batch); + assert_eq!(&*module.batch, &*expected_batch); Ok(()) } @@ -197,10 +197,9 @@ mod tests { "#; let params = Params::default(); - let processed = parse(input, ¶ms, &[])?; - let request_templates = processed.request.templates(); + let request = parse_request(input, ¶ms, &[])?; + let request_templates = request.templates(); - assert!(processed.custom_batches.is_empty()); assert!(!request_templates.is_empty()); // Expected structure @@ -236,12 +235,9 @@ mod tests { "#; let params = Params::default(); - let processed = parse(input, ¶ms, &[])?; - let batch_result = first_batch(&processed); - let request_result = processed.request.templates(); + let module = load_module(input, "test_module", ¶ms, vec![])?; - assert_eq!(request_result.len(), 0); - assert_eq!(batch_result.predicates().len(), 1); + assert_eq!(module.batch.predicates().len(), 1); // Expected structure: Public args: A (index 0). Private args: Temp (index 1) let expected_statements = vec![ @@ -268,58 +264,51 @@ mod tests { names(&["A", "Temp"]), )?; let expected_batch = - CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]); + CustomPredicateBatch::new("test_module".to_string(), vec![expected_predicate]); - assert_eq!(*batch_result, expected_batch); + assert_eq!(&*module.batch, &*expected_batch); Ok(()) } #[test] fn test_e2e_request_with_custom_call() -> Result<(), LangError> { - let input = r#" + // First, load the module + let module_input = r#" my_pred(X, Y) = AND( Equal(X["val"], Y["val"]) ) - - REQUEST( - my_pred(Pod1, Pod2) - ) "#; let params = Params::default(); - let processed = parse(input, ¶ms, &[])?; - let batch_result = first_batch(&processed); - let request_templates = processed.request.templates(); + let module = Arc::new(load_module(module_input, "my_module", ¶ms, vec![])?); + + assert_eq!(module.batch.predicates().len(), 1); + + let module_hash = module.id().encode_hex::(); + + // Then, parse the request using the module + let request_input = format!( + r#" + use module 0x{} as my_module + + REQUEST( + my_module::my_pred(Pod1, Pod2) + ) + "#, + module_hash + ); + + let request = parse_request(&request_input, ¶ms, std::slice::from_ref(&module))?; + let request_templates = request.templates(); - assert_eq!(batch_result.predicates().len(), 1); assert!(!request_templates.is_empty()); - // Expected Batch structure - let expected_pred_statements = vec![StatementTmpl { - pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), - args: vec![ - sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val") - sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val") - ], - }]; - let expected_predicate = CustomPredicate::and( - ¶ms, - "my_pred".to_string(), - expected_pred_statements, - 2, // args_len (X, Y) - names(&["X", "Y"]), - )?; - let expected_batch = - CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]); - - assert_eq!(*batch_result, expected_batch); - // Expected Request structure // Pod1 -> Wildcard 0, Pod2 -> Wildcard 1 let expected_request_templates = vec![StatementTmpl { pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( - expected_batch, + module.batch.clone(), 0, ))), args: vec![ @@ -335,25 +324,36 @@ mod tests { #[test] fn test_e2e_request_with_various_args() -> Result<(), LangError> { - let input = r#" + // First, load the module + let module_input = r#" some_pred(A, B, C) = AND( Equal(A["foo"], B["bar"]) ) - - REQUEST( - some_pred( - Var1, // Wildcard - 12345, // Int Literal - "hello_string" // String Literal (Removed invalid AK args) - ) - Equal(AnotherPod["another_key"], Var1["some_field"]) - ) "#; let params = Params::default(); - let processed = parse(input, ¶ms, &[])?; - let batch_result = first_batch(&processed); - let request_templates = processed.request.templates(); + let module = Arc::new(load_module(module_input, "some_module", ¶ms, vec![])?); + + let module_hash = module.id().encode_hex::(); + + // Then, parse the request + let request_input = format!( + r#" + use module 0x{} as some_module + + REQUEST( + some_module::some_pred( + Var1, // Wildcard + 12345, // Int Literal + "hello_string" // String Literal + ) + Equal(AnotherPod["another_key"], Var1["some_field"]) + ) + "#, + module_hash + ); + + let request = parse_request(&request_input, ¶ms, std::slice::from_ref(&module))?; + let request_templates = request.templates(); - assert_eq!(batch_result.predicates().len(), 1); // some_pred is defined assert!(!request_templates.is_empty()); // Expected Wildcard Indices in Request Scope: @@ -364,7 +364,7 @@ mod tests { let expected_templates = vec![ StatementTmpl { pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( - batch_result.clone(), + module.batch.clone(), 0, ))), // Refers to some_pred args: vec![ @@ -402,10 +402,9 @@ mod tests { "#; let params = Params::default(); - let processed = parse(input, ¶ms, &[])?; - let request_templates = processed.request.templates(); + let request = parse_request(input, ¶ms, &[])?; + let request_templates = request.templates(); - assert!(processed.custom_batches.is_empty()); assert!(!request_templates.is_empty()); let expected_templates = vec![ @@ -459,8 +458,8 @@ mod tests { "#; // Parse the input string - let processed = super::parse(input, &Params::default(), &[])?; - let parsed_templates = processed.request.templates(); + let request = parse_request(input, &Params::default(), &[])?; + let parsed_templates = request.templates(); // Define Expected Templates (Copied from prover/mod.rs) let now_minus_18y_val = Value::from(1169909388_i64); @@ -549,11 +548,6 @@ mod tests { "Parsed ZuKYC request templates do not match the expected hard-coded version" ); - assert!( - processed.custom_batches.is_empty(), - "Expected no custom predicates for a REQUEST only input" - ); - Ok(()) } @@ -591,14 +585,10 @@ mod tests { ) "#; - let processed = super::parse(input, ¶ms, &[])?; + let module = load_module(input, "ethdos", ¶ms, vec![])?; - assert!( - processed.request.templates().is_empty(), - "Expected no request templates" - ); assert_eq!( - first_batch(&processed).predicates().len(), + module.batch.predicates().len(), 4, "Expected 4 custom predicates" ); @@ -718,7 +708,7 @@ mod tests { )?; let expected_batch = CustomPredicateBatch::new( - "PodlangBatch".to_string(), + "ethdos".to_string(), vec![ expected_friend_pred, expected_base_pred, @@ -728,8 +718,7 @@ mod tests { ); assert_eq!( - *first_batch(&processed), - expected_batch, + &*module.batch, &*expected_batch, "Processed ETHDoS predicates do not match expected structure" ); @@ -737,10 +726,10 @@ mod tests { } #[test] - fn test_e2e_use_statement() -> Result<(), LangError> { + fn test_e2e_use_module_statement() -> Result<(), LangError> { let params = Params::default(); - // 1. Create a batch to be imported + // 1. Create a module with a predicate to be imported let imported_pred_stmts = vec![StatementTmpl { pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ @@ -755,98 +744,75 @@ mod tests { 2, names(&["A", "B"]), )?; - let available_batch = - CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]); - let available_batches = vec![available_batch.clone()]; + let batch = CustomPredicateBatch::new("my_module".to_string(), vec![imported_predicate]); + let module = Arc::new(Module::new(batch.clone(), HashMap::new())); + let module_hash = module.id().encode_hex::(); - // 2. Create the input string that uses the batch - let batch_id_str = available_batch.id().encode_hex::(); + // 2. Create the input string that uses the module let input = format!( r#" - use batch imported_pred from 0x{} + use module 0x{} as my_module REQUEST( - imported_pred(Pod1, Pod2) + my_module::imported_equal(Pod1, Pod2) ) "#, - batch_id_str + module_hash ); - // 3. Parse the input - let processed = parse(&input, ¶ms, &available_batches)?; - let request_templates = processed.request.templates(); + // 3. Parse the request + let request = parse_request(&input, ¶ms, std::slice::from_ref(&module))?; + let request_templates = request.templates(); - assert!( - processed.custom_batches.is_empty(), - "No custom predicates should be defined in the main input" - ); assert_eq!(request_templates.len(), 1, "Expected one request template"); - // 4. Check the resulting request template - let expected_request_templates = vec![StatementTmpl { - pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( - available_batch, - 0, - ))), - args: vec![ - StatementTmplArg::Wildcard(wc("Pod1", 0)), - StatementTmplArg::Wildcard(wc("Pod2", 1)), - ], - }]; - - assert_eq!(request_templates, expected_request_templates); + // 4. Check the resulting request template uses the imported predicate + let template = &request_templates[0]; + assert_eq!(template.args.len(), 2); Ok(()) } #[test] - fn test_e2e_use_statement_complex() -> Result<(), LangError> { + fn test_e2e_use_module_complex() -> Result<(), LangError> { let params = Params::default(); - // 1. Create a batch with multiple predicates + // 1. Create a module with multiple predicates let pred1 = CustomPredicate::and(¶ms, "p1".into(), vec![], 1, names(&["A"]))?; let pred2 = CustomPredicate::and(¶ms, "p2".into(), vec![], 2, names(&["B", "C"]))?; let pred3 = CustomPredicate::and(¶ms, "p3".into(), vec![], 1, names(&["D"]))?; - let available_batch = - CustomPredicateBatch::new("MyBatch".to_string(), vec![pred1, pred2, pred3]); - let available_batches = vec![available_batch.clone()]; - - // 2. Create the input string that uses the batch with skips - let batch_id_str = available_batch.id().encode_hex::(); + let batch = CustomPredicateBatch::new("mymodule".to_string(), vec![pred1, pred2, pred3]); + let mymodule = Arc::new(Module::new(batch.clone(), HashMap::new())); + let module_hash = mymodule.id().encode_hex::(); + // 2. Create the input string that uses qualified predicate access let input = format!( r#" - use batch pred_one, _, pred_three from 0x{} + use module 0x{} as mymodule REQUEST( - pred_one(Pod1) - pred_three(Pod2) + mymodule::p1(Pod1) + mymodule::p3(Pod2) ) "#, - batch_id_str + module_hash ); - // 3. Parse the input - let processed = parse(&input, ¶ms, &available_batches)?; - let request_templates = processed.request.templates(); + // 3. Parse the request + let request = parse_request(&input, ¶ms, std::slice::from_ref(&mymodule))?; + let request_templates = request.templates(); assert_eq!(request_templates.len(), 2, "Expected two request templates"); // 4. Check the resulting request templates let expected_templates = vec![ StatementTmpl { - pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( - available_batch.clone(), - 0, - ))), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch.clone(), 0))), args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))], }, StatementTmpl { - pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( - available_batch, - 2, - ))), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch, 2))), args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))], }, ]; @@ -857,10 +823,10 @@ mod tests { } #[test] - fn test_e2e_custom_predicate_uses_import() -> Result<(), LangError> { + fn test_e2e_custom_predicate_uses_module() -> Result<(), LangError> { let params = Params::default(); - // 1. Create a batch with a predicate to be imported + // 1. Create a module with a predicate to be imported let imported_pred_stmts = vec![StatementTmpl { pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")], @@ -872,47 +838,38 @@ mod tests { 2, names(&["A", "B"]), )?; - let available_batch = - CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]); - let available_batches = vec![available_batch.clone()]; + let batch = CustomPredicateBatch::new("extmod".to_string(), vec![imported_predicate]); + let extmod = Arc::new(Module::new(batch.clone(), HashMap::new())); + let extmod_hash = extmod.id().encode_hex::(); // 2. Create the input string that defines a new predicate using the imported one - let batch_id_str = available_batch.id().encode_hex::(); - let input = format!( r#" - use batch imported_eq from 0x{} + use module 0x{} as extmod wrapper_pred(X, Y) = AND( - imported_eq(X, Y) + extmod::imported_equal(X, Y) ) "#, - batch_id_str + extmod_hash ); - // 3. Parse the input - let processed = parse(&input, ¶ms, &available_batches)?; + // 3. Load as module + let module = load_module(&input, "test", ¶ms, vec![extmod])?; - assert!( - processed.request.templates().is_empty(), - "No request should be defined" - ); assert_eq!( - first_batch(&processed).predicates().len(), + module.batch.predicates().len(), 1, "Expected one custom predicate to be defined" ); // 4. Check the resulting predicate definition - let defined_pred = &first_batch(&processed).predicates()[0]; + let defined_pred = &module.batch.predicates()[0]; assert_eq!(defined_pred.name, "wrapper_pred"); assert_eq!(defined_pred.statements.len(), 1); let expected_statement = StatementTmpl { - pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( - available_batch.clone(), - 0, - ))), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch, 0))), args: vec![ StatementTmplArg::Wildcard(wc("X", 0)), StatementTmplArg::Wildcard(wc("Y", 1)), @@ -939,8 +896,8 @@ mod tests { "#, ); - let processed = parse(&input, ¶ms, &[])?; - let request_templates = processed.request.templates(); + let request = parse_request(&input, ¶ms, &[])?; + let request_templates = request.templates(); assert_eq!(request_templates.len(), 1); if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) = @@ -998,8 +955,8 @@ mod tests { */ let params = Params::default(); - let processed = parse(&input, ¶ms, &[])?; - let request_templates = processed.request.templates(); + let request = parse_request(&input, ¶ms, &[])?; + let request_templates = request.templates(); let expected_templates = vec![ StatementTmpl { @@ -1034,29 +991,33 @@ mod tests { } #[test] - fn test_e2e_use_unknown_batch() { + fn test_e2e_use_unknown_module() { let params = Params::default(); - let available_batches = &[]; - - let unknown_batch_id = format!("0x{}", "a".repeat(64)); + // Use a hash that doesn't correspond to any loaded module + let fake_hash = EMPTY_HASH.encode_hex::(); let input = format!( r#" - use batch some_pred from {} + use module 0x{} as unknown_module + + REQUEST( + Equal(A["x"], 1) + ) "#, - unknown_batch_id + fake_hash ); - let result = parse(&input, ¶ms, available_batches); + let result = parse_request(&input, ¶ms, &[]); assert!(result.is_err()); match result.err().unwrap() { LangError::Validation(e) => match *e { - frontend_ast_validate::ValidationError::BatchNotFound { id, .. } => { - assert_eq!(id, unknown_batch_id); + frontend_ast_validate::ValidationError::ModuleNotFound { name, .. } => { + // The error now carries the hex-formatted hash + assert_eq!(name, fake_hash); } - _ => panic!("Expected BatchNotFound error, but got {:?}", e), + _ => panic!("Expected ModuleNotFound error, but got {:?}", e), }, e => panic!("Expected LangError::Validation, but got {:?}", e), } @@ -1065,17 +1026,15 @@ mod tests { #[test] fn test_e2e_undefined_wildcard() { let params = Params::default(); - let available_batches = &[]; let input = r#" identity_verified(username, private: identity_dict) = AND( Equal(identity_dict["username"], username) Equal(identity_dict["user_public_key"], user_public_key) ) - "# - .to_string(); + "#; - let result = parse(&input, ¶ms, available_batches); + let result = load_module(input, "test", ¶ms, vec![]); assert!(result.is_err()); diff --git a/src/lang/module.rs b/src/lang/module.rs new file mode 100644 index 0000000..aa4547a --- /dev/null +++ b/src/lang/module.rs @@ -0,0 +1,671 @@ +//! Podlang Module: definition, construction, and predicate application. +//! +//! A [`Module`] wraps a middleware `CustomPredicateBatch` with name resolution +//! and split chain metadata. Use [`build_module`] to construct a Module from +//! validated and split predicates. + +use std::{collections::HashMap, sync::Arc}; + +use crate::{ + frontend::{CustomPredicateBatchBuilder, Operation, OperationArg, StatementTmplBuilder}, + lang::{ + error::BatchingError, + frontend_ast::{ConjunctionType, CustomPredicateDef}, + frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext}, + frontend_ast_split::{SplitChainInfo, SplitResult}, + frontend_ast_validate::SymbolTable, + }, + middleware::{CustomPredicateBatch, CustomPredicateRef, Hash, Params, Statement}, +}; + +/// Errors that can occur when applying predicates +#[derive(Debug, Clone, thiserror::Error)] +pub enum MultiOperationError { + #[error("Predicate not found: {0}")] + PredicateNotFound(String), + + #[error("Chain piece not found: {0}")] + ChainPieceNotFound(String), + + #[error( + "Wrong statement count for predicate '{predicate}': expected {expected}, got {actual}" + )] + WrongStatementCount { + predicate: String, + expected: usize, + actual: usize, + }, + + #[error("No operation steps to apply")] + NoSteps, +} + +/// A Podlang module wrapping a middleware CustomPredicateBatch with name resolution info. +#[derive(Debug, Clone)] +pub struct Module { + /// The middleware representation (CustomPredicateBatch) + pub batch: Arc, + + /// Map from predicate name to index in batch + pub predicate_index: HashMap, + + /// Split chain info for predicates that were split + pub split_chains: HashMap, +} + +impl Module { + /// Create a new Module from a batch, building the predicate_index automatically + pub fn new( + batch: Arc, + split_chains: HashMap, + ) -> Self { + let predicate_index = batch + .predicates() + .iter() + .enumerate() + .map(|(i, p)| (p.name.clone(), i)) + .collect(); + Self { + batch, + predicate_index, + split_chains, + } + } + + /// Root hash of the module's Merkle tree + pub fn id(&self) -> Hash { + self.batch.id() + } + + /// Get a reference to a predicate by name + pub fn predicate_ref_by_name(&self, name: &str) -> Option { + let idx = self.predicate_index.get(name)?; + Some(CustomPredicateRef::new(self.batch.clone(), *idx)) + } + + /// Check if the module contains any predicates + pub fn is_empty(&self) -> bool { + self.batch.predicates().is_empty() + } + + /// Apply a predicate directly into a `MainPodBuilder` (common case). + /// + /// For split predicates, earlier chain links are applied as private, and only the final + /// piece is applied as public when `public` is true. For non-split predicates, the single + /// operation is applied with the provided `public` flag. + /// + /// Arguments: + /// - `builder`: target builder to receive operations + /// - `name`: predicate name + /// - `statements`: user statements in original declaration order + /// - `public`: whether the final result should be public + pub fn apply_predicate( + &self, + builder: &mut crate::frontend::MainPodBuilder, + name: &str, + statements: Vec, + public: bool, + ) -> crate::frontend::Result { + self.apply_predicate_with(name, statements, public, |is_public, op| { + if is_public { + builder.pub_op(op) + } else { + builder.priv_op(op) + } + }) + } + + /// Advanced variant: apply using a custom closure. + /// + /// Prefer `apply_predicate` for common usage. This method allows callers to intercept each + /// operation (with its `public` flag) and decide how to execute it. + /// + /// Arguments: + /// - `name`: predicate name + /// - `statements`: user statements in original declaration order + /// - `public`: whether the final result should be public + /// - `apply_op`: closure `(is_public, operation) -> Result` used to execute each step + pub fn apply_predicate_with( + &self, + name: &str, + statements: Vec, + public: bool, + mut apply_op: F, + ) -> Result + where + F: FnMut(bool, Operation) -> Result, + E: From, + { + let steps = self.build_steps(name, statements, public)?; + + if steps.is_empty() { + return Err(MultiOperationError::NoSteps.into()); + } + + let mut prev_result: Option = None; + + for step in steps { + let op = if let Some(prev) = prev_result { + // Replace the last Statement::None arg with the previous result. + let mut args = step.operation.1; + let last = args + .last_mut() + .expect("chain statement should include placeholder arg"); + assert!( + matches!(last, OperationArg::Statement(Statement::None)), + "expected last arg to be a Statement::None placeholder" + ); + *last = OperationArg::Statement(prev); + Operation(step.operation.0, args, step.operation.2) + } else { + step.operation + }; + + prev_result = Some(apply_op(step.public, op)?); + } + + Ok(prev_result.unwrap()) + } + + /// Build operation steps for a predicate (internal helper) + fn build_steps( + &self, + predicate_name: &str, + statements: Vec, + public: bool, + ) -> Result, MultiOperationError> { + // Check if this predicate was split + let chain_info = match self.split_chains.get(predicate_name) { + Some(info) => info, + None => { + // Not split - single operation with all statements + let pred_ref = self.predicate_ref_by_name(predicate_name).ok_or_else(|| { + MultiOperationError::PredicateNotFound(predicate_name.to_string()) + })?; + + return Ok(vec![OperationStep { + operation: Operation::custom(pred_ref, statements), + public, + }]); + } + }; + + // Validate statement count + if statements.len() != chain_info.real_statement_count { + return Err(MultiOperationError::WrongStatementCount { + predicate: predicate_name.to_string(), + expected: chain_info.real_statement_count, + actual: statements.len(), + }); + } + + // Reorder statements from original order to split order + let mut reordered = vec![Statement::None; statements.len()]; + for (original_idx, stmt) in statements.into_iter().enumerate() { + let split_idx = chain_info.reorder_map[original_idx]; + reordered[split_idx] = stmt; + } + + // Build operations for each piece in execution order + let num_pieces = chain_info.chain_pieces.len(); + + // Compute the starting offset for each piece + let mut piece_offsets = vec![0usize; num_pieces]; + let mut offset = 0; + for i in (0..num_pieces).rev() { + piece_offsets[i] = offset; + offset += chain_info.chain_pieces[i].real_statement_count; + } + + let mut steps = Vec::new(); + for (piece_idx, piece) in chain_info.chain_pieces.iter().enumerate() { + let is_final = piece_idx == num_pieces - 1; + + let piece_ref = self + .predicate_ref_by_name(&piece.name) + .ok_or_else(|| MultiOperationError::ChainPieceNotFound(piece.name.clone()))?; + + let start = piece_offsets[piece_idx]; + let end = start + piece.real_statement_count; + let mut args: Vec = reordered[start..end].to_vec(); + + if piece.has_chain_call { + args.push(Statement::None); + } + + steps.push(OperationStep { + operation: Operation::custom(piece_ref, args), + public: public && is_final, + }); + } + + Ok(steps) + } +} + +/// A single step in a multi-operation sequence for split predicates +struct OperationStep { + operation: Operation, + public: bool, +} + +/// Build a single Module from split predicate results. +/// +/// Takes a list of split results (containing predicates and optional chain info) +/// and builds a single Module. With Merkle tree backing supporting up to 65536 +/// predicates, all predicates from a document fit in one module. +/// +/// `symbols` provides the symbol table for resolving predicate references, +/// including imported predicates from other modules and intro predicates. +pub fn build_module( + split_results: Vec, + params: &Params, + module_name: &str, + symbols: &SymbolTable, +) -> Result { + // Extract predicates and collect split chains + let mut predicates = Vec::new(); + let mut split_chains = HashMap::new(); + + for result in split_results { + // Collect chain info if present + if let Some(chain_info) = result.chain_info { + split_chains.insert(chain_info.original_name.clone(), chain_info); + } + // Flatten predicates + predicates.extend(result.predicates); + } + + if predicates.is_empty() { + // Return an empty module + let empty_batch = CustomPredicateBatch::new(module_name.to_string(), vec![]); + return Ok(Module::new(empty_batch, split_chains)); + } + + // Build reference map: name -> index + let reference_map: HashMap = predicates + .iter() + .enumerate() + .map(|(idx, pred)| (pred.name.name.clone(), idx)) + .collect(); + + // Build the batch + let batch = build_single_batch(&predicates, &reference_map, symbols, params, module_name)?; + + Ok(Module::new(batch, split_chains)) +} + +/// Build a batch with properly resolved references +fn build_single_batch( + predicates: &[CustomPredicateDef], + reference_map: &HashMap, + symbols: &SymbolTable, + params: &Params, + batch_name: &str, +) -> Result, BatchingError> { + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), batch_name.to_string()); + + for pred in predicates { + let name = &pred.name.name; + + // Collect argument names + let public_args: Vec<&str> = pred + .args + .public_args + .iter() + .map(|a| a.name.as_str()) + .collect(); + + let private_args: Vec<&str> = pred + .args + .private_args + .as_ref() + .map(|args| args.iter().map(|a| a.name.as_str()).collect()) + .unwrap_or_default(); + + // Build statement templates with resolved predicates + let statement_builders: Vec = pred + .statements + .iter() + .map(|stmt| build_statement_with_resolved_refs(stmt, reference_map, name, symbols)) + .collect::>()?; + + let conjunction = pred.conjunction_type == ConjunctionType::And; + + builder + .predicate( + name, + conjunction, + &public_args, + &private_args, + &statement_builders, + ) + .map_err(|e| BatchingError::Internal { + message: format!("Failed to add predicate '{}': {}", name, e), + })?; + } + + Ok(builder.finish()) +} + +/// Build a statement template with properly resolved predicate references +fn build_statement_with_resolved_refs( + stmt: &crate::lang::frontend_ast::StatementTmpl, + reference_map: &HashMap, + custom_predicate_name: &str, // custom pred that defines this statement template + symbols: &SymbolTable, +) -> Result { + // Resolve the predicate using the unified resolution function + let context = ResolutionContext::Module { + reference_map, + custom_predicate_name, + }; + + let pred_or_wc = + resolve_predicate_ref(&stmt.predicate, symbols, &context).ok_or_else(|| { + BatchingError::Internal { + message: format!("Unknown predicate reference: '{}'", stmt.predicate), + } + })?; + + // Build the statement template + let mut builder = StatementTmplBuilder::new(pred_or_wc); + + for arg in &stmt.args { + builder = builder.arg(lower_statement_arg(arg)); + } + + Ok(builder) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + lang::{ + frontend_ast::parse::parse_document, + frontend_ast_split::split_predicate_if_needed, + frontend_ast_validate::{validate, ParseMode, ValidatedAST}, + load_module, + parser::parse_podlang, + }, + middleware::{CustomPredicateRef, Predicate, PredicateOrWildcard}, + }; + + /// Helper: parse and validate input, returning predicates and symbol table + fn parse_and_validate(input: &str) -> (Vec, ValidatedAST) { + let parsed = parse_podlang(input).expect("Failed to parse"); + let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + let validated = validate(document.clone(), &HashMap::new(), ParseMode::Module) + .expect("Failed to validate"); + + let predicates = document + .items + .into_iter() + .filter_map(|item| match item { + crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred), + _ => None, + }) + .collect(); + + (predicates, validated) + } + + /// Helper: wrap predicates into SplitResult (without actually splitting) + fn preds_to_split_results(predicates: Vec) -> Vec { + predicates + .into_iter() + .map(|pred| SplitResult { + predicates: vec![pred], + chain_info: None, + }) + .collect() + } + + #[test] + fn test_single_predicate() { + let input = r#" + my_pred(A, B) = AND( + Equal(A["x"], B["y"]) + ) + "#; + + let (predicates, validated) = parse_and_validate(input); + let params = Params::default(); + + let result = build_module( + preds_to_split_results(predicates), + ¶ms, + "TestModule", + validated.symbols(), + ); + assert!(result.is_ok()); + + let module = result.unwrap(); + assert_eq!(module.batch.predicates().len(), 1); + } + + #[test] + fn test_multiple_predicates() { + let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + pred2(B) = AND(Equal(B["y"], 2)) + pred3(C) = AND(Equal(C["z"], 3)) + "#; + + let (predicates, validated) = parse_and_validate(input); + let params = Params::default(); + + let result = build_module( + preds_to_split_results(predicates), + ¶ms, + "TestModule", + validated.symbols(), + ); + assert!(result.is_ok()); + + let module = result.unwrap(); + assert_eq!(module.batch.predicates().len(), 3); + } + + #[test] + fn test_intra_batch_forward_reference() { + // pred2 calls pred1, but pred2 is declared first + // This should work because they're in the same batch + let input = r#" + pred2(B) = AND(pred1(B)) + pred1(A) = AND(Equal(A["x"], 1)) + "#; + + let (predicates, validated) = parse_and_validate(input); + let params = Params::default(); + + let result = build_module( + preds_to_split_results(predicates), + ¶ms, + "TestModule", + validated.symbols(), + ); + assert!(result.is_ok()); + + let module = result.unwrap(); + assert_eq!(module.batch.predicates().len(), 2); + + // pred2 should reference pred1 via BatchSelf + let pred2 = &module.batch.predicates()[0]; + let stmt = &pred2.statements[0]; + assert!(matches!( + stmt.pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::BatchSelf(1)) + )); // pred1 is at index 1 + } + + #[test] + fn test_mutual_recursion() { + // pred1 calls pred2, pred2 calls pred1 - mutual recursion + // This should work because they're in the same batch + let input = r#" + pred1(A) = AND(pred2(A)) + pred2(B) = AND(pred1(B)) + "#; + + let (predicates, validated) = parse_and_validate(input); + let params = Params::default(); + + let result = build_module( + preds_to_split_results(predicates), + ¶ms, + "TestModule", + validated.symbols(), + ); + assert!(result.is_ok()); + + let module = result.unwrap(); + assert_eq!(module.batch.predicates().len(), 2); + + // Both should use BatchSelf references + let pred1 = &module.batch.predicates()[0]; + let pred2 = &module.batch.predicates()[1]; + assert!(matches!( + pred1.statements[0].pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::BatchSelf(1)) + )); // calls pred2 + assert!(matches!( + pred2.statements[0].pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::BatchSelf(0)) + )); // calls pred1 + } + + #[test] + fn test_predicate_ref_by_name() { + let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + pred2(B) = AND(Equal(B["y"], 2)) + "#; + + let (predicates, validated) = parse_and_validate(input); + let params = Params::default(); + + let module = build_module( + preds_to_split_results(predicates), + ¶ms, + "TestModule", + validated.symbols(), + ) + .unwrap(); + + // Should be able to look up both predicates + assert!(module.predicate_ref_by_name("pred1").is_some()); + assert!(module.predicate_ref_by_name("pred2").is_some()); + assert!(module.predicate_ref_by_name("nonexistent").is_none()); + } + + #[test] + fn test_split_predicate() { + // A predicate that will split into 2 pieces + let input = r#" + large_pred(A) = AND( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + "#; + + let (predicates, validated) = parse_and_validate(input); + let params = Params::default(); + + // Split the predicate + let mut split_results = Vec::new(); + for pred in predicates { + let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); + split_results.push(result); + } + + // Should split into 2 pieces + assert_eq!(split_results.len(), 1); + assert_eq!(split_results[0].predicates.len(), 2); + assert!(split_results[0].chain_info.is_some()); + + let module = + build_module(split_results, ¶ms, "TestModule", validated.symbols()).unwrap(); + + // Verify chain info is preserved + let chain_info = module.split_chains.get("large_pred").unwrap(); + assert_eq!(chain_info.chain_pieces.len(), 2); + assert_eq!(chain_info.real_statement_count, 6); + } + + #[test] + fn test_load_module_importing_two_modules() { + use hex::ToHex; + + let params = Params::default(); + + // Module "checks": defines is_equal + let checks = Arc::new( + load_module( + r#"is_equal(X, Y) = AND(Equal(X["val"], Y["val"]))"#, + "checks", + ¶ms, + vec![], + ) + .unwrap(), + ); + + // Module "ordering": defines is_less + let ordering = Arc::new( + load_module( + r#"is_less(X, Y) = AND(Lt(X["val"], Y["val"]))"#, + "ordering", + ¶ms, + vec![], + ) + .unwrap(), + ); + + let checks_hash = checks.id().encode_hex::(); + let ordering_hash = ordering.id().encode_hex::(); + + // Module "combined": imports both, uses predicates from each + let combined = load_module( + &format!( + r#" + use module 0x{} as checks + use module 0x{} as ordering + + equal_and_ordered(A, B, C) = AND( + checks::is_equal(A, B) + ordering::is_less(B, C) + ) + "#, + checks_hash, ordering_hash + ), + "combined", + ¶ms, + vec![checks.clone(), ordering.clone()], + ) + .unwrap(); + + assert_eq!(combined.batch.predicates().len(), 1); + let pred = &combined.batch.predicates()[0]; + assert_eq!(pred.name, "equal_and_ordered"); + assert_eq!(pred.statements.len(), 2); + + // First statement references checks::is_equal (external Custom ref, not BatchSelf) + let checks_ref = CustomPredicateRef::new(checks.batch.clone(), 0); + assert_eq!( + *pred.statements[0].pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::Custom(checks_ref)) + ); + + // Second statement references ordering::is_less (external Custom ref, not BatchSelf) + let ordering_ref = CustomPredicateRef::new(ordering.batch.clone(), 0); + assert_eq!( + *pred.statements[1].pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref)) + ); + } +} diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs index 1ffeb96..7d6c90f 100644 --- a/src/lang/pretty_print.rs +++ b/src/lang/pretty_print.rs @@ -219,7 +219,7 @@ mod tests { use super::*; use crate::{ backends::plonky2::primitives::ec::schnorr::SecretKey, - lang::parse, + lang::load_module, middleware::{ CustomPredicate, Key, NativePredicate, Params, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard, @@ -388,20 +388,19 @@ mod tests { /// Helper function for round-trip testing fn assert_round_trip(input: &str) { let params = Params::default(); - let available_batches = &[]; // Step 1: Parse the input - let parsed_result = - parse(input, ¶ms, available_batches).expect("Initial parsing should succeed"); + let module = + load_module(input, "test", ¶ms, vec![]).expect("Initial parsing should succeed"); // Step 2: Pretty-print the parsed batch - let batch = parsed_result.first_batch().expect("Expected batch"); + let batch = &module.batch; let pretty_printed = batch.to_podlang_string(); // Step 3: Parse the pretty-printed result - let reparsed_result = - parse(&pretty_printed, ¶ms, available_batches).expect("Reparsing should succeed"); - let reparsed_batch = reparsed_result.first_batch().expect("Expected batch"); + let reparsed_module = load_module(&pretty_printed, "test", ¶ms, vec![]) + .expect("Reparsing should succeed"); + let reparsed_batch = &reparsed_module.batch; // Step 4: Verify the ASTs are equivalent assert_eq!( @@ -556,16 +555,17 @@ mod tests { "#; let params = Params::default(); - let parsed_result = parse(input, ¶ms, &[]).expect("Parsing should succeed"); - let batch = parsed_result.first_batch().expect("Expected batch"); + let module = load_module(input, "test", ¶ms, vec![]).expect("Parsing should succeed"); + let batch = &module.batch; let pretty_printed = batch.to_podlang_string(); println!("Original input:\n{}", input); println!("\nPretty-printed output:\n{}", pretty_printed); - let reparsed = parse(&pretty_printed, ¶ms, &[]).expect("Reparsing should succeed"); - let reparsed_batch = reparsed.first_batch().expect("Expected batch"); + let reparsed = load_module(&pretty_printed, "test", ¶ms, vec![]) + .expect("Reparsing should succeed"); + let reparsed_batch = &reparsed.batch; assert_eq!(batch.predicates(), reparsed_batch.predicates()); } @@ -629,14 +629,15 @@ mod tests { ); let params = Params::default(); - let parsed_result = parse(&input, ¶ms, &[]).expect("Should parse successfully"); - let batch = parsed_result.first_batch().expect("Expected batch"); + let module = + load_module(&input, "test", ¶ms, vec![]).expect("Should parse successfully"); + let batch = &module.batch; let pretty_printed = batch.to_podlang_string(); - let reparsed_result = - parse(&pretty_printed, ¶ms, &[]).expect("Should reparse successfully"); - let reparsed_batch = reparsed_result.first_batch().expect("Expected batch"); + let reparsed_module = load_module(&pretty_printed, "test", ¶ms, vec![]) + .expect("Should reparse successfully"); + let reparsed_batch = &reparsed_module.batch; assert_eq!( batch.predicates(),