From 42f979c4083ba0a31b0e1dfff7096c5553d3da4b Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Thu, 13 Nov 2025 10:23:21 +0100 Subject: [PATCH] Frontend AST for Podlang (#432) * Basic frontend AST and semantic validation * Intro statement support * Simplify validator lifetime * Fix arity validation * Lowering and splitting * Remove legacy processor and use frontend AST by default * Use builders instead of creating middleware types directly * Typos/formatting * Improve error messages when overflowing a batch due to splitting * Add FromStr implementation for NativePredicate * Remove 'raw' fields, and switch HashHex representation to byte vector rather than string * Simpler wrapper types for batch and intro predicate hashes * Parse secret and public keys to their respective data structures earlier * More detail around string escape validity * Simplify native predicate arity handling and move method to NativePredicate impl * Store hashes using middleware::Hash, and simplify lowering by using pre-parsed values * Simplify predicate building * Formatting * Better error messages/suggestions for cases where predicate splitting fails * Formatting * Clippy fix * Return error if we get a too-large int --- src/frontend/custom.rs | 2 +- src/lang/error.rs | 326 +++++-- src/lang/frontend_ast.rs | 1328 ++++++++++++++++++++++++++++ src/lang/frontend_ast_lower.rs | 749 ++++++++++++++++ src/lang/frontend_ast_split.rs | 1002 +++++++++++++++++++++ src/lang/frontend_ast_validate.rs | 775 +++++++++++++++++ src/lang/grammar.pest | 3 +- src/lang/mod.rs | 58 +- src/lang/parser.rs | 12 + src/lang/processor.rs | 1350 ----------------------------- src/middleware/statement.rs | 76 ++ 11 files changed, 4250 insertions(+), 1431 deletions(-) create mode 100644 src/lang/frontend_ast.rs create mode 100644 src/lang/frontend_ast_lower.rs create mode 100644 src/lang/frontend_ast_split.rs create mode 100644 src/lang/frontend_ast_validate.rs delete mode 100644 src/lang/processor.rs diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index fa75d8f..a481aa6 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -165,7 +165,7 @@ impl CustomPredicateBatchBuilder { /// creates the custom predicate from the given input, adds it to the /// self.predicates, and returns the index of the created predicate - fn predicate( + pub fn predicate( &mut self, name: &str, conjunction: bool, diff --git a/src/lang/error.rs b/src/lang/error.rs index 96161de..72637ea 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -1,95 +1,283 @@ use thiserror::Error; -use crate::{frontend, lang::parser::ParseError, middleware}; +use crate::{ + frontend, + lang::{frontend_ast::Span, parser::ParseError}, + middleware, +}; #[derive(Error, Debug)] pub enum LangError { #[error("Parsing failed: {0}")] Parse(Box), - #[error("AST processing error: {0}")] - Processor(Box), - #[error("Middleware error during processing: {0}")] Middleware(Box), #[error("Frontend error: {0}")] Frontend(Box), + + #[error("Validation error: {0}")] + Validation(Box), + + #[error("Lowering error: {0}")] + Lowering(Box), } -/// Errors that can occur during the processing of Podlang Pest tree into middleware structures. -#[derive(thiserror::Error, Debug)] -pub enum ProcessorError { - #[error("Undefined identifier: '{name}' at {span:?}")] - UndefinedIdentifier { +/// Validation errors from frontend AST validation +#[derive(Debug, thiserror::Error)] +pub enum ValidationError { + #[error("Invalid hash: {hash}")] + InvalidHash { hash: String, span: Option }, + + #[error("Duplicate predicate definition: {name}")] + DuplicatePredicate { name: String, - span: Option<(usize, usize)>, + first_span: Option, + second_span: Option, }, - #[error("Duplicate definition: '{name}' at {span:?}")] - DuplicateDefinition { + + #[error("Duplicate import name: {name}")] + DuplicateImport { name: String, span: Option }, + + #[error("Import arity mismatch: expected {expected} predicates, found {found}")] + ImportArityMismatch { + expected: usize, + found: usize, + span: Option, + }, + + #[error("Batch not found: {id}")] + BatchNotFound { id: String, span: Option }, + + #[error("Undefined predicate: {name}")] + UndefinedPredicate { name: String, span: Option }, + + #[error("Undefined wildcard: {name} in predicate {pred_name}")] + UndefinedWildcard { name: String, - span: Option<(usize, usize)>, + pred_name: String, + span: Option, }, - #[error("Duplicate wildcard: ?{name} in scope at {span:?}")] - DuplicateWildcard { - name: String, - span: Option<(usize, usize)>, - }, - #[error("Type error: expected {expected}, found {found} for '{item}' at {span:?}")] - TypeError { - expected: String, - found: String, - item: String, - span: Option<(usize, usize)>, - }, - #[error( - "Invalid argument count for '{predicate}': expected {expected}, found {found} at {span:?}" - )] + + #[error("Argument count mismatch for {predicate}: expected {expected}, found {found}")] ArgumentCountMismatch { predicate: String, expected: usize, found: usize, - span: Option<(usize, usize)>, + span: Option, }, - #[error("Multiple REQUEST definitions found. Only one is allowed. First at {first_span:?}, second at {second_span:?}")] + + #[error("Invalid argument type for {predicate}: anchored keys not allowed")] + InvalidArgumentType { + predicate: String, + span: Option, + }, + + #[error("Duplicate wildcard in predicate arguments: {name}")] + DuplicateWildcard { name: String, span: Option }, + + #[error("Empty statement list in {context}")] + EmptyStatementList { context: String, span: Option }, + + #[error("Multiple REQUEST definitions found. Only one is allowed.")] MultipleRequestDefinitions { - first_span: Option<(usize, usize)>, - second_span: Option<(usize, usize)>, + first_span: Option, + second_span: Option, }, - #[error("Internal processing error: {0}")] - Internal(String), +} + +/// Lowering errors from frontend AST lowering to middleware +#[derive(Debug, thiserror::Error)] +pub enum LoweringError { + #[error("Too many custom predicates in batch '{batch_name}': {count} exceeds limit of {max}{}", if *.original_count != *.count { format!(" (started with {} predicates before automatic splitting)", original_count) } else { String::new() })] + TooManyPredicates { + batch_name: String, + count: usize, + max: usize, + original_count: usize, + }, + + #[error("Too many statements in predicate '{predicate}': {count} exceeds limit of {max}")] + TooManyStatements { + predicate: String, + count: usize, + max: usize, + }, + + #[error("Too many wildcards in predicate '{predicate}': {count} exceeds limit of {max}")] + TooManyWildcards { + predicate: String, + count: usize, + max: usize, + }, + + #[error("Too many arguments in statement template: {count} exceeds limit of {max}")] + TooManyStatementArgs { count: usize, max: usize }, + + #[error("Predicate '{name}' not found in symbol table")] + PredicateNotFound { name: String }, + + #[error("Invalid argument type in statement template")] + InvalidArgumentType, + #[error("Middleware error: {0}")] - Middleware(middleware::Error), - #[error("Undefined wildcard: '?{name}' in predicate '{pred_name}' at {span:?}")] - UndefinedWildcard { - name: String, - pred_name: String, - span: Option<(usize, usize)>, + Middleware(#[from] middleware::Error), + + #[error("Splitting error: {0}")] + Splitting(#[from] SplittingError), + + #[error("Cannot lower document with validation errors")] + ValidationErrors, +} + +/// Context information for split boundary failures +#[derive(Debug, Clone)] +pub struct SplitContext { + /// Index of the split boundary (0-based) + pub split_index: usize, + /// Range of statement indices in the segment before the split + pub statement_range: (usize, usize), + /// Public arguments coming into this segment + pub incoming_public: Vec, + /// Wildcards that cross this boundary (need to be promoted) + pub crossing_wildcards: Vec, + /// Total public arguments needed (incoming + crossing) + pub total_public: usize, +} + +/// Suggestions for refactoring predicates that fail to split +#[derive(Debug, Clone)] +pub enum RefactorSuggestion { + /// A wildcard is used across too many statements + ReduceWildcardSpan { + wildcard: String, + first_use: usize, + last_use: usize, + span: usize, }, - #[error("Invalid literal format for {kind}: '{value}' at {span:?}")] - InvalidLiteralFormat { - kind: String, - value: String, - span: Option<(usize, usize)>, + /// Multiple wildcards should be grouped together + GroupWildcardUsages { wildcards: Vec }, +} + +impl RefactorSuggestion { + pub fn format(&self) -> String { + match self { + RefactorSuggestion::ReduceWildcardSpan { + wildcard, + first_use, + last_use, + span, + } => { + format!( + "Wildcard '{}' is used across {} statements (statements {}-{}).\n\ + Consider grouping all '{}' operations together, or split the wildcard\n\ + into separate early/late variables.", + wildcard, span, first_use, last_use, wildcard + ) + } + RefactorSuggestion::GroupWildcardUsages { wildcards } => { + format!( + "Group operations for wildcards: {}\n\ + These wildcards are used across multiple segments. Try to complete\n\ + all operations for each wildcard before moving to the next.", + wildcards.join(", ") + ) + } + } + } +} + +/// Formats a detailed error message for TooManyPublicArgsAtSplit +fn format_public_args_at_split_error( + predicate: &str, + context: &SplitContext, + max_allowed: usize, + suggestion: &Option>, +) -> String { + let mut msg = format!( + "Too many public arguments at split boundary {} in predicate '{}':\n", + context.split_index, predicate + ); + + msg.push_str(&format!( + " {} incoming public + {} crossing wildcards = {} total (exceeds max of {})\n", + context.incoming_public.len(), + context.crossing_wildcards.len(), + context.total_public, + max_allowed + )); + + msg.push_str(&format!( + " Statements {}-{} in this segment\n", + context.statement_range.0, context.statement_range.1 + )); + + if !context.incoming_public.is_empty() { + msg.push_str(&format!( + " Incoming public args: {}\n", + context.incoming_public.join(", ") + )); + } + + if !context.crossing_wildcards.is_empty() { + msg.push_str(&format!( + " Wildcards crossing this boundary: {}\n", + context.crossing_wildcards.join(", ") + )); + } + + if let Some(suggestion) = suggestion { + msg.push_str("\nSuggestion:\n"); + msg.push_str(&suggestion.format()); + } + + msg +} + +/// Splitting errors from predicate splitting +#[derive(Debug, thiserror::Error)] +pub enum SplittingError { + #[error("Too many public arguments in predicate '{predicate}': {count} exceeds max of {max_allowed}. {message}")] + TooManyPublicArgs { + predicate: String, + count: usize, + max_allowed: usize, + message: String, }, - #[error("Batch with ID '{id}' not found at {span:?}")] - BatchNotFound { - id: String, - span: Option<(usize, usize)>, + + #[error("Too many total arguments in predicate '{predicate}': {count} exceeds max of {max_allowed}. {message}")] + TooManyTotalArgs { + predicate: String, + count: usize, + max_allowed: usize, + message: String, }, - #[error("Number of predicates in 'use' statement ({found}) exceeds the number of predicates in the batch ({expected}) at {span:?}")] - ImportArityMismatch { - expected: usize, - found: usize, - span: Option<(usize, usize)>, + + #[error("Too many total arguments in chain link {link_index} of predicate '{predicate}': {public_count} public + {private_count} private = {total_count} total (exceeds max of {max_allowed})")] + TooManyTotalArgsInChainLink { + predicate: String, + link_index: usize, + public_count: usize, + private_count: usize, + total_count: usize, + max_allowed: usize, }, - #[error("Duplicate import name '{name}' at {span:?}")] - DuplicateImportName { - name: String, - span: Option<(usize, usize)>, + + #[error("{}", format_public_args_at_split_error(.predicate, .context, *.max_allowed, .suggestion))] + TooManyPublicArgsAtSplit { + predicate: String, + context: Box, + max_allowed: usize, + suggestion: Option>, + }, + + #[error("Too many predicates in chain for '{predicate}': {count} exceeds batch limit of {max_allowed}")] + TooManyPredicatesInChain { + predicate: String, + count: usize, + max_allowed: usize, }, - #[error("Frontend error: {0}")] - Frontend(#[from] frontend::Error), } impl From for LangError { @@ -98,14 +286,20 @@ impl From for LangError { } } -impl From for LangError { - fn from(err: ProcessorError) -> Self { - LangError::Processor(Box::new(err)) - } -} - impl From for LangError { fn from(err: middleware::Error) -> Self { LangError::Middleware(Box::new(err)) } } + +impl From for LangError { + fn from(err: ValidationError) -> Self { + LangError::Validation(Box::new(err)) + } +} + +impl From for LangError { + fn from(err: LoweringError) -> Self { + LangError::Lowering(Box::new(err)) + } +} diff --git a/src/lang/frontend_ast.rs b/src/lang/frontend_ast.rs new file mode 100644 index 0000000..d820d7f --- /dev/null +++ b/src/lang/frontend_ast.rs @@ -0,0 +1,1328 @@ +//! Frontend AST for the Podlang language +//! +//! This module defines an intermediate AST that captures all features of the grammar +//! and supports bidirectional conversion (parsing and pretty-printing). + +use std::fmt; + +use hex::{FromHex, ToHex}; + +use crate::backends::plonky2::primitives::ec::{curve::Point, schnorr::SecretKey}; + +/// The root document containing all top-level declarations +#[derive(Debug, Clone, PartialEq)] +pub struct Document { + pub items: Vec, +} + +/// Top-level items that can appear in a document +#[derive(Debug, Clone, PartialEq)] +pub enum DocumentItem { + UseBatchStatement(UseBatchStatement), + UseIntroStatement(UseIntroStatement), + CustomPredicateDef(CustomPredicateDef), + RequestDef(RequestDef), +} + +/// Import statement: `use batch pred1, pred2, _ from 0x...` +#[derive(Debug, Clone, PartialEq)] +pub struct UseBatchStatement { + pub imports: Vec, + pub batch_ref: HashHex, + pub span: Option, +} + +/// Intro statement: `use intro pred() from 0x...` +#[derive(Debug, Clone, PartialEq)] +pub struct UseIntroStatement { + pub name: Identifier, + pub args: Vec, + pub intro_hash: HashHex, + pub span: Option, +} +/// Individual import name (identifier or unused "_") +#[derive(Debug, Clone, PartialEq)] +pub enum ImportName { + Named(String), + Unused, // "_" +} + +/// Batch reference (hash) +#[derive(Debug, Clone, PartialEq)] +pub struct BatchRef { + pub hash: HashHex, + pub span: Option, +} + +/// Intro predicate reference (hash) +#[derive(Debug, Clone, PartialEq)] +pub struct IntroPredicateRef { + pub hash: HashHex, + pub span: Option, +} + +/// Custom predicate definition +#[derive(Debug, Clone, PartialEq)] +pub struct CustomPredicateDef { + pub name: Identifier, + pub args: ArgSection, + pub conjunction_type: ConjunctionType, + pub statements: Vec, + pub span: Option, +} + +/// Request definition +#[derive(Debug, Clone, PartialEq)] +pub struct RequestDef { + pub statements: Vec, + pub span: Option, +} + +/// Argument section with public and optional private arguments +#[derive(Debug, Clone, PartialEq)] +pub struct ArgSection { + pub public_args: Vec, + pub private_args: Option>, + pub span: Option, +} + +/// Conjunction type for custom predicates +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ConjunctionType { + And, + Or, +} + +/// Statement template: predicate call with arguments +#[derive(Debug, Clone, PartialEq)] +pub struct StatementTmpl { + pub predicate: Identifier, + pub args: Vec, + pub span: Option, +} + +/// Arguments that can be passed to statements +#[derive(Debug, Clone, PartialEq)] +pub enum StatementTmplArg { + Literal(LiteralValue), + Wildcard(Identifier), + AnchoredKey(AnchoredKey), +} + +/// Anchored key: Var["key"] or Var.key +#[derive(Debug, Clone, PartialEq)] +pub struct AnchoredKey { + pub root: Identifier, + pub key: AnchoredKeyPath, + pub span: Option, +} + +impl AnchoredKey { + pub fn key_str(&self) -> &str { + match &self.key { + AnchoredKeyPath::Bracket(ls) => &ls.value, + AnchoredKeyPath::Dot(id) => &id.name, + } + } +} + +/// Key path in an anchored key +#[derive(Debug, Clone, PartialEq)] +pub enum AnchoredKeyPath { + Bracket(LiteralString), // ["key"] + Dot(Identifier), // .key +} + +/// Identifier (variable names, predicate names, etc.) +#[derive(Debug, Clone, PartialEq)] +pub struct Identifier { + pub name: String, + pub span: Option, +} + +/// Hash value in hex format (0x...) +#[derive(Debug, Clone, PartialEq)] +pub struct HashHex { + pub hash: crate::middleware::Hash, + pub span: Option, +} + +/// All possible literal values +#[derive(Debug, Clone, PartialEq)] +pub enum LiteralValue { + Int(LiteralInt), + Bool(LiteralBool), + String(LiteralString), + Raw(LiteralRaw), + PublicKey(LiteralPublicKey), + SecretKey(LiteralSecretKey), + Array(LiteralArray), + Set(LiteralSet), + Dict(LiteralDict), +} + +/// Integer literal +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralInt { + pub value: i64, + pub span: Option, +} + +/// Boolean literal +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralBool { + pub value: bool, + pub span: Option, +} + +/// String literal +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralString { + pub value: String, // Unescaped value + pub span: Option, +} + +/// Raw value literal: Raw(0x...) +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralRaw { + pub hash: HashHex, + pub span: Option, +} + +/// Public key literal: PublicKey(base58string) +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralPublicKey { + pub point: Point, + pub span: Option, +} + +/// Secret key literal: SecretKey(base64string) +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralSecretKey { + pub secret_key: SecretKey, + pub span: Option, +} + +/// Array literal: [...] +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralArray { + pub elements: Vec, + pub span: Option, +} + +/// Set literal: #[...] +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralSet { + pub elements: Vec, + pub span: Option, +} + +/// Dictionary literal: {...} +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralDict { + pub pairs: Vec, + pub span: Option, +} + +/// Key-value pair in a dictionary +#[derive(Debug, Clone, PartialEq)] +pub struct DictPair { + pub key: LiteralString, + pub value: LiteralValue, + pub span: Option, +} + +/// Source location information for error reporting and formatting +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Span { + pub start: usize, + pub end: usize, +} + +// Display implementations for pretty-printing + +impl fmt::Display for Document { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (i, item) in self.items.iter().enumerate() { + if i > 0 { + writeln!(f)?; + } + write!(f, "{}", item)?; + } + Ok(()) + } +} + +impl fmt::Display for DocumentItem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DocumentItem::UseBatchStatement(u) => write!(f, "{}", u), + DocumentItem::UseIntroStatement(u) => write!(f, "{}", u), + DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c), + DocumentItem::RequestDef(r) => write!(f, "{}", r), + } + } +} + +impl fmt::Display for UseBatchStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "use batch ")?; + for (i, import) in self.imports.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", import)?; + } + write!(f, " from {}", self.batch_ref) + } +} + +impl fmt::Display for UseIntroStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "use intro {}(", self.name)?; + for (i, arg) in self.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + write!(f, ") from {}", self.intro_hash) + } +} + +impl fmt::Display for ImportName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ImportName::Named(name) => write!(f, "{}", name), + ImportName::Unused => write!(f, "_"), + } + } +} + +impl fmt::Display for BatchRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.hash) + } +} + +impl fmt::Display for IntroPredicateRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.hash) + } +} + +impl fmt::Display for HashHex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "0x{}", self.hash.encode_hex::()) + } +} + +impl fmt::Display for CustomPredicateDef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "{}({}) = {}(", + self.name, self.args, self.conjunction_type + )?; + for stmt in &self.statements { + writeln!(f, " {}", stmt)?; + } + write!(f, ")") + } +} + +impl fmt::Display for ArgSection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (i, arg) in self.public_args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + if let Some(private_args) = &self.private_args { + if !self.public_args.is_empty() { + write!(f, ", ")?; + } + write!(f, "private: ")?; + for (i, arg) in private_args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + } + Ok(()) + } +} + +impl fmt::Display for ConjunctionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ConjunctionType::And => write!(f, "AND"), + ConjunctionType::Or => write!(f, "OR"), + } + } +} + +impl fmt::Display for RequestDef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "REQUEST(")?; + for stmt in &self.statements { + writeln!(f, " {}", stmt)?; + } + write!(f, ")") + } +} + +impl fmt::Display for StatementTmpl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}(", self.predicate)?; + for (i, arg) in self.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + write!(f, ")") + } +} + +impl fmt::Display for StatementTmplArg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StatementTmplArg::Literal(lit) => write!(f, "{}", lit), + StatementTmplArg::Wildcard(id) => write!(f, "{}", id), + StatementTmplArg::AnchoredKey(ak) => write!(f, "{}", ak), + } + } +} + +impl fmt::Display for Identifier { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +impl fmt::Display for AnchoredKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.key { + AnchoredKeyPath::Bracket(s) => write!(f, "{}[{}]", self.root, s), + AnchoredKeyPath::Dot(id) => write!(f, "{}.{}", self.root, id), + } + } +} + +impl fmt::Display for LiteralValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LiteralValue::Int(i) => write!(f, "{}", i), + LiteralValue::Bool(b) => write!(f, "{}", b), + LiteralValue::String(s) => write!(f, "{}", s), + LiteralValue::Raw(r) => write!(f, "{}", r), + LiteralValue::PublicKey(pk) => write!(f, "{}", pk), + LiteralValue::SecretKey(sk) => write!(f, "{}", sk), + LiteralValue::Array(a) => write!(f, "{}", a), + LiteralValue::Set(s) => write!(f, "{}", s), + LiteralValue::Dict(d) => write!(f, "{}", d), + } + } +} + +impl fmt::Display for LiteralInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +impl fmt::Display for LiteralBool { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", if self.value { "true" } else { "false" }) + } +} + +impl fmt::Display for LiteralString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "\"")?; + for ch in self.value.chars() { + match ch { + '"' => write!(f, "\\\"")?, + '\\' => write!(f, "\\\\")?, + '\n' => write!(f, "\\n")?, + '\r' => write!(f, "\\r")?, + '\t' => write!(f, "\\t")?, + '\u{0008}' => write!(f, "\\b")?, + '\u{000C}' => write!(f, "\\f")?, + _ => write!(f, "{}", ch)?, + } + } + write!(f, "\"") + } +} + +impl fmt::Display for LiteralRaw { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Raw({})", self.hash) + } +} + +impl fmt::Display for LiteralPublicKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PublicKey({})", self.point) + } +} + +impl fmt::Display for LiteralSecretKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SecretKey({})", self.secret_key) + } +} + +impl fmt::Display for LiteralArray { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for (i, elem) in self.elements.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", elem)?; + } + write!(f, "]") + } +} + +impl fmt::Display for LiteralSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "#[")?; + for (i, elem) in self.elements.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", elem)?; + } + write!(f, "]") + } +} + +impl fmt::Display for LiteralDict { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{{")?; + for (i, pair) in self.pairs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", pair)?; + } + write!(f, "}}") + } +} + +impl fmt::Display for DictPair { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}: {}", self.key, self.value) + } +} + +// Parser module for converting Pest pairs to AST +pub mod parse { + use pest::iterators::Pair; + + use super::*; + use crate::lang::parser::{self, Rule}; + + /// Convert a Pest document pair to an AST Document + pub fn parse_document(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::document); + let mut items = Vec::new(); + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::use_batch_statement => { + items.push(DocumentItem::UseBatchStatement(parse_use_batch_statement( + inner_pair, + ))); + } + Rule::use_intro_statement => { + items.push(DocumentItem::UseIntroStatement(parse_use_intro_statement( + inner_pair, + ))); + } + Rule::custom_predicate_def => { + items.push(DocumentItem::CustomPredicateDef( + parse_custom_predicate_def(inner_pair)?, + )); + } + Rule::request_def => { + items.push(DocumentItem::RequestDef(parse_request_def(inner_pair)?)); + } + Rule::EOI => {} + _ => unreachable!("Unexpected rule in document: {:?}", inner_pair.as_rule()), + } + } + + Ok(Document { items }) + } + + fn parse_use_batch_statement(pair: Pair) -> UseBatchStatement { + assert_eq!(pair.as_rule(), Rule::use_batch_statement); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + + let use_list_pair = inner + .find(|p| p.as_rule() == Rule::use_predicate_list) + .unwrap(); + let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap(); + + let imports = use_list_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::import_name) + .map(parse_import_name) + .collect(); + + UseBatchStatement { + imports, + batch_ref: parse_hash_hex(batch_ref_pair.into_inner().next().unwrap()), + span: Some(span), + } + } + + fn parse_use_intro_statement(pair: Pair) -> UseIntroStatement { + assert_eq!(pair.as_rule(), Rule::use_intro_statement); + let span = get_span(&pair); + let inner = pair.into_inner(); + + let name = parse_identifier( + inner + .clone() + .find(|p| p.as_rule() == Rule::identifier) + .unwrap(), + ); + + let args: Vec = inner + .clone() + .find(|p| p.as_rule() == Rule::use_intro_arg_list) + .map(|arg_list| { + arg_list + .into_inner() + .filter(|p| p.as_rule() == Rule::identifier) + .map(parse_identifier) + .collect() + }) + .unwrap_or_default(); + + let intro_predicate_ref_pair = inner + .clone() + .find(|p| p.as_rule() == Rule::intro_predicate_ref) + .unwrap(); + + UseIntroStatement { + name, + args, + intro_hash: parse_hash_hex(intro_predicate_ref_pair.into_inner().next().unwrap()), + span: Some(span), + } + } + + fn parse_import_name(pair: Pair) -> ImportName { + assert_eq!(pair.as_rule(), Rule::import_name); + let s = pair.as_str(); + if s == "_" { + ImportName::Unused + } else { + ImportName::Named(s.to_string()) + } + } + + fn parse_hash_hex(pair: Pair) -> HashHex { + assert_eq!(pair.as_rule(), Rule::hash_hex); + let span = get_span(&pair); + let hex_str = pair.as_str(); + + // Grammar guarantees "0x" prefix and exactly 64 hex chars + assert!(hex_str.starts_with("0x")); + let hex_without_prefix = &hex_str[2..]; + + // Parse hex string directly to middleware::Hash + let hash = crate::middleware::Hash::from_hex(hex_without_prefix) + .expect("Grammar should guarantee valid hex"); + + HashHex { + hash, + span: Some(span), + } + } + + fn parse_custom_predicate_def( + pair: Pair, + ) -> Result { + assert_eq!(pair.as_rule(), Rule::custom_predicate_def); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + + let name = parse_identifier(inner.next().unwrap()); + let args = parse_arg_section(inner.next().unwrap()); + let conjunction_type = parse_conjunction_type(inner.next().unwrap()); + let statement_list = inner.next().unwrap(); + + let statements = statement_list + .into_inner() + .filter(|p| p.as_rule() == Rule::statement) + .map(parse_statement) + .collect::, _>>()?; + + Ok(CustomPredicateDef { + name, + args, + conjunction_type, + statements, + span: Some(span), + }) + } + + fn parse_arg_section(pair: Pair) -> ArgSection { + assert_eq!(pair.as_rule(), Rule::arg_section); + let span = get_span(&pair); + let mut public_args = Vec::new(); + let mut private_args = None; + + for inner_pair in pair.into_inner() { + match inner_pair.as_rule() { + Rule::public_arg_list => { + public_args = inner_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::identifier) + .map(parse_identifier) + .collect(); + } + Rule::private_arg_list => { + private_args = Some( + inner_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::identifier) + .map(parse_identifier) + .collect(), + ); + } + _ => {} + } + } + + ArgSection { + public_args, + private_args, + span: Some(span), + } + } + + fn parse_conjunction_type(pair: Pair) -> ConjunctionType { + assert_eq!(pair.as_rule(), Rule::conjunction_type); + match pair.as_str() { + "AND" => ConjunctionType::And, + "OR" => ConjunctionType::Or, + _ => unreachable!("Invalid conjunction type: {}", pair.as_str()), + } + } + + fn parse_request_def(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::request_def); + let span = get_span(&pair); + let mut statements = Vec::new(); + + for inner_pair in pair.into_inner() { + if inner_pair.as_rule() == Rule::statement_list { + statements = inner_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::statement) + .map(parse_statement) + .collect::, _>>()?; + } + } + + Ok(RequestDef { + statements, + span: Some(span), + }) + } + + fn parse_statement(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::statement); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + + let predicate = parse_identifier(inner.next().unwrap()); + let mut args = Vec::new(); + + if let Some(arg_list) = inner.next() { + if arg_list.as_rule() == Rule::statement_arg_list { + args = arg_list + .into_inner() + .filter(|p| p.as_rule() == Rule::statement_arg) + .map(parse_statement_arg) + .collect::, _>>()?; + } + } + + Ok(StatementTmpl { + predicate, + args, + span: Some(span), + }) + } + + fn parse_statement_arg(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::statement_arg); + let inner = pair.into_inner().next().unwrap(); + + match inner.as_rule() { + Rule::literal_value => Ok(StatementTmplArg::Literal(parse_literal_value(inner)?)), + Rule::identifier => Ok(StatementTmplArg::Wildcard(parse_identifier(inner))), + Rule::anchored_key => Ok(StatementTmplArg::AnchoredKey(parse_anchored_key(inner)?)), + _ => unreachable!("Unexpected statement arg rule: {:?}", inner.as_rule()), + } + } + + fn parse_anchored_key(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::anchored_key); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + + let root = parse_identifier(inner.next().unwrap()); + let key_part = inner.next().unwrap(); + + let key = match key_part.as_rule() { + Rule::literal_string => AnchoredKeyPath::Bracket(parse_literal_string(key_part)?), + Rule::identifier => AnchoredKeyPath::Dot(parse_identifier(key_part)), + _ => unreachable!("Unexpected anchored key part: {:?}", key_part.as_rule()), + }; + + Ok(AnchoredKey { + root, + key, + span: Some(span), + }) + } + + fn parse_identifier(pair: Pair) -> Identifier { + assert_eq!(pair.as_rule(), Rule::identifier); + Identifier { + name: pair.as_str().to_string(), + span: Some(get_span(&pair)), + } + } + + fn parse_literal_value(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_value); + let inner = pair.into_inner().next().unwrap(); + + match inner.as_rule() { + Rule::literal_int => Ok(LiteralValue::Int(parse_literal_int(inner)?)), + Rule::literal_bool => Ok(LiteralValue::Bool(parse_literal_bool(inner))), + Rule::literal_string => Ok(LiteralValue::String(parse_literal_string(inner)?)), + Rule::literal_raw => Ok(LiteralValue::Raw(parse_literal_raw(inner))), + Rule::literal_public_key => { + Ok(LiteralValue::PublicKey(parse_literal_public_key(inner)?)) + } + Rule::literal_secret_key => { + Ok(LiteralValue::SecretKey(parse_literal_secret_key(inner)?)) + } + Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)), + Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)), + Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)), + _ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()), + } + } + + fn parse_literal_int(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_int); + let value = pair + .as_str() + .parse() + .map_err(|e| parser::ParseError::InvalidInt(format!("{}: {}", pair.as_str(), e)))?; + Ok(LiteralInt { + value, + span: Some(get_span(&pair)), + }) + } + + fn parse_literal_bool(pair: Pair) -> LiteralBool { + assert_eq!(pair.as_rule(), Rule::literal_bool); + LiteralBool { + value: pair.as_str() == "true", + span: Some(get_span(&pair)), + } + } + + fn parse_literal_string(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_string); + let span = get_span(&pair); + + // Extract the unescaped value from between quotes + let inner = pair.into_inner().next().unwrap(); + let value = unescape_string(inner.as_str())?; + + Ok(LiteralString { + value, + span: Some(span), + }) + } + + fn parse_literal_raw(pair: Pair) -> LiteralRaw { + assert_eq!(pair.as_rule(), Rule::literal_raw); + let span = get_span(&pair); + let hash_pair = pair.into_inner().next().unwrap(); + LiteralRaw { + hash: parse_hash_hex(hash_pair), + span: Some(span), + } + } + + fn parse_literal_public_key(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_public_key); + let span = get_span(&pair); + let base58_pair = pair.into_inner().next().unwrap(); + let base58_str = base58_pair.as_str(); + let point = base58_str + .parse() + .map_err(|e| parser::ParseError::InvalidPublicKey(format!("{}: {}", base58_str, e)))?; + Ok(LiteralPublicKey { + point, + span: Some(span), + }) + } + + fn parse_literal_secret_key(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_secret_key); + let span = get_span(&pair); + let base64_pair = pair.into_inner().next().unwrap(); + let base64_str = base64_pair.as_str(); + let secret_key = base64_str + .parse() + .map_err(|e| parser::ParseError::InvalidSecretKey(format!("{}: {}", base64_str, e)))?; + Ok(LiteralSecretKey { + secret_key, + span: Some(span), + }) + } + + fn parse_literal_array(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_array); + let span = get_span(&pair); + let elements: Result, _> = pair + .into_inner() + .filter(|p| p.as_rule() == Rule::literal_value) + .map(parse_literal_value) + .collect(); + Ok(LiteralArray { + elements: elements?, + span: Some(span), + }) + } + + fn parse_literal_set(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_set); + let span = get_span(&pair); + let elements: Result, _> = pair + .into_inner() + .filter(|p| p.as_rule() == Rule::literal_value) + .map(parse_literal_value) + .collect(); + Ok(LiteralSet { + elements: elements?, + span: Some(span), + }) + } + + fn parse_literal_dict(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_dict); + let span = get_span(&pair); + let pairs: Result, _> = pair + .into_inner() + .filter(|p| p.as_rule() == Rule::dict_pair) + .map(parse_dict_pair) + .collect(); + Ok(LiteralDict { + pairs: pairs?, + span: Some(span), + }) + } + + fn parse_dict_pair(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::dict_pair); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + let key = parse_literal_string(inner.next().unwrap())?; + let value = parse_literal_value(inner.next().unwrap())?; + Ok(DictPair { + key, + value, + span: Some(span), + }) + } + + fn get_span(pair: &Pair) -> Span { + let span = pair.as_span(); + Span { + start: span.start(), + end: span.end(), + } + } + + fn unescape_string(s: &str) -> Result { + let mut result = String::new(); + let mut chars = s.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '\\' { + match chars.next() { + Some('"') => result.push('"'), + Some('\\') => result.push('\\'), + Some('/') => result.push('/'), + Some('b') => result.push('\u{0008}'), + Some('f') => result.push('\u{000C}'), + Some('n') => result.push('\n'), + Some('r') => result.push('\r'), + Some('t') => result.push('\t'), + Some('u') => { + // Grammar guarantees exactly 4 hex digits after \u + // We only need to check if the codepoint is valid unicode + let hex: String = chars.by_ref().take(4).collect(); + let code = u32::from_str_radix(&hex, 16) + .expect("Grammar should guarantee valid hex digits"); + let unicode_char = char::from_u32(code).ok_or_else(|| { + parser::ParseError::InvalidEscapeSequence(format!( + "\\u{}: invalid unicode codepoint", + hex + )) + })?; + result.push(unicode_char); + } + Some(other) => { + // Grammar should prevent this, but handle gracefully + unreachable!( + "Grammar should only allow specific escape sequences, got: \\{}", + other + ); + } + None => { + // Grammar should prevent this + unreachable!("Grammar should not allow backslash at end of string"); + } + } + } else { + result.push(ch); + } + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::parser::parse_podlang; + + /// Test that parsing and pretty-printing produces equivalent output + fn test_roundtrip(input: &str) { + let parsed = parse_podlang(input).expect("Failed to parse input"); + let document_pair = parsed.into_iter().next().expect("No document pair"); + let mut ast = parse::parse_document(document_pair).expect("Failed to parse"); + let output = ast.to_string(); + // Parse the output to verify it's still valid + let reparsed = parse_podlang(&output).expect("Failed to parse pretty-printed output"); + let reparsed_document_pair = reparsed + .into_iter() + .next() + .expect("No document pair in reparse"); + let mut reparsed_ast = + parse::parse_document(reparsed_document_pair).expect("Failed to parse"); + + // Clear spans for comparison (they'll be different after pretty-printing) + clear_spans(&mut ast); + clear_spans(&mut reparsed_ast); + + // Compare the ASTs (they should be structurally equivalent) + assert_eq!(ast, reparsed_ast, "AST mismatch for input:\n{}", input); + } + + fn clear_spans(doc: &mut Document) { + for item in &mut doc.items { + match item { + DocumentItem::UseBatchStatement(u) => { + u.span = None; + u.batch_ref.span = None; + } + DocumentItem::UseIntroStatement(u) => { + u.span = None; + u.name.span = None; + u.intro_hash.span = None; + } + DocumentItem::CustomPredicateDef(c) => { + c.span = None; + c.name.span = None; + c.args.span = None; + for arg in &mut c.args.public_args { + arg.span = None; + } + if let Some(private) = &mut c.args.private_args { + for arg in private { + arg.span = None; + } + } + for stmt in &mut c.statements { + clear_statement_spans(stmt); + } + } + DocumentItem::RequestDef(r) => { + r.span = None; + for stmt in &mut r.statements { + clear_statement_spans(stmt); + } + } + } + } + } + + fn clear_statement_spans(stmt: &mut StatementTmpl) { + stmt.span = None; + stmt.predicate.span = None; + for arg in &mut stmt.args { + match arg { + StatementTmplArg::Literal(lit) => clear_literal_spans(lit), + StatementTmplArg::Wildcard(id) => id.span = None, + StatementTmplArg::AnchoredKey(ak) => { + ak.span = None; + ak.root.span = None; + match &mut ak.key { + AnchoredKeyPath::Bracket(s) => s.span = None, + AnchoredKeyPath::Dot(id) => id.span = None, + } + } + } + } + } + + fn clear_literal_spans(lit: &mut LiteralValue) { + match lit { + LiteralValue::Int(i) => i.span = None, + LiteralValue::Bool(b) => b.span = None, + LiteralValue::String(s) => s.span = None, + LiteralValue::Raw(r) => { + r.span = None; + r.hash.span = None; + } + LiteralValue::PublicKey(pk) => pk.span = None, + LiteralValue::SecretKey(sk) => sk.span = None, + LiteralValue::Array(a) => { + a.span = None; + for elem in &mut a.elements { + clear_literal_spans(elem); + } + } + LiteralValue::Set(s) => { + s.span = None; + for elem in &mut s.elements { + clear_literal_spans(elem); + } + } + LiteralValue::Dict(d) => { + d.span = None; + for pair in &mut d.pairs { + pair.span = None; + pair.key.span = None; + clear_literal_spans(&mut pair.value); + } + } + } + } + + #[test] + fn test_empty_document() { + test_roundtrip(""); + } + + #[test] + fn test_simple_request() { + let input = r#"REQUEST( + Equal(A["foo"], B["bar"]) + NotEqual(C["baz"], 123) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_custom_predicate() { + let input = r#"my_pred(A, B) = AND ( + Equal(A["foo"], B.bar) + Lt(A["key with spaces"], 100) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_private_args() { + let input = r#"pred_with_private(X, private: TempKey) = OR ( + Equal(X["key"], TempKey["value"]) + Contains(X["list"], TempKey["item"]) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_use_batch_statement() { + let input = r#"use batch pred1, pred2, _ from 0x0000000000000000000000000000000000000000000000000000000000000000"#; + test_roundtrip(input); + } + + #[test] + fn test_use_intro_statement() { + let input = r#"use intro pred1() from 0x0000000000000000000000000000000000000000000000000000000000000000"#; + test_roundtrip(input); + } + + #[test] + fn test_literals() { + // Generate valid PublicKey and SecretKey for the test + let sk = SecretKey::new_rand(); + let pk = sk.public_key(); + + let input = format!( + r#"REQUEST( + Equal(A["int"], 42) + Equal(B["neg"], -100) + Equal(C["bool"], true) + Equal(D["bool2"], false) + Equal(E["string"], "hello world") + Equal(F["raw"], Raw(0x0000000000000000000000000000000000000000000000000000000000000001)) + Equal(G["pk"], PublicKey({})) + Equal(H["sk"], SecretKey({})) +)"#, + pk, sk + ); + test_roundtrip(&input); + } + + #[test] + fn test_containers() { + let input = r#"REQUEST( + Equal(A["array"], [1, 2, 3]) + Equal(B["set"], #["a", "b", "c"]) + Equal(C["dict"], {"key1": "value1", "key2": 42}) + Equal(D["nested"], [{"inner": #[1, 2]}, [true, false]]) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_anchored_keys() { + let input = r#"REQUEST( + Equal(Var["bracket_key"], Other["key2"]) + Equal(Var.dot_key, Other.key3) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_complete_document() { + let input = r#"use batch imported_pred from 0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd + +is_valid(User, private: Config) = AND ( + Equal(User["age"], Config["min_age"]) + imported_pred(User, Config) +) + +check_both(A, B, C) = OR ( + is_valid(A) + is_valid(B) + Equal(C["flag"], true) +) + +REQUEST( + check_both(Pod1, Pod2, Pod3) + NotContains(Pod1["list"], Pod2["value"]) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_string_escapes() { + let input = r#"REQUEST( + Equal(A["escaped"], "line1\nline2") + Equal(B["quote"], "say \"hello\"") + Equal(C["backslash"], "path\\to\\file") + Equal(D["tab"], "col1\tcol2") +)"#; + + let parsed = parse_podlang(input).expect("Failed to parse input"); + let document_pair = parsed.into_iter().next().expect("No document pair"); + let ast = parse::parse_document(document_pair).expect("Failed to parse"); + + // Check that the AST correctly unescaped the strings + if let DocumentItem::RequestDef(req) = &ast.items[0] { + if let StatementTmplArg::Literal(LiteralValue::String(s)) = &req.statements[0].args[1] { + assert_eq!(s.value, "line1\nline2"); + } + if let StatementTmplArg::Literal(LiteralValue::String(s)) = &req.statements[1].args[1] { + assert_eq!(s.value, "say \"hello\""); + } + if let StatementTmplArg::Literal(LiteralValue::String(s)) = &req.statements[2].args[1] { + assert_eq!(s.value, "path\\to\\file"); + } + if let StatementTmplArg::Literal(LiteralValue::String(s)) = &req.statements[3].args[1] { + assert_eq!(s.value, "col1\tcol2"); + } + } + } + + #[test] + fn test_ast_structure() { + let input = r#"my_pred(A, B, private: C) = AND ( + Equal(A["foo"], B["bar"]) +) + +REQUEST( + my_pred(X, Y) +)"#; + + let parsed = parse_podlang(input).expect("Failed to parse input"); + let document_pair = parsed.into_iter().next().expect("No document pair"); + let ast = parse::parse_document(document_pair).expect("Failed to parse"); + + assert_eq!(ast.items.len(), 2); + + // Check custom predicate structure + if let DocumentItem::CustomPredicateDef(pred) = &ast.items[0] { + assert_eq!(pred.name.name, "my_pred"); + assert_eq!(pred.args.public_args.len(), 2); + assert_eq!(pred.args.public_args[0].name, "A"); + assert_eq!(pred.args.public_args[1].name, "B"); + assert_eq!(pred.args.private_args.as_ref().unwrap().len(), 1); + assert_eq!(pred.args.private_args.as_ref().unwrap()[0].name, "C"); + assert_eq!(pred.conjunction_type, ConjunctionType::And); + assert_eq!(pred.statements.len(), 1); + } else { + panic!("Expected CustomPredicateDef"); + } + + // Check request structure + if let DocumentItem::RequestDef(req) = &ast.items[1] { + assert_eq!(req.statements.len(), 1); + assert_eq!(req.statements[0].predicate.name, "my_pred"); + assert_eq!(req.statements[0].args.len(), 2); + } else { + panic!("Expected RequestDef"); + } + } + + #[test] + fn test_invalid_escape_sequences() { + // Test invalid unicode codepoint - surrogate pair range + let input = r#"REQUEST(Equal(A["key"], "test\uD800"))"#; + let parsed = crate::lang::parser::parse_podlang(input).expect("Grammar should accept this"); + let result = parse::parse_document(parsed.into_iter().next().unwrap()); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::lang::parser::ParseError::InvalidEscapeSequence(_) + )); + } +} diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs new file mode 100644 index 0000000..e5ad305 --- /dev/null +++ b/src/lang/frontend_ast_lower.rs @@ -0,0 +1,749 @@ +//! Lowering from frontend AST to middleware structures +//! +//! This module converts validated frontend AST to middleware data structures. +//! Currently implements basic 1:1 conversion without automatic predicate splitting. + +use std::{ + collections::{HashMap, HashSet}, + str::FromStr, + sync::Arc, +}; + +use crate::{ + frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder}, + lang::{ + frontend_ast::*, + frontend_ast_split, + frontend_ast_validate::{PredicateKind, ValidatedAST}, + }, + middleware::{ + self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params, + Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, + Wildcard, + }, +}; + +/// Result of lowering: optional custom predicate batch and optional request +/// +/// A Podlang file can contain: +/// - Just custom predicates (batch: Some, request: None) +/// - Just a request (batch: None, request: Some) +/// - Both (batch: Some, request: Some) +/// - Neither (batch: None, request: None) - just imports +#[derive(Debug, Clone)] +pub struct LoweredOutput { + pub batch: Option>, + pub request: Option, +} + +pub use crate::lang::error::LoweringError; + +/// Lower a validated AST to middleware structures +/// +/// Returns both the custom predicate batch (if any) and the request (if any). +/// At least one will be Some if the document contains custom predicates or a request. +pub fn lower( + validated: ValidatedAST, + params: &Params, + batch_name: String, +) -> Result { + if !validated.diagnostics().is_empty() { + // For now, treat any diagnostics as errors + // In future we could allow warnings + return Err(LoweringError::ValidationErrors); + } + + let lowerer = Lowerer::new(validated, params); + lowerer.lower(batch_name) +} + +struct Lowerer<'a> { + validated: ValidatedAST, + params: &'a Params, + /// Map of predicate names to their index in the current batch (for split predicates) + batch_predicate_index: HashMap, +} + +impl<'a> Lowerer<'a> { + fn new(validated: ValidatedAST, params: &'a Params) -> Self { + Self { + validated, + params, + batch_predicate_index: HashMap::new(), + } + } + + fn lower(mut self, batch_name: String) -> Result { + // Lower custom predicates (if any) + let batch = self.lower_batch(batch_name)?; + + // Lower request (if any) - pass batch so BatchSelf refs can be converted to Custom refs + let request = self.lower_request(batch.as_ref())?; + + Ok(LoweredOutput { batch, request }) + } + + fn lower_batch( + &mut self, + batch_name: String, + ) -> Result>, LoweringError> { + // Extract and split custom predicates from document + let (custom_predicates, original_count) = self.extract_and_split_predicates()?; + + // If no custom predicates, return None + if custom_predicates.is_empty() { + return Ok(None); + } + + // Check batch size constraint + if custom_predicates.len() > self.params.max_custom_batch_size { + return Err(LoweringError::TooManyPredicates { + batch_name: batch_name.clone(), + count: custom_predicates.len(), + max: self.params.max_custom_batch_size, + original_count, + }); + } + + // Build index of all predicates in the batch + for (idx, pred) in custom_predicates.iter().enumerate() { + self.batch_predicate_index + .insert(pred.name.name.clone(), idx); + } + + // Create custom predicate batch using builder + let mut cpb_builder = + CustomPredicateBatchBuilder::new(self.params.clone(), batch_name.clone()); + + for pred_def in &custom_predicates { + self.lower_custom_predicate(pred_def, &mut cpb_builder)?; + } + + Ok(Some(cpb_builder.finish())) + } + + fn lower_request( + &self, + batch: Option<&Arc>, + ) -> Result, LoweringError> { + let doc = self.validated.document(); + + // Find request definition (if any) + let request_def = doc.items.iter().find_map(|item| match item { + DocumentItem::RequestDef(req) => Some(req), + _ => None, + }); + + let Some(request_def) = request_def else { + return Ok(None); + }; + + // Build wildcard map from all wildcards used in the request statements + let wildcard_map = self.build_request_wildcard_map(request_def); + + // Lower each statement to a builder first + let mut statement_builders = Vec::new(); + for stmt in &request_def.statements { + let stmt_builder = self.lower_statement_to_builder(stmt)?; + statement_builders.push(stmt_builder); + } + + // Resolve builders to middleware statement templates + let mut request_templates = Vec::new(); + for stmt_builder in statement_builders { + let mw_stmt = + self.resolve_request_statement_builder(stmt_builder, &wildcard_map, batch)?; + request_templates.push(mw_stmt); + } + + Ok(Some(crate::frontend::PodRequest::new(request_templates))) + } + + fn resolve_request_statement_builder( + &self, + stmt_builder: StatementTmplBuilder, + wildcard_map: &HashMap, + batch: Option<&Arc>, + ) -> Result { + // First desugar the builder + let desugared = stmt_builder.desugar(); + + // Convert BatchSelf predicate to Custom if we have a batch + let mut predicate = desugared.predicate; + if let Some(batch_ref) = batch { + if let Predicate::BatchSelf(index) = predicate { + predicate = Predicate::Custom(middleware::CustomPredicateRef::new( + batch_ref.clone(), + index, + )); + } + } + + // Convert BuilderArgs to StatementTmplArgs + let mut mw_args = Vec::new(); + for builder_arg in desugared.args { + let mw_arg = match builder_arg { + BuilderArg::Literal(value) => MWStatementTmplArg::Literal(value), + BuilderArg::WildcardLiteral(name) => { + let index = wildcard_map.get(&name).expect("Wildcard not found"); + MWStatementTmplArg::Wildcard(Wildcard::new(name, *index)) + } + BuilderArg::Key(root_name, key_str) => { + let root_index = wildcard_map + .get(&root_name) + .expect("Root wildcard not found"); + let wildcard = Wildcard::new(root_name, *root_index); + let key = middleware::Key::from(key_str.as_str()); + MWStatementTmplArg::AnchoredKey(wildcard, key) + } + }; + mw_args.push(mw_arg); + } + + Ok(MWStatementTmpl { + pred: predicate, + args: mw_args, + }) + } + + fn build_request_wildcard_map(&self, request_def: &RequestDef) -> HashMap { + // Collect all unique wildcards from all statements + let mut wildcard_names = Vec::new(); + let mut seen = HashSet::new(); + + for stmt in &request_def.statements { + self.collect_statement_wildcards(stmt, &mut wildcard_names, &mut seen); + } + + // Build map from name to index + wildcard_names + .into_iter() + .enumerate() + .map(|(idx, name)| (name, idx)) + .collect() + } + + fn collect_statement_wildcards( + &self, + stmt: &StatementTmpl, + names: &mut Vec, + seen: &mut HashSet, + ) { + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + if !seen.contains(&id.name) { + seen.insert(id.name.clone()); + names.push(id.name.clone()); + } + } + StatementTmplArg::AnchoredKey(ak) => { + if !seen.contains(&ak.root.name) { + seen.insert(ak.root.name.clone()); + names.push(ak.root.name.clone()); + } + } + StatementTmplArg::Literal(_) => {} + } + } + } + + fn extract_and_split_predicates( + &self, + ) -> Result<(Vec, usize), LoweringError> { + let doc = self.validated.document(); + let predicates: Vec = doc + .items + .iter() + .filter_map(|item| match item { + DocumentItem::CustomPredicateDef(pred) => Some(pred.clone()), + _ => None, + }) + .collect(); + + let original_count = predicates.len(); + + // Apply splitting to each predicate as needed + let mut split_predicates = Vec::new(); + for pred in predicates { + let chain = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; + split_predicates.extend(chain); + } + + Ok((split_predicates, original_count)) + } + + fn lower_custom_predicate( + &self, + pred_def: &CustomPredicateDef, + cpb_builder: &mut CustomPredicateBatchBuilder, + ) -> Result<(), LoweringError> { + let name = pred_def.name.name.clone(); + + // Note: Constraint checking is handled by the splitting phase + // Predicates passed here should already be within limits + + // Collect public and private argument names + let mut public_arg_names = Vec::new(); + let mut private_arg_names = Vec::new(); + + for arg in &pred_def.args.public_args { + public_arg_names.push(arg.name.clone()); + } + + if let Some(private_args) = &pred_def.args.private_args { + for arg in private_args { + private_arg_names.push(arg.name.clone()); + } + } + + // Lower statements to builders + let mut statement_builders = Vec::new(); + for stmt in &pred_def.statements { + let stmt_builder = self.lower_statement_to_builder(stmt)?; + statement_builders.push(stmt_builder); + } + + // Convert to &str slices for builder API + let public_args_str: Vec<&str> = public_arg_names.iter().map(|s| s.as_str()).collect(); + let private_args_str: Vec<&str> = private_arg_names.iter().map(|s| s.as_str()).collect(); + + // Add predicate to batch using builder + let conjunction = pred_def.conjunction_type == ConjunctionType::And; + + cpb_builder + .predicate( + &name, + conjunction, + &public_args_str, + &private_args_str, + &statement_builders, + ) + .map_err(|e| match e { + crate::frontend::Error::Middleware(mw_err) => LoweringError::Middleware(mw_err), + _ => LoweringError::InvalidArgumentType, + })?; + + Ok(()) + } + + fn lower_statement_to_builder( + &self, + stmt: &StatementTmpl, + ) -> Result { + // Get predicate + let pred_name = &stmt.predicate.name; + let symbols = self.validated.symbols(); + + // Check for native predicates first + let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) { + Predicate::Native(native) + } else if let Some(&index) = self.batch_predicate_index.get(pred_name) { + // References to other predicates in the same batch (including split chains) + Predicate::BatchSelf(index) + } else if let Some(info) = symbols.predicates.get(pred_name) { + match &info.kind { + PredicateKind::Native(np) => Predicate::Native(*np), + PredicateKind::Custom { index } => Predicate::BatchSelf(*index), + PredicateKind::BatchImported { batch, index } => { + Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index)) + } + PredicateKind::IntroImported { + name, + verifier_data_hash, + } => Predicate::Intro(IntroPredicateRef { + name: name.clone(), + args_len: info.public_arity, + verifier_data_hash: *verifier_data_hash, + }), + } + } else { + unreachable!("Predicate {} not found", pred_name); + }; + + // Check args count + if stmt.args.len() > self.params.max_statement_args { + return Err(LoweringError::TooManyStatementArgs { + count: stmt.args.len(), + max: self.params.max_statement_args, + }); + } + + // Convert AST args to BuilderArgs + let mut builder = StatementTmplBuilder::new(predicate); + for arg in &stmt.args { + let builder_arg = self.lower_statement_arg_to_builder(arg)?; + builder = builder.arg(builder_arg); + } + + // Return builder without calling .desugar() - that will happen later + Ok(builder) + } + + fn lower_statement_arg_to_builder( + &self, + arg: &StatementTmplArg, + ) -> Result { + match arg { + StatementTmplArg::Literal(lit) => { + let value = self.lower_literal(lit)?; + Ok(BuilderArg::Literal(value)) + } + StatementTmplArg::Wildcard(id) => { + // For builder, we just need the wildcard name + Ok(BuilderArg::WildcardLiteral(id.name.clone())) + } + StatementTmplArg::AnchoredKey(ak) => { + let key_str = match &ak.key { + AnchoredKeyPath::Bracket(s) => s.value.clone(), + AnchoredKeyPath::Dot(id) => id.name.clone(), + }; + Ok(BuilderArg::Key(ak.root.name.clone(), key_str)) + } + } + } + + fn lower_literal(&self, lit: &LiteralValue) -> Result { + let value = match lit { + LiteralValue::Int(i) => middleware::Value::from(i.value), + LiteralValue::Bool(b) => middleware::Value::from(b.value), + LiteralValue::String(s) => middleware::Value::from(s.value.clone()), + LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash), + LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point), + LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()), + LiteralValue::Array(a) => { + let elements: Result, _> = + a.elements.iter().map(|e| self.lower_literal(e)).collect(); + let array = containers::Array::new(self.params.max_depth_mt_containers, elements?)?; + middleware::Value::from(array) + } + LiteralValue::Set(s) => { + let elements: Result, _> = + s.elements.iter().map(|e| self.lower_literal(e)).collect(); + let set_values: std::collections::HashSet<_> = elements?.into_iter().collect(); + let set = containers::Set::new(self.params.max_depth_mt_containers, set_values)?; + middleware::Value::from(set) + } + LiteralValue::Dict(d) => { + let pairs: Result, LoweringError> = d + .pairs + .iter() + .map(|pair| { + let key = middleware::Key::from(pair.key.value.as_str()); + let value = self.lower_literal(&pair.value)?; + Ok((key, value)) + }) + .collect(); + let dict_map: std::collections::HashMap<_, _> = pairs?.into_iter().collect(); + let dict = + containers::Dictionary::new(self.params.max_depth_mt_containers, dict_map)?; + middleware::Value::from(dict) + } + }; + Ok(value) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{ + frontend_ast::parse::parse_document, frontend_ast_validate::validate, parser::parse_podlang, + }; + + fn parse_validate_and_lower( + input: &str, + params: &Params, + ) -> Result { + let parsed = parse_podlang(input).expect("Failed to parse"); + let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + let validated = validate(document, &[]).expect("Failed to validate"); + lower(validated, params, "test_batch".to_string()) + } + + // Helper to get the batch from the output (expecting it to exist) + fn expect_batch(output: &LoweredOutput) -> &Arc { + output.batch.as_ref().expect("Expected batch to be present") + } + + #[test] + fn test_simple_predicate() { + let input = r#" + my_pred(A, B) = AND ( + Equal(A["foo"], B["bar"]) + ) + "#; + + let params = Params::default(); + let result = parse_validate_and_lower(input, ¶ms); + if let Err(e) = &result { + eprintln!("Error: {:?}", e); + } + assert!(result.is_ok()); + + let lowered = result.unwrap(); + assert_eq!(expect_batch(&lowered).predicates().len(), 1); + + let pred = &expect_batch(&lowered).predicates()[0]; + assert_eq!(pred.name, "my_pred"); + assert_eq!(pred.args_len(), 2); + assert_eq!(pred.wildcard_names().len(), 2); + assert_eq!(pred.statements().len(), 1); + } + + #[test] + fn test_private_args() { + let input = r#" + my_pred(A, private: B, C) = AND ( + Equal(A["x"], B["y"]) + Equal(B["z"], C["w"]) + ) + "#; + + let params = Params::default(); + let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); + + let lowered = result.unwrap(); + let pred = &expect_batch(&lowered).predicates()[0]; + assert_eq!(pred.args_len(), 1); // Only A is public + assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total + } + + #[test] + fn test_or_predicate() { + let input = r#" + my_pred(A, B) = OR ( + Equal(A["x"], 1) + Equal(B["y"], 2) + ) + "#; + + let params = Params::default(); + let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); + + let lowered = result.unwrap(); + let pred = &expect_batch(&lowered).predicates()[0]; + assert!(pred.is_disjunction()); + } + + #[test] + fn test_automatic_splitting() { + let input = r#" + my_pred(A) = AND ( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + "#; + + let params = Params::default(); // max_custom_predicate_arity = 5 + let result = parse_validate_and_lower(input, ¶ms); + if let Err(e) = &result { + eprintln!("Splitting error: {:?}", e); + } + assert!(result.is_ok()); + + let lowered = result.unwrap(); + // Should be automatically split into 2 predicates (my_pred and my_pred_1) + assert_eq!(expect_batch(&lowered).predicates().len(), 2); + + // First predicate should have 5 statements (4 + chain call) + assert_eq!(expect_batch(&lowered).predicates()[0].statements().len(), 5); + + // Second predicate should have 2 statements + assert_eq!(expect_batch(&lowered).predicates()[1].statements().len(), 2); + } + + #[test] + fn test_multiple_predicates() { + let input = r#" + pred1(A) = AND ( + Equal(A["x"], 1) + ) + + pred2(B) = AND ( + Equal(B["y"], 2) + ) + "#; + + let params = Params::default(); + let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); + + let lowered = result.unwrap(); + assert_eq!(expect_batch(&lowered).predicates().len(), 2); + } + + #[test] + fn test_batch_self_reference() { + let input = r#" + pred1(A) = AND ( + Equal(A["x"], 1) + ) + + pred2(B) = AND ( + pred1(B) + ) + "#; + + let params = Params::default(); + let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); + + let lowered = result.unwrap(); + let pred2 = &expect_batch(&lowered).predicates()[1]; + let stmt = &pred2.statements()[0]; + + // Should be BatchSelf(0) referring to pred1 + assert!(matches!(stmt.pred, Predicate::BatchSelf(0))); + } + + #[test] + fn test_literals() { + let input = r#" + my_pred(X) = AND ( + Equal(X["int"], 42) + Equal(X["bool"], true) + Equal(X["string"], "hello") + ) + "#; + + let params = Params::default(); + let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); + } + + #[test] + fn test_syntactic_sugar_desugaring() { + let input = r#" + my_pred(D) = AND ( + DictContains(D, "key", "value") + ) + "#; + + let params = Params::default(); + let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); + + let lowered = result.unwrap(); + let pred = &expect_batch(&lowered).predicates()[0]; + let stmt = &pred.statements()[0]; + + // Should desugar to the Contains predicate + assert!(matches!( + stmt.pred, + Predicate::Native(NativePredicate::Contains) + )); + } + + #[test] + fn test_error_message_with_splitting() { + // Create a document with predicates that will exceed the batch limit after splitting + // We'll create 2 predicates with 4 statements each (max arity = 5) + // Each will NOT split individually, but together they exceed a small batch limit + let input = r#" + pred1(A) = AND ( + Equal(A["a"], 1) + Equal(A["b"], 2) + ) + pred2(B) = AND ( + Equal(B["c"], 3) + Equal(B["d"], 4) + ) + "#; + + // Use very restrictive params to force the error + let params = Params { + max_custom_batch_size: 1, + ..Default::default() + }; + + let result = parse_validate_and_lower(input, ¶ms); + + // Should fail with TooManyPredicates error + assert!(result.is_err()); + let err = result.unwrap_err(); + + if let LoweringError::TooManyPredicates { + count, + max, + original_count, + .. + } = err + { + assert_eq!(count, 2); // 2 predicates after splitting (no splitting occurred) + assert_eq!(max, 1); + assert_eq!(original_count, 2); // Started with 2 predicates + + // Error message should NOT mention splitting since no splitting occurred + let err_msg = format!("{}", err); + assert!(!err_msg.contains("before automatic splitting")); + } else { + panic!("Expected TooManyPredicates error, got: {:?}", err); + } + } + + #[test] + fn test_error_message_after_splitting() { + // Create TWO predicates that will EACH split into 2 predicates + // This tests the case where splitting causes the batch to be too large + // but no individual predicate chain exceeds the limit + let input = r#" + pred1(A) = AND ( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + pred2(B) = AND ( + Equal(B["a"], 1) + Equal(B["b"], 2) + Equal(B["c"], 3) + Equal(B["d"], 4) + Equal(B["e"], 5) + Equal(B["f"], 6) + ) + "#; + + // Use params where each predicate splits into 2, but total of 4 exceeds batch limit + let params = Params { + // Allow 3 predicates in batch + // Default max_custom_predicate_arity is 5, so each will split into 2 predicates + // Total: 2 original predicates -> 4 after splitting (exceeds limit of 3) + max_custom_batch_size: 3, + ..Default::default() + }; + + let result = parse_validate_and_lower(input, ¶ms); + + // Should fail with TooManyPredicates error + assert!(result.is_err()); + let err = result.unwrap_err(); + + if let LoweringError::TooManyPredicates { + count, + max, + original_count, + .. + } = err + { + assert_eq!(count, 4); // 4 predicates after splitting (2 from each) + assert_eq!(max, 3); + assert_eq!(original_count, 2); // Started with 2 predicates + + // Error message SHOULD mention splitting since splitting occurred + let err_msg = format!("{}", err); + assert!(err_msg.contains("before automatic splitting")); + assert!(err_msg.contains("started with 2 predicates")); + } else { + panic!("Expected TooManyPredicates error, got: {:?}", err); + } + } +} diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs new file mode 100644 index 0000000..a8e3780 --- /dev/null +++ b/src/lang/frontend_ast_split.rs @@ -0,0 +1,1002 @@ +//! Predicate splitting for frontend AST +//! +//! This module implements automatic predicate splitting when predicates exceed +//! middleware constraints. +//! +//! When splitting a predicate, we try to group statements that use the same +//! wildcards together. However, if a private wildcard must be used across a +//! split boundary, it must be promoted to a public argument in the latter +//! predicate, to ensure that it is bound to the same value in both predicates. +//! +//! A wildcard is "live" at a split boundary if it is used in a statement on both +//! sides of the boundary. We want to minimize the number of live wildcards at +//! split boundaries, to minimize the number of promotions required. +//! +//! We use a greedy algorithm to order the statements in a predicate to minimize +//! the number of live wildcards at split boundaries. + +use std::collections::{HashMap, HashSet}; + +// SplittingError is now defined in error.rs +pub use crate::lang::error::SplittingError; +use crate::{lang::frontend_ast::*, middleware::Params}; + +/// A link in the predicate chain +#[derive(Debug, Clone)] +pub struct ChainLink { + /// Statements in this link + pub statements: Vec, + /// Public arguments coming into this link + pub public_args_in: Vec, + /// Private arguments used only in this link + pub private_args: Vec, + /// Public arguments promoted to pass to next link (empty if last link) + pub public_args_out: Vec, +} + +/// Wildcard usage information +#[derive(Debug, Clone)] +struct WildcardUsage { + /// Indices of statements using this wildcard + used_in_statements: HashSet, +} + +/// Early validation: Check if predicate is fundamentally splittable +pub fn validate_predicate_is_splittable( + pred: &CustomPredicateDef, + params: &Params, +) -> Result<(), SplittingError> { + let public_args = pred.args.public_args.len(); + + // Check: public args must fit in operation arg limit + if public_args > params.max_statement_args { + return Err(SplittingError::TooManyPublicArgs { + predicate: pred.name.name.clone(), + count: public_args, + max_allowed: params.max_statement_args, + message: "Public arguments exceed max operation args - cannot call this predicate" + .to_string(), + }); + } + + Ok(()) +} + +/// Split a predicate into a chain if it exceeds statement limit +pub fn split_predicate_if_needed( + pred: CustomPredicateDef, + params: &Params, +) -> Result, SplittingError> { + // Early validation + validate_predicate_is_splittable(&pred, params)?; + + // If within limits, no splitting needed + if pred.statements.len() <= params.max_custom_predicate_arity { + return Ok(vec![pred]); + } + + // Need to split - execute the splitting algorithm + let chain = split_into_chain(pred, params)?; + + Ok(chain) +} + +fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap { + let mut usage: HashMap = HashMap::new(); + + for (idx, stmt) in statements.iter().enumerate() { + let wildcards = collect_wildcards_from_statement(stmt); + + for wildcard in wildcards { + usage + .entry(wildcard.clone()) + .or_insert_with(|| WildcardUsage { + used_in_statements: HashSet::new(), + }) + .used_in_statements + .insert(idx); + } + } + + usage +} + +/// Collect all wildcard names from a statement +fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet { + let mut wildcards = HashSet::new(); + + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + wildcards.insert(id.name.clone()); + } + StatementTmplArg::AnchoredKey(ak) => { + wildcards.insert(ak.root.name.clone()); + } + StatementTmplArg::Literal(_) => {} + } + } + + wildcards +} + +/// Order constraints optimally to minimize liveness at boundaries +fn order_constraints_optimally( + statements: Vec, + _usage: &HashMap, + params: &Params, +) -> Vec { + // If no splitting needed, preserve original order + if statements.len() <= params.max_custom_predicate_arity { + return statements; + } + + let mut ordered = Vec::new(); + let mut remaining: HashSet = (0..statements.len()).collect(); + let mut active_wildcards: HashSet = HashSet::new(); + + while !remaining.is_empty() { + let best_idx = find_best_next_statement( + &statements, + &remaining, + &active_wildcards, + ordered.len(), + params, + ); + + remaining.remove(&best_idx); + let stmt = &statements[best_idx]; + ordered.push(stmt.clone()); + + // Update active wildcards + let stmt_wildcards = collect_wildcards_from_statement(stmt); + active_wildcards.extend(stmt_wildcards); + + // Remove wildcards no longer needed by remaining statements + let needed_later: HashSet<_> = remaining + .iter() + .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) + .collect(); + active_wildcards.retain(|w| needed_later.contains(w)); + } + + ordered +} + +/// Compute tie-breaker metrics for deterministic ordering when scores are equal +/// Returns (simplicity, public_closure, negative_fanout) tuple for use in max_by_key +fn compute_tie_breakers( + stmt: &StatementTmpl, + active_wildcards: &HashSet, + statements: &[StatementTmpl], + remaining: &HashSet, +) -> (usize, usize, i32) { + let stmt_wildcards = collect_wildcards_from_statement(stmt); + + // Metric 1: Simplicity - prefer statements with fewer wildcards + let simplicity = usize::MAX - stmt_wildcards.len(); + + // Metric 2: Public closure - prefer statements that close active wildcards + // (wildcards that won't be needed by any remaining statements) + let needed_later: HashSet = remaining + .iter() + .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) + .collect(); + + let closes_count = stmt_wildcards + .intersection(active_wildcards) + .filter(|w| !needed_later.contains(*w)) + .count(); + + // Metric 3: Fanout - prefer statements with lower future usage + // (number of remaining statements that use any wildcard from this statement) + let fanout = remaining + .iter() + .filter(|&&i| { + let other_wildcards = collect_wildcards_from_statement(&statements[i]); + !stmt_wildcards.is_disjoint(&other_wildcards) + }) + .count(); + + (simplicity, closes_count, -(fanout as i32)) +} + +/// Find the best next statement to add based on scoring heuristic +fn find_best_next_statement( + statements: &[StatementTmpl], + remaining: &HashSet, + active_wildcards: &HashSet, + ordered_count: usize, + params: &Params, +) -> usize { + // Calculate distance to next split point + let bucket_size = params.max_custom_predicate_arity - 1; // Reserve slot for chain call + let distance_to_split = bucket_size - (ordered_count % bucket_size); + let approaching_split = distance_to_split <= 2; + + remaining + .iter() + .max_by_key(|&&idx| { + let primary_score = score_statement( + &statements[idx], + active_wildcards, + statements, + remaining, + approaching_split, + ); + let tie_breakers = + compute_tie_breakers(&statements[idx], active_wildcards, statements, remaining); + (primary_score, tie_breakers) + }) + .copied() + .unwrap() +} + +/// Score a statement based on how well it minimizes liveness +fn score_statement( + stmt: &StatementTmpl, + active_wildcards: &HashSet, + statements: &[StatementTmpl], + remaining: &HashSet, + approaching_split: bool, +) -> i32 { + let stmt_wildcards = collect_wildcards_from_statement(stmt); + + // How many active wildcards does this reuse? + let reuse_count = stmt_wildcards.intersection(active_wildcards).count(); + + // How many new wildcards does this introduce? + let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count(); + + // After adding this statement, what would be active? + let mut projected_active = active_wildcards.clone(); + projected_active.extend(stmt_wildcards.clone()); + + // Which wildcards are still needed by other remaining statements? + let needed_later: HashSet = remaining + .iter() + .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) + .collect(); + + // Wildcards we can close = active now but not needed later + projected_active.retain(|w| needed_later.contains(w)); + let still_active_count = projected_active.len(); + + // Base score calculation + // - Prefer statements that reuse active wildcards (don't introduce new liveness) + // - Penalize introducing new wildcards (increases liveness) + // - Penalize keeping many wildcards active (higher liveness) + let base_score = (reuse_count * 3) as i32 + - (new_wildcard_count * 4) as i32 + - (still_active_count * 2) as i32; + + // Look-ahead bonus: when approaching split, heavily favor closing wildcards + if approaching_split { + let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count; + base_score + (closes_count * 10) as i32 + } else { + base_score + } +} + +/// Calculate which wildcards are live at a split boundary +fn calculate_live_wildcards( + before_split: &[StatementTmpl], + after_split: &[StatementTmpl], +) -> HashSet { + let before: HashSet<_> = before_split + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + + let after: HashSet<_> = after_split + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + + // Live = in both sets (crosses boundary) + before.intersection(&after).cloned().collect() +} + +/// Generate a refactor suggestion for wildcards crossing a boundary +fn generate_refactor_suggestion( + crossing_wildcards: &[String], + ordered_statements: &[StatementTmpl], + _pos: usize, + _end: usize, +) -> Option { + use crate::lang::error::RefactorSuggestion; + + if crossing_wildcards.is_empty() { + return None; + } + + // Analyze the span of each crossing wildcard + let mut wildcard_spans: Vec<(String, usize, usize, usize)> = Vec::new(); + + for wildcard in crossing_wildcards { + let mut first_use = None; + let mut last_use = None; + + for (i, stmt) in ordered_statements.iter().enumerate() { + let wildcards = collect_wildcards_from_statement(stmt); + if wildcards.contains(wildcard) { + if first_use.is_none() { + first_use = Some(i); + } + last_use = Some(i); + } + } + + if let (Some(first), Some(last)) = (first_use, last_use) { + let span = last - first; + wildcard_spans.push((wildcard.clone(), first, last, span)); + } + } + + // Sort by span (largest first) + wildcard_spans.sort_by(|a, b| b.3.cmp(&a.3)); + + if let Some((wildcard, first, last, span)) = wildcard_spans.first() { + // If a single wildcard has a large span, suggest reducing it + if *span > 3 { + return Some(RefactorSuggestion::ReduceWildcardSpan { + wildcard: wildcard.clone(), + first_use: *first, + last_use: *last, + span: *span, + }); + } + } + + // If multiple wildcards cross the boundary, suggest grouping + if crossing_wildcards.len() > 1 { + return Some(RefactorSuggestion::GroupWildcardUsages { + wildcards: crossing_wildcards.to_vec(), + }); + } + + None +} + +/// Split into chain using bucket-filling approach +fn split_into_chain( + pred: CustomPredicateDef, + params: &Params, +) -> Result, SplittingError> { + let original_name = pred.name.name.clone(); + let conjunction = pred.conjunction_type; + + let usage = analyze_wildcards(&pred.statements); + + let ordered_statements = order_constraints_optimally(pred.statements, &usage, params); + + let original_public_args: Vec = pred + .args + .public_args + .iter() + .map(|id| id.name.clone()) + .collect(); + + let mut chain_links = Vec::new(); + let mut pos = 0; + let mut incoming_public = original_public_args.clone(); + + while pos < ordered_statements.len() { + let remaining = ordered_statements.len() - pos; + let is_last = remaining <= params.max_custom_predicate_arity; + + let bucket_size = if is_last { + remaining // Last predicate uses all remaining + } else { + params.max_custom_predicate_arity - 1 // Reserve slot for chain call + }; + + let end = pos + bucket_size; + + // Calculate liveness at this split boundary + let live_at_boundary = if is_last { + HashSet::new() + } else { + calculate_live_wildcards(&ordered_statements[pos..end], &ordered_statements[end..]) + }; + + // Check: Can we fit promoted wildcards in public args? + // Need to account for possible overlap between incoming_public and live_at_boundary + let incoming_set: HashSet<_> = incoming_public.iter().cloned().collect(); + let new_promotions: Vec<_> = live_at_boundary + .iter() + .filter(|w| !incoming_set.contains(*w)) + .cloned() + .collect(); + let total_public = incoming_public.len() + new_promotions.len(); + if total_public > params.max_statement_args { + let context = crate::lang::error::SplitContext { + split_index: chain_links.len(), + statement_range: (pos, end), + incoming_public: incoming_public.clone(), + crossing_wildcards: new_promotions.clone(), + total_public, + }; + + let suggestion = + generate_refactor_suggestion(&new_promotions, &ordered_statements, pos, end); + + return Err(SplittingError::TooManyPublicArgsAtSplit { + predicate: original_name.clone(), + context: Box::new(context), + max_allowed: params.max_statement_args, + suggestion: suggestion.map(Box::new), + }); + } + + // Calculate private args (used in this segment but not incoming and not outgoing) + let segment_wildcards: HashSet<_> = ordered_statements[pos..end] + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + + let mut private_args: Vec = segment_wildcards + .difference(&incoming_set) + .filter(|w| !live_at_boundary.contains(*w)) + .cloned() + .collect(); + private_args.sort(); // Deterministic ordering + + // Check: Total args constraint (incoming + new promotions + private) + let public_count = incoming_public.len() + new_promotions.len(); + let private_count = private_args.len(); + let total_args = public_count + private_count; + if total_args > params.max_custom_predicate_wildcards { + return Err(SplittingError::TooManyTotalArgsInChainLink { + predicate: original_name.clone(), + link_index: chain_links.len(), + public_count, + private_count, + total_count: total_args, + max_allowed: params.max_custom_predicate_wildcards, + }); + } + + let mut public_args_out: Vec = live_at_boundary.iter().cloned().collect(); + public_args_out.sort(); // Deterministic ordering + + chain_links.push(ChainLink { + statements: ordered_statements[pos..end].to_vec(), + public_args_in: incoming_public.clone(), + private_args, + public_args_out: public_args_out.clone(), + }); + + pos = end; + + // Next link's incoming public args = current incoming + newly promoted live wildcards + // Only add wildcards that aren't already in incoming_public to avoid duplicates + for wildcard in public_args_out { + if !incoming_set.contains(&wildcard) { + incoming_public.push(wildcard); + } + } + } + + let chain_predicates = + generate_chain_predicates(&original_name, chain_links, conjunction, params)?; + + validate_chain(&chain_predicates, &original_name, params)?; + + Ok(chain_predicates) +} + +/// Phase 4: Generate synthetic predicates from chain links +fn generate_chain_predicates( + original_name: &str, + chain_links: Vec, + conjunction: ConjunctionType, + _params: &Params, +) -> Result, SplittingError> { + let mut predicates = Vec::new(); + + for (i, link) in chain_links.iter().enumerate() { + let pred_name = if i == 0 { + Identifier { + name: original_name.to_string(), + span: None, + } + } else { + Identifier { + name: format!("{}_{}", original_name, i), + span: None, + } + }; + + let is_last = i == chain_links.len() - 1; + let mut statements = link.statements.clone(); + + // Add chain call if not last + if !is_last { + let next_pred_name = Identifier { + name: format!("{}_{}", original_name, i + 1), + span: None, + }; + + // Create arguments for chain call: all public args (incoming + promoted) + let mut chain_call_args = Vec::new(); + for arg_name in &link.public_args_in { + chain_call_args.push(StatementTmplArg::Wildcard(Identifier { + name: arg_name.clone(), + span: None, + })); + } + for arg_name in &link.public_args_out { + chain_call_args.push(StatementTmplArg::Wildcard(Identifier { + name: arg_name.clone(), + span: None, + })); + } + + let chain_call = StatementTmpl { + predicate: next_pred_name, + args: chain_call_args, + span: None, + }; + + statements.push(chain_call); + } + + // Build public args (incoming) + let public_args: Vec = link + .public_args_in + .iter() + .map(|name| Identifier { + name: name.clone(), + span: None, + }) + .collect(); + + // Build private args (private + promoted for next) + let mut private_arg_names = link.private_args.clone(); + if !is_last { + private_arg_names.extend(link.public_args_out.clone()); + } + + let private_args = if private_arg_names.is_empty() { + None + } else { + Some( + private_arg_names + .into_iter() + .map(|name| Identifier { name, span: None }) + .collect(), + ) + }; + + predicates.push(CustomPredicateDef { + name: pred_name, + args: ArgSection { + public_args, + private_args, + span: None, + }, + conjunction_type: conjunction, + statements, + span: None, + }); + } + + Ok(predicates) +} + +/// Phase 5: Validate the generated chain +fn validate_chain( + chain: &[CustomPredicateDef], + original_name: &str, + params: &Params, +) -> Result<(), SplittingError> { + if chain.len() > params.max_custom_batch_size { + return Err(SplittingError::TooManyPredicatesInChain { + predicate: original_name.to_string(), + count: chain.len(), + max_allowed: params.max_custom_batch_size, + }); + } + + for pred in chain { + // Each predicate should have ≤ max_statements + assert!(pred.statements.len() <= params.max_custom_predicate_arity); + + // Public args should fit + assert!(pred.args.public_args.len() <= params.max_statement_args); + + // Total args should fit + let total = + pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len()); + assert!(total <= params.max_custom_predicate_wildcards); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lang::{frontend_ast::parse::parse_document, parser::parse_podlang}; + + fn parse_predicate(input: &str) -> CustomPredicateDef { + let parsed = parse_podlang(input).expect("Failed to parse"); + let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + + for item in document.items { + if let DocumentItem::CustomPredicateDef(pred) = item { + return pred; + } + } + + panic!("No custom predicate found"); + } + + #[test] + fn test_validate_splittable() { + let input = r#" + my_pred(A, B) = AND ( + Equal(A, B) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); + + assert!(validate_predicate_is_splittable(&pred, ¶ms).is_ok()); + } + + #[test] + fn test_validate_too_many_public_args() { + let input = r#" + my_pred(A, B, C, D, E, F) = AND ( + Equal(A, B) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); // max_statement_args = 5 + + let result = validate_predicate_is_splittable(&pred, ¶ms); + assert!(matches!( + result, + Err(SplittingError::TooManyPublicArgs { .. }) + )); + } + + #[test] + fn test_no_split_needed() { + let input = r#" + my_pred(A, B) = AND ( + Equal(A["x"], B["y"]) + Equal(A["z"], 1) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); + + let result = split_predicate_if_needed(pred, ¶ms); + assert!(result.is_ok()); + + let chain = result.unwrap(); + assert_eq!(chain.len(), 1); // No split needed + } + + #[test] + fn test_simple_split() { + let input = r#" + my_pred(A) = AND ( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); // max_custom_predicate_arity = 5 + + let result = split_predicate_if_needed(pred, ¶ms); + assert!(result.is_ok()); + + let chain = result.unwrap(); + assert_eq!(chain.len(), 2); // Should split into 2 predicates + + // First predicate: 4 statements + chain call = 5 + assert_eq!(chain[0].statements.len(), 5); + + // Second predicate: 2 remaining statements + assert_eq!(chain[1].statements.len(), 2); + } + + #[test] + fn test_split_with_private_wildcards() { + let input = r#" + complex(A, B, private: T1, T2) = AND ( + Equal(T1["x"], A["y"]) + Equal(T1["z"], 100) + Equal(T2["a"], T1["x"]) + HashOf(T2["b"], B) + Equal(A["result"], T2["a"]) + Equal(B["final"], T2["b"]) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); // max_custom_predicate_arity = 5 + + let result = split_predicate_if_needed(pred, ¶ms); + assert!(result.is_ok()); + + let chain = result.unwrap(); + assert_eq!(chain.len(), 2); // Should split into 2 predicates + + // First predicate should have wildcards that cross boundary promoted + // Check that chain call is present + let last_stmt = &chain[0].statements.last().unwrap(); + assert_eq!(last_stmt.predicate.name, "complex_1"); + } + + #[test] + fn test_split_into_three_predicates() { + let input = r#" + large_pred(A) = AND ( + Equal(A["a"], 1) + Equal(A["b"], 2) + Equal(A["c"], 3) + Equal(A["d"], 4) + Equal(A["e"], 5) + Equal(A["f"], 6) + Equal(A["g"], 7) + Equal(A["h"], 8) + Equal(A["i"], 9) + Equal(A["j"], 10) + Equal(A["k"], 11) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); // max_custom_predicate_arity = 5 + + let result = split_predicate_if_needed(pred, ¶ms); + assert!(result.is_ok()); + + let chain = result.unwrap(); + assert_eq!(chain.len(), 3); // Should split into 3 predicates + + // First: 4 + chain call = 5 + assert_eq!(chain[0].statements.len(), 5); + // Second: 4 + chain call = 5 + assert_eq!(chain[1].statements.len(), 5); + // Third: 3 remaining + assert_eq!(chain[2].statements.len(), 3); + } + + #[test] + fn test_no_duplicate_promoted_wildcards() { + // Test that a wildcard used across multiple chain boundaries + // doesn't get duplicated in incoming_public + let input = r#" + reuse_pred(A, private: T) = AND ( + Equal(T["x"], A["start"]) + Equal(T["y"], 1) + Equal(T["z"], 2) + Equal(T["w"], 3) + Equal(A["mid"], T["x"]) + Equal(T["a"], 4) + Equal(T["b"], 5) + Equal(T["c"], 6) + Equal(A["end"], T["x"]) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); + + let result = split_predicate_if_needed(pred, ¶ms); + assert!(result.is_ok()); + + let chain = result.unwrap(); + // Should split into 2 predicates + // T is used in first segment and crosses to second, then used again in second + assert_eq!(chain.len(), 2); + + // Check that second predicate's public args don't have duplicates + let second_pred_public_count = chain[1].args.public_args.len(); + let second_pred_public_names: Vec<_> = chain[1] + .args + .public_args + .iter() + .map(|id| &id.name) + .collect(); + let unique_count = second_pred_public_names + .iter() + .collect::>() + .len(); + + assert_eq!( + second_pred_public_count, unique_count, + "Public args should not contain duplicates" + ); + } + + #[test] + fn test_greedy_ordering_reduces_liveness() { + // This test verifies that our greedy ordering algorithm reduces wildcard liveness + // by clustering statements that use the same wildcards together. + // + // The predicate has 8 statements using 3 private wildcards (T1, T2, T3): + // - T1 used in statements 1, 4, 7 + // - T2 used in statements 2, 5, 8 + // - T3 used in statements 3, 6 + // + // NAIVE ORDERING (original order): + // Would interleave T1, T2, T3 usage throughout the predicate. + // When splitting at statement limit (5 statements per predicate): + // Predicate 1: statements 1-5 (introduces T1, T2, T3 - none complete) + // Predicate 2: statements 6-8 (all 3 wildcards still live) + // Result: 2 public args (A, B) + 3 promoted wildcards = 5 total in predicate 2 + // + // GREEDY ORDERING (our algorithm): + // Clusters statements by wildcard to minimize liveness: + // Groups T1 statements together, then T2, then T3 + // Predicate 1: completes some wildcards before the split point + // Predicate 2: fewer wildcards need to cross the boundary + // Result: 2 public args (A, B) + 1-2 promoted wildcards = 3-4 total in predicate 2 + let input = r#" + clustered(A, B, private: T1, T2, T3) = AND ( + Equal(T1["x"], 1) + Equal(T2["y"], 2) + Equal(T3["z"], 3) + Equal(T1["a"], 4) + Equal(T2["b"], 5) + Equal(T3["c"], 6) + Equal(T1["d"], A["result"]) + Equal(T2["e"], B["value"]) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); + + let result = split_predicate_if_needed(pred, ¶ms); + assert!(result.is_ok()); + + let chain = result.unwrap(); + assert_eq!(chain.len(), 2, "Predicate should split into 2 links"); + + let second_pred = &chain[1]; + let second_pred_public_count = second_pred.args.public_args.len(); + + // Verify greedy ordering achieves better results than naive ordering would + // Started with 2 public args (A, B) + // Naive would have: 2 + 3 promoted = 5 public args in second predicate + // Greedy achieves: 2 + 1-2 promoted = 3-4 public args in second predicate + assert!( + second_pred_public_count <= 4, + "Greedy ordering should reduce promotions to ≤4 public args, but got {}", + second_pred_public_count + ); + } + + #[test] + fn test_error_message_formatting() { + // Test that error messages format correctly with detailed context + // We'll manually construct the error to test the formatting + use crate::lang::error::{RefactorSuggestion, SplitContext}; + + let context = SplitContext { + split_index: 0, + statement_range: (0, 4), + incoming_public: vec!["A".to_string(), "B".to_string(), "C".to_string()], + crossing_wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()], + total_public: 6, + }; + + let suggestion = Some(RefactorSuggestion::GroupWildcardUsages { + wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()], + }); + + let error = SplittingError::TooManyPublicArgsAtSplit { + predicate: "test_pred".to_string(), + context: Box::new(context), + max_allowed: 5, + suggestion: suggestion.map(Box::new), + }; + + let error_msg = format!("{}", error); + + // Verify the error message contains all the key information + assert!(error_msg.contains("test_pred")); + assert!(error_msg.contains("split boundary 0")); + assert!(error_msg.contains("3 incoming public")); + assert!(error_msg.contains("3 crossing wildcards")); + assert!(error_msg.contains("= 6 total")); + assert!(error_msg.contains("exceeds max of 5")); + assert!(error_msg.contains("Statements 0-4")); + assert!(error_msg.contains("Incoming public args: A, B, C")); + assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3")); + assert!(error_msg.contains("Suggestion:")); + assert!(error_msg.contains("Group operations for wildcards")); + + eprintln!("\n=== Example Error Message ===\n{}\n", error_msg); + } + + #[test] + fn test_error_too_many_total_args_formatting() { + // Test the TooManyTotalArgsInChainLink error message formatting + let error = SplittingError::TooManyTotalArgsInChainLink { + predicate: "huge_pred".to_string(), + link_index: 1, + public_count: 5, + private_count: 6, + total_count: 11, + max_allowed: 10, + }; + + let error_msg = format!("{}", error); + + // Verify the error message includes breakdown + assert!(error_msg.contains("huge_pred")); + assert!(error_msg.contains("chain link 1")); + assert!(error_msg.contains("5 public")); + assert!(error_msg.contains("6 private")); + assert!(error_msg.contains("= 11 total")); + assert!(error_msg.contains("exceeds max of 10")); + + eprintln!("\n=== Example TooManyTotalArgs Error ===\n{}\n", error_msg); + } + + #[test] + fn test_refactor_suggestion_reduce_wildcard_span() { + // Test the "reduce wildcard span" suggestion formatting + use crate::lang::error::RefactorSuggestion; + + let suggestion = RefactorSuggestion::ReduceWildcardSpan { + wildcard: "T".to_string(), + first_use: 0, + last_use: 7, + span: 7, + }; + + let suggestion_text = suggestion.format(); + + // Verify the suggestion formats correctly + assert!(suggestion_text.contains("'T'")); + assert!(suggestion_text.contains("used across 7 statements")); + assert!(suggestion_text.contains("statements 0-7")); + assert!(suggestion_text.contains("grouping all 'T' operations together")); + + eprintln!( + "\n=== Example ReduceWildcardSpan Suggestion ===\n{}\n", + suggestion_text + ); + } + + #[test] + fn test_refactor_suggestion_group_wildcards() { + // Test the "group wildcard usages" suggestion formatting + use crate::lang::error::RefactorSuggestion; + + let suggestion = RefactorSuggestion::GroupWildcardUsages { + wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()], + }; + + let suggestion_text = suggestion.format(); + + // Verify the suggestion formats correctly + assert!(suggestion_text.contains("Group operations for wildcards")); + assert!(suggestion_text.contains("T1, T2, T3")); + assert!(suggestion_text.contains("used across multiple segments")); + + eprintln!( + "\n=== Example GroupWildcardUsages Suggestion ===\n{}\n", + suggestion_text + ); + } +} diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs new file mode 100644 index 0000000..6fd2349 --- /dev/null +++ b/src/lang/frontend_ast_validate.rs @@ -0,0 +1,775 @@ +//! Validation for the frontend AST +//! +//! This module provides semantic validation for parsed AST documents, +//! including name resolution, arity checking, and wildcard validation. + +use std::{collections::HashMap, str::FromStr, sync::Arc}; + +use hex::ToHex; + +use crate::{ + lang::frontend_ast::*, + middleware::{CustomPredicateBatch, Hash, NativePredicate}, +}; + +/// A validated AST document with symbol table and diagnostics +#[derive(Debug, Clone)] +pub struct ValidatedAST { + document: Document, + symbols: SymbolTable, + diagnostics: Vec, +} + +impl ValidatedAST { + pub fn document(&self) -> &Document { + &self.document + } + + pub fn symbols(&self) -> &SymbolTable { + &self.symbols + } + + pub fn diagnostics(&self) -> &[Diagnostic] { + &self.diagnostics + } + + pub fn into_document(self) -> Document { + self.document + } +} + +/// Symbol table containing all predicates and their metadata +#[derive(Debug, Clone)] +pub struct SymbolTable { + /// All predicates available in this scope + pub predicates: HashMap, + /// Wildcard scopes for each custom predicate + pub wildcard_scopes: HashMap, +} + +/// Information about a predicate +#[derive(Debug, Clone)] +pub struct PredicateInfo { + pub kind: PredicateKind, + pub arity: usize, + pub public_arity: usize, + pub source_span: Option, +} + +/// Kind of predicate +#[derive(Debug, Clone)] +pub enum PredicateKind { + Native(NativePredicate), + Custom { + index: usize, + }, + BatchImported { + batch: Arc, + index: usize, + }, + IntroImported { + name: String, + verifier_data_hash: Hash, + }, +} + +/// Wildcard scope for a custom predicate +#[derive(Debug, Clone)] +pub struct WildcardScope { + pub wildcards: HashMap, +} + +/// Information about a wildcard +#[derive(Debug, Clone)] +pub struct WildcardInfo { + pub index: usize, + pub is_public: bool, + pub source_span: Option, +} + +/// Diagnostic message (warning or info) +#[derive(Debug, Clone)] +pub struct Diagnostic { + pub level: DiagnosticLevel, + pub message: String, + pub span: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DiagnosticLevel { + Warning, + Info, +} + +pub use crate::lang::error::ValidationError; + +/// Validate an AST document +pub fn validate( + document: Document, + available_batches: &[Arc], +) -> Result { + let validator = Validator::new(available_batches); + validator.validate(document) +} + +struct Validator { + available_batches: HashMap>, + symbols: SymbolTable, + diagnostics: Vec, + custom_predicate_count: usize, +} + +impl Validator { + fn new(batches: &[Arc]) -> Self { + let mut available_batches = HashMap::new(); + for batch in batches { + // Store by hex ID for lookup + let id = format!("0x{}", batch.id().encode_hex::()); + available_batches.insert(id, batch.clone()); + } + + Self { + available_batches, + symbols: SymbolTable { + predicates: HashMap::new(), + wildcard_scopes: HashMap::new(), + }, + diagnostics: Vec::new(), + custom_predicate_count: 0, + } + } + + fn validate(mut self, document: Document) -> Result { + // Pass 1: Build symbol table + self.build_symbol_table(&document)?; + + // Pass 2: Validate all references + self.validate_references(&document)?; + + Ok(ValidatedAST { + document, + symbols: self.symbols, + diagnostics: self.diagnostics, + }) + } + + fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> { + // First process imports + for item in &document.items { + if let DocumentItem::UseBatchStatement(use_stmt) = item { + self.process_use_batch_statement(use_stmt)?; + } + if let DocumentItem::UseIntroStatement(use_stmt) = item { + self.process_use_intro_statement(use_stmt)?; + } + } + + // Then process custom predicate definitions + for item in &document.items { + if let DocumentItem::CustomPredicateDef(pred_def) = item { + self.process_custom_predicate_def(pred_def)?; + } + } + + // Check for multiple REQUEST definitions (only one allowed) + let mut first_request_span = None; + for item in &document.items { + if let DocumentItem::RequestDef(req) = item { + if let Some(first_span) = first_request_span { + return Err(ValidationError::MultipleRequestDefinitions { + first_span: Some(first_span), + second_span: req.span, + }); + } + first_request_span = req.span; + } + } + + Ok(()) + } + + fn process_use_batch_statement( + &mut self, + use_stmt: &UseBatchStatement, + ) -> Result<(), ValidationError> { + let batch_id = format!("0x{}", use_stmt.batch_ref.hash.encode_hex::()); + + let batch = self.available_batches.get(&batch_id).ok_or_else(|| { + ValidationError::BatchNotFound { + id: batch_id.clone(), + span: use_stmt.batch_ref.span, + } + })?; + + if use_stmt.imports.len() != batch.predicates().len() { + return Err(ValidationError::ImportArityMismatch { + expected: batch.predicates().len(), + found: use_stmt.imports.len(), + span: use_stmt.span, + }); + } + + for (i, import) in use_stmt.imports.iter().enumerate() { + if let ImportName::Named(name) = import { + if self.symbols.predicates.contains_key(name) { + return Err(ValidationError::DuplicateImport { + name: name.clone(), + span: use_stmt.span, + }); + } + + let pred = &batch.predicates()[i]; + // CustomPredicate has args_len (public args) and wildcard_names (total args) + let total_arity = pred.wildcard_names.len(); + let public_arity = pred.args_len; + + self.symbols.predicates.insert( + name.clone(), + PredicateInfo { + kind: PredicateKind::BatchImported { + batch: batch.clone(), + index: i, + }, + arity: total_arity, + public_arity, + source_span: use_stmt.span, + }, + ); + } + } + + Ok(()) + } + + fn process_use_intro_statement( + &mut self, + use_stmt: &UseIntroStatement, + ) -> Result<(), ValidationError> { + let intro_name = &use_stmt.name.name; + let args = &use_stmt.args; + let intro_predicate_ref = &use_stmt.intro_hash; + + if self.symbols.predicates.contains_key(intro_name) { + return Err(ValidationError::DuplicateImport { + name: intro_name.clone(), + span: use_stmt.span, + }); + } + + self.symbols.predicates.insert( + intro_name.clone(), + PredicateInfo { + kind: PredicateKind::IntroImported { + name: intro_name.clone(), + // Hash is already parsed in the AST + verifier_data_hash: intro_predicate_ref.hash, + }, + arity: args.len(), + public_arity: args.len(), + source_span: use_stmt.span, + }, + ); + Ok(()) + } + + fn process_custom_predicate_def( + &mut self, + pred_def: &CustomPredicateDef, + ) -> Result<(), ValidationError> { + let name = &pred_def.name.name; + + if self.symbols.predicates.contains_key(name) { + let first_span = self.symbols.predicates[name].source_span; + return Err(ValidationError::DuplicatePredicate { + name: name.clone(), + first_span, + second_span: pred_def.name.span, + }); + } + + // Check for empty statement list + if pred_def.statements.is_empty() { + return Err(ValidationError::EmptyStatementList { + context: format!("predicate '{}'", name), + span: pred_def.span, + }); + } + + // Build wildcard scope + let mut wildcards = HashMap::new(); + let mut wildcard_index = 0; + + // Process public arguments + for arg in &pred_def.args.public_args { + if wildcards.contains_key(&arg.name) { + return Err(ValidationError::DuplicateWildcard { + name: arg.name.clone(), + span: arg.span, + }); + } + wildcards.insert( + arg.name.clone(), + WildcardInfo { + index: wildcard_index, + is_public: true, + source_span: arg.span, + }, + ); + wildcard_index += 1; + } + + // Process private arguments + let mut private_count = 0; + if let Some(private_args) = &pred_def.args.private_args { + for arg in private_args { + if wildcards.contains_key(&arg.name) { + return Err(ValidationError::DuplicateWildcard { + name: arg.name.clone(), + span: arg.span, + }); + } + wildcards.insert( + arg.name.clone(), + WildcardInfo { + index: wildcard_index, + is_public: false, + source_span: arg.span, + }, + ); + wildcard_index += 1; + private_count += 1; + } + } + + // Add to symbol table + self.symbols.predicates.insert( + name.clone(), + PredicateInfo { + kind: PredicateKind::Custom { + index: self.custom_predicate_count, + }, + arity: pred_def.args.public_args.len() + private_count, + public_arity: pred_def.args.public_args.len(), + source_span: pred_def.name.span, + }, + ); + + self.symbols + .wildcard_scopes + .insert(name.clone(), WildcardScope { wildcards }); + self.custom_predicate_count += 1; + + Ok(()) + } + + fn validate_references(&mut self, document: &Document) -> Result<(), ValidationError> { + for item in &document.items { + match item { + DocumentItem::CustomPredicateDef(pred_def) => { + self.validate_custom_predicate_statements(pred_def)?; + } + DocumentItem::RequestDef(req_def) => { + self.validate_request_statements(req_def)?; + } + _ => {} + } + } + Ok(()) + } + + fn validate_custom_predicate_statements( + &self, + pred_def: &CustomPredicateDef, + ) -> Result<(), ValidationError> { + let pred_name = pred_def.name.name.clone(); + + for stmt in &pred_def.statements { + let wildcard_scope = self + .symbols + .wildcard_scopes + .get(&pred_name) + .expect("Wildcard scope should exist after pass 1"); + self.validate_statement(stmt, Some((&pred_name, wildcard_scope)))?; + } + + Ok(()) + } + + fn validate_request_statements(&mut self, req_def: &RequestDef) -> Result<(), ValidationError> { + if req_def.statements.is_empty() { + self.diagnostics.push(Diagnostic { + level: DiagnosticLevel::Warning, + message: "Empty REQUEST block".to_string(), + span: req_def.span, + }); + } + + for stmt in &req_def.statements { + self.validate_statement(stmt, None)?; + } + + Ok(()) + } + + fn validate_statement( + &self, + stmt: &StatementTmpl, + wildcard_context: Option<(&str, &WildcardScope)>, + ) -> Result<(), ValidationError> { + let pred_name = &stmt.predicate.name; + + // Check if predicate exists + let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) { + // Native predicate + PredicateInfo { + kind: PredicateKind::Native(native), + arity: native.arity(), + public_arity: native.arity(), + source_span: None, + } + } else if let Some(info) = self.symbols.predicates.get(pred_name) { + // Custom or imported predicate + info.clone() + } else { + return Err(ValidationError::UndefinedPredicate { + name: pred_name.clone(), + span: stmt.predicate.span, + }); + }; + + let expected_arity = pred_info.public_arity; + + if stmt.args.len() != expected_arity { + return Err(ValidationError::ArgumentCountMismatch { + predicate: pred_name.clone(), + expected: expected_arity, + found: stmt.args.len(), + span: stmt.span, + }); + } + + // Validate arguments + self.validate_statement_args(stmt, &pred_info, wildcard_context)?; + + Ok(()) + } + + fn validate_statement_args( + &self, + stmt: &StatementTmpl, + pred_info: &PredicateInfo, + wildcard_context: Option<(&str, &WildcardScope)>, + ) -> Result<(), ValidationError> { + // For custom predicates, only wildcards and literals are allowed + if matches!( + pred_info.kind, + PredicateKind::Custom { .. } | PredicateKind::BatchImported { .. } + ) { + for arg in &stmt.args { + match arg { + StatementTmplArg::AnchoredKey(_) => { + return Err(ValidationError::InvalidArgumentType { + predicate: stmt.predicate.name.clone(), + span: stmt.span, + }); + } + StatementTmplArg::Wildcard(id) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&id.name) { + return Err(ValidationError::UndefinedWildcard { + name: id.name.clone(), + pred_name: pred_name.to_string(), + span: id.span, + }); + } + } + } + StatementTmplArg::Literal(_) => {} + } + } + } else { + // Native predicates can have anchored keys + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&id.name) { + return Err(ValidationError::UndefinedWildcard { + name: id.name.clone(), + pred_name: pred_name.to_string(), + span: id.span, + }); + } + } + } + StatementTmplArg::AnchoredKey(ak) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&ak.root.name) { + return Err(ValidationError::UndefinedWildcard { + name: ak.root.name.clone(), + pred_name: pred_name.to_string(), + span: ak.root.span, + }); + } + } + } + StatementTmplArg::Literal(_) => {} + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + lang::{frontend_ast::parse::parse_document, parser::parse_podlang}, + middleware::{CustomPredicate, Params, EMPTY_HASH}, + }; + + fn parse_and_validate( + input: &str, + batches: &[Arc], + ) -> Result { + let parsed = parse_podlang(input).expect("Failed to parse"); + let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + validate(document, batches) + } + + #[test] + fn test_validate_empty() { + let result = parse_and_validate("", &[]); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_simple_request() { + let input = r#"REQUEST( + Equal(A["foo"], B["bar"]) + )"#; + let result = parse_and_validate(input, &[]); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_custom_predicate() { + let input = r#" + my_pred(A, B) = AND ( + Equal(A["foo"], B["bar"]) + ) + "#; + let result = parse_and_validate(input, &[]); + assert!(result.is_ok()); + + let validated = result.unwrap(); + assert!(validated.symbols.predicates.contains_key("my_pred")); + assert!(validated.symbols.wildcard_scopes.contains_key("my_pred")); + } + + #[test] + fn test_undefined_predicate() { + let input = r#"REQUEST( + UndefinedPred(A, B) + )"#; + let result = parse_and_validate(input, &[]); + assert!(matches!( + result, + Err(ValidationError::UndefinedPredicate { .. }) + )); + } + + #[test] + fn test_undefined_wildcard() { + let input = r#" + my_pred(A) = AND ( + Equal(A["foo"], B["bar"]) + ) + "#; + let result = parse_and_validate(input, &[]); + assert!( + matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B") + ); + } + + #[test] + fn test_arity_mismatch() { + let input = r#"REQUEST( + Equal(A, B, C) + )"#; + let result = parse_and_validate(input, &[]); + assert!(matches!( + result, + Err(ValidationError::ArgumentCountMismatch { .. }) + )); + } + + #[test] + fn test_duplicate_predicate() { + let input = r#" + my_pred(A) = AND (Equal(A["x"], 1)) + my_pred(B) = AND (Equal(B["y"], 2)) + "#; + let result = parse_and_validate(input, &[]); + assert!(matches!( + result, + Err(ValidationError::DuplicatePredicate { .. }) + )); + } + + #[test] + fn test_duplicate_wildcard() { + let input = r#" + my_pred(A, A) = AND (Equal(A["x"], 1)) + "#; + let result = parse_and_validate(input, &[]); + assert!(matches!( + result, + Err(ValidationError::DuplicateWildcard { .. }) + )); + } + + #[test] + fn test_custom_predicate_with_anchored_key() { + let input = r#" + my_pred(A, B) = AND ( + Equal(A["foo"], B["bar"]) + ) + + REQUEST( + my_pred(X["key"], Y) + ) + "#; + let result = parse_and_validate(input, &[]); + assert!(matches!( + result, + Err(ValidationError::InvalidArgumentType { .. }) + )); + } + + #[test] + fn test_forward_reference() { + let input = r#" + pred1(A) = AND ( + pred2(A) + ) + + pred2(B) = AND ( + Equal(B["x"], 1) + ) + "#; + let result = parse_and_validate(input, &[]); + assert!(result.is_ok()); + } + + #[test] + fn test_private_args() { + let input = r#" + my_pred(A, private: B, C) = AND ( + Equal(A["x"], B["y"]) + Equal(B["z"], C["w"]) + ) + "#; + let result = parse_and_validate(input, &[]); + assert!(result.is_ok()); + + let validated = result.unwrap(); + let pred_info = &validated.symbols.predicates["my_pred"]; + assert_eq!(pred_info.arity, 3); + assert_eq!(pred_info.public_arity, 1); + } + + #[test] + fn test_empty_statement_list() { + // Create a custom predicate with empty statements to test validation + let document = Document { + items: vec![DocumentItem::CustomPredicateDef(CustomPredicateDef { + name: Identifier { + name: "my_pred".to_string(), + span: None, + }, + args: ArgSection { + public_args: vec![Identifier { + name: "A".to_string(), + span: None, + }], + private_args: None, + span: None, + }, + conjunction_type: ConjunctionType::And, + statements: vec![], // Empty statements + span: None, + })], + }; + let result = validate(document, &[]); + assert!(matches!( + result, + Err(ValidationError::EmptyStatementList { .. }) + )); + } + + #[test] + fn test_multiple_request_definitions() { + let input = r#" + REQUEST(Equal(A["x"], 1)) + REQUEST(Equal(B["y"], 2)) + "#; + let result = parse_and_validate(input, &[]); + assert!(matches!( + result, + Err(ValidationError::MultipleRequestDefinitions { .. }) + )); + } + + #[test] + fn test_use_statement() { + let params = Params::default(); + + // Create a batch to import + let pred = CustomPredicate::and( + ¶ms, + "imported".to_string(), + vec![], + 2, + vec!["X".to_string(), "Y".to_string()], + ) + .unwrap(); + + let batch = CustomPredicateBatch::new(¶ms, "TestBatch".to_string(), vec![pred]); + + let batch_id = batch.id().encode_hex::(); + let input = format!( + r#" + use batch imported_pred from 0x{} + use intro intro_pred() from 0x{} + + REQUEST( + imported_pred(A, B) + intro_pred() + ) + "#, + batch_id, + EMPTY_HASH.encode_hex::() + ); + + let result = parse_and_validate(&input, &[batch]); + assert!(result.is_ok()); + + let validated = result.unwrap(); + assert!(validated.symbols.predicates.contains_key("imported_pred")); + assert!(validated.symbols.predicates.contains_key("intro_pred")); + } + + #[test] + fn test_syntactic_sugar_predicates() { + let input = r#"REQUEST( + GtEq(A["x"], B["y"]) + DictContains(D, K, V) + SetNotContains(S, E) + )"#; + let result = parse_and_validate(input, &[]); + assert!(result.is_ok()); + } +} diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index a11f308..f6d6baa 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -33,8 +33,9 @@ use_predicate_list = { import_name ~ ("," ~ import_name)* } import_name = { identifier | "_" } batch_ref = { hash_hex } -use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ batch_ref } +use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ intro_predicate_ref } use_intro_arg_list = { identifier ~ ("," ~ identifier)* } +intro_predicate_ref = { hash_hex } request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" } diff --git a/src/lang/mod.rs b/src/lang/mod.rs index d202d8b..053062d 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -1,17 +1,27 @@ pub mod error; +pub mod frontend_ast; +pub mod frontend_ast_lower; +pub mod frontend_ast_split; +pub mod frontend_ast_validate; pub mod parser; pub mod pretty_print; -pub mod processor; use std::sync::Arc; pub use error::LangError; pub use parser::{parse_podlang, Pairs, ParseError, Rule}; pub use pretty_print::PrettyPrint; -pub use processor::process_pest_tree; -use processor::PodlangOutput; -use crate::middleware::{CustomPredicateBatch, Params}; +use crate::{ + frontend::PodRequest, + middleware::{CustomPredicateBatch, Params}, +}; + +#[derive(Debug, Clone, PartialEq)] +pub struct PodlangOutput { + pub custom_batch: Arc, + pub request: PodRequest, +} pub fn parse( input: &str, @@ -19,7 +29,28 @@ pub fn parse( available_batches: &[Arc], ) -> Result { let pairs = parse_podlang(input)?; - processor::process_pest_tree(pairs, params, available_batches).map_err(LangError::from) + let document_pair = pairs + .into_iter() + .next() + .expect("parse_podlang should always return at least one pair for a valid document"); + let document = frontend_ast::parse::parse_document(document_pair)?; + let validated = frontend_ast_validate::validate(document, available_batches)?; + let lowered = frontend_ast_lower::lower(validated, params, "PodlangBatch".to_string())?; + + let custom_batch = lowered.batch.unwrap_or_else(|| { + // If no batch, create an empty one + CustomPredicateBatch::new(params, "PodlangBatch".to_string(), vec![]) + }); + + let request = lowered.request.unwrap_or_else(|| { + // If no request, create an empty one + PodRequest::new(vec![]) + }); + + Ok(PodlangOutput { + custom_batch, + request, + }) } #[cfg(test)] @@ -30,7 +61,6 @@ mod tests { use super::*; use crate::{ backends::plonky2::primitives::ec::schnorr::SecretKey, - lang::error::ProcessorError, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, Params, Predicate, RawValue, StatementTmpl, StatementTmplArg, Value, Wildcard, @@ -963,13 +993,13 @@ mod tests { assert!(result.is_err()); match result.err().unwrap() { - LangError::Processor(e) => match *e { - ProcessorError::BatchNotFound { id, .. } => { + LangError::Validation(e) => match *e { + frontend_ast_validate::ValidationError::BatchNotFound { id, .. } => { assert_eq!(id, unknown_batch_id); } _ => panic!("Expected BatchNotFound error, but got {:?}", e), }, - e => panic!("Expected LangError::Processor, but got {:?}", e), + e => panic!("Expected LangError::Validation, but got {:?}", e), } } @@ -991,16 +1021,18 @@ mod tests { assert!(result.is_err()); match result.err().unwrap() { - LangError::Processor(e) => match *e { - ProcessorError::UndefinedWildcard { - name, pred_name, .. + LangError::Validation(e) => match *e { + frontend_ast_validate::ValidationError::UndefinedWildcard { + name, + pred_name, + .. } => { assert_eq!(name, "user_public_key"); assert_eq!(pred_name, "identity_verified"); } _ => panic!("Expected UndefinedWildcard error, but got {:?}", e), }, - e => panic!("Expected LangError::Processor, but got {:?}", e), + e => panic!("Expected LangError::Validation, but got {:?}", e), } } } diff --git a/src/lang/parser.rs b/src/lang/parser.rs index 69eecc4..000e683 100644 --- a/src/lang/parser.rs +++ b/src/lang/parser.rs @@ -12,8 +12,20 @@ pub type Pairs<'a, R> = PestPairs<'a, R>; #[derive(thiserror::Error, Debug)] pub enum ParseError { + #[error("Invalid integer: {0}")] + InvalidInt(String), + #[error("Pest parsing error: {0}")] Pest(Box>), + + #[error("Invalid public key: {0}")] + InvalidPublicKey(String), + + #[error("Invalid secret key: {0}")] + InvalidSecretKey(String), + + #[error("Invalid escape sequence in string: {0}")] + InvalidEscapeSequence(String), } impl From> for ParseError { diff --git a/src/lang/processor.rs b/src/lang/processor.rs deleted file mode 100644 index 1be986a..0000000 --- a/src/lang/processor.rs +++ /dev/null @@ -1,1350 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; - -use pest::iterators::{Pair, Pairs}; -use plonky2::field::types::Field; - -use super::error::ProcessorError; -use crate::{ - backends::plonky2::{ - deserialize_bytes, - primitives::ec::{curve::Point, schnorr::SecretKey}, - }, - frontend::{BuilderArg, CustomPredicateBatchBuilder, PodRequest, StatementTmplBuilder}, - lang::parser::Rule, - middleware::{ - self, CustomPredicateBatch, CustomPredicateRef, Hash, IntroPredicateRef, Key, - NativePredicate, Params, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard, F, - VALUE_SIZE, - }, -}; - -fn get_span(pair: &Pair) -> (usize, usize) { - let span = pair.as_span(); - (span.start(), span.end()) -} - -pub fn native_predicate_from_string(s: &str) -> Option { - match s { - // TODO: update any code that still uses ValueOf to use Equal instead - "ValueOf" => Some(NativePredicate::Equal), - "Equal" => Some(NativePredicate::Equal), - "NotEqual" => Some(NativePredicate::NotEqual), - // Syntactic sugar for Gt/GtEq is handled at a later stage - "Gt" => Some(NativePredicate::Gt), - "GtEq" => Some(NativePredicate::GtEq), - "Lt" => Some(NativePredicate::Lt), - "LtEq" => Some(NativePredicate::LtEq), - "Contains" => Some(NativePredicate::Contains), - "NotContains" => Some(NativePredicate::NotContains), - "SumOf" => Some(NativePredicate::SumOf), - "ProductOf" => Some(NativePredicate::ProductOf), - "MaxOf" => Some(NativePredicate::MaxOf), - "HashOf" => Some(NativePredicate::HashOf), - "PublicKeyOf" => Some(NativePredicate::PublicKeyOf), - "SignedBy" => Some(NativePredicate::SignedBy), - "ContainerInsert" => Some(NativePredicate::ContainerInsert), - "ContainerUpdate" => Some(NativePredicate::ContainerUpdate), - "ContainerDelete" => Some(NativePredicate::ContainerDelete), - "DictContains" => Some(NativePredicate::DictContains), - "DictNotContains" => Some(NativePredicate::DictNotContains), - "ArrayContains" => Some(NativePredicate::ArrayContains), - "SetContains" => Some(NativePredicate::SetContains), - "SetNotContains" => Some(NativePredicate::SetNotContains), - "DictInsert" => Some(NativePredicate::DictInsert), - "DictUpdate" => Some(NativePredicate::DictUpdate), - "DictDelete" => Some(NativePredicate::DictDelete), - "SetInsert" => Some(NativePredicate::SetInsert), - "SetDelete" => Some(NativePredicate::SetDelete), - "ArrayUpdate" => Some(NativePredicate::ArrayUpdate), - "None" => Some(NativePredicate::None), - "False" => Some(NativePredicate::False), - _ => None, - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct PodlangOutput { - pub custom_batch: Arc, - pub request: PodRequest, -} - -struct ProcessingContext<'a> { - params: &'a Params, - /// Maps imported predicate names to their full reference (batch and index) - imported_predicates: HashMap, - /// Maps imported intro predicate names to their intro refs - imported_intro_predicates: HashMap, - /// Maps predicate names to their batch index and public argument count (from Pass 1) - custom_predicate_signatures: HashMap, - /// Stores the original Pest pairs for custom predicate definitions for Pass 2 - custom_predicate_pairs: Vec>, - /// Stores the original Pest pair for the request definition for Pass 2 - request_pair: Option>, -} - -impl<'a> ProcessingContext<'a> { - fn new(params: &'a Params) -> Self { - ProcessingContext { - params, - imported_predicates: HashMap::new(), - imported_intro_predicates: HashMap::new(), - custom_predicate_signatures: HashMap::new(), - custom_predicate_pairs: Vec::new(), - request_pair: None, - } - } -} - -pub fn process_pest_tree( - mut pairs_iterator_for_document_rule: Pairs<'_, Rule>, - params: &Params, - available_batches: &[Arc], -) -> Result { - let mut processing_ctx = ProcessingContext::new(params); - - let document_node = pairs_iterator_for_document_rule.next().ok_or_else(|| { - ProcessorError::Internal(format!( - "Parser returned no pairs for the expected top-level rule: {:?}.", - Rule::document - )) - })?; - - if document_node.as_rule() != Rule::document { - return Err(ProcessorError::Internal(format!( - "Expected top-level pair to be Rule::{:?}, but found Rule::{:?}.", - Rule::document, - document_node.as_rule() - ))); - } - - let document_content_pairs = document_node.into_inner(); - - first_pass( - document_content_pairs, - &mut processing_ctx, - available_batches, - )?; - - second_pass(&mut processing_ctx, params) -} - -/// Pass 1: Iterates through top-level definitions, records custom predicate -/// signatures and stores pairs for Pass 2. -fn first_pass<'a>( - document_pairs: Pairs<'a, Rule>, - ctx: &mut ProcessingContext<'a>, - available_batches: &[Arc], -) -> Result<(), ProcessorError> { - let mut defined_custom_names: HashSet = HashSet::new(); - let mut first_request_span: Option<(usize, usize)> = None; - - for pair in document_pairs { - match pair.as_rule() { - Rule::use_batch_statement => { - process_use_batch_statement(&pair, ctx, available_batches)?; - } - Rule::use_intro_statement => { - process_use_intro_statement(&pair, ctx)?; - } - Rule::custom_predicate_def => { - let pred_name_pair = pair - .clone() - .into_inner() - .find(|p| p.as_rule() == Rule::identifier) - .unwrap(); - let pred_name = pred_name_pair.as_str().to_string(); - - if defined_custom_names.contains(&pred_name) - || ctx.imported_predicates.contains_key(&pred_name) - || ctx.imported_intro_predicates.contains_key(&pred_name) - { - return Err(ProcessorError::DuplicateDefinition { - name: pred_name, - span: Some(get_span(&pred_name_pair)), - }); - } - defined_custom_names.insert(pred_name.clone()); - - let public_arity = count_public_args(&pair)?; - ctx.custom_predicate_signatures.insert( - pred_name.clone(), - (ctx.custom_predicate_pairs.len(), public_arity), - ); - ctx.custom_predicate_pairs.push(pair); - } - Rule::request_def => { - if ctx.request_pair.is_some() { - return Err(ProcessorError::MultipleRequestDefinitions { - first_span: first_request_span, - second_span: Some(get_span(&pair)), - }); - } - first_request_span = Some(get_span(&pair)); - ctx.request_pair = Some(pair); - } - Rule::EOI => break, - Rule::COMMENT | Rule::WHITESPACE => {} - _ => { - unreachable!("Unexpected rule: {:?}", pair.as_rule()); - } - } - } - Ok(()) -} - -fn count_public_args(pred_def_pair: &Pair) -> Result { - let arg_section_pair = pred_def_pair - .clone() - .into_inner() - .find(|p| p.as_rule() == Rule::arg_section) - .unwrap(); - - let public_arg_list_pair = arg_section_pair - .into_inner() - .find(|p| p.as_rule() == Rule::public_arg_list) - .unwrap(); - - Ok(public_arg_list_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::identifier) - .count()) -} - -fn process_use_batch_statement( - use_pair: &Pair, - ctx: &mut ProcessingContext, - available_batches: &[Arc], -) -> Result<(), ProcessorError> { - let mut inner = use_pair.clone().into_inner(); - - let import_list_pair = inner - .find(|p| p.as_rule() == Rule::use_predicate_list) - .unwrap(); - let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap(); - let batch_id_pair = batch_ref_pair.into_inner().next().unwrap(); - let batch_id_str_full = batch_id_pair.as_str(); - - let batch_id_hex = batch_id_str_full - .strip_prefix("0x") - .unwrap_or(batch_id_str_full); - let batch_id_val = parse_hex_str_to_raw_value(batch_id_hex).map_err(|_| { - ProcessorError::InvalidLiteralFormat { - kind: "batch ID hash".to_string(), - value: batch_id_str_full.to_string(), - span: Some(get_span(&batch_id_pair)), - } - })?; - - let target_batch = available_batches - .iter() - .find(|b| b.id().0 == batch_id_val.0) - .ok_or_else(|| ProcessorError::BatchNotFound { - id: batch_id_str_full.to_string(), - span: Some(get_span(&batch_id_pair)), - })?; - - let import_names: Vec> = import_list_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::import_name) - .collect(); - - if import_names.len() != target_batch.predicates().len() { - return Err(ProcessorError::ImportArityMismatch { - expected: target_batch.predicates().len(), - found: import_names.len(), - span: Some(get_span(use_pair)), - }); - } - - for (i, import_name_pair) in import_names.into_iter().enumerate() { - if import_name_pair.as_str() == "_" { - continue; - } - - let name = import_name_pair.as_str().to_string(); - - if ctx.imported_predicates.contains_key(&name) - || ctx.imported_intro_predicates.contains_key(&name) - || ctx.custom_predicate_signatures.contains_key(&name) - { - return Err(ProcessorError::DuplicateImportName { - name, - span: Some(get_span(&import_name_pair)), - }); - } - - let custom_pred_ref = CustomPredicateRef::new(target_batch.clone(), i); - ctx.imported_predicates.insert(name, custom_pred_ref); - } - - Ok(()) -} - -fn process_use_intro_statement( - use_pair: &Pair, - ctx: &mut ProcessingContext, -) -> Result<(), ProcessorError> { - let mut inner = use_pair.clone().into_inner(); - - // Structure: identifier, '(', optional arg list, ')', 'from', batch_ref - let name_pair = inner.find(|p| p.as_rule() == Rule::identifier).unwrap(); - let pred_name = name_pair.as_str().to_string(); - - if ctx.imported_predicates.contains_key(&pred_name) - || ctx.imported_intro_predicates.contains_key(&pred_name) - || ctx.custom_predicate_signatures.contains_key(&pred_name) - { - return Err(ProcessorError::DuplicateImportName { - name: pred_name, - span: Some(get_span(&name_pair)), - }); - } - - let args_len = inner - .clone() - .find(|p| p.as_rule() == Rule::use_intro_arg_list) - .map(|arg_list| { - arg_list - .into_inner() - .filter(|p| p.as_rule() == Rule::identifier) - .count() - }) - .unwrap_or(0); - - let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap(); - let hash_hex_pair = batch_ref_pair.into_inner().next().unwrap(); - let hash_str_full = hash_hex_pair.as_str(); - let hex_no_prefix = hash_str_full.strip_prefix("0x").unwrap_or(hash_str_full); - let raw_val = parse_hex_str_to_raw_value(hex_no_prefix).map_err(|_| { - ProcessorError::InvalidLiteralFormat { - kind: "intro verifier hash".to_string(), - value: hash_str_full.to_string(), - span: Some(get_span(&hash_hex_pair)), - } - })?; - let verifier_hash: Hash = Hash::from(raw_val); - - let intro_ref = IntroPredicateRef { - name: pred_name.clone(), - args_len, - verifier_data_hash: verifier_hash, - }; - - ctx.imported_intro_predicates.insert(pred_name, intro_ref); - - Ok(()) -} - -enum StatementContext<'a> { - CustomPredicate { - pred_name: &'a str, - argument_names: &'a HashSet, - }, - Request { - custom_batch: &'a Arc, - wildcard_names: &'a mut Vec, - defined_wildcards: &'a mut HashSet, - }, -} - -fn second_pass( - ctx: &mut ProcessingContext, - params: &Params, -) -> Result { - let mut cpb_builder = - CustomPredicateBatchBuilder::new(ctx.params.clone(), "PodlangBatch".to_string()); - - for pred_pair in &ctx.custom_predicate_pairs { - process_and_add_custom_predicate_to_batch(params, pred_pair, ctx, &mut cpb_builder)?; - } - - let custom_batch = cpb_builder.finish(); - - let request_templates = if let Some(req_pair) = &ctx.request_pair { - process_request_def(params, req_pair, ctx, &custom_batch)? - } else { - Vec::new() - }; - - Ok(PodlangOutput { - custom_batch, - request: PodRequest::new(request_templates), - }) -} - -fn pest_pair_to_builder_arg( - params: &Params, - arg_content_pair: &Pair, - context: &StatementContext, -) -> Result { - match arg_content_pair.as_rule() { - Rule::literal_value => { - let value = process_literal_value(params, arg_content_pair)?; - Ok(BuilderArg::Literal(value)) - } - Rule::identifier => { - let wc_str = arg_content_pair.as_str(); - if let StatementContext::CustomPredicate { - argument_names, - pred_name, - } = context - { - if !argument_names.contains(wc_str) { - return Err(ProcessorError::UndefinedWildcard { - name: wc_str.to_string(), - pred_name: pred_name.to_string(), - span: Some(get_span(arg_content_pair)), - }); - } - } - Ok(BuilderArg::WildcardLiteral(wc_str.to_string())) - } - Rule::anchored_key => { - let mut inner_ak_pairs = arg_content_pair.clone().into_inner(); - let root_pair = inner_ak_pairs.next().unwrap(); - let root_wc_str = root_pair.as_str(); - - if let StatementContext::CustomPredicate { - argument_names, - pred_name, - } = context - { - if !argument_names.contains(root_wc_str) { - return Err(ProcessorError::UndefinedWildcard { - name: root_wc_str.to_string(), - pred_name: pred_name.to_string(), - span: Some(get_span(arg_content_pair)), - }); - } - } - - let key_part_pair = inner_ak_pairs.next().unwrap(); - let key_str = match key_part_pair.as_rule() { - Rule::literal_string => parse_pest_string_literal(&key_part_pair)?, - Rule::identifier => key_part_pair.as_str().to_string(), - _ => unreachable!( - "unknown key type in anchored key: {:?}", - key_part_pair.as_rule() - ), - }; - Ok(BuilderArg::Key(root_wc_str.to_string(), key_str)) - } - _ => unreachable!("Unexpected rule: {:?}", arg_content_pair.as_rule()), - } -} - -fn validate_dyn_len_predicate( - stmt_name_str: &str, - args: &[BuilderArg], - expected_arity: usize, - stmt_span: (usize, usize), - stmt_name_span: (usize, usize), -) -> Result<(), ProcessorError> { - if args.len() != expected_arity { - return Err(ProcessorError::ArgumentCountMismatch { - predicate: stmt_name_str.to_string(), - expected: expected_arity, - found: args.len(), - span: Some(stmt_name_span), - }); - } - for (idx, arg) in args.iter().enumerate() { - if !matches!(arg, BuilderArg::WildcardLiteral(_) | BuilderArg::Literal(_)) { - return Err(ProcessorError::TypeError { - expected: "Wildcard or Literal".to_string(), - found: format!("{:?}", arg), - item: format!( - "argument {} of custom predicate call '{}'", - idx + 1, - stmt_name_str - ), - span: Some(stmt_span), - }); - } - } - Ok(()) -} - -fn validate_and_build_statement_template( - stmt_name_str: &str, - pred: &Predicate, - args: Vec, - processing_ctx: &ProcessingContext, - stmt_span: (usize, usize), - stmt_name_span: (usize, usize), -) -> Result { - match pred { - Predicate::Native(native_pred) => { - let expected_arity = match native_pred { - NativePredicate::Gt - | NativePredicate::GtEq - | NativePredicate::Equal - | NativePredicate::NotEqual - | NativePredicate::Lt - | NativePredicate::LtEq - | NativePredicate::SetContains - | NativePredicate::DictNotContains - | NativePredicate::SetNotContains - | NativePredicate::NotContains - | NativePredicate::PublicKeyOf - | NativePredicate::SignedBy => 2, - NativePredicate::Contains - | NativePredicate::ArrayContains - | NativePredicate::DictContains - | NativePredicate::SumOf - | NativePredicate::ProductOf - | NativePredicate::MaxOf - | NativePredicate::HashOf - | NativePredicate::ContainerDelete - | NativePredicate::DictDelete - | NativePredicate::SetInsert - | NativePredicate::SetDelete => 3, - NativePredicate::ContainerInsert - | NativePredicate::ContainerUpdate - | NativePredicate::DictInsert - | NativePredicate::DictUpdate - | NativePredicate::ArrayUpdate => 4, - NativePredicate::None | NativePredicate::False => 0, - }; - - if args.len() != expected_arity { - return Err(ProcessorError::ArgumentCountMismatch { - predicate: stmt_name_str.to_string(), - expected: expected_arity, - found: args.len(), - span: Some(stmt_name_span), - }); - } - } - Predicate::Custom(custom_ref) => { - let expected_arity = custom_ref.predicate().args_len; - validate_dyn_len_predicate( - stmt_name_str, - &args, - expected_arity, - stmt_span, - stmt_name_span, - )?; - } - Predicate::Intro(intro_ref) => { - let expected_arity = intro_ref.args_len; - validate_dyn_len_predicate( - stmt_name_str, - &args, - expected_arity, - stmt_span, - stmt_name_span, - )?; - } - Predicate::BatchSelf(_) => { - let (_original_pred_idx, expected_arity_val) = processing_ctx - .custom_predicate_signatures - .get(stmt_name_str) - .ok_or_else(|| { - ProcessorError::Internal(format!( - "Custom predicate signature not found for '{}' during validation", - stmt_name_str - )) - })?; - - if args.len() != *expected_arity_val { - return Err(ProcessorError::ArgumentCountMismatch { - predicate: stmt_name_str.to_string(), - expected: *expected_arity_val, - found: args.len(), - span: Some(stmt_name_span), - }); - } - - for (idx, arg) in args.iter().enumerate() { - if !matches!(arg, BuilderArg::WildcardLiteral(_) | BuilderArg::Literal(_)) { - return Err(ProcessorError::TypeError { - expected: "Wildcard or Literal".to_string(), - found: format!("{:?}", arg), - item: format!( - "argument {} of custom predicate call '{}'", - idx + 1, - stmt_name_str - ), - span: Some(stmt_span), - }); - } - } - } - } - - let mut stb = StatementTmplBuilder::new(pred.clone()); - for arg in args { - stb = stb.arg(arg); - } - Ok(stb.desugar()) -} - -fn process_and_add_custom_predicate_to_batch( - params: &Params, - pred_def_pair: &Pair, - processing_ctx: &ProcessingContext, - cpb_builder: &mut CustomPredicateBatchBuilder, -) -> Result<(), ProcessorError> { - let mut inner_pairs = pred_def_pair.clone().into_inner(); - let name_pair = inner_pairs - .find(|p| p.as_rule() == Rule::identifier) - .unwrap(); - let name = name_pair.as_str().to_string(); - - let arg_section_pair = inner_pairs - .find(|p| p.as_rule() == Rule::arg_section) - .unwrap(); - - let mut public_arg_strings: Vec = Vec::new(); - let mut private_arg_strings: Vec = Vec::new(); - let mut defined_arg_names: HashSet = HashSet::new(); - - for arg_part_pair in arg_section_pair.into_inner() { - match arg_part_pair.as_rule() { - Rule::public_arg_list => { - for arg_ident_pair in arg_part_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::identifier) - { - let arg_name = arg_ident_pair.as_str().to_string(); - if !defined_arg_names.insert(arg_name.clone()) { - return Err(ProcessorError::DuplicateWildcard { - name: arg_name, - span: Some(get_span(&arg_ident_pair)), - }); - } - public_arg_strings.push(arg_name); - } - } - Rule::private_arg_list => { - for arg_ident_pair in arg_part_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::identifier) - { - let arg_name = arg_ident_pair.as_str().to_string(); - if !defined_arg_names.insert(arg_name.clone()) { - return Err(ProcessorError::DuplicateWildcard { - name: arg_name, - span: Some(get_span(&arg_ident_pair)), - }); - } - private_arg_strings.push(arg_name); - } - } - Rule::private_kw | Rule::COMMENT | Rule::WHITESPACE => {} - _ if arg_part_pair.as_str() == "," => {} - _ => { - unreachable!("Unexpected rule: {:?}", arg_part_pair.as_rule()); - } - } - } - - let conjunction_type_pair = inner_pairs - .find(|p| p.as_rule() == Rule::conjunction_type) - .unwrap(); - let conjunction = match conjunction_type_pair.as_str() { - "AND" => true, - "OR" => false, - _ => { - unreachable!( - "Invalid conjunction type: {}", - conjunction_type_pair.as_str() - ); - } - }; - - let statement_list_pair = inner_pairs - .find(|p| p.as_rule() == Rule::statement_list) - .unwrap_or_else(|| { - unreachable!("statement_list rule must be present in predicate definition") - }); - - let mut statement_builders = Vec::new(); - for stmt_pair in statement_list_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::statement) - { - let stb = process_statement_template( - params, - &stmt_pair, - processing_ctx, - &mut StatementContext::CustomPredicate { - pred_name: &name, - argument_names: &defined_arg_names, - }, - )?; - statement_builders.push(stb); - } - - let public_args_strs: Vec<&str> = public_arg_strings.iter().map(AsRef::as_ref).collect(); - let private_args_strs: Vec<&str> = private_arg_strings.iter().map(AsRef::as_ref).collect(); - let sts_slice: &[StatementTmplBuilder] = &statement_builders; - if conjunction { - cpb_builder.predicate_and(&name, &public_args_strs, &private_args_strs, sts_slice)?; - } else { - cpb_builder.predicate_or(&name, &public_args_strs, &private_args_strs, sts_slice)?; - } - - Ok(()) -} - -fn process_request_def( - params: &Params, - req_def_pair: &Pair, - processing_ctx: &ProcessingContext, - custom_batch: &Arc, -) -> Result, ProcessorError> { - let mut request_wildcard_names: Vec = Vec::new(); - let mut defined_request_wildcards: HashSet = HashSet::new(); - - let mut request_statement_builders: Vec = Vec::new(); - - if let Some(statement_list_pair) = req_def_pair - .clone() - .into_inner() - .find(|p| p.as_rule() == Rule::statement_list) - { - for stmt_pair in statement_list_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::statement) - { - let built_stb = process_statement_template( - params, - &stmt_pair, - processing_ctx, - &mut StatementContext::Request { - custom_batch, - wildcard_names: &mut request_wildcard_names, - defined_wildcards: &mut defined_request_wildcards, - }, - )?; - request_statement_builders.push(built_stb); - } - } - - let mut request_templates: Vec = - Vec::with_capacity(request_statement_builders.len()); - for stb in request_statement_builders { - let tmpl = - resolve_request_statement_builder(stb, &request_wildcard_names, processing_ctx.params)?; - request_templates.push(tmpl); - } - - Ok(request_templates) -} - -fn process_statement_template( - params: &Params, - stmt_pair: &Pair, - processing_ctx: &ProcessingContext, - context: &mut StatementContext, -) -> Result { - let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); - let name_pair = inner_stmt_pairs - .find(|p| p.as_rule() == Rule::identifier) - .unwrap(); - let stmt_name_str = name_pair.as_str(); - - let builder_args = parse_statement_args(params, stmt_pair, context)?; - - if let StatementContext::Request { - wildcard_names, - defined_wildcards, - .. - } = context - { - let mut temp_stmt_wildcard_names: Vec = Vec::new(); - for arg in &builder_args { - match arg { - BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()), - BuilderArg::Key(root_wc_str, _key_str) => { - temp_stmt_wildcard_names.push(root_wc_str.clone()); - } - _ => {} - } - } - for name in temp_stmt_wildcard_names { - if defined_wildcards.insert(name.clone()) { - wildcard_names.push(name); - } - } - } - - let middleware_predicate_type = if let Some(native_pred) = - native_predicate_from_string(stmt_name_str) - { - Predicate::Native(native_pred) - } else if let Some(custom_ref) = processing_ctx.imported_predicates.get(stmt_name_str) { - Predicate::Custom(custom_ref.clone()) - } else if let Some(intro_ref) = processing_ctx.imported_intro_predicates.get(stmt_name_str) { - Predicate::Intro(intro_ref.clone()) - } else if let Some((pred_index, _expected_arity)) = processing_ctx - .custom_predicate_signatures - .get(stmt_name_str) - { - match context { - StatementContext::CustomPredicate { .. } => Predicate::BatchSelf(*pred_index), - StatementContext::Request { custom_batch, .. } => { - let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index); - Predicate::Custom(custom_pred_ref) - } - } - } else { - return Err(ProcessorError::UndefinedIdentifier { - name: stmt_name_str.to_string(), - span: Some(get_span(&name_pair)), - }); - }; - - let stb = validate_and_build_statement_template( - stmt_name_str, - &middleware_predicate_type, - builder_args, - processing_ctx, - get_span(stmt_pair), - get_span(&name_pair), - )?; - - Ok(stb.desugar()) -} - -fn process_literal_value( - params: &Params, - lit_val_pair: &Pair, -) -> Result { - let inner_lit = lit_val_pair.clone().into_inner().next().unwrap(); - - match inner_lit.as_rule() { - Rule::literal_int => { - let val = inner_lit.as_str().parse::().unwrap(); - Ok(Value::from(val)) - } - Rule::literal_bool => { - let val = inner_lit.as_str().parse::().unwrap(); - Ok(Value::from(val)) - } - Rule::literal_raw => { - let full_literal_str = inner_lit.clone().into_inner().next().unwrap(); - let hex_str_no_prefix = full_literal_str - .as_str() - .strip_prefix("0x") - .unwrap_or(full_literal_str.as_str()); - parse_hex_str_to_raw_value(hex_str_no_prefix) - .map_err(|e| match e { - ProcessorError::InvalidLiteralFormat { kind, value, .. } => { - ProcessorError::InvalidLiteralFormat { - kind, - value, - span: Some(get_span(&inner_lit)), - } - } - ProcessorError::Internal(message) => ProcessorError::InvalidLiteralFormat { - kind: format!("raw hex processing (internal: {})", message), - value: full_literal_str.to_string(), - span: Some(get_span(&inner_lit)), - }, - _ => ProcessorError::InvalidLiteralFormat { - kind: "raw hex processing error".to_string(), - value: full_literal_str.to_string(), - span: Some(get_span(&inner_lit)), - }, - }) - .map(Value::from) - } - Rule::literal_public_key => { - let pk_str_pair = inner_lit.into_inner().next().unwrap(); - let pk_b58 = pk_str_pair.as_str(); - let point: Point = - pk_b58 - .parse() - .map_err(|e| ProcessorError::InvalidLiteralFormat { - kind: "PublicKey".to_string(), - value: format!("{} (error: {})", pk_b58, e), - span: Some(get_span(&pk_str_pair)), - })?; - Ok(Value::from(point)) - } - Rule::literal_string => Ok(Value::from(parse_pest_string_literal(&inner_lit)?)), - Rule::literal_array => { - let elements: Result, ProcessorError> = inner_lit - .into_inner() - .map(|elem_pair| process_literal_value(params, &elem_pair)) - .collect(); - let middleware_array = - middleware::containers::Array::new(params.max_depth_mt_containers, elements?) - .map_err(|e| { - ProcessorError::Internal(format!("Failed to create Array: {}", e)) - })?; - Ok(Value::from(middleware_array)) - } - Rule::literal_set => { - let elements: Result, ProcessorError> = inner_lit - .into_inner() - .map(|elem_pair| process_literal_value(params, &elem_pair)) - .collect(); - let middleware_set = - middleware::containers::Set::new(params.max_depth_mt_containers, elements?) - .map_err(|e| { - ProcessorError::Internal(format!("Failed to create Set: {}", e)) - })?; - Ok(Value::from(middleware_set)) - } - Rule::literal_dict => { - let pairs: Result, ProcessorError> = inner_lit - .into_inner() - .map(|dict_entry_pair| { - let mut entry_inner = dict_entry_pair.clone().into_inner(); - let key_pair = entry_inner.next().unwrap(); - let val_pair = entry_inner.next().unwrap(); - let key_str = parse_pest_string_literal(&key_pair)?; - let val = process_literal_value(params, &val_pair)?; - Ok((Key::new(key_str), val)) - }) - .collect(); - let middleware_dict = - middleware::containers::Dictionary::new(params.max_depth_mt_containers, pairs?) - .map_err(|e| { - ProcessorError::Internal(format!("Failed to create Dictionary: {}", e)) - })?; - Ok(Value::from(middleware_dict)) - } - Rule::literal_secret_key => { - let sk_str_pair = inner_lit.clone().into_inner().next().unwrap(); - let sk_base64 = sk_str_pair.as_str(); - let bytes = deserialize_bytes(sk_base64).map_err(|_e| { - ProcessorError::InvalidLiteralFormat { - kind: "SecretKey".to_string(), - value: sk_base64.to_string(), - span: Some(get_span(&inner_lit)), - } - })?; - let secret_key = SecretKey::from_bytes(&bytes).map_err(|_e| { - ProcessorError::InvalidLiteralFormat { - kind: "SecretKey".to_string(), - value: sk_base64.to_string(), - span: Some(get_span(&inner_lit)), - } - })?; - Ok(Value::from(secret_key)) - } - _ => unreachable!("Unexpected rule: {:?}", inner_lit.as_rule()), - } -} - -fn parse_pest_string_literal(pair: &Pair) -> Result { - let inner_pair = pair.clone().into_inner().next().unwrap(); - - let raw_content = inner_pair.as_str(); - let mut result = String::with_capacity(raw_content.len()); - let mut chars = raw_content.chars().peekable(); - - while let Some(c) = chars.next() { - if c == '\\' { - match chars.next() { - Some('"') => result.push('"'), - Some('\\') => result.push('\\'), - Some('/') => result.push('/'), - Some('b') => result.push('\x08'), - Some('f') => result.push('\x0C'), - Some('n') => result.push('\n'), - Some('r') => result.push('\r'), - Some('t') => result.push('\t'), - Some('u') => { - let mut hex_code = String::with_capacity(4); - for _ in 0..4 { - hex_code.push(chars.next().ok_or_else(|| { - ProcessorError::InvalidLiteralFormat { - kind: "unicode escape".to_string(), - value: format!("\\u{}... (incomplete)", hex_code), - span: Some(get_span(&inner_pair)), - } - })?); - } - let char_code = u32::from_str_radix(&hex_code, 16).map_err(|_| { - ProcessorError::InvalidLiteralFormat { - kind: "unicode escape".to_string(), - value: format!("\\u{}", hex_code), - span: Some(get_span(&inner_pair)), - } - })?; - result.push(std::char::from_u32(char_code).ok_or_else(|| { - ProcessorError::InvalidLiteralFormat { - kind: "unicode escape (invalid code point)".to_string(), - value: format!("\\u{}", hex_code), - span: Some(get_span(&inner_pair)), - } - })?); - } - Some(other) => { - return Err(ProcessorError::InvalidLiteralFormat { - kind: "escape sequence".to_string(), - value: format!("\\{}", other), - span: Some(get_span(&inner_pair)), - }) - } - None => { - return Err(ProcessorError::InvalidLiteralFormat { - kind: "escape sequence".to_string(), - value: "\\ (ends with escape)".to_string(), - span: Some(get_span(&inner_pair)), - }) - } - } - } else { - result.push(c); - } - } - Ok(result) -} - -// Translates a big-endian hex string to a little-endian RawValue -fn parse_hex_str_to_raw_value(hex_str: &str) -> Result { - let mut v = [F::ZERO; VALUE_SIZE]; - let value_range = 0..VALUE_SIZE; - for i in value_range { - let start_idx = i * 16; - let end_idx = start_idx + 16; - let hex_part = &hex_str[start_idx..end_idx]; - - let u64_val = u64::from_str_radix(hex_part, 16).unwrap(); - v[VALUE_SIZE - i - 1] = F::from_canonical_u64(u64_val); - } - Ok(middleware::RawValue(v)) -} - -// Helper to resolve a wildcard name string to an indexed middleware::Wildcard -// based on an ordered list of names from the current scope (e.g., request or predicate def). -fn resolve_wildcard( - ordered_scope_wildcard_names: &[String], - name_to_resolve: &str, -) -> Result { - ordered_scope_wildcard_names - .iter() - .position(|n| n == name_to_resolve) - .map(|index| Wildcard::new(name_to_resolve.to_string(), index)) - .ok_or_else(|| ProcessorError::UndefinedWildcard { - name: name_to_resolve.to_string(), - pred_name: "REQUEST".to_string(), - span: None, - }) -} - -fn resolve_request_statement_builder( - stb: StatementTmplBuilder, - ordered_request_wildcard_names: &[String], - params: &Params, -) -> Result { - let stb = stb.desugar(); - - let mut middleware_args = Vec::with_capacity(stb.args.len()); - for builder_arg in stb.args { - let mw_arg = match builder_arg { - BuilderArg::Literal(v) => StatementTmplArg::Literal(v), - BuilderArg::Key(root_wc_str, key_str) => { - let root_wc = resolve_wildcard(ordered_request_wildcard_names, &root_wc_str)?; - let key = Key::from(key_str); - StatementTmplArg::AnchoredKey(root_wc, key) - } - BuilderArg::WildcardLiteral(wc_name) => { - let wc = resolve_wildcard(ordered_request_wildcard_names, &wc_name)?; - StatementTmplArg::Wildcard(wc) - } - }; - middleware_args.push(mw_arg); - } - - if middleware_args.len() > params.max_statement_args { - return Err(ProcessorError::Middleware(middleware::Error::max_length( - format!("Arguments for predicate {:?}", stb.predicate), - middleware_args.len(), - params.max_statement_args, - ))); - } - - Ok(StatementTmpl { - pred: stb.predicate, - args: middleware_args, - }) -} - -fn parse_statement_args( - params: &Params, - stmt_pair: &Pair, - context: &StatementContext, -) -> Result, ProcessorError> { - let mut builder_args = Vec::new(); - let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); - - if let Some(arg_list_pair) = inner_stmt_pairs.find(|p| p.as_rule() == Rule::statement_arg_list) - { - for arg_pair in arg_list_pair - .into_inner() - .filter(|p| p.as_rule() == Rule::statement_arg) - { - let arg_content_pair = arg_pair.into_inner().next().unwrap(); - let builder_arg = pest_pair_to_builder_arg(params, &arg_content_pair, context)?; - builder_args.push(builder_arg); - } - } - Ok(builder_args) -} - -#[cfg(test)] -mod processor_tests { - use std::collections::HashMap; - - use pest::iterators::Pairs; - - use super::{first_pass, second_pass, ProcessingContext}; - use crate::{ - lang::{ - error::ProcessorError, - parser::{parse_podlang, Rule}, - }, - middleware::Params, - }; - - fn get_document_content_pairs(input: &str) -> Result, ProcessorError> { - let full_parse_tree = parse_podlang(input) - .map_err(|e| ProcessorError::Internal(format!("Test parsing failed: {:?}", e)))?; - - let document_node = full_parse_tree.peek().ok_or_else(|| { - ProcessorError::Internal("Parser returned no pairs for the document rule.".to_string()) - })?; - - if document_node.as_rule() != Rule::document { - return Err(ProcessorError::Internal(format!( - "Expected top-level pair to be Rule::document, but found {:?}.", - document_node.as_rule() - ))); - } - Ok(full_parse_tree.into_iter().next().unwrap().into_inner()) - } - - #[test] - fn test_fp_empty_input() -> Result<(), ProcessorError> { - let input = ""; - let pairs = get_document_content_pairs(input)?; - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx, &[])?; - assert!(ctx.custom_predicate_signatures.is_empty()); - assert!(ctx.custom_predicate_pairs.is_empty()); - assert!(ctx.request_pair.is_none()); - Ok(()) - } - - #[test] - fn test_fp_only_request() -> Result<(), ProcessorError> { - let input = "REQUEST( Equal(A[\"k\"],B.k) )"; // Escaped quotes - let pairs = get_document_content_pairs(input)?; - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx, &[])?; - assert!(ctx.custom_predicate_signatures.is_empty()); - assert!(ctx.custom_predicate_pairs.is_empty()); - assert!(ctx.request_pair.is_some()); - assert_eq!( - ctx.request_pair.as_ref().unwrap().as_rule(), - Rule::request_def - ); - Ok(()) - } - - #[test] - fn test_fp_simple_predicate() -> Result<(), ProcessorError> { - let input = "my_pred(A, B) = AND( Equal(A[\"k\"],B.k) )"; // Escaped quotes - let pairs = get_document_content_pairs(input)?; - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx, &[])?; - assert_eq!(ctx.custom_predicate_signatures.len(), 1); - assert_eq!(ctx.custom_predicate_pairs.len(), 1); - assert!(ctx.request_pair.is_none()); - - let (index, arity) = ctx.custom_predicate_signatures.get("my_pred").unwrap(); - assert_eq!(*index, 0); - assert_eq!(*arity, 2); // A, B - assert_eq!( - ctx.custom_predicate_pairs[0].as_rule(), - Rule::custom_predicate_def - ); - Ok(()) - } - - #[test] - fn test_fp_multiple_predicates() -> Result<(), ProcessorError> { - let input = r#" - pred1(X) = AND( Equal(X["k"],X.k) ) - pred2(Y, Z) = OR( Equal(Y["v"], 123) ) - "#; - let pairs = get_document_content_pairs(input)?; - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx, &[])?; - assert_eq!(ctx.custom_predicate_signatures.len(), 2); - assert_eq!(ctx.custom_predicate_pairs.len(), 2); - - let (idx1, arity1) = ctx.custom_predicate_signatures.get("pred1").unwrap(); - assert_eq!(*idx1, 0); - assert_eq!(*arity1, 1); - - let (idx2, arity2) = ctx.custom_predicate_signatures.get("pred2").unwrap(); - assert_eq!(*idx2, 1); - assert_eq!(*arity2, 2); - Ok(()) - } - - #[test] - fn test_fp_predicate_public_args_count() -> Result<(), ProcessorError> { - let inputs_and_expected_arities = vec![ - ("p1(A) = AND(None()) // One public arg", 1), - ("p3(A,B,C) = AND(None()) // Three public args", 3), - ("p_pub_priv(Pub1, private: Priv1) = AND(None())", 1), - ]; - - for (input_str, expected_arity) in inputs_and_expected_arities { - let pairs = get_document_content_pairs(input_str)?; - let params = Params::default(); - let mut ctx = ProcessingContext { - params: ¶ms, - imported_predicates: HashMap::new(), - imported_intro_predicates: HashMap::new(), - custom_predicate_signatures: HashMap::new(), - custom_predicate_pairs: Vec::new(), - request_pair: None, - }; - first_pass(pairs, &mut ctx, &[])?; - let pred_name = ctx - .custom_predicate_signatures - .keys() - .next() - .expect("No predicate found in test string"); - let (_, arity) = ctx.custom_predicate_signatures.get(pred_name).unwrap(); - assert_eq!(*arity, expected_arity, "Mismatch for input: {}", input_str); - } - Ok(()) - } - - #[test] - fn test_fp_duplicate_predicate() { - let input = r#" - my_pred(A) = AND(None()) - my_pred(B) = OR(None()) - "#; - let pairs = get_document_content_pairs(input).unwrap(); - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - let result = first_pass(pairs, &mut ctx, &[]); - assert!(result.is_err()); - match result.err().unwrap() { - // Use .err().unwrap() for ProcessorError - ProcessorError::DuplicateDefinition { name, .. } => { - assert_eq!(name, "my_pred"); - } - e => panic!("Expected DuplicateDefinition, got {:?}", e), - } - } - - #[test] - fn test_fp_multiple_requests() { - let input = r#" - REQUEST(None()) - REQUEST(None()) - "#; - let pairs = get_document_content_pairs(input).unwrap(); - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - let result = first_pass(pairs, &mut ctx, &[]); - assert!(result.is_err()); - match result.err().unwrap() { - // Use .err().unwrap() for ProcessorError - ProcessorError::MultipleRequestDefinitions { .. } => { /* Correct error */ } - e => panic!("Expected MultipleRequestDefinitions, got {:?}", e), - } - } - - #[test] - fn test_fp_mixed_content() -> Result<(), ProcessorError> { - let input = r#" - pred_one(X) = AND(None()) - REQUEST( pred_one(A) ) - pred_two(Y, Z) = OR(None()) - "#; - let pairs = get_document_content_pairs(input)?; - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx, &[])?; - - assert_eq!(ctx.custom_predicate_signatures.len(), 2); - assert_eq!(ctx.custom_predicate_pairs.len(), 2); - assert!(ctx.request_pair.is_some()); - - let (idx1, arity1) = ctx.custom_predicate_signatures.get("pred_one").unwrap(); - assert_eq!(*idx1, 0); - assert_eq!(*arity1, 1); - - let (idx2, arity2) = ctx.custom_predicate_signatures.get("pred_two").unwrap(); - assert_eq!(*idx2, 1); - assert_eq!(*arity2, 2); - - // Check that the pairs were stored in the correct order and have the correct content (simplistic check) - assert!(ctx.custom_predicate_pairs[0].as_str().contains("pred_one")); - assert!(ctx.custom_predicate_pairs[1].as_str().contains("pred_two")); - assert!(ctx - .request_pair - .as_ref() - .unwrap() - .as_str() - .contains("pred_one(A)")); - - Ok(()) - } - - #[test] - fn test_sp_unknown_predicate() -> Result<(), ProcessorError> { - // Undefined predicates will be flagged as an error on the second pass - let input = r#" - REQUEST( - pred_one(A) - ) - "#; - let pairs = get_document_content_pairs(input)?; - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx, &[])?; - let result = second_pass(&mut ctx, ¶ms); - assert!(result.is_err()); - match result.err().unwrap() { - ProcessorError::UndefinedIdentifier { name, span: _ } => { - assert_eq!(name, "pred_one") - } - e => panic!("Expected UndefinedIdentifier, got {:?}", e), - } - - // Native predicate names are case-sensitive - let input = r#" - REQUEST( - EQUAL(A["b"], C.d) - ) - "#; - let pairs = get_document_content_pairs(input)?; - let params = Params::default(); - let mut ctx = ProcessingContext::new(¶ms); - first_pass(pairs, &mut ctx, &[])?; - let result = second_pass(&mut ctx, ¶ms); - assert!(result.is_err()); - match result.err().unwrap() { - ProcessorError::UndefinedIdentifier { name, span: _ } => { - assert_eq!(name, "EQUAL") - } - e => panic!("Expected UndefinedIdentifier, got {:?}", e), - } - - Ok(()) - } -} diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 6989081..5889d5e 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -1,6 +1,7 @@ use std::{ fmt::{self, Display}, iter, + str::FromStr, }; use plonky2::field::types::Field; @@ -51,6 +52,42 @@ pub enum NativePredicate { ArrayUpdate = 1014, } +impl NativePredicate { + pub fn arity(&self) -> usize { + match self { + NativePredicate::None | NativePredicate::False => 0, + NativePredicate::Equal + | NativePredicate::NotEqual + | NativePredicate::Lt + | NativePredicate::Gt + | NativePredicate::GtEq + | NativePredicate::LtEq + | NativePredicate::NotContains + | NativePredicate::SetNotContains + | NativePredicate::DictNotContains + | NativePredicate::PublicKeyOf + | NativePredicate::SignedBy + | NativePredicate::SetContains => 2, + NativePredicate::Contains + | NativePredicate::DictContains + | NativePredicate::ArrayContains + | NativePredicate::SumOf + | NativePredicate::ProductOf + | NativePredicate::MaxOf + | NativePredicate::HashOf + | NativePredicate::SetInsert + | NativePredicate::SetDelete => 3, + NativePredicate::DictInsert + | NativePredicate::DictUpdate + | NativePredicate::DictDelete + | NativePredicate::ArrayUpdate + | NativePredicate::ContainerInsert + | NativePredicate::ContainerUpdate + | NativePredicate::ContainerDelete => 4, + } + } +} + impl Display for NativePredicate { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self { @@ -95,6 +132,45 @@ impl ToFields for NativePredicate { } } +impl FromStr for NativePredicate { + type Err = Error; + fn from_str(s: &str) -> Result { + match s { + "Equal" => Ok(NativePredicate::Equal), + "NotEqual" => Ok(NativePredicate::NotEqual), + "Gt" => Ok(NativePredicate::Gt), + "GtEq" => Ok(NativePredicate::GtEq), + "Lt" => Ok(NativePredicate::Lt), + "LtEq" => Ok(NativePredicate::LtEq), + "Contains" => Ok(NativePredicate::Contains), + "NotContains" => Ok(NativePredicate::NotContains), + "SumOf" => Ok(NativePredicate::SumOf), + "ProductOf" => Ok(NativePredicate::ProductOf), + "MaxOf" => Ok(NativePredicate::MaxOf), + "HashOf" => Ok(NativePredicate::HashOf), + "PublicKeyOf" => Ok(NativePredicate::PublicKeyOf), + "SignedBy" => Ok(NativePredicate::SignedBy), + "ContainerInsert" => Ok(NativePredicate::ContainerInsert), + "ContainerUpdate" => Ok(NativePredicate::ContainerUpdate), + "ContainerDelete" => Ok(NativePredicate::ContainerDelete), + "DictContains" => Ok(NativePredicate::DictContains), + "DictNotContains" => Ok(NativePredicate::DictNotContains), + "ArrayContains" => Ok(NativePredicate::ArrayContains), + "SetContains" => Ok(NativePredicate::SetContains), + "SetNotContains" => Ok(NativePredicate::SetNotContains), + "DictInsert" => Ok(NativePredicate::DictInsert), + "DictUpdate" => Ok(NativePredicate::DictUpdate), + "DictDelete" => Ok(NativePredicate::DictDelete), + "SetInsert" => Ok(NativePredicate::SetInsert), + "SetDelete" => Ok(NativePredicate::SetDelete), + "ArrayUpdate" => Ok(NativePredicate::ArrayUpdate), + "None" => Ok(NativePredicate::None), + "False" => Ok(NativePredicate::False), + _ => Err(Error::custom(format!("Invalid native predicate: {}", s))), + } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub struct IntroPredicateRef { pub name: String,