From 21ab3c2d0d77042a3d37ba9977c592c8d44ef64d Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Fri, 13 Jun 2025 19:09:08 +0200 Subject: [PATCH] Basic 'use' syntax for importing custom predicates (#286) * Basic 'use' syntax for importing custom predicates * Add extra test for unknown batches * Fix unused import * Enforce that imports must match number of predicates in a batch --- src/lang/error.rs | 18 ++- src/lang/grammar.pest | 9 +- src/lang/mod.rs | 281 +++++++++++++++++++++++++++++++---- src/lang/parser.rs | 12 +- src/lang/processor.rs | 286 ++++++++++++++++++++++++------------ src/middleware/basetypes.rs | 29 +++- 6 files changed, 499 insertions(+), 136 deletions(-) diff --git a/src/lang/error.rs b/src/lang/error.rs index b359f68..b24d8d7 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -17,7 +17,7 @@ pub enum LangError { Frontend(Box), } -/// Errors that can occur during the processing of Podlog Pest tree into middleware structures. +/// Errors that can occur during the processing of Podlang Pest tree into middleware structures. #[derive(thiserror::Error, Debug)] pub enum ProcessorError { #[error("Undefined identifier: '{name}' at {span:?}")] @@ -71,6 +71,22 @@ pub enum ProcessorError { value: String, span: Option<(usize, usize)>, }, + #[error("Batch with ID '{id}' not found at {span:?}")] + BatchNotFound { + id: String, + span: Option<(usize, usize)>, + }, + #[error("Number of predicates in 'use' statement ({found}) exceeds the number of predicates in the batch ({expected}) at {span:?}")] + ImportArityMismatch { + expected: usize, + found: usize, + span: Option<(usize, usize)>, + }, + #[error("Duplicate import name '{name}' at {span:?}")] + DuplicateImportName { + name: String, + span: Option<(usize, usize)>, + }, #[error("Frontend error: {0}")] Frontend(#[from] frontend::Error), } diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index bd75359..0a7de02 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -1,4 +1,4 @@ -// Grammar for the "Podlog" language. Used for describing POD2 Custom +// Grammar for the "Podlang" language. Used for describing POD2 Custom // Predicates and Proof Requests. // Silent rules (`_`) are automatically handled by Pest between other rules. @@ -27,7 +27,12 @@ arg_section = { public_arg_list = { identifier ~ ("," ~ identifier)* } private_arg_list = { identifier ~ ("," ~ identifier)* } -document = { SOI ~ (custom_predicate_def | request_def)* ~ EOI } +document = { SOI ~ (use_statement | custom_predicate_def | request_def)* ~ EOI } + +use_statement = { "use" ~ use_predicate_list ~ "from" ~ batch_ref } +use_predicate_list = { import_name ~ ("," ~ import_name)* } +import_name = { identifier | "_" } +batch_ref = { literal_raw } request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" } diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 5d527a9..6022ba9 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -2,28 +2,37 @@ pub mod error; pub mod parser; pub mod processor; +use std::sync::Arc; + pub use error::LangError; -pub use parser::{parse_podlog, Pairs, ParseError, Rule}; +pub use parser::{parse_podlang, Pairs, ParseError, Rule}; pub use processor::process_pest_tree; -use processor::ProcessedOutput; +use processor::PodlangOutput; -use crate::middleware::Params; +use crate::middleware::{CustomPredicateBatch, Params}; -pub fn parse(input: &str, params: &Params) -> Result { - let pairs = parse_podlog(input)?; - processor::process_pest_tree(pairs, params).map_err(LangError::from) +pub fn parse( + input: &str, + params: &Params, + available_batches: &[Arc], +) -> Result { + let pairs = parse_podlang(input)?; + processor::process_pest_tree(pairs, params, available_batches).map_err(LangError::from) } #[cfg(test)] mod tests { - + use hex::ToHex; use pretty_assertions::assert_eq; use super::*; - use crate::middleware::{ - CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, KeyOrWildcard, - NativePredicate, Params, PodType, Predicate, SelfOrWildcard, StatementTmpl, - StatementTmplArg, Value, Wildcard, SELF_ID_HASH, + use crate::{ + lang::error::ProcessorError, + middleware::{ + CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, KeyOrWildcard, + NativePredicate, Params, PodType, Predicate, SelfOrWildcard, StatementTmpl, + StatementTmplArg, Value, Wildcard, SELF_ID_HASH, + }, }; // Helper functions @@ -67,8 +76,7 @@ mod tests { "#; let params = Params::default(); - let pairs = parse_podlog(input)?; - let processed = process_pest_tree(pairs, ¶ms)?; + let processed = parse(input, ¶ms, &[])?; let batch_result = processed.custom_batch; let request_result = processed.request_templates; @@ -92,8 +100,11 @@ mod tests { 2, // args_len (PodA, PodB) names(&["PodA", "PodB"]), )?; - let expected_batch = - CustomPredicateBatch::new(¶ms, "PodlogBatch".to_string(), vec![expected_predicate]); + let expected_batch = CustomPredicateBatch::new( + ¶ms, + "PodlangBatch".to_string(), + vec![expected_predicate], + ); assert_eq!(batch, expected_batch); @@ -110,8 +121,7 @@ mod tests { "#; let params = Params::default(); - let pairs = parse_podlog(input)?; - let processed = process_pest_tree(pairs, ¶ms)?; + let processed = parse(input, ¶ms, &[])?; let batch_result = processed.custom_batch; let request_templates = processed.request_templates; @@ -153,8 +163,7 @@ mod tests { "#; let params = Params::default(); - let pairs = parse_podlog(input)?; - let processed = process_pest_tree(pairs, ¶ms)?; + let processed = parse(input, ¶ms, &[])?; let batch_result = processed.custom_batch; let request_result = processed.request_templates; @@ -187,8 +196,11 @@ mod tests { 1, // args_len (A) names(&["A", "Temp"]), )?; - let expected_batch = - CustomPredicateBatch::new(¶ms, "PodlogBatch".to_string(), vec![expected_predicate]); + let expected_batch = CustomPredicateBatch::new( + ¶ms, + "PodlangBatch".to_string(), + vec![expected_predicate], + ); assert_eq!(batch, expected_batch); @@ -208,8 +220,7 @@ mod tests { "#; let params = Params::default(); - let pairs = parse_podlog(input)?; - let processed = process_pest_tree(pairs, ¶ms)?; + let processed = parse(input, ¶ms, &[])?; let batch_result = processed.custom_batch; let request_templates = processed.request_templates; @@ -234,8 +245,11 @@ mod tests { 2, // args_len (X, Y) names(&["X", "Y"]), )?; - let expected_batch = - CustomPredicateBatch::new(¶ms, "PodlogBatch".to_string(), vec![expected_predicate]); + let expected_batch = CustomPredicateBatch::new( + ¶ms, + "PodlangBatch".to_string(), + vec![expected_predicate], + ); assert_eq!(batch, expected_batch); @@ -270,8 +284,7 @@ mod tests { "#; let params = Params::default(); - let pairs = parse_podlog(input)?; - let processed = process_pest_tree(pairs, ¶ms)?; + let processed = parse(input, ¶ms, &[])?; let batch_result = processed.custom_batch; let request_templates = processed.request_templates; @@ -323,8 +336,7 @@ mod tests { "#; let params = Params::default(); - let pairs = parse_podlog(input)?; - let processed = process_pest_tree(pairs, ¶ms)?; + let processed = parse(input, ¶ms, &[])?; let batch_result = processed.custom_batch; let request_templates = processed.request_templates; @@ -384,7 +396,7 @@ mod tests { "#; // Parse the input string - let processed = super::parse(input, &Params::default())?; + let processed = super::parse(input, &Params::default(), &[])?; let parsed_templates = processed.request_templates; // Define Expected Templates (Copied from prover/mod.rs) @@ -529,7 +541,7 @@ mod tests { ) "#; - let processed = super::parse(input, ¶ms)?; + let processed = super::parse(input, ¶ms, &[])?; assert!( processed.request_templates.is_empty(), @@ -681,7 +693,7 @@ mod tests { let expected_batch = CustomPredicateBatch::new( ¶ms, - "PodlogBatch".to_string(), + "PodlangBatch".to_string(), vec![ expected_friend_pred, expected_base_pred, @@ -697,4 +709,209 @@ mod tests { Ok(()) } + + #[test] + fn test_e2e_use_statement() -> Result<(), LangError> { + let params = Params::default(); + + // 1. Create a batch to be imported + let imported_pred_stmts = vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak(("A", 0), k("foo")), // ?A["foo"] + sta_ak(("B", 1), k("bar")), // ?B["bar"] + ], + }]; + let imported_predicate = CustomPredicate::and( + ¶ms, + "imported_equal".to_string(), + imported_pred_stmts, + 2, + names(&["A", "B"]), + )?; + let available_batch = + CustomPredicateBatch::new(¶ms, "MyBatch".to_string(), vec![imported_predicate]); + let available_batches = vec![available_batch.clone()]; + + // 2. Create the input string that uses the batch + let batch_id_str = available_batch.id().encode_hex::(); + let input = format!( + r#" + use imported_pred from 0x{} + + REQUEST( + imported_pred(?Pod1, ?Pod2) + ) + "#, + batch_id_str + ); + + // 3. Parse the input + let processed = parse(&input, ¶ms, &available_batches)?; + let request_templates = processed.request_templates; + + assert!( + processed.custom_batch.predicates.is_empty(), + "No custom predicates should be defined in the main input" + ); + assert_eq!(request_templates.len(), 1, "Expected one request template"); + + // 4. Check the resulting request template + let expected_request_templates = vec![StatementTmpl { + pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 0)), + args: vec![ + StatementTmplArg::WildcardLiteral(wc("Pod1", 0)), + StatementTmplArg::WildcardLiteral(wc("Pod2", 1)), + ], + }]; + + assert_eq!(request_templates, expected_request_templates); + + Ok(()) + } + + #[test] + fn test_e2e_use_statement_complex() -> Result<(), LangError> { + let params = Params::default(); + + // 1. Create a batch with multiple predicates + let pred1 = CustomPredicate::and(¶ms, "p1".into(), vec![], 1, names(&["A"]))?; + let pred2 = CustomPredicate::and(¶ms, "p2".into(), vec![], 2, names(&["B", "C"]))?; + let pred3 = CustomPredicate::and(¶ms, "p3".into(), vec![], 1, names(&["D"]))?; + + let available_batch = + CustomPredicateBatch::new(¶ms, "MyBatch".to_string(), vec![pred1, pred2, pred3]); + let available_batches = vec![available_batch.clone()]; + + // 2. Create the input string that uses the batch with skips + let batch_id_str = available_batch.id().encode_hex::(); + + let input = format!( + r#" + use pred_one, _, pred_three from 0x{} + + REQUEST( + pred_one(?Pod1) + pred_three(?Pod2) + ) + "#, + batch_id_str + ); + + // 3. Parse the input + let processed = parse(&input, ¶ms, &available_batches)?; + let request_templates = processed.request_templates; + + assert_eq!(request_templates.len(), 2, "Expected two request templates"); + + // 4. Check the resulting request templates + let expected_templates = vec![ + StatementTmpl { + pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)), + args: vec![StatementTmplArg::WildcardLiteral(wc("Pod1", 0))], + }, + StatementTmpl { + pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 2)), + args: vec![StatementTmplArg::WildcardLiteral(wc("Pod2", 1))], + }, + ]; + + assert_eq!(request_templates, expected_templates); + + Ok(()) + } + + #[test] + fn test_e2e_custom_predicate_uses_import() -> Result<(), LangError> { + let params = Params::default(); + + // 1. Create a batch with a predicate to be imported + let imported_pred_stmts = vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![sta_ak(("A", 0), k("foo")), sta_ak(("B", 1), k("bar"))], + }]; + let imported_predicate = CustomPredicate::and( + ¶ms, + "imported_equal".to_string(), + imported_pred_stmts, + 2, + names(&["A", "B"]), + )?; + let available_batch = + CustomPredicateBatch::new(¶ms, "MyBatch".to_string(), vec![imported_predicate]); + let available_batches = vec![available_batch.clone()]; + + // 2. Create the input string that defines a new predicate using the imported one + let batch_id_str = available_batch.id().encode_hex::(); + + let input = format!( + r#" + use imported_eq from 0x{} + + wrapper_pred(X, Y) = AND( + imported_eq(?X, ?Y) + ) + "#, + batch_id_str + ); + + // 3. Parse the input + let processed = parse(&input, ¶ms, &available_batches)?; + + assert!( + processed.request_templates.is_empty(), + "No request should be defined" + ); + assert_eq!( + processed.custom_batch.predicates.len(), + 1, + "Expected one custom predicate to be defined" + ); + + // 4. Check the resulting predicate definition + let defined_pred = &processed.custom_batch.predicates[0]; + assert_eq!(defined_pred.name, "wrapper_pred"); + assert_eq!(defined_pred.statements.len(), 1); + + let expected_statement = StatementTmpl { + pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)), + args: vec![ + StatementTmplArg::WildcardLiteral(wc("X", 0)), + StatementTmplArg::WildcardLiteral(wc("Y", 1)), + ], + }; + + assert_eq!(defined_pred.statements[0], expected_statement); + + Ok(()) + } + + #[test] + fn test_e2e_use_unknown_batch() { + let params = Params::default(); + let available_batches = &[]; + + let unknown_batch_id = format!("0x{}", "a".repeat(64)); + + let input = format!( + r#" + use some_pred from {} + "#, + unknown_batch_id + ); + + let result = parse(&input, ¶ms, available_batches); + + assert!(result.is_err()); + + match result.err().unwrap() { + LangError::Processor(e) => match *e { + ProcessorError::BatchNotFound { id, .. } => { + assert_eq!(id, unknown_batch_id); + } + _ => panic!("Expected BatchNotFound error, but got {:?}", e), + }, + e => panic!("Expected LangError::Processor, but got {:?}", e), + } + } } diff --git a/src/lang/parser.rs b/src/lang/parser.rs index e2e6634..a648fe6 100644 --- a/src/lang/parser.rs +++ b/src/lang/parser.rs @@ -6,7 +6,7 @@ use pest_derive::Parser; // and changes to the grammar file will not automatically trigger a recompile. #[derive(Parser)] #[grammar = "lang/grammar.pest"] -pub struct PodlogParser; +pub struct PodlangParser; pub type Pairs<'a, R> = PestPairs<'a, R>; @@ -22,9 +22,9 @@ impl From> for ParseError { } } -/// Parses a Podlog input string according to the grammar rules. -pub fn parse_podlog(input: &str) -> Result, ParseError> { - let pairs = PodlogParser::parse(Rule::document, input)?; +/// Parses a Podlang input string according to the grammar rules. +pub fn parse_podlang(input: &str) -> Result, ParseError> { + let pairs = PodlangParser::parse(Rule::document, input)?; Ok(pairs) } @@ -33,14 +33,14 @@ mod tests { use super::*; fn assert_parses(rule: Rule, input: &str) { - match PodlogParser::parse(rule, input) { + match PodlangParser::parse(rule, input) { Ok(_) => (), // Successfully parsed Err(e) => panic!("Failed to parse input:\n{}\nError: {}", input, e), } } fn assert_fails(rule: Rule, input: &str) { - match PodlogParser::parse(rule, input) { + match PodlangParser::parse(rule, input) { Ok(pairs) => panic!( "Expected parse to fail, but it succeeded. Parsed:\n{:#?}", pairs diff --git a/src/lang/processor.rs b/src/lang/processor.rs index ae5549e..263c399 100644 --- a/src/lang/processor.rs +++ b/src/lang/processor.rs @@ -53,13 +53,15 @@ pub fn native_predicate_from_string(s: &str) -> Option { } #[derive(Debug, Clone, PartialEq)] -pub struct ProcessedOutput { +pub struct PodlangOutput { pub custom_batch: Arc, pub request_templates: Vec, } struct ProcessingContext<'a> { params: &'a Params, + /// Maps imported predicate names to their full reference (batch and index) + imported_predicates: HashMap, /// Maps predicate names to their batch index and public argument count (from Pass 1) custom_predicate_signatures: HashMap, /// Stores the original Pest pairs for custom predicate definitions for Pass 2 @@ -72,6 +74,7 @@ impl<'a> ProcessingContext<'a> { fn new(params: &'a Params) -> Self { ProcessingContext { params, + imported_predicates: HashMap::new(), custom_predicate_signatures: HashMap::new(), custom_predicate_pairs: Vec::new(), request_pair: None, @@ -82,7 +85,8 @@ impl<'a> ProcessingContext<'a> { pub fn process_pest_tree( mut pairs_iterator_for_document_rule: Pairs<'_, Rule>, params: &Params, -) -> Result { + available_batches: &[Arc], +) -> Result { let mut processing_ctx = ProcessingContext::new(params); let document_node = pairs_iterator_for_document_rule.next().ok_or_else(|| { @@ -102,7 +106,11 @@ pub fn process_pest_tree( let document_content_pairs = document_node.into_inner(); - first_pass(document_content_pairs, &mut processing_ctx)?; + first_pass( + document_content_pairs, + &mut processing_ctx, + available_batches, + )?; second_pass(&mut processing_ctx) } @@ -112,12 +120,16 @@ pub fn process_pest_tree( fn first_pass<'a>( document_pairs: Pairs<'a, Rule>, ctx: &mut ProcessingContext<'a>, + available_batches: &[Arc], ) -> Result<(), ProcessorError> { let mut defined_custom_names: HashSet = HashSet::new(); let mut first_request_span: Option<(usize, usize)> = None; for pair in document_pairs { match pair.as_rule() { + Rule::use_statement => { + process_use_statement(&pair, ctx, available_batches)?; + } Rule::custom_predicate_def => { let pred_name_pair = pair .clone() @@ -126,7 +138,9 @@ fn first_pass<'a>( .unwrap(); let pred_name = pred_name_pair.as_str().to_string(); - if defined_custom_names.contains(&pred_name) { + if defined_custom_names.contains(&pred_name) + || ctx.imported_predicates.contains_key(&pred_name) + { return Err(ProcessorError::DuplicateDefinition { name: pred_name, span: Some(get_span(&pred_name_pair)), @@ -179,9 +193,85 @@ fn count_public_args(pred_def_pair: &Pair) -> Result Result { +fn process_use_statement( + use_pair: &Pair, + ctx: &mut ProcessingContext, + available_batches: &[Arc], +) -> Result<(), ProcessorError> { + let mut inner = use_pair.clone().into_inner(); + + let import_list_pair = inner + .find(|p| p.as_rule() == Rule::use_predicate_list) + .unwrap(); + let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap(); + let batch_id_pair = batch_ref_pair.into_inner().next().unwrap(); + let batch_id_str_full = batch_id_pair.as_str(); + + let batch_id_hex = batch_id_str_full + .strip_prefix("0x") + .unwrap_or(batch_id_str_full); + let batch_id_val = parse_hex_str_to_raw_value(batch_id_hex).map_err(|_| { + ProcessorError::InvalidLiteralFormat { + kind: "batch ID hash".to_string(), + value: batch_id_str_full.to_string(), + span: Some(get_span(&batch_id_pair)), + } + })?; + + let target_batch = available_batches + .iter() + .find(|b| b.id().0 == batch_id_val.0) + .ok_or_else(|| ProcessorError::BatchNotFound { + id: batch_id_str_full.to_string(), + span: Some(get_span(&batch_id_pair)), + })?; + + let import_names: Vec> = import_list_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::import_name) + .collect(); + + if import_names.len() != target_batch.predicates().len() { + return Err(ProcessorError::ImportArityMismatch { + expected: target_batch.predicates().len(), + found: import_names.len(), + span: Some(get_span(use_pair)), + }); + } + + for (i, import_name_pair) in import_names.into_iter().enumerate() { + if import_name_pair.as_str() == "_" { + continue; + } + + let name = import_name_pair.as_str().to_string(); + + if ctx.imported_predicates.contains_key(&name) { + return Err(ProcessorError::DuplicateImportName { + name, + span: Some(get_span(&import_name_pair)), + }); + } + + let custom_pred_ref = CustomPredicateRef::new(target_batch.clone(), i); + ctx.imported_predicates.insert(name, custom_pred_ref); + } + + Ok(()) +} + +enum StatementContext<'a> { + CustomPredicate, + Request { + custom_batch: &'a Arc, + wildcard_names: &'a mut Vec, + defined_wildcards: &'a mut HashSet, + }, +} + +fn second_pass(ctx: &mut ProcessingContext) -> Result { let mut cpb_builder = - CustomPredicateBatchBuilder::new(ctx.params.clone(), "PodlogBatch".to_string()); + CustomPredicateBatchBuilder::new(ctx.params.clone(), "PodlangBatch".to_string()); for pred_pair in &ctx.custom_predicate_pairs { process_and_add_custom_predicate_to_batch(pred_pair, ctx, &mut cpb_builder)?; @@ -195,7 +285,7 @@ fn second_pass(ctx: &mut ProcessingContext) -> Result { + Predicate::Custom(custom_ref) => { + let expected_arity = custom_ref.predicate().args_len; + if args.len() != expected_arity { + return Err(ProcessorError::ArgumentCountMismatch { + predicate: stmt_name_str.to_string(), + expected: expected_arity, + found: args.len(), + span: Some(stmt_name_span), + }); + } + for (idx, arg) in args.iter().enumerate() { + if !matches!(arg, BuilderArg::WildcardLiteral(_) | BuilderArg::Literal(_)) { + return Err(ProcessorError::TypeError { + expected: "Wildcard or Literal".to_string(), + found: format!("{:?}", arg), + item: format!( + "argument {} of custom predicate call '{}'", + idx + 1, + stmt_name_str + ), + span: Some(stmt_span), + }); + } + } + } + Predicate::BatchSelf(_) => { let (_original_pred_idx, expected_arity_val) = processing_ctx .custom_predicate_signatures .get(stmt_name_str) @@ -454,36 +569,10 @@ fn process_and_add_custom_predicate_to_batch( .into_inner() .filter(|p| p.as_rule() == Rule::statement) { - let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); - let stmt_name_pair = inner_stmt_pairs - .find(|p| p.as_rule() == Rule::identifier) - .unwrap_or_else(|| unreachable!("statement name must be present in statement")); - let stmt_name_str = stmt_name_pair.as_str(); - - let builder_args = parse_statement_args(&stmt_pair)?; - - let middleware_predicate_type = - if let Some(native_pred) = native_predicate_from_string(stmt_name_str) { - Predicate::Native(native_pred) - } else if let Some((pred_index, _expected_arity)) = processing_ctx - .custom_predicate_signatures - .get(stmt_name_str) - { - Predicate::BatchSelf(*pred_index) - } else { - return Err(ProcessorError::UndefinedIdentifier { - name: stmt_name_str.to_string(), - span: Some(get_span(&stmt_name_pair)), - }); - }; - - let stb = validate_and_build_statement_template( - stmt_name_str, - &middleware_predicate_type, - builder_args, + let stb = process_statement_template( + &stmt_pair, processing_ctx, - get_span(&stmt_pair), - get_span(&stmt_name_pair), + StatementContext::CustomPredicate, )?; statement_builders.push(stb); } @@ -520,12 +609,14 @@ fn process_request_def( .into_inner() .filter(|p| p.as_rule() == Rule::statement) { - let built_stb = process_proof_request_statement_template( + let built_stb = process_statement_template( &stmt_pair, processing_ctx, - Some(custom_batch), // Pass as Option<&Arc<...>> - &mut request_wildcard_names, - &mut defined_request_wildcards, + StatementContext::Request { + custom_batch, + wildcard_names: &mut request_wildcard_names, + defined_wildcards: &mut defined_request_wildcards, + }, )?; request_statement_builders.push(built_stb); } @@ -542,12 +633,10 @@ fn process_request_def( Ok(request_templates) } -fn process_proof_request_statement_template( +fn process_statement_template( stmt_pair: &Pair, processing_ctx: &ProcessingContext, - custom_batch_for_request: Option<&Arc>, - request_wildcard_names: &mut Vec, - defined_request_wildcards: &mut HashSet, + mut context: StatementContext, ) -> Result { let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); let name_pair = inner_stmt_pairs @@ -556,50 +645,58 @@ fn process_proof_request_statement_template( let stmt_name_str = name_pair.as_str(); let builder_args = parse_statement_args(stmt_pair)?; - let mut temp_stmt_wildcard_names: Vec = Vec::new(); - for arg in &builder_args { - match arg { - BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()), - BuilderArg::Key(pod_id_str, key_wc_str) => { - if let SelfOrWildcardStr::Wildcard(name) = pod_id_str { - temp_stmt_wildcard_names.push(name.clone()); - } - if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str { - temp_stmt_wildcard_names.push(key_wc_name.clone()); + if let StatementContext::Request { + wildcard_names, + defined_wildcards, + .. + } = &mut context + { + let mut temp_stmt_wildcard_names: Vec = Vec::new(); + for arg in &builder_args { + match arg { + BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()), + BuilderArg::Key(pod_id_str, key_wc_str) => { + if let SelfOrWildcardStr::Wildcard(name) = pod_id_str { + temp_stmt_wildcard_names.push(name.clone()); + } + if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str { + temp_stmt_wildcard_names.push(key_wc_name.clone()); + } } + _ => {} + } + } + for name in temp_stmt_wildcard_names { + if defined_wildcards.insert(name.clone()) { + wildcard_names.push(name); } - _ => {} } } - for name in temp_stmt_wildcard_names { - if defined_request_wildcards.insert(name.clone()) { - request_wildcard_names.push(name); - } - } - - let middleware_predicate_type = - if let Some(native_pred) = native_predicate_from_string(stmt_name_str) { - Predicate::Native(native_pred) - } else if let Some((pred_index, _expected_arity)) = processing_ctx - .custom_predicate_signatures - .get(stmt_name_str) - { - if let Some(batch_ref) = custom_batch_for_request { - Predicate::Custom(CustomPredicateRef::new(batch_ref.clone(), *pred_index)) - } else { - return Err(ProcessorError::Internal(format!( - "Custom predicate '{}' found but no custom batch provided for request processing.", - stmt_name_str - ))); + let middleware_predicate_type = if let Some(native_pred) = + native_predicate_from_string(stmt_name_str) + { + Predicate::Native(native_pred) + } else if let Some(custom_ref) = processing_ctx.imported_predicates.get(stmt_name_str) { + Predicate::Custom(custom_ref.clone()) + } else if let Some((pred_index, _expected_arity)) = processing_ctx + .custom_predicate_signatures + .get(stmt_name_str) + { + match context { + StatementContext::CustomPredicate => Predicate::BatchSelf(*pred_index), + StatementContext::Request { custom_batch, .. } => { + let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index); + Predicate::Custom(custom_pred_ref) } - } else { - return Err(ProcessorError::UndefinedIdentifier { - name: stmt_name_str.to_string(), - span: Some(get_span(&name_pair)), - }); - }; + } + } else { + return Err(ProcessorError::UndefinedIdentifier { + name: stmt_name_str.to_string(), + span: Some(get_span(&name_pair)), + }); + }; let stb = validate_and_build_statement_template( stmt_name_str, @@ -882,13 +979,13 @@ mod processor_tests { use crate::{ lang::{ error::ProcessorError, - parser::{parse_podlog, Rule}, + parser::{parse_podlang, Rule}, }, middleware::Params, }; fn get_document_content_pairs(input: &str) -> Result, ProcessorError> { - let full_parse_tree = parse_podlog(input) + let full_parse_tree = parse_podlang(input) .map_err(|e| ProcessorError::Internal(format!("Test parsing failed: {:?}", e)))?; let document_node = full_parse_tree.peek().ok_or_else(|| { @@ -910,7 +1007,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input)?; let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; assert!(ctx.custom_predicate_signatures.is_empty()); assert!(ctx.custom_predicate_pairs.is_empty()); assert!(ctx.request_pair.is_none()); @@ -923,7 +1020,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input)?; let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; assert!(ctx.custom_predicate_signatures.is_empty()); assert!(ctx.custom_predicate_pairs.is_empty()); assert!(ctx.request_pair.is_some()); @@ -940,7 +1037,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input)?; let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; assert_eq!(ctx.custom_predicate_signatures.len(), 1); assert_eq!(ctx.custom_predicate_pairs.len(), 1); assert!(ctx.request_pair.is_none()); @@ -964,7 +1061,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input)?; let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; assert_eq!(ctx.custom_predicate_signatures.len(), 2); assert_eq!(ctx.custom_predicate_pairs.len(), 2); @@ -991,11 +1088,12 @@ mod processor_tests { let params = Params::default(); let mut ctx = ProcessingContext { params: ¶ms, + imported_predicates: HashMap::new(), custom_predicate_signatures: HashMap::new(), custom_predicate_pairs: Vec::new(), request_pair: None, }; - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; let pred_name = ctx .custom_predicate_signatures .keys() @@ -1016,7 +1114,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input).unwrap(); let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - let result = first_pass(pairs, &mut ctx); + let result = first_pass(pairs, &mut ctx, &[]); assert!(result.is_err()); match result.err().unwrap() { // Use .err().unwrap() for ProcessorError @@ -1036,7 +1134,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input).unwrap(); let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - let result = first_pass(pairs, &mut ctx); + let result = first_pass(pairs, &mut ctx, &[]); assert!(result.is_err()); match result.err().unwrap() { // Use .err().unwrap() for ProcessorError @@ -1055,7 +1153,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input)?; let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; assert_eq!(ctx.custom_predicate_signatures.len(), 2); assert_eq!(ctx.custom_predicate_pairs.len(), 2); @@ -1093,7 +1191,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input)?; let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; let result = second_pass(&mut ctx); assert!(result.is_err()); match result.err().unwrap() { @@ -1112,7 +1210,7 @@ mod processor_tests { let pairs = get_document_content_pairs(input)?; let params = Params::default(); let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx)?; + first_pass(pairs, &mut ctx, &[])?; let result = second_pass(&mut ctx); assert!(result.is_err()); match result.err().unwrap() { diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index 933da5f..00b7acd 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -28,9 +28,10 @@ use std::{ cmp::{Ord, Ordering}, fmt, + fmt::Write, }; -use hex::{FromHex, FromHexError}; +use hex::{FromHex, FromHexError, ToHex}; use plonky2::{ field::types::{Field, PrimeField64}, hash::poseidon::PoseidonHash, @@ -143,6 +144,32 @@ pub struct Hash( pub [F; HASH_SIZE], ); +impl ToHex for Hash { + fn encode_hex>(&self) -> T { + self.0 + .iter() + .rev() + .fold(String::with_capacity(64), |mut s, limb| { + write!(s, "{:016x}", limb.0).unwrap(); + s + }) + .chars() + .collect() + } + + fn encode_hex_upper>(&self) -> T { + self.0 + .iter() + .rev() + .fold(String::with_capacity(64), |mut s, limb| { + write!(s, "{:016X}", limb.0).unwrap(); + s + }) + .chars() + .collect() + } +} + pub fn hash_value(input: &RawValue) -> Hash { hash_fields(&input.0) }