From 879c7201adc02b96b927203024cc6983ec5f04c7 Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Fri, 30 Jan 2026 19:30:57 +0100 Subject: [PATCH] Fix parsing of intro statement templates inside custom predicates (#467) * Fix parsing of intro statement templates inside custom predicates * Tidy up comments --- src/lang/frontend_ast_batch.rs | 167 ++++++++++----------- src/lang/frontend_ast_lower.rs | 265 +++++++++++++++++++++++---------- 2 files changed, 263 insertions(+), 169 deletions(-) diff --git a/src/lang/frontend_ast_batch.rs b/src/lang/frontend_ast_batch.rs index fe4b30e..fb58748 100644 --- a/src/lang/frontend_ast_batch.rs +++ b/src/lang/frontend_ast_batch.rs @@ -12,7 +12,7 @@ //! cross-batch calls always point to earlier batches via `CustomPredicateRef`. //! - Forward cross-batch references cannot occur with this planner (they are treated as unreachable). -use std::{collections::HashMap, str::FromStr, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use petgraph::{algo::condensation, graph::DiGraph, prelude::NodeIndex, visit::EdgeRef}; @@ -21,12 +21,11 @@ use crate::{ lang::{ error::BatchingError, frontend_ast::{ConjunctionType, CustomPredicateDef}, - frontend_ast_lower::lower_statement_arg, + frontend_ast_lower::{lower_statement_arg, resolve_predicate, ResolutionContext}, frontend_ast_split::{SplitChainInfo, SplitResult}, + frontend_ast_validate::SymbolTable, }, - middleware::{ - CustomPredicateBatch, CustomPredicateRef, NativePredicate, Params, Predicate, Statement, - }, + middleware::{CustomPredicateBatch, CustomPredicateRef, Params, Statement}, }; /// A single step in a multi-operation sequence for split predicates @@ -318,13 +317,6 @@ struct PredicateAssignment { index_in_batch: usize, } -/// Information about an imported predicate for use during batching -#[derive(Debug, Clone)] -pub struct ImportedPredicateInfo { - pub batch: Arc, - pub index: usize, -} - /// Pack predicates into multiple batches /// /// Takes a list of split results (containing predicates and optional chain info) @@ -337,13 +329,13 @@ pub struct ImportedPredicateInfo { /// - Within a batch, predicates can reference each other freely via `BatchSelf`; cross-batch /// references always point to earlier batches via `CustomPredicateRef`. /// -/// `imported_predicates` maps predicate names to their imported batch info, -/// allowing predicates to call imported predicates from other batches. +/// `symbols` provides the symbol table for resolving predicate references, +/// including imported predicates from other batches and intro predicates. pub fn batch_predicates( split_results: Vec, params: &Params, base_batch_name: &str, - imported_predicates: &HashMap, + symbols: &SymbolTable, ) -> Result { // Extract predicates and collect split chains let mut predicates = Vec::new(); @@ -403,7 +395,7 @@ pub fn batch_predicates( batch_idx, &reference_map, &batches, - imported_predicates, + symbols, params, &batch_name, )?; @@ -593,7 +585,7 @@ fn build_single_batch( batch_idx: usize, reference_map: &HashMap, existing_batches: &[Arc], - imported_predicates: &HashMap, + symbols: &SymbolTable, params: &Params, batch_name: &str, ) -> Result, BatchingError> { @@ -624,11 +616,10 @@ fn build_single_batch( .map(|stmt| { build_statement_with_resolved_refs( stmt, - name, batch_idx, reference_map, existing_batches, - imported_predicates, + symbols, ) }) .collect::>()?; @@ -654,46 +645,26 @@ fn build_single_batch( /// Build a statement template with properly resolved predicate references fn build_statement_with_resolved_refs( stmt: &crate::lang::frontend_ast::StatementTmpl, - caller_name: &str, current_batch_idx: usize, reference_map: &HashMap, existing_batches: &[Arc], - imported_predicates: &HashMap, + symbols: &SymbolTable, ) -> Result { let callee_name = &stmt.predicate.name; - // Resolve the predicate - let predicate = if let Ok(native) = NativePredicate::from_str(callee_name) { - Predicate::Native(native) - } else if let Some(&(target_batch, target_idx)) = reference_map.get(callee_name) { - // Local predicate in this document - if target_batch == current_batch_idx { - // Same batch - use BatchSelf - Predicate::BatchSelf(target_idx) - } else if target_batch < current_batch_idx { - // Earlier batch - use Custom ref - let batch = &existing_batches[target_batch]; - Predicate::Custom(CustomPredicateRef::new(batch.clone(), target_idx)) - } else { - // Forward reference to a later batch should be impossible with the dependency-aware planner - unreachable!( - "Forward cross-batch reference: '{}' (batch {}) -> '{}' (batch {})", - caller_name, current_batch_idx, callee_name, target_batch - ); - } - } else if let Some(imported) = imported_predicates.get(callee_name) { - // Imported predicate from another batch - Predicate::Custom(CustomPredicateRef::new( - imported.batch.clone(), - imported.index, - )) - } else { - // Unknown predicate - return Err(BatchingError::Internal { - message: format!("Unknown predicate reference: '{}'", callee_name), - }); + // Resolve the predicate using the unified resolution function + let context = ResolutionContext::Batch { + current_batch_idx, + reference_map, + existing_batches, }; + let predicate = resolve_predicate(callee_name, symbols, &context).ok_or_else(|| { + BatchingError::Internal { + message: format!("Unknown predicate reference: '{}'", callee_name), + } + })?; + // Build the statement template let mut builder = StatementTmplBuilder::new(predicate); @@ -709,24 +680,30 @@ mod tests { use super::*; use crate::{ lang::{ - frontend_ast::parse::parse_document, frontend_ast_split::split_predicate_if_needed, + frontend_ast::parse::parse_document, + frontend_ast_split::split_predicate_if_needed, + frontend_ast_validate::{validate, ValidatedAST}, parser::parse_podlang, }, - middleware::PredicateOrWildcard, + middleware::{Predicate, PredicateOrWildcard}, }; - fn parse_predicates(input: &str) -> Vec { + /// Helper: parse and validate input, returning predicates and symbol table + fn parse_and_validate(input: &str) -> (Vec, ValidatedAST) { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + let validated = validate(document.clone(), &[]).expect("Failed to validate"); - document + let predicates = document .items .into_iter() .filter_map(|item| match item { crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred), _ => None, }) - .collect() + .collect(); + + (predicates, validated) } /// Helper: wrap predicates into SplitResult (without actually splitting) @@ -748,14 +725,14 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); let result = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ); assert!(result.is_ok()); @@ -772,14 +749,14 @@ mod tests { pred3(C) = AND(Equal(C["z"], 3)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // max_custom_batch_size = 4 let result = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ); assert!(result.is_ok()); @@ -798,14 +775,14 @@ mod tests { pred5(E) = AND(Equal(E["v"], 5)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // max_custom_batch_size = 4 let result = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ); assert!(result.is_ok()); @@ -828,14 +805,14 @@ mod tests { pred1(A) = AND(Equal(A["x"], 1)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); let result = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ); assert!(result.is_ok()); @@ -861,14 +838,14 @@ mod tests { pred2(B) = AND(pred1(B)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); let result = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ); assert!(result.is_ok()); @@ -902,14 +879,14 @@ mod tests { pred5(E) = AND(pred1(E)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // max_custom_batch_size = 4 let result = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ); assert!(result.is_ok()); @@ -949,7 +926,7 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // Split the large predicate @@ -966,7 +943,7 @@ mod tests { // That's 5 predicates, which spans 2 batches assert_eq!(total_preds, 5); - let result = batch_predicates(all_split_results, ¶ms, "TestBatch", &HashMap::new()); + let result = batch_predicates(all_split_results, ¶ms, "TestBatch", validated.symbols()); assert!(result.is_ok()); let batches = result.unwrap(); @@ -995,14 +972,14 @@ mod tests { pred5(E) = AND(Equal(E["v"], 5)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // max_custom_batch_size = 4 let batches = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ) .expect("Planner should avoid forward cross-batch reference"); @@ -1019,8 +996,13 @@ mod tests { fn test_empty_input() { let split_results: Vec = vec![]; let params = Params::default(); + // For empty input, we need an empty symbol table + let empty_symbols = SymbolTable { + predicates: HashMap::new(), + wildcard_scopes: HashMap::new(), + }; - let result = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()); + let result = batch_predicates(split_results, ¶ms, "TestBatch", &empty_symbols); assert!(result.is_ok()); let batches = result.unwrap(); @@ -1035,14 +1017,14 @@ mod tests { pred2(B) = AND(Equal(B["y"], 2)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); let batches = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ) .unwrap(); @@ -1061,7 +1043,7 @@ mod tests { pred2(B) = AND(pred1(B)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params { max_custom_batch_size: 1, // force SCC > capacity ..Default::default() @@ -1071,7 +1053,7 @@ mod tests { preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ); assert!(result.is_err()); assert!(result @@ -1098,7 +1080,7 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // max_custom_batch_size = 4 // Split and batch @@ -1107,8 +1089,9 @@ mod tests { let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); all_split_results.push(result); } - let batches = batch_predicates(all_split_results, ¶ms, "TestBatch", &HashMap::new()) - .expect("Batch failed"); + let batches = + batch_predicates(all_split_results, ¶ms, "TestBatch", validated.symbols()) + .expect("Batch failed"); assert_eq!(batches.batch_count(), 2); @@ -1147,14 +1130,14 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); let batches = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ) .unwrap(); @@ -1195,7 +1178,7 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // Split the predicate @@ -1210,7 +1193,7 @@ mod tests { assert_eq!(split_results[0].predicates.len(), 2); assert!(split_results[0].chain_info.is_some()); - let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) .expect("Batch failed"); // Verify chain info @@ -1259,7 +1242,7 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // Split the predicate @@ -1274,7 +1257,7 @@ mod tests { assert_eq!(split_results[0].predicates.len(), 3); assert!(split_results[0].chain_info.is_some()); - let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) .expect("Batch failed"); // Verify chain info @@ -1320,7 +1303,7 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); // Split the predicate @@ -1330,7 +1313,7 @@ mod tests { split_results.push(result); } - let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) .expect("Batch failed"); // Try with wrong number of statements (3 instead of 6) @@ -1363,14 +1346,14 @@ mod tests { my_pred(A) = AND(Equal(A["x"], 1)) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); let batches = batch_predicates( preds_to_split_results(predicates), ¶ms, "TestBatch", - &HashMap::new(), + validated.symbols(), ) .unwrap(); @@ -1401,7 +1384,7 @@ mod tests { ) "#; - let predicates = parse_predicates(input); + let (predicates, validated) = parse_and_validate(input); let params = Params::default(); let mut split_results = Vec::new(); @@ -1410,7 +1393,7 @@ mod tests { split_results.push(result); } - let batches = batch_predicates(split_results, ¶ms, "TestBatch", &HashMap::new()) + let batches = batch_predicates(split_results, ¶ms, "TestBatch", validated.symbols()) .expect("Batch failed"); let statements: Vec = (0..6).map(test_statement).collect(); diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index 0e7238d..ce693b9 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -6,6 +6,7 @@ use std::{ collections::{HashMap, HashSet}, str::FromStr, + sync::Arc, }; use crate::{ @@ -14,15 +15,124 @@ use crate::{ frontend_ast::*, frontend_ast_batch::{self, PredicateBatches}, frontend_ast_split, - frontend_ast_validate::{PredicateKind, ValidatedAST}, + frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST}, }, middleware::{ - self, containers, IntroPredicateRef, NativePredicate, Params, Predicate, - PredicateOrWildcard, StatementTmpl as MWStatementTmpl, - StatementTmplArg as MWStatementTmplArg, Wildcard, + containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key, + NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl, + StatementTmplArg as MWStatementTmplArg, Value, Wildcard, }, }; +/// Context for predicate resolution - determines how local custom predicates are resolved +pub enum ResolutionContext<'a> { + /// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches + Request { + batches: Option<&'a PredicateBatches>, + }, + /// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef + Batch { + current_batch_idx: usize, + reference_map: &'a HashMap, + existing_batches: &'a [Arc], + }, +} + +/// Resolve a predicate name to a Predicate using the symbol table +pub fn resolve_predicate( + pred_name: &str, + symbols: &SymbolTable, + context: &ResolutionContext, +) -> Option { + // 1. Try native predicate first + if let Ok(native) = NativePredicate::from_str(pred_name) { + return Some(Predicate::Native(native)); + } + + // 2. Look up in symbol table + if let Some(info) = symbols.predicates.get(pred_name) { + let predicate = match &info.kind { + PredicateKind::Native(np) => Predicate::Native(*np), + + PredicateKind::Custom { .. } => match context { + ResolutionContext::Request { batches } => { + let batches = batches.as_ref()?; + let pred_ref = batches.predicate_ref_by_name(pred_name)?; + Predicate::Custom(pred_ref) + } + ResolutionContext::Batch { + current_batch_idx, + reference_map, + existing_batches, + } => resolve_local_predicate( + pred_name, + *current_batch_idx, + reference_map, + existing_batches, + )?, + }, + + PredicateKind::BatchImported { batch, index } => { + Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index)) + } + + PredicateKind::IntroImported { + name, + verifier_data_hash, + } => Predicate::Intro(IntroPredicateRef { + name: name.clone(), + args_len: info.public_arity, + verifier_data_hash: *verifier_data_hash, + }), + }; + return Some(predicate); + } + + // 3. In batch context, also check reference_map for split chain pieces + // (predicates created by splitting that aren't in the original symbol table) + if let ResolutionContext::Batch { + current_batch_idx, + reference_map, + existing_batches, + } = context + { + if reference_map.contains_key(pred_name) { + return resolve_local_predicate( + pred_name, + *current_batch_idx, + reference_map, + existing_batches, + ); + } + } + + None +} + +/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map +fn resolve_local_predicate( + pred_name: &str, + current_batch_idx: usize, + reference_map: &HashMap, + existing_batches: &[Arc], +) -> Option { + let &(target_batch, target_idx) = reference_map.get(pred_name)?; + if target_batch == current_batch_idx { + Some(Predicate::BatchSelf(target_idx)) + } else if target_batch < current_batch_idx { + let batch = &existing_batches[target_batch]; + Some(Predicate::Custom(CustomPredicateRef::new( + batch.clone(), + target_idx, + ))) + } else { + unreachable!( + "Forward cross-batch reference should be impossible: {} -> {}", + current_batch_idx, target_batch + ); + } +} + // ============================================================================ // Shared lowering utilities // ============================================================================ @@ -33,37 +143,37 @@ use crate::{ /// Lower a literal value from AST to middleware Value. /// /// This is a pure conversion that cannot fail. -pub fn lower_literal(lit: &LiteralValue) -> middleware::Value { +pub fn lower_literal(lit: &LiteralValue) -> Value { match lit { - LiteralValue::Int(i) => middleware::Value::from(i.value), - LiteralValue::Bool(b) => middleware::Value::from(b.value), - LiteralValue::String(s) => middleware::Value::from(s.value.clone()), - LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash), - LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point), - LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()), + LiteralValue::Int(i) => Value::from(i.value), + LiteralValue::Bool(b) => Value::from(b.value), + LiteralValue::String(s) => Value::from(s.value.clone()), + LiteralValue::Raw(r) => Value::from(r.hash.hash), + LiteralValue::PublicKey(pk) => Value::from(pk.point), + LiteralValue::SecretKey(sk) => Value::from(sk.secret_key.clone()), LiteralValue::Array(a) => { let elements: Vec<_> = a.elements.iter().map(lower_literal).collect(); let array = containers::Array::new(elements); - middleware::Value::from(array) + Value::from(array) } LiteralValue::Set(s) => { let elements: std::collections::HashSet<_> = s.elements.iter().map(lower_literal).collect(); let set = containers::Set::new(elements); - middleware::Value::from(set) + Value::from(set) } LiteralValue::Dict(d) => { - let pairs: std::collections::HashMap<_, _> = d + let pairs: HashMap<_, _> = d .pairs .iter() .map(|pair| { - let key = middleware::Key::from(pair.key.value.as_str()); + let key = Key::from(pair.key.value.as_str()); let value = lower_literal(&pair.value); (key, value) }) .collect(); let dict = containers::Dictionary::new(pairs); - middleware::Value::from(dict) + Value::from(dict) } } } @@ -151,41 +261,18 @@ impl<'a> Lowerer<'a> { return Ok(None); } - // Build map of imported predicates for batching - let imported_predicates = self.build_imported_predicates_map(); - // Use the new batching module to pack predicates into batches + // Pass the symbol table for unified predicate resolution let batches = frontend_ast_batch::batch_predicates( custom_predicates, self.params, &batch_name, - &imported_predicates, + self.validated.symbols(), )?; Ok(Some(batches)) } - fn build_imported_predicates_map( - &self, - ) -> HashMap { - let symbols = self.validated.symbols(); - let mut imported = HashMap::new(); - - for (name, info) in &symbols.predicates { - if let PredicateKind::BatchImported { batch, index } = &info.kind { - imported.insert( - name.clone(), - frontend_ast_batch::ImportedPredicateInfo { - batch: batch.clone(), - index: *index, - }, - ); - } - } - - imported - } - fn lower_request( &self, batches: Option<&PredicateBatches>, @@ -232,42 +319,13 @@ impl<'a> Lowerer<'a> { let pred_name = &stmt.predicate.name; let symbols = self.validated.symbols(); - // Resolve predicate - for request statements, local custom predicates - // must be resolved to CustomPredicateRef (not BatchSelf) - let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) { - Predicate::Native(native) - } else if let Some(info) = symbols.predicates.get(pred_name) { - match &info.kind { - PredicateKind::Native(np) => Predicate::Native(*np), - PredicateKind::Custom { .. } => { - // Local custom predicates - resolve to CustomPredicateRef - let batches = batches.ok_or_else(|| LoweringError::PredicateNotFound { - name: pred_name.clone(), - })?; - let pred_ref = batches.predicate_ref_by_name(pred_name).ok_or_else(|| { - LoweringError::PredicateNotFound { - name: pred_name.clone(), - } - })?; - Predicate::Custom(pred_ref) - } - PredicateKind::BatchImported { batch, index } => { - Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index)) - } - PredicateKind::IntroImported { - name, - verifier_data_hash, - } => Predicate::Intro(IntroPredicateRef { - name: name.clone(), - args_len: info.public_arity, - verifier_data_hash: *verifier_data_hash, - }), - } - } else { - return Err(LoweringError::PredicateNotFound { + // Resolve predicate using the unified resolution function + let context = ResolutionContext::Request { batches }; + let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| { + LoweringError::PredicateNotFound { name: pred_name.clone(), - }); - }; + } + })?; // Create a builder with the resolved predicate and desugar let mut builder = StatementTmplBuilder::new(predicate); @@ -291,7 +349,7 @@ impl<'a> Lowerer<'a> { .get(&root_name) .expect("Root wildcard not found"); let wildcard = Wildcard::new(root_name, *root_index); - let key = middleware::Key::from(key_str.as_str()); + let key = Key::from(key_str.as_str()); MWStatementTmplArg::AnchoredKey(wildcard, key) } }; @@ -646,4 +704,57 @@ mod tests { assert_eq!(batches.total_predicate_count(), 5); assert_eq!(batches.batch_count(), 2); } + + #[test] + fn test_intro_predicate_in_custom_predicate() { + use hex::ToHex; + + use crate::middleware::EMPTY_HASH; + + // Import an intro predicate and use it inside a custom predicate definition + let intro_hash = EMPTY_HASH.encode_hex::(); + let input = format!( + r#" + use intro external_check(X) from 0x{intro_hash} + + my_pred(A) = AND ( + Equal(A["foo"], 42) + external_check(A) + ) + "# + ); + + let params = Params::default(); + + // Parse, validate, and lower + let parsed = parse_podlang(&input).expect("Failed to parse"); + let document = + parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document"); + let validated = validate(document, &[]).expect("Failed to validate"); + let result = lower(validated, ¶ms, "test_batch".to_string()); + + assert!(result.is_ok(), "Lowering failed: {:?}", result.err()); + + let lowered = result.unwrap(); + let batch = expect_batch(&lowered); + + // Should have one custom predicate + assert_eq!(batch.predicates().len(), 1); + + let pred = &batch.predicates()[0]; + assert_eq!(pred.name, "my_pred"); + // 2 statements: Equal and external_check + assert_eq!(pred.statements().len(), 2); + + // Verify the second statement is an intro predicate reference + let intro_stmt = &pred.statements()[1]; + match intro_stmt.pred_or_wc() { + PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => { + assert_eq!(intro_ref.name, "external_check"); + assert_eq!(intro_ref.args_len, 1); + assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH); + } + other => panic!("Expected Intro predicate, got {:?}", other), + } + } }