From d1b7b4d37e87536219db499412718e281742aa1c Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Wed, 28 Jan 2026 06:54:21 +0100 Subject: [PATCH] Improved predicate splitting (#445) * Multi-batch splitting * Invoke split predicates by name, passing in full argument list * Reorder batches to prevent failure of forward references where possible * Rename APIs for clarity * Simplify example * Add more docs * Review updates * Remove duplicate code * Comment topological sort algorithm --- Cargo.toml | 1 + examples/main_pod_points.rs | 5 +- src/backends/plonky2/mainpod/mod.rs | 4 +- src/examples/custom.rs | 6 +- src/frontend/error.rs | 6 + src/frontend/mod.rs | 85 +- src/lang/error.rs | 34 +- src/lang/frontend_ast_batch.rs | 1446 +++++++++++++++++++++++++++ src/lang/frontend_ast_lower.rs | 592 +++++------ src/lang/frontend_ast_split.rs | 243 +++-- src/lang/mod.rs | 113 ++- src/lang/pretty_print.rs | 21 +- 12 files changed, 2090 insertions(+), 466 deletions(-) create mode 100644 src/lang/frontend_ast_batch.rs diff --git a/Cargo.toml b/Cargo.toml index c080093..537c38e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ rand = "0.8.5" hashbrown = { version = "0.14.3", default-features = false, features = ["serde"] } pest = "2.8.0" pest_derive = "2.8.0" +petgraph = "0.6" directories = { version = "6.0.0", optional = true } minicbor-serde = { version = "0.5.0", features = ["std"], optional = true } serde_bytes = "0.11" diff --git a/examples/main_pod_points.rs b/examples/main_pod_points.rs index 8eb3cfc..2b5f257 100644 --- a/examples/main_pod_points.rs +++ b/examples/main_pod_points.rs @@ -88,7 +88,10 @@ fn main() -> Result<(), Box> { game_pk = game_pk, ); println!("# custom predicate batch:{}", input); - let batch = parse(&input, ¶ms, &[])?.custom_batch; + let batch = parse(&input, ¶ms, &[])? 
+ .first_batch() + .expect("Expected batch") + .clone(); let points_pred = batch.predicate_ref_by_name("points").unwrap(); let over_9000_pred = batch.predicate_ref_by_name("over_9000").unwrap(); diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 68334b2..78b617c 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1179,7 +1179,9 @@ pub mod tests { &[], ) .unwrap() - .custom_batch; + .first_batch() + .unwrap() + .clone(); let mut builder = MainPodBuilder::new(&params, &DEFAULT_VD_SET); let cpr = CustomPredicateRef { batch, index: 0 }; let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap(); diff --git a/src/examples/custom.rs b/src/examples/custom.rs index 63f12bb..8c4550c 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -32,7 +32,11 @@ pub fn eth_dos_batch(params: &Params) -> Result> { eth_dos_ind(src, dst, distance) ) "#; - let batch = parse(input, params, &[]).expect("lang parse").custom_batch; + let batch = parse(input, params, &[]) + .expect("lang parse") + .first_batch() + .expect("Expected batch") + .clone(); println!("a.0. {}", batch.predicates[0]); println!("a.1. {}", batch.predicates[1]); println!("a.2. 
{}", batch.predicates[2]); diff --git a/src/frontend/error.rs b/src/frontend/error.rs index aab6eb3..3d162d5 100644 --- a/src/frontend/error.rs +++ b/src/frontend/error.rs @@ -71,6 +71,12 @@ impl From for Error { } } +impl From for Error { + fn from(value: crate::lang::MultiOperationError) -> Self { + Error::custom(value.to_string()) + } +} + impl Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index b71ef5c..aa96f46 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1390,7 +1390,11 @@ pub mod tests { Equal(b, 5) ) "#; - let batch = parse(input, ¶ms, &[]).unwrap().custom_batch; + let batch = parse(input, ¶ms, &[]) + .unwrap() + .first_batch() + .unwrap() + .clone(); let pred_test = batch.predicate_ref_by_name("Test").unwrap(); // Try to build with wrong type in 1st arg @@ -1414,4 +1418,83 @@ pub mod tests { Ok(()) } + + #[test] + fn test_apply_predicate_e2e() -> Result<()> { + // End-to-end test of apply_predicate with MockProver + // Tests a split predicate being applied through the full pipeline + let params = Params::default(); + let vd_set = &*MOCK_VD_SET; + + // Create a predicate that will split (6 Equal statements) + // The predicate checks that values at different keys are equal to specific literals + let input = r#" + large_pred(A) = AND( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + "#; + + // Parse and batch the predicate (this handles splitting internally) + let parsed = parse(input, ¶ms, &[])?; + let batches = &parsed.custom_batches; + + // Verify it was split + assert!(batches.split_chain("large_pred").is_some()); + let chain_info = batches.split_chain("large_pred").unwrap(); + assert_eq!(chain_info.chain_pieces.len(), 2); + assert_eq!(chain_info.real_statement_count, 6); + + // Create a signed dict with the required entries + let mut 
signed_builder = SignedDictBuilder::new(¶ms); + signed_builder.insert("a", 1); + signed_builder.insert("b", 2); + signed_builder.insert("c", 3); + signed_builder.insert("d", 4); + signed_builder.insert("e", 5); + signed_builder.insert("f", 6); + let signer = Signer(SecretKey(1u32.into())); + let signed_dict = signed_builder.sign(&signer)?; + + // Build the main pod + let mut builder = MainPodBuilder::new(¶ms, vd_set); + builder.pub_op(Operation::dict_signed_by(&signed_dict))?; + + // Create 6 Equal statements (one for each predicate constraint) in original order + // Each proves that signed_dict["x"] = n, matching the Equal(A["x"], n) template + let st_a = builder.priv_op(Operation::eq((&signed_dict, "a"), 1))?; + let st_b = builder.priv_op(Operation::eq((&signed_dict, "b"), 2))?; + let st_c = builder.priv_op(Operation::eq((&signed_dict, "c"), 3))?; + let st_d = builder.priv_op(Operation::eq((&signed_dict, "d"), 4))?; + let st_e = builder.priv_op(Operation::eq((&signed_dict, "e"), 5))?; + let st_f = builder.priv_op(Operation::eq((&signed_dict, "f"), 6))?; + + // Pass statements in original declaration order + let statements = vec![st_a, st_b, st_c, st_d, st_e, st_f]; + + // Use apply_predicate (primary API) to automatically wire the split chain + let result = batches.apply_predicate(&mut builder, "large_pred", statements, true)?; + + // The result should be a valid statement + let predicate = batches.predicate_ref_by_name("large_pred").unwrap(); + match &result { + Statement::Custom(pred_ref, _) => { + assert_eq!(pred_ref, &predicate); + } + _ => panic!("Expected Statement::Custom, got {:?}", result), + } + + // Prove with MockProver + let prover = MockProver {}; + let pod = builder.prove(&prover)?; + + // Verify the pod + pod.pod.verify()?; + + Ok(()) + } } diff --git a/src/lang/error.rs b/src/lang/error.rs index 72637ea..318e715 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -22,6 +22,9 @@ pub enum LangError { #[error("Lowering error: {0}")] 
Lowering(Box), + + #[error("Batching error: {0}")] + Batching(Box), } /// Validation errors from frontend AST validation @@ -90,14 +93,6 @@ pub enum ValidationError { /// Lowering errors from frontend AST lowering to middleware #[derive(Debug, thiserror::Error)] pub enum LoweringError { - #[error("Too many custom predicates in batch '{batch_name}': {count} exceeds limit of {max}{}", if *.original_count != *.count { format!(" (started with {} predicates before automatic splitting)", original_count) } else { String::new() })] - TooManyPredicates { - batch_name: String, - count: usize, - max: usize, - original_count: usize, - }, - #[error("Too many statements in predicate '{predicate}': {count} exceeds limit of {max}")] TooManyStatements { predicate: String, @@ -127,6 +122,9 @@ pub enum LoweringError { #[error("Splitting error: {0}")] Splitting(#[from] SplittingError), + #[error("Batching error: {0}")] + Batching(#[from] BatchingError), + #[error("Cannot lower document with validation errors")] ValidationErrors, } @@ -235,6 +233,13 @@ fn format_public_args_at_split_error( msg } +/// Batching errors from multi-batch packing +#[derive(Debug, thiserror::Error)] +pub enum BatchingError { + #[error("Internal batching error: {message}")] + Internal { message: String }, +} + /// Splitting errors from predicate splitting #[derive(Debug, thiserror::Error)] pub enum SplittingError { @@ -271,13 +276,6 @@ pub enum SplittingError { max_allowed: usize, suggestion: Option>, }, - - #[error("Too many predicates in chain for '{predicate}': {count} exceeds batch limit of {max_allowed}")] - TooManyPredicatesInChain { - predicate: String, - count: usize, - max_allowed: usize, - }, } impl From for LangError { @@ -303,3 +301,9 @@ impl From for LangError { LangError::Lowering(Box::new(err)) } } + +impl From for LangError { + fn from(err: BatchingError) -> Self { + LangError::Batching(Box::new(err)) + } +} diff --git a/src/lang/frontend_ast_batch.rs b/src/lang/frontend_ast_batch.rs new file 
mode 100644 index 0000000..fe4b30e --- /dev/null +++ b/src/lang/frontend_ast_batch.rs @@ -0,0 +1,1446 @@ +//! Multi-batch packing for predicates +//! +//! This module implements packing of multiple predicates (including split chains) +//! into multiple CustomPredicateBatches when they exceed single-batch limits. +//! +//! Packing strategy (dependency-aware): +//! - Build a dependency graph of predicates (edges: callee → caller for local refs). +//! - Condense strongly connected components (SCCs) to ensure mutually-recursive preds stay together. +//! - Topologically order the SCC DAG; within each topological layer, pack larger components first +//! (ties broken by declaration order) to reduce wasted space. +//! - Within a batch, intra-batch calls use `BatchSelf` and work regardless of declaration order; +//! cross-batch calls always point to earlier batches via `CustomPredicateRef`. +//! - Forward cross-batch references cannot occur with this planner (they are treated as unreachable). + +use std::{collections::HashMap, str::FromStr, sync::Arc}; + +use petgraph::{algo::condensation, graph::DiGraph, prelude::NodeIndex, visit::EdgeRef}; + +use crate::{ + frontend::{CustomPredicateBatchBuilder, Operation, OperationArg, StatementTmplBuilder}, + lang::{ + error::BatchingError, + frontend_ast::{ConjunctionType, CustomPredicateDef}, + frontend_ast_lower::lower_statement_arg, + frontend_ast_split::{SplitChainInfo, SplitResult}, + }, + middleware::{ + CustomPredicateBatch, CustomPredicateRef, NativePredicate, Params, Predicate, Statement, + }, +}; + +/// A single step in a multi-operation sequence for split predicates +#[derive(Debug, Clone)] +struct OperationStep { + /// The operation to perform + operation: Operation, + /// Whether this step's result should be public + public: bool, +} + +/// Errors that can occur when building multi-operations +#[derive(Debug, Clone, thiserror::Error)] +pub enum MultiOperationError { + #[error("Predicate not found: {0}")] + 
PredicateNotFound(String), + + #[error("Chain piece not found: {0}")] + ChainPieceNotFound(String), + + #[error( + "Wrong statement count for predicate '{predicate}': expected {expected}, got {actual}" + )] + WrongStatementCount { + predicate: String, + expected: usize, + actual: usize, + }, + + #[error("No operation steps to apply")] + NoSteps, +} + +/// Container for multiple predicate batches +#[derive(Debug, Clone)] +pub struct PredicateBatches { + batches: Vec>, + /// Maps predicate name to (batch_index, predicate_index_within_batch) + predicate_index: HashMap, + /// Split chain metadata for predicates that were split + /// Maps original predicate name to its chain info + split_chains: HashMap, +} + +impl Default for PredicateBatches { + fn default() -> Self { + Self::new() + } +} + +impl PredicateBatches { + pub fn new() -> Self { + Self { + batches: Vec::new(), + predicate_index: HashMap::new(), + split_chains: HashMap::new(), + } + } + + /// Get split chain info for a predicate (if it was split) + pub fn split_chain(&self, name: &str) -> Option<&SplitChainInfo> { + self.split_chains.get(name) + } + + /// Get a reference to a predicate by name + pub fn predicate_ref_by_name(&self, name: &str) -> Option { + let (batch_idx, pred_idx) = self.predicate_index.get(name)?; + let batch = self.batches.get(*batch_idx)?; + Some(CustomPredicateRef::new(batch.clone(), *pred_idx)) + } + + /// Get all batches + pub fn batches(&self) -> &[Arc] { + &self.batches + } + + /// Get the first batch (for backwards compatibility) + pub fn first_batch(&self) -> Option<&Arc> { + self.batches.first() + } + + /// Get batch count + pub fn batch_count(&self) -> usize { + self.batches.len() + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.batches.is_empty() + } + + /// Total predicate count across all batches + pub fn total_predicate_count(&self) -> usize { + self.batches.iter().map(|b| b.predicates().len()).sum() + } + + /// Build operation steps for a predicate 
(internal helper) + /// + /// For non-split predicates, returns a single operation. + /// For split predicates, returns the chain of operations in execution order + /// (innermost first), with chain link placeholders. + fn build_steps( + &self, + predicate_name: &str, + statements: Vec, + public: bool, + ) -> Result, MultiOperationError> { + // Check if this predicate was split + let chain_info = match self.split_chains.get(predicate_name) { + Some(info) => info, + None => { + // Not split - single operation with all statements + let pred_ref = self.predicate_ref_by_name(predicate_name).ok_or_else(|| { + MultiOperationError::PredicateNotFound(predicate_name.to_string()) + })?; + + return Ok(vec![OperationStep { + operation: Operation::custom(pred_ref, statements), + public, + }]); + } + }; + + // Validate statement count + if statements.len() != chain_info.real_statement_count { + return Err(MultiOperationError::WrongStatementCount { + predicate: predicate_name.to_string(), + expected: chain_info.real_statement_count, + actual: statements.len(), + }); + } + + // Reorder statements from original order to split order + // reorder_map[original_idx] = split_idx + // So we need to place statements[i] at position reorder_map[i] + let mut reordered = vec![Statement::None; statements.len()]; + for (original_idx, stmt) in statements.into_iter().enumerate() { + let split_idx = chain_info.reorder_map[original_idx]; + reordered[split_idx] = stmt; + } + + // Build operations for each piece in execution order (innermost first) + // + // chain_pieces are in execution order: [continuation_N, ..., continuation_1, main] + // But in split order, statements are laid out: [main's stmts, cont_1's stmts, ..., cont_N's stmts] + // So we need to compute offsets from the END for the first pieces. 
+ // + // Example with 6 statements, max_arity 5: + // split order: [stmt0, stmt1, stmt2, stmt3, stmt4, stmt5] + // chain_pieces[0] (large_pred_1): takes stmt5 (the last 1) + // chain_pieces[1] (large_pred): takes stmt0-4 (the first 5) + // + // We compute offsets by going through pieces in reverse order (matching split order). + + let num_pieces = chain_info.chain_pieces.len(); + + // Compute the starting offset for each piece by iterating in reverse + // (reverse of chain_pieces = same order as split layout) + let mut piece_offsets = vec![0usize; num_pieces]; + let mut offset = 0; + for i in (0..num_pieces).rev() { + piece_offsets[i] = offset; + offset += chain_info.chain_pieces[i].real_statement_count; + } + + let mut steps = Vec::new(); + for (piece_idx, piece) in chain_info.chain_pieces.iter().enumerate() { + let is_final = piece_idx == num_pieces - 1; + + // Get predicate ref for this piece + let piece_ref = self + .predicate_ref_by_name(&piece.name) + .ok_or_else(|| MultiOperationError::ChainPieceNotFound(piece.name.clone()))?; + + // Slice the reordered statements for this piece + let start = piece_offsets[piece_idx]; + let end = start + piece.real_statement_count; + let piece_statements: Vec = reordered[start..end].to_vec(); + + // Build the operation + // For non-final pieces, we'll add a placeholder that will be replaced + // with the previous step's result when applied + let mut args = piece_statements; + if piece.has_chain_call { + // Add placeholder for chain link - will be replaced by apply_multi_operation + args.push(Statement::None); + } + + steps.push(OperationStep { + operation: Operation::custom(piece_ref, args), + public: public && is_final, // Only final piece is public + }); + } + + Ok(steps) + } + + /// Apply a predicate directly into a `MainPodBuilder` (common case). + /// + /// For split predicates, earlier chain links are applied as private, and only the final + /// piece is applied as public when `public` is true. 
For non-split predicates, the single + /// operation is applied with the provided `public` flag. + /// + /// Arguments: + /// - `builder`: target builder to receive operations + /// - `name`: predicate name + /// - `statements`: user statements in original declaration order + /// - `public`: whether the final result should be public + pub fn apply_predicate( + &self, + builder: &mut crate::frontend::MainPodBuilder, + name: &str, + statements: Vec, + public: bool, + ) -> crate::frontend::Result { + self.apply_predicate_with(name, statements, public, |is_public, op| { + if is_public { + builder.pub_op(op) + } else { + builder.priv_op(op) + } + }) + } + + /// Advanced variant: apply using a custom closure. + /// + /// Prefer `apply_predicate` for common usage. This method allows callers to intercept each + /// operation (with its `public` flag) and decide how to execute it. + /// + /// Arguments: + /// - `name`: predicate name + /// - `statements`: user statements in original declaration order + /// - `public`: whether the final result should be public + /// - `apply_op`: closure `(is_public, operation) -> Result` used to execute each step + pub fn apply_predicate_with( + &self, + name: &str, + statements: Vec, + public: bool, + mut apply_op: F, + ) -> Result + where + F: FnMut(bool, Operation) -> Result, + E: From, + { + let steps = self.build_steps(name, statements, public)?; + + if steps.is_empty() { + return Err(MultiOperationError::NoSteps.into()); + } + + let mut prev_result: Option = None; + + for step in steps { + let op = if let Some(prev) = prev_result { + // Replace the last Statement::None arg with the previous result. + // By construction, all steps after the first include a chain placeholder + // as their last argument. 
+ let mut args = step.operation.1; + let last = args + .last_mut() + .expect("chain statement should include placeholder arg"); + assert!( + matches!(last, OperationArg::Statement(Statement::None)), + "expected last arg to be a Statement::None placeholder" + ); + *last = OperationArg::Statement(prev); + Operation(step.operation.0, args, step.operation.2) + } else { + step.operation + }; + + prev_result = Some(apply_op(step.public, op)?); + } + + // Safe to unwrap because we checked steps.is_empty() above + Ok(prev_result.unwrap()) + } +} + +/// Assignment of a predicate to a batch +#[derive(Debug, Clone)] +struct PredicateAssignment { + /// Full name (e.g., "my_pred_1" for split link) + full_name: String, + /// Which batch this goes into + batch_index: usize, + /// Index within that batch + index_in_batch: usize, +} + +/// Information about an imported predicate for use during batching +#[derive(Debug, Clone)] +pub struct ImportedPredicateInfo { + pub batch: Arc, + pub index: usize, +} + +/// Pack predicates into multiple batches +/// +/// Takes a list of split results (containing predicates and optional chain info) +/// and packs them into batches, handling cross-batch references correctly. +/// +/// Predicates are packed dependency‑aware: +/// - Mutually recursive predicates (SCCs) are kept together. +/// - Components are ordered topologically; within each layer, larger components are packed first +/// (ties by declaration order) to reduce wasted space. +/// - Within a batch, predicates can reference each other freely via `BatchSelf`; cross-batch +/// references always point to earlier batches via `CustomPredicateRef`. +/// +/// `imported_predicates` maps predicate names to their imported batch info, +/// allowing predicates to call imported predicates from other batches. 
+pub fn batch_predicates( + split_results: Vec, + params: &Params, + base_batch_name: &str, + imported_predicates: &HashMap, +) -> Result { + // Extract predicates and collect split chains + let mut predicates = Vec::new(); + let mut split_chains = HashMap::new(); + + for result in split_results { + // Collect chain info if present + if let Some(chain_info) = result.chain_info { + split_chains.insert(chain_info.original_name.clone(), chain_info); + } + // Flatten predicates + predicates.extend(result.predicates); + } + + if predicates.is_empty() { + return Ok(PredicateBatches::new()); + } + + // Plan batch assignments in declaration order + let assignments = plan_batch_assignments(&predicates, params.max_custom_batch_size)?; + + // Build reference map: name -> (batch_idx, idx_in_batch) + let reference_map: HashMap = assignments + .iter() + .map(|a| (a.full_name.clone(), (a.batch_index, a.index_in_batch))) + .collect(); + + // Determine number of batches + let num_batches = assignments + .iter() + .map(|a| a.batch_index) + .max() + .map(|m| m + 1) + .unwrap_or(0); + + // Build batches in order + let mut batches = Vec::new(); + let mut predicate_index = HashMap::new(); + + for batch_idx in 0..num_batches { + // Collect predicates for this batch (in assignment order) + let batch_predicates: Vec<_> = predicates + .iter() + .zip(assignments.iter()) + .filter(|(_, a)| a.batch_index == batch_idx) + .map(|(p, _)| p.clone()) + .collect(); + + let batch_name = if num_batches == 1 { + base_batch_name.to_string() + } else { + format!("{}_{}", base_batch_name, batch_idx) + }; + + let batch = build_single_batch( + &batch_predicates, + batch_idx, + &reference_map, + &batches, + imported_predicates, + params, + &batch_name, + )?; + + // Update predicate index + for (idx, pred) in batch_predicates.iter().enumerate() { + predicate_index.insert(pred.name.name.clone(), (batch_idx, idx)); + } + + batches.push(batch); + } + + Ok(PredicateBatches { + batches, + predicate_index, + 
split_chains, + }) +} + +/// Plan batch assignments (greedy fill in declaration order) +fn plan_batch_assignments( + predicates: &[CustomPredicateDef], + max_batch_size: usize, +) -> Result, BatchingError> { + // Map name -> original index + let mut name_to_index: HashMap = HashMap::new(); + let index_to_name: Vec = predicates + .iter() + .enumerate() + .map(|(i, pred)| { + name_to_index.insert(pred.name.name.clone(), i); + pred.name.name.clone() + }) + .collect(); + + let n = predicates.len(); + // Build graph with nodes 0..n and edges callee -> caller for local refs + let mut graph: DiGraph = DiGraph::new(); + let nodes: Vec = (0..n).map(|i| graph.add_node(i)).collect(); + for (caller_idx, pred) in predicates.iter().enumerate() { + for stmt in &pred.statements { + if let Some(&callee_idx) = name_to_index.get(&stmt.predicate.name) { + graph.add_edge(nodes[callee_idx], nodes[caller_idx], ()); + } + } + } + + // Condense SCCs into DAG; each node weight is Vec of members + // Pass `true` to remove self-loops, ensuring acyclicity for topo sort + let mut condensed = condensation(graph, /*make_acyclic=*/ true); + + // Verify each component fits in a batch and sort members by original index + for comp_members in condensed.node_weights_mut() { + comp_members.sort_unstable(); + if comp_members.len() > max_batch_size { + let members = comp_members + .iter() + .map(|&i| index_to_name[i].clone()) + .collect::>() + .join(", "); + // An SCC larger than the per-batch capacity cannot be packed: all members of a + // mutually-recursive group must live in the same batch. Splitting reduces per‑predicate + // arity but does not break cycles, and the split chain for a single predicate remains + // acyclic (so it does not increase the SCC size). Users must refactor to break the + // cycle or increase `max_custom_batch_size`. + return Err(BatchingError::Internal { + message: format!( + "Mutually recursive group of size {} exceeds batch capacity {}. Predicates: [{}]. 
\\n+ Consider breaking the cycle or increasing max_custom_batch_size.", + comp_members.len(), + max_batch_size, + members + ), + }); + } + } + + // Topological sort using a layer-wise variant of Kahn's algorithm. + // + // Standard Kahn's algorithm processes nodes one at a time from a queue. This variant + // instead processes entire "layers" (all nodes at the same topological depth) together, + // which allows sorting within each layer for better bin-packing while still respecting + // dependency order. + // + // Algorithm: + // 1. Compute in-degree for each node + // 2. Initialize first layer with all zero in-degree nodes (no dependencies) + // 3. For each layer: + // a. Sort by component size (desc) for bin-packing, then by key for determinism + // b. Add to output order + // c. Decrement in-degree of all neighbors; those hitting zero form the next layer + // 4. Assert all nodes visited (would fail if graph had cycles, but condensation ensures DAG) + + let node_count = condensed.node_count(); + + // Step 1: Compute in-degrees + let mut indeg = vec![0usize; node_count]; + for e in condensed.edge_references() { + indeg[e.target().index()] += 1; + } + + // Stable key per component: minimal original index inside the component + // Used as tiebreaker when sorting layers for deterministic output + let mut comp_key: Vec = vec![0; node_count]; + for ni in condensed.node_indices() { + let members = &condensed[ni]; + let key = members.iter().copied().min().expect("non-empty component"); + comp_key[ni.index()] = key; + } + + // Step 2: Initialize with zero in-degree nodes + let mut current_layer: Vec = condensed + .node_indices() + .filter(|&ni| indeg[ni.index()] == 0) + .collect(); + + let mut order: Vec = Vec::with_capacity(node_count); + use std::cmp::Reverse; + + // Step 3: Process layer by layer + while !current_layer.is_empty() { + // Sort by size desc (for bin-packing), then by comp_key asc (for determinism) + current_layer.sort_by_key(|&ni| { + let size = 
condensed[ni].len(); + (Reverse(size), comp_key[ni.index()]) + }); + + // Add this layer to the output order + order.extend(current_layer.iter().copied()); + + // Build next layer: decrement in-degrees, collect nodes that hit zero + let mut next_layer: Vec<NodeIndex> = Vec::new(); + for &u in &current_layer { + for v in condensed.neighbors(u) { + let idx = v.index(); + indeg[idx] -= 1; + if indeg[idx] == 0 { + next_layer.push(v); + } + } + } + current_layer = next_layer; + } + + // Step 4: Verify all nodes were visited (cycle detection) + assert_eq!(order.len(), node_count, "condensed graph must be acyclic"); + + // Greedy pack components by the layer-aware order + let mut pred_batch: Vec<usize> = vec![0; n]; + let mut current_batch = 0usize; + let mut current_count = 0usize; + for cid in order { + let comp = &condensed[cid]; + let comp_size = comp.len(); + // If the next component doesn't fit in the remaining capacity, start a new batch. + // This is the normal batch boundary; precedence is preserved, and we mitigate wasted + // space by sorting components within each topo layer by size (desc) earlier. 
+ if current_count + comp_size > max_batch_size { + current_batch += 1; + current_count = 0; + } + for &pi in comp { + pred_batch[pi] = current_batch; + } + current_count += comp_size; + } + + // Compute index_in_batch by original order to match builder's enumeration + let mut per_batch_counts: HashMap = HashMap::new(); + let mut assignments = Vec::with_capacity(n); + for (i, pred) in predicates.iter().enumerate() { + let b = pred_batch[i]; + let idx = per_batch_counts.get(&b).cloned().unwrap_or(0); + per_batch_counts.insert(b, idx + 1); + assignments.push(PredicateAssignment { + full_name: pred.name.name.clone(), + batch_index: b, + index_in_batch: idx, + }); + } + + Ok(assignments) +} + +/// Build a single batch with properly resolved references +fn build_single_batch( + predicates: &[CustomPredicateDef], + batch_idx: usize, + reference_map: &HashMap, + existing_batches: &[Arc], + imported_predicates: &HashMap, + params: &Params, + batch_name: &str, +) -> Result, BatchingError> { + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), batch_name.to_string()); + + for pred in predicates { + let name = &pred.name.name; + + // Collect argument names + let public_args: Vec<&str> = pred + .args + .public_args + .iter() + .map(|a| a.name.as_str()) + .collect(); + + let private_args: Vec<&str> = pred + .args + .private_args + .as_ref() + .map(|args| args.iter().map(|a| a.name.as_str()).collect()) + .unwrap_or_default(); + + // Build statement templates with resolved predicates + let statement_builders: Vec = pred + .statements + .iter() + .map(|stmt| { + build_statement_with_resolved_refs( + stmt, + name, + batch_idx, + reference_map, + existing_batches, + imported_predicates, + ) + }) + .collect::>()?; + + let conjunction = pred.conjunction_type == ConjunctionType::And; + + builder + .predicate( + name, + conjunction, + &public_args, + &private_args, + &statement_builders, + ) + .map_err(|e| BatchingError::Internal { + message: format!("Failed to add 
predicate '{}': {}", name, e), + })?; + } + + Ok(builder.finish()) +} + +/// Build a statement template with properly resolved predicate references +fn build_statement_with_resolved_refs( + stmt: &crate::lang::frontend_ast::StatementTmpl, + caller_name: &str, + current_batch_idx: usize, + reference_map: &HashMap, + existing_batches: &[Arc], + imported_predicates: &HashMap, +) -> Result { + let callee_name = &stmt.predicate.name; + + // Resolve the predicate + let predicate = if let Ok(native) = NativePredicate::from_str(callee_name) { + Predicate::Native(native) + } else if let Some(&(target_batch, target_idx)) = reference_map.get(callee_name) { + // Local predicate in this document + if target_batch == current_batch_idx { + // Same batch - use BatchSelf + Predicate::BatchSelf(target_idx) + } else if target_batch < current_batch_idx { + // Earlier batch - use Custom ref + let batch = &existing_batches[target_batch]; + Predicate::Custom(CustomPredicateRef::new(batch.clone(), target_idx)) + } else { + // Forward reference to a later batch should be impossible with the dependency-aware planner + unreachable!( + "Forward cross-batch reference: '{}' (batch {}) -> '{}' (batch {})", + caller_name, current_batch_idx, callee_name, target_batch + ); + } + } else if let Some(imported) = imported_predicates.get(callee_name) { + // Imported predicate from another batch + Predicate::Custom(CustomPredicateRef::new( + imported.batch.clone(), + imported.index, + )) + } else { + // Unknown predicate + return Err(BatchingError::Internal { + message: format!("Unknown predicate reference: '{}'", callee_name), + }); + }; + + // Build the statement template + let mut builder = StatementTmplBuilder::new(predicate); + + for arg in &stmt.args { + builder = builder.arg(lower_statement_arg(arg)); + } + + Ok(builder) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + lang::{ + frontend_ast::parse::parse_document, frontend_ast_split::split_predicate_if_needed, + 
parser::parse_podlang, + }, + middleware::PredicateOrWildcard, + }; + + fn parse_predicates(input: &str) -> Vec<CustomPredicateDef> { + let parsed = parse_podlang(input).expect("Failed to parse"); + let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + + document + .items + .into_iter() + .filter_map(|item| match item { + crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred), + _ => None, + }) + .collect() + } + + /// Helper: wrap predicates into SplitResult (without actually splitting) + fn preds_to_split_results(predicates: Vec<CustomPredicateDef>) -> Vec<SplitResult> { + predicates + .into_iter() + .map(|pred| SplitResult { + predicates: vec![pred], + chain_info: None, + }) + .collect() + } + + #[test] + fn test_single_predicate_single_batch() { + let input = r#" + my_pred(A, B) = AND( + Equal(A["x"], B["y"]) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + let result = batch_predicates( + preds_to_split_results(predicates), + &params, + "TestBatch", + &HashMap::new(), + ); + assert!(result.is_ok()); + + let batches = result.unwrap(); + assert_eq!(batches.batch_count(), 1); + assert_eq!(batches.total_predicate_count(), 1); + } + + #[test] + fn test_multiple_predicates_single_batch() { + let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + pred2(B) = AND(Equal(B["y"], 2)) + pred3(C) = AND(Equal(C["z"], 3)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); // max_custom_batch_size = 4 + + let result = batch_predicates( + preds_to_split_results(predicates), + &params, + "TestBatch", + &HashMap::new(), + ); + assert!(result.is_ok()); + + let batches = result.unwrap(); + assert_eq!(batches.batch_count(), 1); + assert_eq!(batches.total_predicate_count(), 3); + } + + #[test] + fn test_predicates_span_multiple_batches() { + let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + pred2(B) = AND(Equal(B["y"], 2)) + pred3(C) = AND(Equal(C["z"], 3)) + pred4(D) = AND(Equal(D["w"], 4)) + 
pred5(E) = AND(Equal(E["v"], 5)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); // max_custom_batch_size = 4 + + let result = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ); + assert!(result.is_ok()); + + let batches = result.unwrap(); + assert_eq!(batches.batch_count(), 2); + assert_eq!(batches.total_predicate_count(), 5); + + // First batch should have 4 predicates + assert_eq!(batches.batches()[0].predicates().len(), 4); + // Second batch should have 1 predicate + assert_eq!(batches.batches()[1].predicates().len(), 1); + } + + #[test] + fn test_intra_batch_forward_reference() { + // pred2 calls pred1, but pred2 is declared first + // This should work because they're in the same batch + let input = r#" + pred2(B) = AND(pred1(B)) + pred1(A) = AND(Equal(A["x"], 1)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + let result = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ); + assert!(result.is_ok()); + + let batches = result.unwrap(); + assert_eq!(batches.batch_count(), 1); + + // pred2 should reference pred1 via BatchSelf + use crate::middleware::PredicateOrWildcard; + let pred2 = &batches.batches()[0].predicates()[0]; + let stmt = &pred2.statements[0]; + assert!(matches!( + stmt.pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::BatchSelf(1)) + )); // pred1 is at index 1 + } + + #[test] + fn test_mutual_recursion_in_same_batch() { + // pred1 calls pred2, pred2 calls pred1 - mutual recursion + // This should work because they're in the same batch + let input = r#" + pred1(A) = AND(pred2(A)) + pred2(B) = AND(pred1(B)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + let result = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ); + assert!(result.is_ok()); + + let batches = 
result.unwrap(); + assert_eq!(batches.batch_count(), 1); + assert_eq!(batches.total_predicate_count(), 2); + + // Both should use BatchSelf references + let pred1 = &batches.batches()[0].predicates()[0]; + let pred2 = &batches.batches()[0].predicates()[1]; + assert!(matches!( + pred1.statements[0].pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::BatchSelf(1)) + )); // calls pred2 + assert!(matches!( + pred2.statements[0].pred_or_wc(), + PredicateOrWildcard::Predicate(Predicate::BatchSelf(0)) + )); // calls pred1 + } + + #[test] + fn test_cross_batch_reference() { + // 5 predicates where pred5 calls pred1 + // pred1-4 go in batch 0, pred5 in batch 1 + // pred5's call to pred1 should be a cross-batch reference + let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + pred2(B) = AND(Equal(B["y"], 2)) + pred3(C) = AND(Equal(C["z"], 3)) + pred4(D) = AND(Equal(D["w"], 4)) + pred5(E) = AND(pred1(E)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); // max_custom_batch_size = 4 + + let result = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ); + assert!(result.is_ok()); + + let batches = result.unwrap(); + assert_eq!(batches.batch_count(), 2); + + // pred5 should reference pred1 via CustomPredicateRef + let pred5_batch = &batches.batches()[1]; + let pred5 = &pred5_batch.predicates()[0]; + let pred5_stmt = &pred5.statements[0]; + + // The predicate should be a Custom reference to batch 0 + match pred5_stmt.pred_or_wc() { + PredicateOrWildcard::Predicate(Predicate::Custom(ref_)) => { + // Should reference batch 0, index 0 (pred1) + assert_eq!(ref_.batch.id(), batches.batches()[0].id()); + } + _ => panic!("Expected Custom predicate reference"), + } + } + + #[test] + fn test_split_chain_spans_batches() { + // Create a predicate that will split into 2-3 predicates + // Then add more predicates to force the chain to span batches + let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + 
pred2(B) = AND(Equal(B["y"], 2)) + pred3(C) = AND(Equal(C["z"], 3)) + large_pred(D) = AND( + Equal(D["a"], 1) + Equal(D["b"], 2) + Equal(D["c"], 3) + Equal(D["d"], 4) + Equal(D["e"], 5) + Equal(D["f"], 6) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + // Split the large predicate + let mut all_split_results = Vec::new(); + for pred in predicates { + let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); + all_split_results.push(result); + } + + // Count total predicates across all split results + let total_preds: usize = all_split_results.iter().map(|r| r.predicates.len()).sum(); + + // We should have: pred1, pred2, pred3, large_pred_1 (continuation), large_pred + // That's 5 predicates, which spans 2 batches + assert_eq!(total_preds, 5); + + let result = batch_predicates(all_split_results, ¶ms, "TestBatch", &HashMap::new()); + assert!(result.is_ok()); + + let batches = result.unwrap(); + assert_eq!(batches.batch_count(), 2); + assert_eq!(batches.total_predicate_count(), 5); + + // Verify chain info was captured + let chain_info = batches.split_chain("large_pred"); + assert!(chain_info.is_some()); + let info = chain_info.unwrap(); + assert_eq!(info.original_name, "large_pred"); + assert_eq!(info.real_statement_count, 6); + } + + #[test] + fn test_forward_cross_batch_reference_avoided_by_planner() { + // 5 predicates where pred4 calls pred5 (forward declaration) + // With max_custom_batch_size = 4, naive packing would place pred5 in batch 1 + // The dependency-aware planner should instead pack pred5 before pred4 + // to avoid a forward cross-batch reference. 
+ let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + pred2(B) = AND(Equal(B["y"], 2)) + pred3(C) = AND(Equal(C["z"], 3)) + pred4(D) = AND(pred5(D)) + pred5(E) = AND(Equal(E["v"], 5)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); // max_custom_batch_size = 4 + + let batches = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ) + .expect("Planner should avoid forward cross-batch reference"); + + // Expect two batches and the reference to point within the same batch or earlier batch. + assert_eq!(batches.batch_count(), 2); + // pred5 should be in batch 0 and pred4 in batch 1 (given stable topo + packing) + let pred5_ref = batches.predicate_ref_by_name("pred5").unwrap(); + let pred4_ref = batches.predicate_ref_by_name("pred4").unwrap(); + assert_eq!(pred5_ref.batch.id(), batches.batches()[0].id()); + assert_eq!(pred4_ref.batch.id(), batches.batches()[1].id()); + } + + #[test] + fn test_empty_input() { + let split_results: Vec = vec![]; + let params = Params::default(); + + let result = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()); + assert!(result.is_ok()); + + let batches = result.unwrap(); + assert!(batches.is_empty()); + assert_eq!(batches.batch_count(), 0); + } + + #[test] + fn test_predicate_ref_by_name() { + let input = r#" + pred1(A) = AND(Equal(A["x"], 1)) + pred2(B) = AND(Equal(B["y"], 2)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + let batches = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ) + .unwrap(); + + // Should be able to look up both predicates + assert!(batches.predicate_ref_by_name("pred1").is_some()); + assert!(batches.predicate_ref_by_name("pred2").is_some()); + assert!(batches.predicate_ref_by_name("nonexistent").is_none()); + } + + #[test] + fn test_mutual_recursion_exceeds_capacity_error() { + // Two predicates that call each other 
(SCC size = 2) with max batch size 1 + // Should error because an SCC cannot be split across batches + let input = r#" + pred1(A) = AND(pred2(A)) + pred2(B) = AND(pred1(B)) + "#; + + let predicates = parse_predicates(input); + let params = Params { + max_custom_batch_size: 1, // force SCC > capacity + ..Default::default() + }; + + let result = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("exceeds batch capacity")); + } + + #[test] + fn test_split_chain_across_batches_placement() { + // Create a large predicate that splits into 2 pieces, plus enough predicates + // to force the chain to span batches; verify continuation is placed earlier batch + let input = r#" + p1(A) = AND(Equal(A["x"], 1)) + p2(B) = AND(Equal(B["y"], 2)) + p3(C) = AND(Equal(C["z"], 3)) + large_pred(D) = AND( + Equal(D["a"], 1) + Equal(D["b"], 2) + Equal(D["c"], 3) + Equal(D["d"], 4) + Equal(D["e"], 5) + Equal(D["f"], 6) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); // max_custom_batch_size = 4 + + // Split and batch + let mut all_split_results = Vec::new(); + for pred in predicates { + let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); + all_split_results.push(result); + } + let batches = batch_predicates(all_split_results, ¶ms, "TestBatch", &HashMap::new()) + .expect("Batch failed"); + + assert_eq!(batches.batch_count(), 2); + + // Verify chain info + let chain_info = batches + .split_chain("large_pred") + .expect("Expected chain info"); + assert_eq!(chain_info.chain_pieces.len(), 2); + // Expect continuation piece name to be large_pred_1 (innermost first) + let cont_name = &chain_info.chain_pieces[0].name; + assert_eq!(cont_name, "large_pred_1"); + + // Expect continuation in batch 0 and main in batch 1 + let cont_ref = batches.predicate_ref_by_name("large_pred_1").unwrap(); + let 
main_ref = batches.predicate_ref_by_name("large_pred").unwrap(); + assert_eq!(cont_ref.batch.id(), batches.batches()[0].id()); + assert_eq!(main_ref.batch.id(), batches.batches()[1].id()); + } + + /// Helper: create a unique Statement for testing + /// Uses Equal with distinct literal values to create distinguishable statements + fn test_statement(id: usize) -> Statement { + use crate::middleware::ValueRef; + Statement::Equal( + ValueRef::Literal((id as i64).into()), + ValueRef::Literal((id as i64).into()), + ) + } + + #[test] + fn test_apply_predicate_non_split() { + // A simple predicate that doesn't need splitting + let input = r#" + my_pred(A, B) = AND( + Equal(A["x"], B["y"]) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + let batches = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ) + .unwrap(); + + // Create fake statements + let statements = vec![Statement::None, Statement::None]; + + // Track operations applied + let mut operations_applied: Vec<(bool, usize)> = Vec::new(); + let mut stmt_counter = 0; + + let result: Result = + batches.apply_predicate_with("my_pred", statements, true, |public, op| { + operations_applied.push((public, op.1.len())); + stmt_counter += 1; + Ok(test_statement(stmt_counter)) + }); + + assert!(result.is_ok()); + // Should be exactly one operation + assert_eq!(operations_applied.len(), 1); + // Should be public + assert!(operations_applied[0].0); + // Should have 2 arguments + assert_eq!(operations_applied[0].1, 2); + } + + #[test] + fn test_apply_predicate_2_piece_split() { + // A predicate that will split into 2 pieces + let input = r#" + large_pred(A) = AND( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + // Split the predicate + let mut split_results = Vec::new(); + for 
pred in predicates { + let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); + split_results.push(result); + } + + // Should split into 2 pieces + assert_eq!(split_results.len(), 1); + assert_eq!(split_results[0].predicates.len(), 2); + assert!(split_results[0].chain_info.is_some()); + + let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + .expect("Batch failed"); + + // Verify chain info + let chain_info = batches.split_chain("large_pred").unwrap(); + assert_eq!(chain_info.chain_pieces.len(), 2); + assert_eq!(chain_info.real_statement_count, 6); + + // Create fake statements (6 for the 6 Equal statements) + let statements: Vec = (0..6).map(test_statement).collect(); + + // Track operations + let mut operations_applied: Vec<(bool, usize)> = Vec::new(); + let mut stmt_counter = 100; + + let result: Result = + batches.apply_predicate_with("large_pred", statements, true, |public, op| { + operations_applied.push((public, op.1.len())); + stmt_counter += 1; + Ok(test_statement(stmt_counter)) + }); + + assert!(result.is_ok()); + // Should be exactly 2 operations (innermost continuation first, then main) + assert_eq!(operations_applied.len(), 2); + // First operation (continuation) should be private + assert!(!operations_applied[0].0); + // Second operation (main) should be public + assert!(operations_applied[1].0); + } + + #[test] + fn test_apply_predicate_3_piece_split() { + // A predicate that will split into 3 pieces (needs more statements) + let input = r#" + very_large_pred(A) = AND( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + Equal(A["g"], 7) + Equal(A["h"], 8) + Equal(A["i"], 9) + Equal(A["j"], 10) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + // Split the predicate + let mut split_results = Vec::new(); + for pred in predicates { + let result = split_predicate_if_needed(pred, ¶ms).expect("Split 
failed"); + split_results.push(result); + } + + // Should split into 3 pieces + assert_eq!(split_results.len(), 1); + assert_eq!(split_results[0].predicates.len(), 3); + assert!(split_results[0].chain_info.is_some()); + + let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + .expect("Batch failed"); + + // Verify chain info + let chain_info = batches.split_chain("very_large_pred").unwrap(); + assert_eq!(chain_info.chain_pieces.len(), 3); + assert_eq!(chain_info.real_statement_count, 10); + + // Create fake statements (10 for the 10 Equal statements) + let statements: Vec = (0..10).map(test_statement).collect(); + + // Track operations + let mut operations_applied: Vec<(bool, usize)> = Vec::new(); + let mut stmt_counter = 100; + + let result: Result = + batches.apply_predicate_with("very_large_pred", statements, true, |public, op| { + operations_applied.push((public, op.1.len())); + stmt_counter += 1; + Ok(test_statement(stmt_counter)) + }); + + assert!(result.is_ok()); + // Should be exactly 3 operations + assert_eq!(operations_applied.len(), 3); + // First two operations (continuations) should be private + assert!(!operations_applied[0].0); + assert!(!operations_applied[1].0); + // Final operation (main) should be public + assert!(operations_applied[2].0); + } + + #[test] + fn test_apply_predicate_wrong_statement_count() { + // A predicate that will split + let input = r#" + large_pred(A) = AND( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + // Split the predicate + let mut split_results = Vec::new(); + for pred in predicates { + let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); + split_results.push(result); + } + + let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + .expect("Batch failed"); + + // Try with wrong number of 
statements (3 instead of 6) + let statements: Vec = (0..3).map(test_statement).collect(); + + let result: Result = + batches.apply_predicate_with("large_pred", statements, true, |_, _| { + Ok(test_statement(999)) + }); + + assert!(result.is_err()); + let err = result.unwrap_err(); + match err { + MultiOperationError::WrongStatementCount { + predicate, + expected, + actual, + } => { + assert_eq!(predicate, "large_pred"); + assert_eq!(expected, 6); + assert_eq!(actual, 3); + } + _ => panic!("Expected WrongStatementCount error, got {:?}", err), + } + } + + #[test] + fn test_apply_predicate_not_found() { + let input = r#" + my_pred(A) = AND(Equal(A["x"], 1)) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + let batches = batch_predicates( + preds_to_split_results(predicates), + ¶ms, + "TestBatch", + &HashMap::new(), + ) + .unwrap(); + + let result: Result = + batches + .apply_predicate_with("nonexistent", vec![], true, |_, _| Ok(test_statement(999))); + + assert!(result.is_err()); + match result.unwrap_err() { + MultiOperationError::PredicateNotFound(name) => { + assert_eq!(name, "nonexistent"); + } + e => panic!("Expected PredicateNotFound error, got {:?}", e), + } + } + + #[test] + fn test_apply_predicate_chain_wiring() { + // Test that chain links are properly wired (previous result replaces Statement::None) + let input = r#" + large_pred(A) = AND( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + "#; + + let predicates = parse_predicates(input); + let params = Params::default(); + + let mut split_results = Vec::new(); + for pred in predicates { + let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); + split_results.push(result); + } + + let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + .expect("Batch failed"); + + let statements: Vec = (0..6).map(test_statement).collect(); + + // Track whether the second 
operation has the first result as its last arg + let mut last_args_of_ops: Vec> = Vec::new(); + let mut stmt_counter = 100; + + let result: Result = + batches.apply_predicate_with("large_pred", statements, true, |_, op| { + // Check the last argument + let last_arg = op.1.last().map(|arg| { + if let OperationArg::Statement(s) = arg { + s.clone() + } else { + Statement::None + } + }); + last_args_of_ops.push(last_arg); + stmt_counter += 1; + Ok(test_statement(stmt_counter)) + }); + + assert!(result.is_ok()); + assert_eq!(last_args_of_ops.len(), 2); + + // First operation's last arg should NOT be the result of previous (no previous) + // It might be Statement::None if no chain call, or a regular arg + + // Second operation's last arg SHOULD be test_statement(101) - the result from first op + assert_eq!(last_args_of_ops[1], Some(test_statement(101))); + } +} diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index 55d4591..0e7238d 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -1,38 +1,103 @@ //! Lowering from frontend AST to middleware structures //! //! This module converts validated frontend AST to middleware data structures. -//! Currently implements basic 1:1 conversion without automatic predicate splitting. +//! Supports automatic predicate splitting and multi-batch packing. 
use std::{ collections::{HashMap, HashSet}, str::FromStr, - sync::Arc, }; use crate::{ - frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder}, + frontend::{BuilderArg, StatementTmplBuilder}, lang::{ frontend_ast::*, + frontend_ast_batch::{self, PredicateBatches}, frontend_ast_split, frontend_ast_validate::{PredicateKind, ValidatedAST}, }, middleware::{ - self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params, - Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl, + self, containers, IntroPredicateRef, NativePredicate, Params, Predicate, + PredicateOrWildcard, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Wildcard, }, }; -/// Result of lowering: optional custom predicate batch and optional request +// ============================================================================ +// Shared lowering utilities +// ============================================================================ +// These functions convert AST types to middleware/builder types and are used +// by both the request lowering (in this module) and predicate batching +// (in frontend_ast_batch). + +/// Lower a literal value from AST to middleware Value. +/// +/// This is a pure conversion that cannot fail. 
+pub fn lower_literal(lit: &LiteralValue) -> middleware::Value { + match lit { + LiteralValue::Int(i) => middleware::Value::from(i.value), + LiteralValue::Bool(b) => middleware::Value::from(b.value), + LiteralValue::String(s) => middleware::Value::from(s.value.clone()), + LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash), + LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point), + LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()), + LiteralValue::Array(a) => { + let elements: Vec<_> = a.elements.iter().map(lower_literal).collect(); + let array = containers::Array::new(elements); + middleware::Value::from(array) + } + LiteralValue::Set(s) => { + let elements: std::collections::HashSet<_> = + s.elements.iter().map(lower_literal).collect(); + let set = containers::Set::new(elements); + middleware::Value::from(set) + } + LiteralValue::Dict(d) => { + let pairs: std::collections::HashMap<_, _> = d + .pairs + .iter() + .map(|pair| { + let key = middleware::Key::from(pair.key.value.as_str()); + let value = lower_literal(&pair.value); + (key, value) + }) + .collect(); + let dict = containers::Dictionary::new(pairs); + middleware::Value::from(dict) + } + } +} + +/// Lower a statement argument from AST to BuilderArg. +/// +/// This is a pure conversion that cannot fail. 
+pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { + match arg { + StatementTmplArg::Literal(lit) => { + let value = lower_literal(lit); + BuilderArg::Literal(value) + } + StatementTmplArg::Wildcard(id) => BuilderArg::WildcardLiteral(id.name.clone()), + StatementTmplArg::AnchoredKey(ak) => { + let key_str = match &ak.key { + AnchoredKeyPath::Bracket(s) => s.value.clone(), + AnchoredKeyPath::Dot(id) => id.name.clone(), + }; + BuilderArg::Key(ak.root.name.clone(), key_str) + } + } +} + +/// Result of lowering: optional custom predicate batches and optional request /// /// A Podlang file can contain: -/// - Just custom predicates (batch: Some, request: None) -/// - Just a request (batch: None, request: Some) -/// - Both (batch: Some, request: Some) -/// - Neither (batch: None, request: None) - just imports +/// - Just custom predicates (batches: Some, request: None) +/// - Just a request (batches: None, request: Some) +/// - Both (batches: Some, request: Some) +/// - Neither (batches: None, request: None) - just imports #[derive(Debug, Clone)] pub struct LoweredOutput { - pub batch: Option>, + pub batches: Option, pub request: Option, } @@ -60,71 +125,70 @@ pub fn lower( struct Lowerer<'a> { validated: ValidatedAST, params: &'a Params, - /// Map of predicate names to their index in the current batch (for split predicates) - batch_predicate_index: HashMap, } impl<'a> Lowerer<'a> { fn new(validated: ValidatedAST, params: &'a Params) -> Self { - Self { - validated, - params, - batch_predicate_index: HashMap::new(), - } + Self { validated, params } } - fn lower(mut self, batch_name: String) -> Result { - // Lower custom predicates (if any) - let batch = self.lower_batch(batch_name)?; + fn lower(self, batch_name: String) -> Result { + // Lower custom predicates (if any) - now supports multiple batches + let batches = self.lower_batches(batch_name)?; - // Lower request (if any) - pass batch so BatchSelf refs can be converted to Custom refs - let request = 
self.lower_request(batch.as_ref())?; + // Lower request (if any) - pass batches so refs can be resolved + let request = self.lower_request(batches.as_ref())?; - Ok(LoweredOutput { batch, request }) + Ok(LoweredOutput { batches, request }) } - fn lower_batch( - &mut self, - batch_name: String, - ) -> Result>, LoweringError> { + fn lower_batches(&self, batch_name: String) -> Result, LoweringError> { // Extract and split custom predicates from document - let (custom_predicates, original_count) = self.extract_and_split_predicates()?; + let custom_predicates = self.extract_and_split_predicates()?; // If no custom predicates, return None if custom_predicates.is_empty() { return Ok(None); } - // Check batch size constraint - if custom_predicates.len() > self.params.max_custom_batch_size { - return Err(LoweringError::TooManyPredicates { - batch_name: batch_name.clone(), - count: custom_predicates.len(), - max: self.params.max_custom_batch_size, - original_count, - }); + // Build map of imported predicates for batching + let imported_predicates = self.build_imported_predicates_map(); + + // Use the new batching module to pack predicates into batches + let batches = frontend_ast_batch::batch_predicates( + custom_predicates, + self.params, + &batch_name, + &imported_predicates, + )?; + + Ok(Some(batches)) + } + + fn build_imported_predicates_map( + &self, + ) -> HashMap { + let symbols = self.validated.symbols(); + let mut imported = HashMap::new(); + + for (name, info) in &symbols.predicates { + if let PredicateKind::BatchImported { batch, index } = &info.kind { + imported.insert( + name.clone(), + frontend_ast_batch::ImportedPredicateInfo { + batch: batch.clone(), + index: *index, + }, + ); + } } - // Build index of all predicates in the batch - for (idx, pred) in custom_predicates.iter().enumerate() { - self.batch_predicate_index - .insert(pred.name.name.clone(), idx); - } - - // Create custom predicate batch using builder - let mut cpb_builder = - 
CustomPredicateBatchBuilder::new(self.params.clone(), batch_name.clone()); - - for pred_def in &custom_predicates { - self.lower_custom_predicate(pred_def, &mut cpb_builder)?; - } - - Ok(Some(cpb_builder.finish())) + imported } fn lower_request( &self, - batch: Option<&Arc>, + batches: Option<&PredicateBatches>, ) -> Result, LoweringError> { let doc = self.validated.document(); @@ -141,44 +205,78 @@ impl<'a> Lowerer<'a> { // Build wildcard map from all wildcards used in the request statements let wildcard_map = self.build_request_wildcard_map(request_def); - // Lower each statement to a builder first - let mut statement_builders = Vec::new(); - for stmt in &request_def.statements { - let stmt_builder = self.lower_statement_to_builder(stmt)?; - statement_builders.push(stmt_builder); - } - - // Resolve builders to middleware statement templates + // Lower each statement to middleware templates, resolving predicates let mut request_templates = Vec::new(); - for stmt_builder in statement_builders { - let mw_stmt = - self.resolve_request_statement_builder(stmt_builder, &wildcard_map, batch)?; + for stmt in &request_def.statements { + let mw_stmt = self.lower_request_statement(stmt, &wildcard_map, batches)?; request_templates.push(mw_stmt); } Ok(Some(crate::frontend::PodRequest::new(request_templates))) } - fn resolve_request_statement_builder( + fn lower_request_statement( &self, - stmt_builder: StatementTmplBuilder, + stmt: &StatementTmpl, wildcard_map: &HashMap, - batch: Option<&Arc>, + batches: Option<&PredicateBatches>, ) -> Result { - // First desugar the builder - let desugared = stmt_builder.desugar(); - - // Convert BatchSelf predicate to Custom if we have a batch - let mut predicate = desugared.predicate; - if let Some(batch_ref) = batch { - if let Predicate::BatchSelf(index) = predicate { - predicate = Predicate::Custom(middleware::CustomPredicateRef::new( - batch_ref.clone(), - index, - )); - } + // Enforce argument count limit for request statements + if 
stmt.args.len() > self.params.max_statement_args { + return Err(LoweringError::TooManyStatementArgs { + count: stmt.args.len(), + max: self.params.max_statement_args, + }); } + let pred_name = &stmt.predicate.name; + let symbols = self.validated.symbols(); + + // Resolve predicate - for request statements, local custom predicates + // must be resolved to CustomPredicateRef (not BatchSelf) + let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) { + Predicate::Native(native) + } else if let Some(info) = symbols.predicates.get(pred_name) { + match &info.kind { + PredicateKind::Native(np) => Predicate::Native(*np), + PredicateKind::Custom { .. } => { + // Local custom predicates - resolve to CustomPredicateRef + let batches = batches.ok_or_else(|| LoweringError::PredicateNotFound { + name: pred_name.clone(), + })?; + let pred_ref = batches.predicate_ref_by_name(pred_name).ok_or_else(|| { + LoweringError::PredicateNotFound { + name: pred_name.clone(), + } + })?; + Predicate::Custom(pred_ref) + } + PredicateKind::BatchImported { batch, index } => { + Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index)) + } + PredicateKind::IntroImported { + name, + verifier_data_hash, + } => Predicate::Intro(IntroPredicateRef { + name: name.clone(), + args_len: info.public_arity, + verifier_data_hash: *verifier_data_hash, + }), + } + } else { + return Err(LoweringError::PredicateNotFound { + name: pred_name.clone(), + }); + }; + + // Create a builder with the resolved predicate and desugar + let mut builder = StatementTmplBuilder::new(predicate); + for arg in &stmt.args { + let builder_arg = lower_statement_arg(arg); + builder = builder.arg(builder_arg); + } + let desugared = builder.desugar(); + // Convert BuilderArgs to StatementTmplArgs let mut mw_args = Vec::new(); for builder_arg in desugared.args { @@ -202,7 +300,7 @@ impl<'a> Lowerer<'a> { Ok(MWStatementTmpl { // TODO: Support wildcard - pred_or_wc: 
PredicateOrWildcard::Predicate(predicate), + pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate), args: mw_args, }) } @@ -251,7 +349,7 @@ impl<'a> Lowerer<'a> { fn extract_and_split_predicates( &self, - ) -> Result<(Vec, usize), LoweringError> { + ) -> Result, LoweringError> { let doc = self.validated.document(); let predicates: Vec = doc .items @@ -262,182 +360,14 @@ impl<'a> Lowerer<'a> { }) .collect(); - let original_count = predicates.len(); - // Apply splitting to each predicate as needed - let mut split_predicates = Vec::new(); + let mut split_results = Vec::new(); for pred in predicates { - let chain = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; - split_predicates.extend(chain); + let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; + split_results.push(result); } - Ok((split_predicates, original_count)) - } - - fn lower_custom_predicate( - &self, - pred_def: &CustomPredicateDef, - cpb_builder: &mut CustomPredicateBatchBuilder, - ) -> Result<(), LoweringError> { - let name = pred_def.name.name.clone(); - - // Note: Constraint checking is handled by the splitting phase - // Predicates passed here should already be within limits - - // Collect public and private argument names - let mut public_arg_names = Vec::new(); - let mut private_arg_names = Vec::new(); - - for arg in &pred_def.args.public_args { - public_arg_names.push(arg.name.clone()); - } - - if let Some(private_args) = &pred_def.args.private_args { - for arg in private_args { - private_arg_names.push(arg.name.clone()); - } - } - - // Lower statements to builders - let mut statement_builders = Vec::new(); - for stmt in &pred_def.statements { - let stmt_builder = self.lower_statement_to_builder(stmt)?; - statement_builders.push(stmt_builder); - } - - // Convert to &str slices for builder API - let public_args_str: Vec<&str> = public_arg_names.iter().map(|s| s.as_str()).collect(); - let private_args_str: Vec<&str> = 
private_arg_names.iter().map(|s| s.as_str()).collect(); - - // Add predicate to batch using builder - let conjunction = pred_def.conjunction_type == ConjunctionType::And; - - cpb_builder - .predicate( - &name, - conjunction, - &public_args_str, - &private_args_str, - &statement_builders, - ) - .map_err(|e| match e { - crate::frontend::Error::Middleware(mw_err) => LoweringError::Middleware(mw_err), - _ => LoweringError::InvalidArgumentType, - })?; - - Ok(()) - } - - fn lower_statement_to_builder( - &self, - stmt: &StatementTmpl, - ) -> Result { - // Get predicate - let pred_name = &stmt.predicate.name; - let symbols = self.validated.symbols(); - - // Check for native predicates first - let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) { - Predicate::Native(native) - } else if let Some(&index) = self.batch_predicate_index.get(pred_name) { - // References to other predicates in the same batch (including split chains) - Predicate::BatchSelf(index) - } else if let Some(info) = symbols.predicates.get(pred_name) { - match &info.kind { - PredicateKind::Native(np) => Predicate::Native(*np), - PredicateKind::Custom { index } => Predicate::BatchSelf(*index), - PredicateKind::BatchImported { batch, index } => { - Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index)) - } - PredicateKind::IntroImported { - name, - verifier_data_hash, - } => Predicate::Intro(IntroPredicateRef { - name: name.clone(), - args_len: info.public_arity, - verifier_data_hash: *verifier_data_hash, - }), - } - } else { - unreachable!("Predicate {} not found", pred_name); - }; - - // Check args count - if stmt.args.len() > self.params.max_statement_args { - return Err(LoweringError::TooManyStatementArgs { - count: stmt.args.len(), - max: self.params.max_statement_args, - }); - } - - // Convert AST args to BuilderArgs - let mut builder = StatementTmplBuilder::new(predicate); - for arg in &stmt.args { - let builder_arg = 
Self::lower_statement_arg_to_builder(arg)?; - builder = builder.arg(builder_arg); - } - - // Return builder without calling .desugar() - that will happen later - Ok(builder) - } - - fn lower_statement_arg_to_builder(arg: &StatementTmplArg) -> Result { - match arg { - StatementTmplArg::Literal(lit) => { - let value = Self::lower_literal(lit)?; - Ok(BuilderArg::Literal(value)) - } - StatementTmplArg::Wildcard(id) => { - // For builder, we just need the wildcard name - Ok(BuilderArg::WildcardLiteral(id.name.clone())) - } - StatementTmplArg::AnchoredKey(ak) => { - let key_str = match &ak.key { - AnchoredKeyPath::Bracket(s) => s.value.clone(), - AnchoredKeyPath::Dot(id) => id.name.clone(), - }; - Ok(BuilderArg::Key(ak.root.name.clone(), key_str)) - } - } - } - - fn lower_literal(lit: &LiteralValue) -> Result { - let value = match lit { - LiteralValue::Int(i) => middleware::Value::from(i.value), - LiteralValue::Bool(b) => middleware::Value::from(b.value), - LiteralValue::String(s) => middleware::Value::from(s.value.clone()), - LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash), - LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point), - LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()), - LiteralValue::Array(a) => { - let elements: Result, _> = - a.elements.iter().map(Self::lower_literal).collect(); - let array = containers::Array::new(elements?); - middleware::Value::from(array) - } - LiteralValue::Set(s) => { - let elements: Result, _> = - s.elements.iter().map(Self::lower_literal).collect(); - let set_values: std::collections::HashSet<_> = elements?.into_iter().collect(); - let set = containers::Set::new(set_values); - middleware::Value::from(set) - } - LiteralValue::Dict(d) => { - let pairs: Result, LoweringError> = d - .pairs - .iter() - .map(|pair| { - let key = middleware::Key::from(pair.key.value.as_str()); - let value = Self::lower_literal(&pair.value)?; - Ok((key, value)) - }) - .collect(); - let dict_map: 
std::collections::HashMap<_, _> = pairs?.into_iter().collect(); - let dict = containers::Dictionary::new(dict_map); - middleware::Value::from(dict) - } - }; - Ok(value) + Ok(split_results) } } @@ -458,9 +388,16 @@ mod tests { lower(validated, params, "test_batch".to_string()) } - // Helper to get the batch from the output (expecting it to exist) - fn expect_batch(output: &LoweredOutput) -> &Arc { - output.batch.as_ref().expect("Expected batch to be present") + // Helper to get the first batch from the output (expecting it to exist) + fn expect_batch( + output: &LoweredOutput, + ) -> &std::sync::Arc { + output + .batches + .as_ref() + .expect("Expected batches to be present") + .first_batch() + .expect("Expected at least one batch") } #[test] @@ -547,13 +484,20 @@ mod tests { let lowered = result.unwrap(); // Should be automatically split into 2 predicates (my_pred and my_pred_1) - assert_eq!(expect_batch(&lowered).predicates().len(), 2); + let batches = lowered.batches.as_ref().expect("Expected batches"); + assert_eq!(batches.total_predicate_count(), 2); - // First predicate should have 5 statements (4 + chain call) - assert_eq!(expect_batch(&lowered).predicates()[0].statements().len(), 5); - - // Second predicate should have 2 statements - assert_eq!(expect_batch(&lowered).predicates()[1].statements().len(), 2); + // With topological sorting, my_pred_1 comes first (since my_pred depends on it) + // my_pred_1 has 2 statements + // my_pred has 5 statements (4 + chain call) + // Just verify we have the right total statement counts + let batch = batches.first_batch().unwrap(); + let total_statements: usize = batch + .predicates() + .iter() + .map(|p| p.statements().len()) + .sum(); + assert_eq!(total_statements, 7); // 5 + 2 = 7 total statements } #[test] @@ -642,108 +586,64 @@ mod tests { } #[test] - fn test_error_message_with_splitting() { - // Create a document with predicates that will exceed the batch limit after splitting - // We'll create 2 predicates with 4 
statements each (max arity = 5) - // Each will NOT split individually, but together they exceed a small batch limit + fn test_multi_batch_packing() { + // Create more predicates than fit in a single batch + // With max_custom_batch_size = 4, 5 predicates should span 2 batches let input = r#" - pred1(A) = AND ( - Equal(A["a"], 1) - Equal(A["b"], 2) - ) - pred2(B) = AND ( - Equal(B["c"], 3) - Equal(B["d"], 4) - ) + pred1(A) = AND(Equal(A["a"], 1)) + pred2(B) = AND(Equal(B["b"], 2)) + pred3(C) = AND(Equal(C["c"], 3)) + pred4(D) = AND(Equal(D["d"], 4)) + pred5(E) = AND(Equal(E["e"], 5)) "#; - // Use very restrictive params to force the error - let params = Params { - max_custom_batch_size: 1, - ..Default::default() - }; + let params = Params::default(); // max_custom_batch_size = 4 let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); - // Should fail with TooManyPredicates error - assert!(result.is_err()); - let err = result.unwrap_err(); + let lowered = result.unwrap(); + let batches = lowered.batches.as_ref().expect("Expected batches"); - if let LoweringError::TooManyPredicates { - count, - max, - original_count, - .. 
- } = err - { - assert_eq!(count, 2); // 2 predicates after splitting (no splitting occurred) - assert_eq!(max, 1); - assert_eq!(original_count, 2); // Started with 2 predicates + // Should have 2 batches + assert_eq!(batches.batch_count(), 2); + assert_eq!(batches.total_predicate_count(), 5); - // Error message should NOT mention splitting since no splitting occurred - let err_msg = format!("{}", err); - assert!(!err_msg.contains("before automatic splitting")); - } else { - panic!("Expected TooManyPredicates error, got: {:?}", err); - } + // First batch should have 4 predicates + assert_eq!(batches.batches()[0].predicates().len(), 4); + // Second batch should have 1 predicate + assert_eq!(batches.batches()[1].predicates().len(), 1); } #[test] - fn test_error_message_after_splitting() { - // Create TWO predicates that will EACH split into 2 predicates - // This tests the case where splitting causes the batch to be too large - // but no individual predicate chain exceeds the limit + fn test_split_chains_span_batches() { + // Create predicates that will split, plus additional predicates + // to force the split chains across batch boundaries let input = r#" - pred1(A) = AND ( - Equal(A["a"], 1) - Equal(A["b"], 2) - Equal(A["c"], 3) - Equal(A["d"], 4) - Equal(A["e"], 5) - Equal(A["f"], 6) - ) - pred2(B) = AND ( - Equal(B["a"], 1) - Equal(B["b"], 2) - Equal(B["c"], 3) - Equal(B["d"], 4) - Equal(B["e"], 5) - Equal(B["f"], 6) + pred1(A) = AND(Equal(A["a"], 1)) + pred2(B) = AND(Equal(B["b"], 2)) + pred3(C) = AND(Equal(C["c"], 3)) + large_pred(D) = AND( + Equal(D["a"], 1) + Equal(D["b"], 2) + Equal(D["c"], 3) + Equal(D["d"], 4) + Equal(D["e"], 5) + Equal(D["f"], 6) ) "#; - // Use params where each predicate splits into 2, but total of 4 exceeds batch limit - let params = Params { - // Allow 3 predicates in batch - // Default max_custom_predicate_arity is 5, so each will split into 2 predicates - // Total: 2 original predicates -> 4 after splitting (exceeds limit of 3) - 
max_custom_batch_size: 3, - ..Default::default() - }; + let params = Params::default(); let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); - // Should fail with TooManyPredicates error - assert!(result.is_err()); - let err = result.unwrap_err(); + let lowered = result.unwrap(); + let batches = lowered.batches.as_ref().expect("Expected batches"); - if let LoweringError::TooManyPredicates { - count, - max, - original_count, - .. - } = err - { - assert_eq!(count, 4); // 4 predicates after splitting (2 from each) - assert_eq!(max, 3); - assert_eq!(original_count, 2); // Started with 2 predicates - - // Error message SHOULD mention splitting since splitting occurred - let err_msg = format!("{}", err); - assert!(err_msg.contains("before automatic splitting")); - assert!(err_msg.contains("started with 2 predicates")); - } else { - panic!("Expected TooManyPredicates error, got: {:?}", err); - } + // pred1, pred2, pred3 + large_pred split into 2 = 5 total predicates + // Should span 2 batches + assert_eq!(batches.total_predicate_count(), 5); + assert_eq!(batches.batch_count(), 2); } } diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index a8e3780..303720e 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -34,6 +34,40 @@ pub struct ChainLink { pub public_args_out: Vec, } +/// Information about a single piece of a split predicate chain +#[derive(Debug, Clone)] +pub struct SplitChainPiece { + /// Name of this predicate piece (e.g., "foo_1") + pub name: String, + /// Number of "real" statements in this piece (excludes chain call) + pub real_statement_count: usize, + /// Whether this piece has a chain call to the next piece + pub has_chain_call: bool, +} + +/// Metadata about a split predicate chain +#[derive(Debug, Clone)] +pub struct SplitChainInfo { + /// Original predicate name (e.g., "foo") + pub original_name: String, + /// Chain pieces in execution order (innermost continuation first: 
[foo_2, foo_1, foo]) + pub chain_pieces: Vec, + /// Total number of "real" user statements (excludes chain calls) + pub real_statement_count: usize, + /// Maps original statement index → reordered index + /// e.g., if original stmt 0 became reordered stmt 3, then `reorder_map[0] = 3` + pub reorder_map: Vec, +} + +/// Result of splitting a predicate +#[derive(Debug, Clone)] +pub struct SplitResult { + /// The predicates (continuations first, original last if split) + pub predicates: Vec, + /// Split chain info, if splitting occurred (None for non-split) + pub chain_info: Option, +} + /// Wildcard usage information #[derive(Debug, Clone)] struct WildcardUsage { @@ -66,19 +100,25 @@ pub fn validate_predicate_is_splittable( pub fn split_predicate_if_needed( pred: CustomPredicateDef, params: &Params, -) -> Result, SplittingError> { +) -> Result { // Early validation validate_predicate_is_splittable(&pred, params)?; // If within limits, no splitting needed if pred.statements.len() <= params.max_custom_predicate_arity { - return Ok(vec![pred]); + return Ok(SplitResult { + predicates: vec![pred], + chain_info: None, + }); } // Need to split - execute the splitting algorithm - let chain = split_into_chain(pred, params)?; + let (predicates, chain_info) = split_into_chain(pred, params)?; - Ok(chain) + Ok(SplitResult { + predicates, + chain_info: Some(chain_info), + }) } fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap { @@ -121,18 +161,33 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet { } /// Order constraints optimally to minimize liveness at boundaries +/// Result of ordering statements optimally for splitting +struct OrderingResult { + /// Reordered statements + statements: Vec, + /// Maps original statement index → reordered index + /// reorder_map[original_idx] = new_idx + reorder_map: Vec, +} + fn order_constraints_optimally( statements: Vec, _usage: &HashMap, params: &Params, -) -> Vec { - // If no splitting needed, preserve 
original order - if statements.len() <= params.max_custom_predicate_arity { - return statements; +) -> OrderingResult { + let n = statements.len(); + + // If no splitting needed, preserve original order (identity mapping) + if n <= params.max_custom_predicate_arity { + return OrderingResult { + statements, + reorder_map: (0..n).collect(), + }; } let mut ordered = Vec::new(); - let mut remaining: HashSet = (0..statements.len()).collect(); + let mut reorder_map = vec![0; n]; + let mut remaining: HashSet = (0..n).collect(); let mut active_wildcards: HashSet = HashSet::new(); while !remaining.is_empty() { @@ -146,6 +201,9 @@ fn order_constraints_optimally( remaining.remove(&best_idx); let stmt = &statements[best_idx]; + + // Record the mapping: original index best_idx → new index ordered.len() + reorder_map[best_idx] = ordered.len(); ordered.push(stmt.clone()); // Update active wildcards @@ -160,7 +218,10 @@ fn order_constraints_optimally( active_wildcards.retain(|w| needed_later.contains(w)); } - ordered + OrderingResult { + statements: ordered, + reorder_map, + } } /// Compute tie-breaker metrics for deterministic ordering when scores are equal @@ -360,16 +421,20 @@ fn generate_refactor_suggestion( } /// Split into chain using bucket-filling approach +/// Returns the split predicates and metadata about the split fn split_into_chain( pred: CustomPredicateDef, params: &Params, -) -> Result, SplittingError> { +) -> Result<(Vec, SplitChainInfo), SplittingError> { let original_name = pred.name.name.clone(); let conjunction = pred.conjunction_type; let usage = analyze_wildcards(&pred.statements); + let real_statement_count = pred.statements.len(); - let ordered_statements = order_constraints_optimally(pred.statements, &usage, params); + let ordering_result = order_constraints_optimally(pred.statements, &usage, params); + let ordered_statements = ordering_result.statements; + let reorder_map = ordering_result.reorder_map; let original_public_args: Vec = pred .args @@ 
-479,12 +544,43 @@ fn split_into_chain( } } - let chain_predicates = + // Build SplitChainInfo from chain_links before generating predicates + // Pieces are in execution order: innermost continuation first, original last + let num_links = chain_links.len(); + let mut chain_pieces = Vec::new(); + for i in (0..num_links).rev() { + let link = &chain_links[i]; + let is_last = i == num_links - 1; + let name = if i == 0 { + original_name.clone() + } else { + format!("{}_{}", original_name, i) + }; + chain_pieces.push(SplitChainPiece { + name, + real_statement_count: link.statements.len(), + has_chain_call: !is_last, + }); + } + + let chain_info = SplitChainInfo { + original_name: original_name.clone(), + chain_pieces, + real_statement_count, + reorder_map, + }; + + let mut chain_predicates = generate_chain_predicates(&original_name, chain_links, conjunction, params)?; - validate_chain(&chain_predicates, &original_name, params)?; + validate_chain(&chain_predicates, params)?; - Ok(chain_predicates) + // Reverse so continuations come before callers in declaration order. + // This ensures that when batched, continuations are in earlier batches + // and can be referenced by their callers. 
+ chain_predicates.reverse(); + + Ok((chain_predicates, chain_info)) } /// Phase 4: Generate synthetic predicates from chain links @@ -519,20 +615,19 @@ fn generate_chain_predicates( span: None, }; - // Create arguments for chain call: all public args (incoming + promoted) - let mut chain_call_args = Vec::new(); - for arg_name in &link.public_args_in { - chain_call_args.push(StatementTmplArg::Wildcard(Identifier { - name: arg_name.clone(), - span: None, - })); - } - for arg_name in &link.public_args_out { - chain_call_args.push(StatementTmplArg::Wildcard(Identifier { - name: arg_name.clone(), - span: None, - })); - } + // Create arguments for chain call: use next link's public_args_in + // which is the deduplicated union of current public_args_in and public_args_out + let next_link = &chain_links[i + 1]; + let chain_call_args: Vec = next_link + .public_args_in + .iter() + .map(|arg_name| { + StatementTmplArg::Wildcard(Identifier { + name: arg_name.clone(), + span: None, + }) + }) + .collect(); let chain_call = StatementTmpl { predicate: next_pred_name, @@ -587,19 +682,10 @@ fn generate_chain_predicates( } /// Phase 5: Validate the generated chain -fn validate_chain( - chain: &[CustomPredicateDef], - original_name: &str, - params: &Params, -) -> Result<(), SplittingError> { - if chain.len() > params.max_custom_batch_size { - return Err(SplittingError::TooManyPredicatesInChain { - predicate: original_name.to_string(), - count: chain.len(), - max_allowed: params.max_custom_batch_size, - }); - } - +/// +/// Note: We no longer check chain length against max_custom_batch_size since +/// chains can now span multiple batches thanks to multi-batch support. 
+fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> { for pred in chain { // Each predicate should have ≤ max_statements assert!(pred.statements.len() <= params.max_custom_predicate_arity); @@ -681,8 +767,9 @@ mod tests { let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); - let chain = result.unwrap(); - assert_eq!(chain.len(), 1); // No split needed + let split_result = result.unwrap(); + assert_eq!(split_result.predicates.len(), 1); // No split needed + assert!(split_result.chain_info.is_none()); // No chain info for non-split } #[test] @@ -704,14 +791,29 @@ mod tests { let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); - let chain = result.unwrap(); + let split_result = result.unwrap(); + let chain = &split_result.predicates; assert_eq!(chain.len(), 2); // Should split into 2 predicates - // First predicate: 4 statements + chain call = 5 - assert_eq!(chain[0].statements.len(), 5); + // Chain is reversed: continuation comes first, original comes last + // Last predicate (index 1): original name, 4 statements + chain call = 5 + assert_eq!(chain[1].statements.len(), 5); + assert_eq!(chain[1].name.name, "my_pred"); - // Second predicate: 2 remaining statements - assert_eq!(chain[1].statements.len(), 2); + // First predicate (index 0): continuation, 2 remaining statements + assert_eq!(chain[0].statements.len(), 2); + assert_eq!(chain[0].name.name, "my_pred_1"); + + // Verify chain_info is present + let chain_info = split_result.chain_info.as_ref().unwrap(); + assert_eq!(chain_info.original_name, "my_pred"); + assert_eq!(chain_info.real_statement_count, 6); + assert_eq!(chain_info.chain_pieces.len(), 2); + // Pieces are in execution order: innermost first + assert_eq!(chain_info.chain_pieces[0].name, "my_pred_1"); + assert!(!chain_info.chain_pieces[0].has_chain_call); + assert_eq!(chain_info.chain_pieces[1].name, "my_pred"); + 
assert!(chain_info.chain_pieces[1].has_chain_call); } #[test] @@ -733,12 +835,15 @@ mod tests { let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); - let chain = result.unwrap(); + let split_result = result.unwrap(); + let chain = &split_result.predicates; assert_eq!(chain.len(), 2); // Should split into 2 predicates - // First predicate should have wildcards that cross boundary promoted - // Check that chain call is present - let last_stmt = &chain[0].statements.last().unwrap(); + // Chain is reversed: continuation first, original last + // Original predicate should have chain call as last statement + let original = &chain[1]; + assert_eq!(original.name.name, "complex"); + let last_stmt = original.statements.last().unwrap(); assert_eq!(last_stmt.predicate.name, "complex_1"); } @@ -766,15 +871,29 @@ mod tests { let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); - let chain = result.unwrap(); + let split_result = result.unwrap(); + let chain = &split_result.predicates; assert_eq!(chain.len(), 3); // Should split into 3 predicates - // First: 4 + chain call = 5 - assert_eq!(chain[0].statements.len(), 5); - // Second: 4 + chain call = 5 + // Chain is reversed: [_2, _1, original] + // chain[0]: large_pred_2 (3 remaining statements) + assert_eq!(chain[0].statements.len(), 3); + assert_eq!(chain[0].name.name, "large_pred_2"); + // chain[1]: large_pred_1 (4 + chain call = 5) assert_eq!(chain[1].statements.len(), 5); - // Third: 3 remaining - assert_eq!(chain[2].statements.len(), 3); + assert_eq!(chain[1].name.name, "large_pred_1"); + // chain[2]: large_pred (4 + chain call = 5) + assert_eq!(chain[2].statements.len(), 5); + assert_eq!(chain[2].name.name, "large_pred"); + + // Verify chain_info + let chain_info = split_result.chain_info.as_ref().unwrap(); + assert_eq!(chain_info.real_statement_count, 11); + assert_eq!(chain_info.chain_pieces.len(), 3); + // Execution order: innermost first + 
assert_eq!(chain_info.chain_pieces[0].name, "large_pred_2"); + assert_eq!(chain_info.chain_pieces[1].name, "large_pred_1"); + assert_eq!(chain_info.chain_pieces[2].name, "large_pred"); } #[test] @@ -801,7 +920,8 @@ mod tests { let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); - let chain = result.unwrap(); + let split_result = result.unwrap(); + let chain = &split_result.predicates; // Should split into 2 predicates // T is used in first segment and crosses to second, then used again in second assert_eq!(chain.len(), 2); @@ -867,7 +987,8 @@ mod tests { let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); - let chain = result.unwrap(); + let split_result = result.unwrap(); + let chain = &split_result.predicates; assert_eq!(chain.len(), 2, "Predicate should split into 2 links"); let second_pred = &chain[1]; diff --git a/src/lang/mod.rs b/src/lang/mod.rs index da21569..432991c 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -1,5 +1,34 @@ +//! Podlang front-end: parsing, validation, lowering, and multi-batch output. +//! +//! This module is the high-level entrypoint to the Podlang pipeline. It: +//! - Parses a Podlang document (`parse_podlang`). +//! - Validates names, imports, and well-formedness (`frontend_ast_validate`). +//! - Lowers to middleware structures, including automatic predicate splitting and +//! dependency-aware packing into one or more custom predicate batches (`frontend_ast_split`, +//! `frontend_ast_batch`, `frontend_ast_lower`). +//! +//! The result is a [`PodlangOutput`], which contains: +//! - `custom_batches`: a [`PredicateBatches`] container (possibly empty) with all custom +//! predicates defined in the document. Use +//! [`PredicateBatches::apply_predicate`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate) +//! to apply a predicate into a `MainPodBuilder` (recommended primary API), or +//! 
[`apply_predicate_with`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate_with) +//! for advanced control. +//! - `request`: a `PodRequest` containing the request templates defined by a `REQUEST(...)` block +//! in the document (or empty if none was provided). +//! +//! Notes +//! - Predicate splitting: large predicates are automatically split into a chain of smaller +//! predicates while preserving semantics; only the final chain result is public when applying a +//! predicate as public. +//! - Multi-batch packing: predicates are packed dependency-aware; cross-batch references always +//! point to earlier batches and forward references cannot occur. +//! - Backwards compatibility: `PodlangOutput::first_batch()` is provided to ease migration of code +//! that expects a single custom predicate batch. +//! pub mod error; pub mod frontend_ast; +pub mod frontend_ast_batch; pub mod frontend_ast_lower; pub mod frontend_ast_split; pub mod frontend_ast_validate; @@ -9,6 +38,8 @@ pub mod pretty_print; use std::sync::Arc; pub use error::LangError; +pub use frontend_ast_batch::{MultiOperationError, PredicateBatches}; +pub use frontend_ast_split::{SplitChainInfo, SplitChainPiece, SplitResult}; pub use parser::{parse_podlang, Pairs, ParseError, Rule}; pub use pretty_print::PrettyPrint; @@ -17,12 +48,34 @@ use crate::{ middleware::{CustomPredicateBatch, Params}, }; -#[derive(Debug, Clone, PartialEq)] +/// Final result of processing a Podlang document. +/// +/// - `custom_batches`: all custom predicates defined in the document, possibly spanning multiple +/// batches. Use [`PredicateBatches`] APIs to look up predicates by name and apply them. +/// - `request`: the request templates defined in the document (empty if not present). +#[derive(Debug, Clone)] pub struct PodlangOutput { - pub custom_batch: Arc, + pub custom_batches: PredicateBatches, pub request: PodRequest, } +impl PodlangOutput { + /// Get the first batch, if any (for backwards compatibility). 
+ /// + /// Prefer using `custom_batches` directly if your code expects multiple batches. + pub fn first_batch(&self) -> Option<&Arc<CustomPredicateBatch>> { + self.custom_batches.first_batch() + } +} + +/// Parse, validate, and lower a Podlang document into middleware structures. +/// +/// - `input`: Podlang source. +/// - `params`: middleware parameters limiting sizes/arity and controlling lowering behavior. +/// - `available_batches`: external batches available for `use batch ... from 0x...` imports. +/// +/// Returns a [`PodlangOutput`] containing custom predicate batches (if any) and a `PodRequest` +/// (possibly empty). pub fn parse( input: &str, params: &Params, @@ -37,10 +90,7 @@ pub fn parse( let validated = frontend_ast_validate::validate(document, available_batches)?; let lowered = frontend_ast_lower::lower(validated, params, "PodlangBatch".to_string())?; - let custom_batch = lowered.batch.unwrap_or_else(|| { - // If no batch, create an empty one - CustomPredicateBatch::new(params, "PodlangBatch".to_string(), vec![]) - }); + let custom_batches = lowered.batches.unwrap_or_default(); let request = lowered.request.unwrap_or_else(|| { // If no request, create an empty one @@ -48,7 +98,7 @@ pub fn parse( }); Ok(PodlangOutput { - custom_batch, + custom_batches, request, }) } @@ -93,6 +143,11 @@ mod tests { PredicateOrWildcard::Predicate(pred) } + // Helper to get the first batch from the output + fn first_batch(output: &super::PodlangOutput) -> &Arc<CustomPredicateBatch> { + output.first_batch().expect("Expected at least one batch") + } + #[test] fn test_e2e_simple_predicate() -> Result<(), LangError> { let input = r#" @@ -103,14 +158,12 @@ mod tests { let params = Params::default(); let processed = parse(input, &params, &[])?; - let batch_result = processed.custom_batch; + let batch_result = first_batch(&processed); let request_result = processed.request.templates(); assert_eq!(request_result.len(), 0); assert_eq!(batch_result.predicates.len(), 1); - let batch = batch_result; - // Expected structure let
expected_statements = vec![StatementTmpl { pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), @@ -132,7 +185,7 @@ mod tests { vec![expected_predicate], ); - assert_eq!(batch, expected_batch); + assert_eq!(*batch_result, expected_batch); Ok(()) } @@ -148,10 +201,9 @@ mod tests { let params = Params::default(); let processed = parse(input, ¶ms, &[])?; - let batch_result = processed.custom_batch; let request_templates = processed.request.templates(); - assert_eq!(batch_result.predicates.len(), 0); + assert!(processed.custom_batches.is_empty()); assert!(!request_templates.is_empty()); // Expected structure @@ -188,14 +240,12 @@ mod tests { let params = Params::default(); let processed = parse(input, ¶ms, &[])?; - let batch_result = processed.custom_batch; + let batch_result = first_batch(&processed); let request_result = processed.request.templates(); assert_eq!(request_result.len(), 0); assert_eq!(batch_result.predicates.len(), 1); - let batch = batch_result; - // Expected structure: Public args: A (index 0). 
Private args: Temp (index 1) let expected_statements = vec![ StatementTmpl { @@ -226,7 +276,7 @@ mod tests { vec![expected_predicate], ); - assert_eq!(batch, expected_batch); + assert_eq!(*batch_result, expected_batch); Ok(()) } @@ -245,14 +295,12 @@ mod tests { let params = Params::default(); let processed = parse(input, ¶ms, &[])?; - let batch_result = processed.custom_batch; + let batch_result = first_batch(&processed); let request_templates = processed.request.templates(); assert_eq!(batch_result.predicates.len(), 1); assert!(!request_templates.is_empty()); - let batch = batch_result; - // Expected Batch structure let expected_pred_statements = vec![StatementTmpl { pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), @@ -274,7 +322,7 @@ mod tests { vec![expected_predicate], ); - assert_eq!(batch, expected_batch); + assert_eq!(*batch_result, expected_batch); // Expected Request structure // Pod1 -> Wildcard 0, Pod2 -> Wildcard 1 @@ -311,7 +359,7 @@ mod tests { let params = Params::default(); let processed = parse(input, ¶ms, &[])?; - let batch_result = processed.custom_batch; + let batch_result = first_batch(&processed); let request_templates = processed.request.templates(); assert_eq!(batch_result.predicates.len(), 1); // some_pred is defined @@ -324,7 +372,10 @@ mod tests { // Expected structure let expected_templates = vec![ StatementTmpl { - pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch_result, 0))), // Refers to some_pred + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( + batch_result.clone(), + 0, + ))), // Refers to some_pred args: vec![ StatementTmplArg::Wildcard(wc("Var1", 0)), // Var1 StatementTmplArg::Literal(Value::from(12345i64)), // 12345 @@ -361,10 +412,9 @@ mod tests { let params = Params::default(); let processed = parse(input, ¶ms, &[])?; - let batch_result = processed.custom_batch; let request_templates = processed.request.templates(); - assert_eq!(batch_result.predicates.len(), 0); + 
assert!(processed.custom_batches.is_empty()); assert!(!request_templates.is_empty()); let expected_templates = vec![ @@ -509,7 +559,7 @@ mod tests { ); assert!( - processed.custom_batch.predicates.is_empty(), + processed.custom_batches.is_empty(), "Expected no custom predicates for a REQUEST only input" ); @@ -560,7 +610,7 @@ mod tests { "Expected no request templates" ); assert_eq!( - processed.custom_batch.predicates.len(), + first_batch(&processed).predicates.len(), 4, "Expected 4 custom predicates" ); @@ -691,7 +741,8 @@ mod tests { ); assert_eq!( - processed.custom_batch, expected_batch, + *first_batch(&processed), + expected_batch, "Processed ETHDoS predicates do not match expected structure" ); @@ -739,7 +790,7 @@ mod tests { let request_templates = processed.request.templates(); assert!( - processed.custom_batch.predicates.is_empty(), + processed.custom_batches.is_empty(), "No custom predicates should be defined in the main input" ); assert_eq!(request_templates.len(), 1, "Expected one request template"); @@ -860,13 +911,13 @@ mod tests { "No request should be defined" ); assert_eq!( - processed.custom_batch.predicates.len(), + first_batch(&processed).predicates.len(), 1, "Expected one custom predicate to be defined" ); // 4. 
Check the resulting predicate definition - let defined_pred = &processed.custom_batch.predicates[0]; + let defined_pred = &first_batch(&processed).predicates[0]; assert_eq!(defined_pred.name, "wrapper_pred"); assert_eq!(defined_pred.statements.len(), 1); diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs index 7440825..b176681 100644 --- a/src/lang/pretty_print.rs +++ b/src/lang/pretty_print.rs @@ -395,15 +395,17 @@ mod tests { parse(input, &params, available_batches).expect("Initial parsing should succeed"); // Step 2: Pretty-print the parsed batch - let pretty_printed = parsed_result.custom_batch.to_podlang_string(); + let batch = parsed_result.first_batch().expect("Expected batch"); + let pretty_printed = batch.to_podlang_string(); // Step 3: Parse the pretty-printed result let reparsed_result = parse(&pretty_printed, &params, available_batches).expect("Reparsing should succeed"); + let reparsed_batch = reparsed_result.first_batch().expect("Expected batch"); // Step 4: Verify the ASTs are equivalent assert_eq!( - parsed_result.custom_batch.predicates, reparsed_result.custom_batch.predicates, + batch.predicates, reparsed_batch.predicates, "Original AST should match reparsed AST.\nOriginal input:\n{}\nPretty-printed:\n{}\n", input, pretty_printed ); @@ -553,18 +555,17 @@ mod tests { let params = Params::default(); let parsed_result = parse(input, &params, &[]).expect("Parsing should succeed"); + let batch = parsed_result.first_batch().expect("Expected batch"); - let pretty_printed = parsed_result.custom_batch.to_podlang_string(); + let pretty_printed = batch.to_podlang_string(); println!("Original input:\n{}", input); println!("\nPretty-printed output:\n{}", pretty_printed); let reparsed = parse(&pretty_printed, &params, &[]).expect("Reparsing should succeed"); + let reparsed_batch = reparsed.first_batch().expect("Expected batch"); - assert_eq!( - parsed_result.custom_batch.predicates, - reparsed.custom_batch.predicates - ); + assert_eq!(batch.predicates,
reparsed_batch.predicates); } #[test] @@ -627,14 +628,16 @@ mod tests { let params = Params::default(); let parsed_result = parse(&input, &params, &[]).expect("Should parse successfully"); + let batch = parsed_result.first_batch().expect("Expected batch"); - let pretty_printed = parsed_result.custom_batch.to_podlang_string(); + let pretty_printed = batch.to_podlang_string(); let reparsed_result = parse(&pretty_printed, &params, &[]).expect("Should reparse successfully"); + let reparsed_batch = reparsed_result.first_batch().expect("Expected batch"); assert_eq!( - parsed_result.custom_batch.predicates, reparsed_result.custom_batch.predicates, + batch.predicates, reparsed_batch.predicates, "Round-trip failed for string: {:?}\nPretty-printed: {}", test_string, pretty_printed );