From 541c26458693afe7b9165abce7a578b46ecab034 Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Sat, 7 Jun 2025 07:17:23 +0200 Subject: [PATCH] Podlog language v1 (#225) * Initial commit for Podlog language * Spell-checker thinks that 'lits' is a bad abbreviation for 'literals' * Enable SetContains/SetNotContains * Update language based on review feedback * Typo/comment fix * Make native predicates case-sensitive * Enforce max batch size in CustomPredicateBatchBuilder * Remove some unnecessary checks for things handled by the grammar * Clean up more unnecessary error-checking * Typo * Simplify hex processing * Replace various errors with unreachable!() * Translate from big-endian hex string to little-endian RawValue * Update hex en/decoding functions --- Cargo.toml | 2 + src/frontend/custom.rs | 16 +- src/lang/error.rs | 94 +++ src/lang/grammar.pest | 96 +++ src/lang/mod.rs | 682 +++++++++++++++++++ src/lang/parser.rs | 210 ++++++ src/lang/processor.rs | 1120 +++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/middleware/basetypes.rs | 17 +- src/middleware/custom.rs | 10 +- src/middleware/serialization.rs | 40 +- 11 files changed, 2259 insertions(+), 29 deletions(-) create mode 100644 src/lang/error.rs create mode 100644 src/lang/grammar.pest create mode 100644 src/lang/mod.rs create mode 100644 src/lang/parser.rs create mode 100644 src/lang/processor.rs diff --git a/Cargo.toml b/Cargo.toml index 9acc509..cb2b69d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,8 @@ serde_json = "1.0.140" base64 = "0.22.1" schemars = "0.8.22" hashbrown = { version = "0.14.3", default-features = false, features = ["serde"] } +pest = "2.8.0" +pest_derive = "2.8.0" # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. 
# [patch."https://github.com/0xPolygonZero/plonky2"] diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index d88d558..610f1a3 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -31,7 +31,7 @@ pub fn key(s: &str) -> KeyOrWildcardStr { } /// Builder Argument for the StatementTmplBuilder -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum BuilderArg { Literal(Value), /// Key: (origin, key), where origin is SELF or Wildcard and key is Key or Wildcard @@ -79,8 +79,8 @@ pub fn literal(v: impl Into) -> BuilderArg { #[derive(Clone)] pub struct StatementTmplBuilder { - predicate: Predicate, - args: Vec, + pub(crate) predicate: Predicate, + pub(crate) args: Vec, } impl StatementTmplBuilder { @@ -98,7 +98,7 @@ impl StatementTmplBuilder { /// Desugar the predicate to a simpler form /// Should mirror the logic in `MainPodBuilder::lower_op` - fn desugar(self) -> StatementTmplBuilder { + pub(crate) fn desugar(self) -> StatementTmplBuilder { match self.predicate { Predicate::Native(NativePredicate::Gt) => { let mut stb = StatementTmplBuilder { @@ -184,6 +184,14 @@ impl CustomPredicateBatchBuilder { priv_args: &[&str], sts: &[StatementTmplBuilder], ) -> Result { + if self.predicates.len() >= self.params.max_custom_batch_size { + return Err(Error::max_length( + "self.predicates.len".to_string(), + self.predicates.len(), + self.params.max_custom_batch_size, + )); + } + if args.len() > self.params.max_statement_args { return Err(Error::max_length( "args.len".to_string(), diff --git a/src/lang/error.rs b/src/lang/error.rs new file mode 100644 index 0000000..b359f68 --- /dev/null +++ b/src/lang/error.rs @@ -0,0 +1,94 @@ +use thiserror::Error; + +use crate::{frontend, lang::parser::ParseError, middleware}; + +#[derive(Error, Debug)] +pub enum LangError { + #[error("Parsing failed: {0}")] + Parse(Box), + + #[error("AST processing error: {0}")] + Processor(Box), + + #[error("Middleware error during processing: {0}")] + Middleware(Box), + + #[error("Frontend 
error: {0}")] + Frontend(Box), +} + +/// Errors that can occur during the processing of Podlog Pest tree into middleware structures. +#[derive(thiserror::Error, Debug)] +pub enum ProcessorError { + #[error("Undefined identifier: '{name}' at {span:?}")] + UndefinedIdentifier { + name: String, + span: Option<(usize, usize)>, + }, + #[error("Duplicate definition: '{name}' at {span:?}")] + DuplicateDefinition { + name: String, + span: Option<(usize, usize)>, + }, + #[error("Duplicate wildcard: ?{name} in scope at {span:?}")] + DuplicateWildcard { + name: String, + span: Option<(usize, usize)>, + }, + #[error("Type error: expected {expected}, found {found} for '{item}' at {span:?}")] + TypeError { + expected: String, + found: String, + item: String, + span: Option<(usize, usize)>, + }, + #[error( + "Invalid argument count for '{predicate}': expected {expected}, found {found} at {span:?}" + )] + ArgumentCountMismatch { + predicate: String, + expected: usize, + found: usize, + span: Option<(usize, usize)>, + }, + #[error("Multiple REQUEST definitions found. Only one is allowed. 
First at {first_span:?}, second at {second_span:?}")] + MultipleRequestDefinitions { + first_span: Option<(usize, usize)>, + second_span: Option<(usize, usize)>, + }, + #[error("Internal processing error: {0}")] + Internal(String), + #[error("Middleware error: {0}")] + Middleware(middleware::Error), + #[error("Undefined wildcard: '?{name}' at {span:?}")] + UndefinedWildcard { + name: String, + span: Option<(usize, usize)>, + }, + #[error("Invalid literal format for {kind}: '{value}' at {span:?}")] + InvalidLiteralFormat { + kind: String, + value: String, + span: Option<(usize, usize)>, + }, + #[error("Frontend error: {0}")] + Frontend(#[from] frontend::Error), +} + +impl From<ParseError> for LangError { + fn from(err: ParseError) -> Self { + LangError::Parse(Box::new(err)) + } +} + +impl From<ProcessorError> for LangError { + fn from(err: ProcessorError) -> Self { + LangError::Processor(Box::new(err)) + } +} + +impl From<middleware::Error> for LangError { + fn from(err: middleware::Error) -> Self { + LangError::Middleware(Box::new(err)) + } +} diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest new file mode 100644 index 0000000..bd75359 --- /dev/null +++ b/src/lang/grammar.pest @@ -0,0 +1,96 @@ +// Grammar for the "Podlog" language. Used for describing POD2 Custom +// Predicates and Proof Requests. + +// Silent rules (`_`) are automatically handled by Pest between other rules. +// WHITESPACE matches one or more spaces, tabs, or newlines. +WHITESPACE = _{ (" " | "\t" | NEWLINE)+ } + +// COMMENT matches '//' followed by any characters until the end of the line. +// Also silent. +COMMENT = _{ "//" ~ (!NEWLINE ~ ANY)* } + +// Define rules for identifiers (predicate names, variable names without '?') +// Must start with alpha or _, followed by alpha, numeric, or _. Only the bare word "private" is reserved (it introduces the private argument list); longer names such as private_key are valid identifiers. +identifier = @{ !("private" ~ !(ASCII_ALPHANUMERIC | "_")) ~ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } + +private_kw = { "private:" } + +self_keyword = @{ "SELF" } + +// Define wildcard names (start with '?') +wildcard = @{ "?" 
~ identifier } + +arg_section = { + public_arg_list ~ ("," ~ private_kw ~ private_arg_list)? +} + +public_arg_list = { identifier ~ ("," ~ identifier)* } +private_arg_list = { identifier ~ ("," ~ identifier)* } + +document = { SOI ~ (custom_predicate_def | request_def)* ~ EOI } + +request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" } + +// Define conjunction type explicitly +conjunction_type = { "AND" | "OR" } + +custom_predicate_def = { + identifier + ~ "(" ~ arg_section ~ ")" + ~ "=" + ~ conjunction_type + ~ "(" ~ statement_list ~ ")" +} + +statement_list = { statement+ } + +statement_arg = { anchored_key | wildcard | literal_value } +statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* } + +statement = { identifier ~ "(" ~ statement_arg_list? ~ ")" } + +// Anchored Key: (SELF | ?Var)["key_literal" | ?KeyVar] +anchored_key = { ( self_keyword | wildcard ) ~ "[" ~ (wildcard | literal_string) ~ "]" } + +// Literal Values (ordered to avoid ambiguity, e.g., string before int) +literal_value = { + literal_dict | + literal_set | + literal_array | + literal_bool | + literal_raw | + literal_string | + literal_int +} + +// Primitive literal types +literal_int = @{ "-"? 
~ ASCII_DIGIT+ } +literal_bool = @{ "true" | "false" } + +// literal_raw: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters) +// representing a 32-byte value in big-endian order +literal_raw = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} } + +// String literal parsing based on https://pest.rs/book/examples/json.html +literal_string = ${ "\"" ~ inner ~ "\"" } // Compound atomic string rule +inner = @{ char* } // Atomic rule for the raw inner content +char = { // Rule for a single logical character (unescaped or escaped) + !("\"" | "\\") ~ ANY // Any char except quote or backslash + | "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t") // Simple escape sequences + | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) // Unicode escape sequence +} + +// Container Literals (recursive definition using literal_value) +literal_array = { "[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" } +literal_set = { "#[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" } +literal_dict = { "{" ~ (dict_pair ~ ("," ~ dict_pair)*)? 
~ "}" } +dict_pair = { literal_string ~ ":" ~ literal_value } + +// --- Rules for testing full input matching --- +test_identifier = { SOI ~ identifier ~ EOI } +test_wildcard = { SOI ~ wildcard ~ EOI } +test_literal_int = { SOI ~ literal_int ~ EOI } +test_literal_raw = { SOI ~ literal_raw ~ EOI } +test_literal_value = { SOI ~ literal_value ~ EOI } +test_statement = { SOI ~ statement ~ EOI } +test_custom_predicate_def = { SOI ~ custom_predicate_def ~ EOI } diff --git a/src/lang/mod.rs b/src/lang/mod.rs new file mode 100644 index 0000000..447f6e2 --- /dev/null +++ b/src/lang/mod.rs @@ -0,0 +1,682 @@ +pub mod error; +pub mod parser; +pub mod processor; + +pub use error::LangError; +pub use parser::{parse_podlog, Pairs, ParseError, Rule}; +pub use processor::process_pest_tree; +use processor::ProcessedOutput; + +use crate::middleware::Params; + +pub fn parse(input: &str, params: &Params) -> Result<ProcessedOutput, LangError> { + let pairs = parse_podlog(input)?; + processor::process_pest_tree(pairs, params).map_err(LangError::from) +} + +#[cfg(test)] +mod tests { + + use pretty_assertions::assert_eq; + + use super::*; + use crate::middleware::{ + CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, KeyOrWildcard, + NativePredicate, Params, PodType, Predicate, SelfOrWildcard, StatementTmpl, + StatementTmplArg, Value, Wildcard, SELF_ID_HASH, + }; + + // Helper functions + fn wc(name: &str, index: usize) -> Wildcard { + Wildcard::new(name.to_string(), index) + } + + fn k(name: &str) -> KeyOrWildcard { + KeyOrWildcard::Key(Key::new(name.to_string())) + } + + fn ko_wc(name: &str, index: usize) -> KeyOrWildcard { + KeyOrWildcard::Wildcard(Wildcard::new(name.to_string(), index)) + } + + fn sta_ak(pod_var: (&str, usize), key_or_wc: KeyOrWildcard) -> StatementTmplArg { + StatementTmplArg::AnchoredKey( + SelfOrWildcard::Wildcard(wc(pod_var.0, pod_var.1)), + key_or_wc, + ) + } + + fn sta_ak_self(key_or_wc: KeyOrWildcard) -> StatementTmplArg { + StatementTmplArg::AnchoredKey(SelfOrWildcard::SELF, 
key_or_wc) + } + + fn sta_lit(value: impl Into<Value>) -> StatementTmplArg { + StatementTmplArg::Literal(value.into()) + } + + #[test] + fn test_e2e_simple_predicate() -> Result<(), LangError> { + let input = r#" + is_equal(PodA, PodB) = AND( + Equal(?PodA["the_key"], ?PodB["the_key"]) + ) + "#; + + let params = Params::default(); + let pairs = parse_podlog(input)?; + let processed = process_pest_tree(pairs, &params)?; + let batch_result = processed.custom_batch; + let request_result = processed.request_templates; + + assert_eq!(request_result.len(), 0); + assert_eq!(batch_result.predicates.len(), 1); + + let batch = batch_result; + + // Expected structure + let expected_statements = vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak(("PodA", 0), k("the_key")), // ?PodA["the_key"] -> Wildcard(0), Key("the_key") + sta_ak(("PodB", 1), k("the_key")), // ?PodB["the_key"] -> Wildcard(1), Key("the_key") + ], + }]; + let expected_predicate = CustomPredicate::and( + &params, + "is_equal".to_string(), + expected_statements, + 2, // args_len (PodA, PodB) + )?; + let expected_batch = + CustomPredicateBatch::new(&params, "PodlogBatch".to_string(), vec![expected_predicate]); + + assert_eq!(batch, expected_batch); + + Ok(()) + } + + #[test] + fn test_e2e_simple_request() -> Result<(), LangError> { + let input = r#" + REQUEST( + ValueOf(?ConstPod["my_val"], 0x0000000000000000000000000000000000000000000000000000000000000001) + Lt(?GovPod["dob"], ?ConstPod["my_val"]) + ) + "#; + + let params = Params::default(); + let pairs = parse_podlog(input)?; + let processed = process_pest_tree(pairs, &params)?; + let batch_result = processed.custom_batch; + let request_templates = processed.request_templates; + + assert_eq!(batch_result.predicates.len(), 0); + assert!(!request_templates.is_empty()); + + let request_templates = request_templates; + + // Expected structure + let expected_templates = vec![ + StatementTmpl { + pred: 
Predicate::Native(NativePredicate::ValueOf), + args: vec![ + sta_ak(("ConstPod", 0), k("my_val")), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val") + sta_lit(SELF_ID_HASH), + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Lt), + args: vec![ + sta_ak(("GovPod", 1), k("dob")), // ?GovPod["dob"] -> Wildcard(1), Key("dob") + sta_ak(("ConstPod", 0), k("my_val")), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val") + ], + }, + ]; + + assert_eq!(request_templates, expected_templates); + + Ok(()) + } + + #[test] + fn test_e2e_predicate_with_private_var() -> Result<(), LangError> { + let input = r#" + uses_private(A, private: Temp) = AND( + Equal(?A["input_key"], ?Temp["const_key"]) + ValueOf(?Temp["const_key"], "some_value") + ) + "#; + + let params = Params::default(); + let pairs = parse_podlog(input)?; + let processed = process_pest_tree(pairs, &params)?; + let batch_result = processed.custom_batch; + let request_result = processed.request_templates; + + assert_eq!(request_result.len(), 0); + assert_eq!(batch_result.predicates.len(), 1); + + let batch = batch_result; + + // Expected structure: Public args: A (index 0). 
Private args: Temp (index 1) + let expected_statements = vec![ + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak(("A", 0), k("input_key")), // ?A["input_key"] -> Wildcard(0), Key("input_key") + sta_ak(("Temp", 1), k("const_key")), // ?Temp["const_key"] -> Wildcard(1), Key("const_key") + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::ValueOf), + args: vec![ + sta_ak(("Temp", 1), k("const_key")), // ?Temp["const_key"] -> Wildcard(1), Key("const_key") + sta_lit("some_value"), // Literal("some_value") + ], + }, + ]; + let expected_predicate = CustomPredicate::and( + &params, + "uses_private".to_string(), + expected_statements, + 1, // args_len (A) + )?; + let expected_batch = + CustomPredicateBatch::new(&params, "PodlogBatch".to_string(), vec![expected_predicate]); + + assert_eq!(batch, expected_batch); + + Ok(()) + } + + #[test] + fn test_e2e_request_with_custom_call() -> Result<(), LangError> { + let input = r#" + my_pred(X, Y) = AND( + Equal(?X["val"], ?Y["val"]) + ) + + REQUEST( + my_pred(?Pod1, ?Pod2) + ) + "#; + + let params = Params::default(); + let pairs = parse_podlog(input)?; + let processed = process_pest_tree(pairs, &params)?; + let batch_result = processed.custom_batch; + let request_templates = processed.request_templates; + + assert_eq!(batch_result.predicates.len(), 1); + assert!(!request_templates.is_empty()); + + let batch = batch_result; + let request_templates = request_templates; + + // Expected Batch structure + let expected_pred_statements = vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak(("X", 0), k("val")), // ?X["val"] -> Wildcard(0), Key("val") + sta_ak(("Y", 1), k("val")), // ?Y["val"] -> Wildcard(1), Key("val") + ], + }]; + let expected_predicate = CustomPredicate::and( + &params, + "my_pred".to_string(), + expected_pred_statements, + 2, // args_len (X, Y) + )?; + let expected_batch = + CustomPredicateBatch::new(&params, "PodlogBatch".to_string(), 
vec![expected_predicate]); + + assert_eq!(batch, expected_batch); + + // Expected Request structure + // Pod1 -> Wildcard 0, Pod2 -> Wildcard 1 + let expected_request_templates = vec![StatementTmpl { + pred: Predicate::Custom(CustomPredicateRef::new(expected_batch, 0)), + args: vec![ + StatementTmplArg::WildcardLiteral(wc("Pod1", 0)), + StatementTmplArg::WildcardLiteral(wc("Pod2", 1)), + ], + }]; + + assert_eq!(request_templates, expected_request_templates); + + Ok(()) + } + + #[test] + fn test_e2e_request_with_various_args() -> Result<(), LangError> { + let input = r#" + some_pred(A, B, C) = AND( Equal(?A["foo"], ?B["bar"]) ) + + REQUEST( + some_pred( + ?Var1, // Wildcard + 12345, // Int Literal + "hello_string" // String Literal (Removed invalid AK args) + ) + Equal(?AnotherPod["another_key"], ?Var1["some_field"]) + ) + "#; + + let params = Params::default(); + let pairs = parse_podlog(input)?; + let processed = process_pest_tree(pairs, &params)?; + let batch_result = processed.custom_batch; + let request_templates = processed.request_templates; + + assert_eq!(batch_result.predicates.len(), 1); // some_pred is defined + assert!(!request_templates.is_empty()); + + let request_templates = request_templates; + + // Expected Wildcard Indices in Request Scope: + // ?Var1 -> 0 + // ?AnotherPod -> 1 + + // Expected structure + let expected_templates = vec![ + StatementTmpl { + pred: Predicate::Custom(CustomPredicateRef::new(batch_result, 0)), // Refers to some_pred + args: vec![ + StatementTmplArg::WildcardLiteral(wc("Var1", 0)), // ?Var1 + StatementTmplArg::Literal(Value::from(12345i64)), // 12345 + StatementTmplArg::Literal(Value::from("hello_string")), // "hello_string" + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + // ?AnotherPod["another_key"] -> Wildcard(1), Key("another_key") + sta_ak(("AnotherPod", 1), k("another_key")), + // ?Var1["some_field"] -> Wildcard(0), Key("some_field") + sta_ak(("Var1", 0), k("some_field")), 
+ ], + }, + ]; + + assert_eq!(request_templates, expected_templates); + + Ok(()) + } + + #[test] + fn test_e2e_syntactic_sugar_predicates() -> Result<(), LangError> { + let input = r#" + REQUEST( + GtEq(?A["foo"], ?B["bar"]) + Gt(?C["baz"], ?D["qux"]) + DictContains(?A["foo"], ?B["bar"], ?C["baz"]) + DictNotContains(?A["foo"], ?B["bar"]) + ArrayContains(?A["foo"], ?B["bar"], ?C["baz"]) + ) + "#; + + let params = Params::default(); + let pairs = parse_podlog(input)?; + let processed = process_pest_tree(pairs, &params)?; + let batch_result = processed.custom_batch; + let request_templates = processed.request_templates; + + assert_eq!(batch_result.predicates.len(), 0); + assert!(!request_templates.is_empty()); + + let request_templates = request_templates; + + let expected_templates = vec![ + StatementTmpl { + pred: Predicate::Native(NativePredicate::LtEq), + args: vec![sta_ak(("B", 1), k("bar")), sta_ak(("A", 0), k("foo"))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Lt), + args: vec![sta_ak(("D", 3), k("qux")), sta_ak(("C", 2), k("baz"))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Contains), + args: vec![ + sta_ak(("A", 0), k("foo")), + sta_ak(("B", 1), k("bar")), + sta_ak(("C", 2), k("baz")), + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::NotContains), + args: vec![sta_ak(("A", 0), k("foo")), sta_ak(("B", 1), k("bar"))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Contains), + args: vec![ + sta_ak(("A", 0), k("foo")), + sta_ak(("B", 1), k("bar")), + sta_ak(("C", 2), k("baz")), + ], + }, + ]; + + assert_eq!(request_templates, expected_templates); + + Ok(()) + } + + #[test] + fn test_e2e_zukyc_request_parsing() -> Result<(), LangError> { + let input = r#" + REQUEST( + // Order matters for comparison with the hardcoded templates + SetNotContains(?sanctions["sanctionList"], ?gov["idNumber"]) + Lt(?gov["dateOfBirth"], ?SELF_HOLDER_18Y["const_18y"]) + Equal(?pay["startDate"], 
?SELF_HOLDER_1Y["const_1y"]) + Equal(?gov["socialSecurityNumber"], ?pay["socialSecurityNumber"]) + ValueOf(?SELF_HOLDER_18Y["const_18y"], 1169909388) + ValueOf(?SELF_HOLDER_1Y["const_1y"], 1706367566) + ) + "#; + + // Parse the input string + let processed = super::parse(input, &Params::default())?; + let parsed_templates = processed.request_templates; + + // Define Expected Templates (Copied from prover/mod.rs) + let now_minus_18y_val = Value::from(1169909388_i64); + let now_minus_1y_val = Value::from(1706367566_i64); + + // Define wildcards and keys for the request + // Note: Indices must match the order of appearance in the *parsed* request + // Order: sanctions, gov, SELF_HOLDER_18Y, pay, SELF_HOLDER_1Y + let wc_sanctions = wc("sanctions", 0); + let wc_gov = wc("gov", 1); + let wc_self_18y = wc("SELF_HOLDER_18Y", 2); + let wc_pay = wc("pay", 3); + let wc_self_1y = wc("SELF_HOLDER_1Y", 4); + + let id_num_key = k("idNumber"); + let dob_key = k("dateOfBirth"); + let const_18y_key = k("const_18y"); + let start_date_key = k("startDate"); + let const_1y_key = k("const_1y"); + let ssn_key = k("socialSecurityNumber"); + let sanction_list_key = k("sanctionList"); + + // Define the request templates using wildcards for constants + let expected_templates = vec![ + // 1. NotContains(?sanctions["sanctionList"], ?gov["idNumber"]) + StatementTmpl { + pred: Predicate::Native(NativePredicate::NotContains), + args: vec![ + sta_ak( + (wc_sanctions.name.as_str(), wc_sanctions.index), + sanction_list_key.clone(), + ), + sta_ak((wc_gov.name.as_str(), wc_gov.index), id_num_key.clone()), + ], + }, + // 2. Lt(?gov["dateOfBirth"], ?SELF_HOLDER_18Y["const_18y"]) + StatementTmpl { + pred: Predicate::Native(NativePredicate::Lt), + args: vec![ + sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key.clone()), + sta_ak( + (wc_self_18y.name.as_str(), wc_self_18y.index), + const_18y_key.clone(), + ), + ], + }, + // 3. 
Equal(?pay["startDate"], ?SELF_HOLDER_1Y["const_1y"]) + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key.clone()), + sta_ak( + (wc_self_1y.name.as_str(), wc_self_1y.index), + const_1y_key.clone(), + ), + ], + }, + // 4. Equal(?gov["socialSecurityNumber"], ?pay["socialSecurityNumber"]) + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key.clone()), + sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key.clone()), + ], + }, + // 5. ValueOf(?SELF_HOLDER_18Y["const_18y"], 1169909388) + StatementTmpl { + pred: Predicate::Native(NativePredicate::ValueOf), + args: vec![ + sta_ak( + (wc_self_18y.name.as_str(), wc_self_18y.index), + const_18y_key.clone(), + ), + sta_lit(now_minus_18y_val.clone()), + ], + }, + // 6. ValueOf(?SELF_HOLDER_1Y["const_1y"], 1706367566) + StatementTmpl { + pred: Predicate::Native(NativePredicate::ValueOf), + args: vec![ + sta_ak( + (wc_self_1y.name.as_str(), wc_self_1y.index), + const_1y_key.clone(), + ), + sta_lit(now_minus_1y_val.clone()), + ], + }, + ]; + + assert_eq!( + parsed_templates, expected_templates, + "Parsed ZuKYC request templates do not match the expected hard-coded version" + ); + + assert!( + processed.custom_batch.predicates.is_empty(), + "Expected no custom predicates for a REQUEST only input" + ); + + Ok(()) + } + + #[test] + fn test_e2e_ethdos_predicates() -> Result<(), LangError> { + let params = Params { + max_input_signed_pods: 3, + max_input_recursive_pods: 3, + max_statements: 31, + max_signed_pod_values: 8, + max_public_statements: 10, + max_statement_args: 6, + max_operation_args: 5, + max_custom_predicate_arity: 5, + max_custom_batch_size: 5, + max_custom_predicate_wildcards: 12, + ..Default::default() + }; + + let input = r#" + eth_friend(src_key, dst_key, private: attestation_pod) = AND( + ValueOf(?attestation_pod["_type"], 1) + 
Equal(?attestation_pod["_signer"], SELF[?src_key]) + Equal(?attestation_pod["attestation"], SELF[?dst_key]) + ) + + eth_dos_distance_base(src_key, dst_key, distance_key) = AND( + Equal(SELF[?src_key], SELF[?dst_key]) + ValueOf(SELF[?distance_key], 0) + ) + + eth_dos_distance_ind(src_key, dst_key, distance_key, private: one_key, shorter_distance_key, intermed_key) = AND( + eth_dos_distance(?src_key, ?dst_key, ?distance_key) + ValueOf(SELF[?one_key], 1) + SumOf(SELF[?distance_key], SELF[?shorter_distance_key], SELF[?one_key]) + eth_friend(?intermed_key, ?dst_key) + ) + + eth_dos_distance(src_key, dst_key, distance_key, private: intermed_key, shorter_distance_key) = OR( + eth_dos_distance_base(?src_key, ?dst_key, ?distance_key) + eth_dos_distance_ind(?src_key, ?dst_key, ?distance_key) + ) + "#; + + let processed = super::parse(input, &params)?; + + assert!( + processed.request_templates.is_empty(), + "Expected no request templates" + ); + assert_eq!( + processed.custom_batch.predicates.len(), + 4, + "Expected 4 custom predicates" + ); + + // Predicate Order: eth_friend (0), base (1), ind (2), distance (3) + + // eth_friend (Index 0) + let expected_friend_stmts = vec![ + StatementTmpl { + pred: Predicate::Native(NativePredicate::ValueOf), + args: vec![ + sta_ak(("attestation_pod", 2), k("_type")), // Pub(0-1), Priv(2) + sta_lit(PodType::MockSigned), + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak(("attestation_pod", 2), k("_signer")), + sta_ak_self(ko_wc("src_key", 0)), // Pub arg 0 + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak(("attestation_pod", 2), k("attestation")), + sta_ak_self(ko_wc("dst_key", 1)), // Pub arg 1 + ], + }, + ]; + let expected_friend_pred = CustomPredicate::new( + &params, + "eth_friend".to_string(), + true, // AND + expected_friend_stmts, + 2, // public_args_len: src_key, dst_key + )?; + + // eth_dos_distance_base (Index 1) + let 
expected_base_stmts = vec![ + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + sta_ak_self(ko_wc("src_key", 0)), + sta_ak_self(ko_wc("dst_key", 1)), + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::ValueOf), + args: vec![sta_ak_self(ko_wc("distance_key", 2)), sta_lit(0i64)], + }, + ]; + let expected_base_pred = CustomPredicate::new( + &params, + "eth_dos_distance_base".to_string(), + true, // AND + expected_base_stmts, + 3, // public_args_len + )?; + + // eth_dos_distance_ind (Index 2) + // Public args indices: 0-2 + // Private args indices: 3-5 (one_key, shorter_distance_key, intermed_key) + let expected_ind_stmts = vec![ + StatementTmpl { + pred: Predicate::BatchSelf(3), // Calls eth_dos_distance (index 3) + args: vec![ + // WildcardLiteral args + StatementTmplArg::WildcardLiteral(wc("src_key", 0)), + StatementTmplArg::WildcardLiteral(wc("dst_key", 1)), // public arg + StatementTmplArg::WildcardLiteral(wc("distance_key", 2)), // public arg + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::ValueOf), + args: vec![sta_ak_self(ko_wc("one_key", 3)), sta_lit(1i64)], // private arg + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::SumOf), + args: vec![ + sta_ak_self(ko_wc("distance_key", 2)), // public arg + sta_ak_self(ko_wc("shorter_distance_key", 4)), // private arg + sta_ak_self(ko_wc("one_key", 3)), // private arg + ], + }, + StatementTmpl { + pred: Predicate::BatchSelf(0), // Calls eth_friend (index 0) + args: vec![ + // WildcardLiteral args + StatementTmplArg::WildcardLiteral(wc("intermed_key", 5)), // private arg + StatementTmplArg::WildcardLiteral(wc("dst_key", 1)), // public arg + ], + }, + ]; + let expected_ind_pred = CustomPredicate::new( + &params, + "eth_dos_distance_ind".to_string(), + true, // AND + expected_ind_stmts, + 3, // public_args_len + )?; + + // eth_dos_distance (Index 3) + let expected_dist_stmts = vec![ + StatementTmpl { + pred: Predicate::BatchSelf(1), // 
Calls eth_dos_distance_base (index 1) + args: vec![ + // WildcardLiteral args + StatementTmplArg::WildcardLiteral(wc("src_key", 0)), + StatementTmplArg::WildcardLiteral(wc("dst_key", 1)), + StatementTmplArg::WildcardLiteral(wc("distance_key", 2)), + ], + }, + StatementTmpl { + pred: Predicate::BatchSelf(2), // Calls eth_dos_distance_ind (index 2) + args: vec![ + // WildcardLiteral args + StatementTmplArg::WildcardLiteral(wc("src_key", 0)), + StatementTmplArg::WildcardLiteral(wc("dst_key", 1)), + StatementTmplArg::WildcardLiteral(wc("distance_key", 2)), + ], + }, + ]; + let expected_dist_pred = CustomPredicate::new( + &params, + "eth_dos_distance".to_string(), + false, // OR + expected_dist_stmts, + 3, // public_args_len + )?; + + let expected_batch = CustomPredicateBatch::new( + &params, + "PodlogBatch".to_string(), + vec![ + expected_friend_pred, + expected_base_pred, + expected_ind_pred, + expected_dist_pred, + ], + ); + + assert_eq!( + processed.custom_batch, expected_batch, + "Processed ETHDoS predicates do not match expected structure" + ); + + Ok(()) + } +} diff --git a/src/lang/parser.rs b/src/lang/parser.rs new file mode 100644 index 0000000..e2e6634 --- /dev/null +++ b/src/lang/parser.rs @@ -0,0 +1,210 @@ +use pest::{iterators::Pairs as PestPairs, Parser}; +use pest_derive::Parser; + +// Derive the parser from the grammar file +// The Rust analyzer will only reload the grammar file when *this* file is recompiled, +// and changes to the grammar file will not automatically trigger a recompile. +#[derive(Parser)] +#[grammar = "lang/grammar.pest"] +pub struct PodlogParser; + +pub type Pairs<'a, R> = PestPairs<'a, R>; + +#[derive(thiserror::Error, Debug)] +pub enum ParseError { + #[error("Pest parsing error: {0}")] + Pest(Box<pest::error::Error<Rule>>), +} + +impl From<pest::error::Error<Rule>> for ParseError { + fn from(err: pest::error::Error<Rule>) -> Self { + ParseError::Pest(Box::new(err)) + } +} + +/// Parses a Podlog input string according to the grammar rules. 
+pub fn parse_podlog(input: &str) -> Result<Pairs<'_, Rule>, ParseError> { + let pairs = PodlogParser::parse(Rule::document, input)?; + Ok(pairs) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_parses(rule: Rule, input: &str) { + match PodlogParser::parse(rule, input) { + Ok(_) => (), // Successfully parsed + Err(e) => panic!("Failed to parse input:\n{}\nError: {}", input, e), + } + } + + fn assert_fails(rule: Rule, input: &str) { + match PodlogParser::parse(rule, input) { + Ok(pairs) => panic!( + "Expected parse to fail, but it succeeded. Parsed:\n{:#?}", + pairs + ), + Err(_) => (), // Failed as expected + } + } + + #[test] + fn test_parse_empty() { + assert_parses(Rule::document, ""); + assert_parses(Rule::document, " "); + assert_parses(Rule::document, "\n\n"); + assert_parses(Rule::document, "// comment only"); + } + + #[test] + fn test_parse_comment() { + assert_parses(Rule::document, "// This is a comment\n"); + assert_parses(Rule::document, " // Indented comment"); + } + + #[test] + fn test_parse_identifier() { + assert_parses(Rule::identifier, "my_pred"); + assert_parses(Rule::identifier, "_internal"); + assert_parses(Rule::identifier, "ValidName123"); + assert_fails(Rule::test_identifier, "?invalid"); // Use test rule + assert_fails(Rule::test_identifier, "1_invalid_start"); // Use test rule + assert_fails(Rule::test_identifier, "invalid-char"); // Use test rule + } + + #[test] + fn test_parse_wildcard() { + assert_parses(Rule::wildcard, "?Var"); + assert_parses(Rule::wildcard, "?_Internal"); + assert_parses(Rule::wildcard, "?X1"); + assert_fails(Rule::test_wildcard, "NotAVar"); // Use test rule + assert_fails(Rule::test_wildcard, "?"); // Use test rule + assert_fails(Rule::test_wildcard, "?invalid-char"); // Use test rule + } + + #[test] + fn test_parse_anchored_key() { + assert_parses(Rule::anchored_key, "?PodVar[\"literal_key\"]"); + assert_parses(Rule::anchored_key, "?PodVar[?KeyVar]"); + assert_parses(Rule::anchored_key, "SELF[?KeyVar]"); + 
assert_parses(Rule::anchored_key, "SELF[\"literal_key\"]"); + assert_fails(Rule::anchored_key, "PodVar[\"key\"]"); // Needs wildcard for pod + assert_fails(Rule::anchored_key, "?PodVar[invalid_key]"); // Key must be literal string or wildcard + assert_fails(Rule::anchored_key, "?PodVar[]"); // Key cannot be empty + } + + #[test] + fn test_parse_literals() { + // Int + assert_parses(Rule::literal_int, "123"); + assert_parses(Rule::literal_int, "-45"); + assert_parses(Rule::literal_int, "0"); + assert_fails(Rule::test_literal_int, "1.23"); // Use test_literal_int rule + // Bool + assert_parses(Rule::literal_bool, "true"); + assert_parses(Rule::literal_bool, "false"); + + // Raw - Require 64 hex digits (32 bytes, equal to 4 * 64-bit field elements) + assert_parses( + Rule::literal_raw, + "0x0000000000000000000000000000000000000000000000000000000000000000", + ); + assert_parses( + Rule::literal_raw, + "0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd", + ); + let long_valid_raw = format!("0x{}", "a".repeat(64)); + assert_parses(Rule::literal_raw, &long_valid_raw); + + // Use anchored rule for failure cases + assert_fails(Rule::test_literal_raw, "0xabc"); // Fails (string is too short) + assert_fails(Rule::test_literal_raw, "0x"); // Fails (needs at least one pair) + assert_fails(Rule::test_literal_raw, &format!("0x{}", "a".repeat(66))); // Fails (string is too long) + + // String + assert_parses(Rule::literal_string, "\"hello\""); + assert_parses(Rule::literal_string, "\"escaped \\\" quote\""); + assert_parses(Rule::literal_string, "\"\\\\ backslash\""); + assert_parses(Rule::literal_string, "\"\\uABCD\""); + assert_fails(Rule::literal_string, "\"unterminated"); + // Array + assert_parses(Rule::literal_array, "[]"); + assert_parses(Rule::literal_array, "[1, \"two\", true]"); + assert_parses(Rule::literal_array, "[ [1], #[2] ]"); + // Set + assert_parses(Rule::literal_set, "#[]"); + assert_parses(Rule::literal_set, "#[1, 2, 3]"); + assert_parses( + 
Rule::literal_set, + "#[ \"a\", 0x0000000000000000000000000000000000000000000000000000000000000000 ]", + ); + // Dict + assert_parses(Rule::literal_dict, "{}"); + assert_parses(Rule::literal_dict, "{ \"name\": \"Alice\", \"age\": 30 }"); + assert_parses(Rule::literal_dict, "{ \"nested\": { \"key\": 1 } }"); + assert_parses( + Rule::literal_dict, + "{ \"raw_val\": 0x0000000000000000000000000000000000000000000000000000000000000000 } ", + ); + assert_fails(Rule::literal_dict, "{ name: \"Alice\" }"); // Key must be string literal with quotes + } + + #[test] + fn test_parse_simple_request() { + assert_parses(Rule::request_def, "REQUEST()"); + assert_parses( + Rule::request_def, + // Trimmed leading/trailing whitespace + r#"REQUEST( + // Check equality + Equal(?gov["socialSecurityNumber"], ?pay["socialSecurityNumber"]) + // Check age > 18 + ValueOf(?const_holder["const_18y"], 1169909388) + Lt(?gov["dateOfBirth"], ?const_holder["const_18y"]) + )"#, + ); + } + + #[test] + fn test_parse_simple_custom_def() { + assert_parses( + Rule::test_custom_predicate_def, + // Trimmed leading/trailing whitespace + r#"my_pred(A, B) = AND( + Equal(?A["foo"], ?B["bar"]) + )"#, + ); + assert_parses( + Rule::test_custom_predicate_def, + // Trimmed leading/trailing whitespace + r#"pred_with_private(X, private: TempKey) = OR( + Equal(?X[?TempKey], ?X["other"]) + )"#, + ); + assert_fails( + Rule::test_custom_predicate_def, + r#"pred_no_stmts(A,B) = AND()"#, + ); + } + + #[test] + fn test_parse_document() { + assert_parses( + Rule::document, + r#"// File defining one predicate and one request + is_valid_user(UserPod, private: ConstVal) = AND( + // User age must be > 18 (using a constant value) + ValueOf(?ConstVal["min_age"], 18) + Gt(?UserPod["age"], ?ConstVal["min_age"]) + // User must not be banned + NotContains(?_BANNED_USERS["list"], ?UserPod["userId"]) + ) + + REQUEST( + is_valid_user(?SomeUser) + Equal(?SomeUser["country"], ?Other["country"]) + )"#, + ); + } +} diff --git 
a/src/lang/processor.rs b/src/lang/processor.rs new file mode 100644 index 0000000..791c778 --- /dev/null +++ b/src/lang/processor.rs @@ -0,0 +1,1120 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use pest::iterators::{Pair, Pairs}; +use plonky2::field::types::Field; + +use super::error::ProcessorError; +use crate::{ + frontend::{ + BuilderArg, CustomPredicateBatchBuilder, KeyOrWildcardStr, SelfOrWildcardStr, + StatementTmplBuilder, + }, + lang::parser::Rule, + middleware::{ + self, CustomPredicateBatch, CustomPredicateRef, Key, KeyOrWildcard, NativePredicate, + Params, Predicate, SelfOrWildcard as MiddlewareSelfOrWildcard, StatementTmpl, + StatementTmplArg, Value, Wildcard, F, VALUE_SIZE, + }, +}; + +fn get_span(pair: &Pair) -> (usize, usize) { + let span = pair.as_span(); + (span.start(), span.end()) +} + +pub fn native_predicate_from_string(s: &str) -> Option { + match s { + "ValueOf" => Some(NativePredicate::ValueOf), + "Equal" => Some(NativePredicate::Equal), + "NotEqual" => Some(NativePredicate::NotEqual), + // Syntactic sugar for Gt/GtEq is handled at a later stage + "Gt" => Some(NativePredicate::Gt), + "GtEq" => Some(NativePredicate::GtEq), + "Lt" => Some(NativePredicate::Lt), + "LtEq" => Some(NativePredicate::LtEq), + "Contains" => Some(NativePredicate::Contains), + "NotContains" => Some(NativePredicate::NotContains), + "SumOf" => Some(NativePredicate::SumOf), + "ProductOf" => Some(NativePredicate::ProductOf), + "MaxOf" => Some(NativePredicate::MaxOf), + "HashOf" => Some(NativePredicate::HashOf), + "DictContains" => Some(NativePredicate::DictContains), + "DictNotContains" => Some(NativePredicate::DictNotContains), + "ArrayContains" => Some(NativePredicate::ArrayContains), + "SetContains" => Some(NativePredicate::SetContains), + "SetNotContains" => Some(NativePredicate::SetNotContains), + "None" => Some(NativePredicate::None), + "False" => Some(NativePredicate::False), + _ => None, + } +} + +#[derive(Debug, Clone, PartialEq)] +pub 
struct ProcessedOutput { + pub custom_batch: Arc, + pub request_templates: Vec, +} + +struct ProcessingContext<'a> { + params: &'a Params, + /// Maps predicate names to their batch index and public argument count (from Pass 1) + custom_predicate_signatures: HashMap, + /// Stores the original Pest pairs for custom predicate definitions for Pass 2 + custom_predicate_pairs: Vec>, + /// Stores the original Pest pair for the request definition for Pass 2 + request_pair: Option>, +} + +impl<'a> ProcessingContext<'a> { + fn new(params: &'a Params) -> Self { + ProcessingContext { + params, + custom_predicate_signatures: HashMap::new(), + custom_predicate_pairs: Vec::new(), + request_pair: None, + } + } +} + +pub fn process_pest_tree( + mut pairs_iterator_for_document_rule: Pairs<'_, Rule>, + params: &Params, +) -> Result { + let mut processing_ctx = ProcessingContext::new(params); + + let document_node = pairs_iterator_for_document_rule.next().ok_or_else(|| { + ProcessorError::Internal(format!( + "Parser returned no pairs for the expected top-level rule: {:?}.", + Rule::document + )) + })?; + + if document_node.as_rule() != Rule::document { + return Err(ProcessorError::Internal(format!( + "Expected top-level pair to be Rule::{:?}, but found Rule::{:?}.", + Rule::document, + document_node.as_rule() + ))); + } + + let document_content_pairs = document_node.into_inner(); + + first_pass(document_content_pairs, &mut processing_ctx)?; + + second_pass(&mut processing_ctx) +} + +/// Pass 1: Iterates through top-level definitions, records custom predicate +/// signatures and stores pairs for Pass 2. 
+fn first_pass<'a>( + document_pairs: Pairs<'a, Rule>, + ctx: &mut ProcessingContext<'a>, +) -> Result<(), ProcessorError> { + let mut defined_custom_names: HashSet = HashSet::new(); + let mut first_request_span: Option<(usize, usize)> = None; + + for pair in document_pairs { + match pair.as_rule() { + Rule::custom_predicate_def => { + let pred_name_pair = pair + .clone() + .into_inner() + .find(|p| p.as_rule() == Rule::identifier) + .unwrap(); + let pred_name = pred_name_pair.as_str().to_string(); + + if defined_custom_names.contains(&pred_name) { + return Err(ProcessorError::DuplicateDefinition { + name: pred_name, + span: Some(get_span(&pred_name_pair)), + }); + } + defined_custom_names.insert(pred_name.clone()); + + let public_arity = count_public_args(&pair)?; + ctx.custom_predicate_signatures.insert( + pred_name.clone(), + (ctx.custom_predicate_pairs.len(), public_arity), + ); + ctx.custom_predicate_pairs.push(pair); + } + Rule::request_def => { + if ctx.request_pair.is_some() { + return Err(ProcessorError::MultipleRequestDefinitions { + first_span: first_request_span, + second_span: Some(get_span(&pair)), + }); + } + first_request_span = Some(get_span(&pair)); + ctx.request_pair = Some(pair); + } + Rule::EOI => break, + Rule::COMMENT | Rule::WHITESPACE => {} + _ => { + unreachable!("Unexpected rule: {:?}", pair.as_rule()); + } + } + } + Ok(()) +} + +fn count_public_args(pred_def_pair: &Pair) -> Result { + let arg_section_pair = pred_def_pair + .clone() + .into_inner() + .find(|p| p.as_rule() == Rule::arg_section) + .unwrap(); + + let public_arg_list_pair = arg_section_pair + .into_inner() + .find(|p| p.as_rule() == Rule::public_arg_list) + .unwrap(); + + Ok(public_arg_list_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::identifier) + .count()) +} + +fn second_pass(ctx: &mut ProcessingContext) -> Result { + let mut cpb_builder = + CustomPredicateBatchBuilder::new(ctx.params.clone(), "PodlogBatch".to_string()); + + for pred_pair in 
&ctx.custom_predicate_pairs { + process_and_add_custom_predicate_to_batch(pred_pair, ctx, &mut cpb_builder)?; + } + + let custom_batch = cpb_builder.finish(); + + let request_templates = if let Some(req_pair) = &ctx.request_pair { + process_request_def(req_pair, ctx, &custom_batch)? + } else { + Vec::new() + }; + + Ok(ProcessedOutput { + custom_batch, + request_templates, + }) +} + +fn pest_pair_to_builder_arg(arg_content_pair: &Pair) -> Result { + match arg_content_pair.as_rule() { + Rule::literal_value => { + let value = process_literal_value(arg_content_pair)?; + Ok(BuilderArg::Literal(value)) + } + Rule::wildcard => { + let name = arg_content_pair.as_str().strip_prefix("?").unwrap(); + Ok(BuilderArg::WildcardLiteral(name.to_string())) + } + Rule::anchored_key => { + let mut inner_ak_pairs = arg_content_pair.clone().into_inner(); + let pod_id_pair = inner_ak_pairs.next().unwrap(); + + let pod_self_or_wc_str = match pod_id_pair.as_rule() { + Rule::wildcard => { + let name = pod_id_pair.as_str().strip_prefix("?").unwrap(); + SelfOrWildcardStr::Wildcard(name.to_string()) + } + Rule::self_keyword => SelfOrWildcardStr::SELF, + _ => { + unreachable!("Unexpected rule: {:?}", pod_id_pair.as_rule()); + } + }; + + let key_part_pair = inner_ak_pairs.next().unwrap(); + + let key_or_wildcard_str = match key_part_pair.as_rule() { + Rule::wildcard => { + let key_wildcard_name = key_part_pair.as_str().strip_prefix("?").unwrap(); + KeyOrWildcardStr::Wildcard(key_wildcard_name.to_string()) + } + Rule::literal_string => { + let key_str_literal = parse_pest_string_literal(&key_part_pair)?; + KeyOrWildcardStr::Key(key_str_literal) + } + _ => { + unreachable!("Unexpected rule: {:?}", key_part_pair.as_rule()); + } + }; + Ok(BuilderArg::Key(pod_self_or_wc_str, key_or_wildcard_str)) + } + _ => unreachable!("Unexpected rule: {:?}", arg_content_pair.as_rule()), + } +} + +fn validate_and_build_statement_template( + stmt_name_str: &str, + pred: &Predicate, + args: Vec, + processing_ctx: 
&ProcessingContext, + stmt_span: (usize, usize), + stmt_name_span: (usize, usize), +) -> Result { + match pred { + Predicate::Native(native_pred) => { + let (expected_arity, mapped_pred_for_arity_check) = match native_pred { + NativePredicate::Gt => (2, NativePredicate::Lt), + NativePredicate::GtEq => (2, NativePredicate::LtEq), + NativePredicate::ValueOf + | NativePredicate::Equal + | NativePredicate::NotEqual + | NativePredicate::Lt + | NativePredicate::LtEq + | NativePredicate::SetContains + | NativePredicate::DictNotContains + | NativePredicate::SetNotContains => (2, *native_pred), + NativePredicate::NotContains + | NativePredicate::Contains + | NativePredicate::ArrayContains + | NativePredicate::DictContains + | NativePredicate::SumOf + | NativePredicate::ProductOf + | NativePredicate::MaxOf + | NativePredicate::HashOf => (3, *native_pred), + NativePredicate::None | NativePredicate::False => (0, *native_pred), + }; + + if args.len() != expected_arity { + return Err(ProcessorError::ArgumentCountMismatch { + predicate: stmt_name_str.to_string(), + expected: expected_arity, + found: args.len(), + span: Some(stmt_name_span), + }); + } + + if mapped_pred_for_arity_check == NativePredicate::ValueOf { + if !matches!(args.get(0), Some(BuilderArg::Key(..))) { + return Err(ProcessorError::TypeError { + expected: "Anchored Key".to_string(), + found: args + .get(0) + .map_or("None".to_string(), |a| format!("{:?}", a)), + item: format!("argument 1 of native predicate '{}'", stmt_name_str), + span: Some(stmt_span), + }); + } + if !matches!(args.get(1), Some(BuilderArg::Literal(..))) { + return Err(ProcessorError::TypeError { + expected: "Literal".to_string(), + found: args + .get(1) + .map_or("None".to_string(), |a| format!("{:?}", a)), + item: format!("argument 2 of native predicate '{}'", stmt_name_str), + span: Some(stmt_span), + }); + } + } else if expected_arity > 0 { + for (i, arg) in args.iter().enumerate() { + if !matches!(arg, BuilderArg::Key(..)) { + return 
Err(ProcessorError::TypeError { + expected: "Anchored Key".to_string(), + found: format!("{:?}", arg), + item: format!( + "argument {} of native predicate '{}'", + i + 1, + stmt_name_str + ), + span: Some(stmt_span), + }); + } + } + } + } + Predicate::Custom(_) | Predicate::BatchSelf(_) => { + let (_original_pred_idx, expected_arity_val) = processing_ctx + .custom_predicate_signatures + .get(stmt_name_str) + .ok_or_else(|| { + ProcessorError::Internal(format!( + "Custom predicate signature not found for '{}' during validation", + stmt_name_str + )) + })?; + + if args.len() != *expected_arity_val { + return Err(ProcessorError::ArgumentCountMismatch { + predicate: stmt_name_str.to_string(), + expected: *expected_arity_val, + found: args.len(), + span: Some(stmt_name_span), + }); + } + + for (idx, arg) in args.iter().enumerate() { + if !matches!(arg, BuilderArg::WildcardLiteral(_) | BuilderArg::Literal(_)) { + return Err(ProcessorError::TypeError { + expected: "Wildcard or Literal".to_string(), + found: format!("{:?}", arg), + item: format!( + "argument {} of custom predicate call '{}'", + idx + 1, + stmt_name_str + ), + span: Some(stmt_span), + }); + } + } + } + } + + let mut stb = StatementTmplBuilder::new(pred.clone()); + for arg in args { + stb = stb.arg(arg); + } + Ok(stb.desugar()) +} + +fn process_and_add_custom_predicate_to_batch( + pred_def_pair: &Pair, + processing_ctx: &ProcessingContext, + cpb_builder: &mut CustomPredicateBatchBuilder, +) -> Result<(), ProcessorError> { + let mut inner_pairs = pred_def_pair.clone().into_inner(); + let name_pair = inner_pairs + .find(|p| p.as_rule() == Rule::identifier) + .unwrap(); + let name = name_pair.as_str().to_string(); + + let arg_section_pair = inner_pairs + .find(|p| p.as_rule() == Rule::arg_section) + .unwrap(); + + let mut public_arg_strings: Vec = Vec::new(); + let mut private_arg_strings: Vec = Vec::new(); + let mut defined_arg_names: HashSet = HashSet::new(); + + for arg_part_pair in 
arg_section_pair.into_inner() { + match arg_part_pair.as_rule() { + Rule::public_arg_list => { + for arg_ident_pair in arg_part_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::identifier) + { + let arg_name = arg_ident_pair.as_str().to_string(); + if !defined_arg_names.insert(arg_name.clone()) { + return Err(ProcessorError::DuplicateWildcard { + name: arg_name, + span: Some(get_span(&arg_ident_pair)), + }); + } + public_arg_strings.push(arg_name); + } + } + Rule::private_arg_list => { + for arg_ident_pair in arg_part_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::identifier) + { + let arg_name = arg_ident_pair.as_str().to_string(); + if !defined_arg_names.insert(arg_name.clone()) { + return Err(ProcessorError::DuplicateWildcard { + name: arg_name, + span: Some(get_span(&arg_ident_pair)), + }); + } + private_arg_strings.push(arg_name); + } + } + Rule::private_kw | Rule::COMMENT | Rule::WHITESPACE => {} + _ if arg_part_pair.as_str() == "," => {} + _ => { + unreachable!("Unexpected rule: {:?}", arg_part_pair.as_rule()); + } + } + } + + let conjunction_type_pair = inner_pairs + .find(|p| p.as_rule() == Rule::conjunction_type) + .unwrap(); + let conjunction = match conjunction_type_pair.as_str() { + "AND" => true, + "OR" => false, + _ => { + unreachable!( + "Invalid conjunction type: {}", + conjunction_type_pair.as_str() + ); + } + }; + + let statement_list_pair = inner_pairs + .find(|p| p.as_rule() == Rule::statement_list) + .unwrap_or_else(|| { + unreachable!("statement_list rule must be present in predicate definition") + }); + + let mut statement_builders = Vec::new(); + for stmt_pair in statement_list_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::statement) + { + let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); + let stmt_name_pair = inner_stmt_pairs + .find(|p| p.as_rule() == Rule::identifier) + .unwrap_or_else(|| unreachable!("statement name must be present in statement")); + let stmt_name_str = stmt_name_pair.as_str(); + + 
let builder_args = parse_statement_args(&stmt_pair)?; + + let middleware_predicate_type = + if let Some(native_pred) = native_predicate_from_string(stmt_name_str) { + Predicate::Native(native_pred) + } else if let Some((pred_index, _expected_arity)) = processing_ctx + .custom_predicate_signatures + .get(stmt_name_str) + { + Predicate::BatchSelf(*pred_index) + } else { + return Err(ProcessorError::UndefinedIdentifier { + name: stmt_name_str.to_string(), + span: Some(get_span(&stmt_name_pair)), + }); + }; + + let stb = validate_and_build_statement_template( + stmt_name_str, + &middleware_predicate_type, + builder_args, + processing_ctx, + get_span(&stmt_pair), + get_span(&stmt_name_pair), + )?; + statement_builders.push(stb); + } + + let public_args_strs: Vec<&str> = public_arg_strings.iter().map(AsRef::as_ref).collect(); + let private_args_strs: Vec<&str> = private_arg_strings.iter().map(AsRef::as_ref).collect(); + let sts_slice: &[StatementTmplBuilder] = &statement_builders; + + if conjunction { + cpb_builder.predicate_and(&name, &public_args_strs, &private_args_strs, sts_slice)?; + } else { + cpb_builder.predicate_or(&name, &public_args_strs, &private_args_strs, sts_slice)?; + } + + Ok(()) +} + +fn process_request_def( + req_def_pair: &Pair, + processing_ctx: &ProcessingContext, + custom_batch: &Arc, +) -> Result, ProcessorError> { + let mut request_wildcard_names: Vec = Vec::new(); + let mut defined_request_wildcards: HashSet = HashSet::new(); + + let mut request_statement_builders: Vec = Vec::new(); + + if let Some(statement_list_pair) = req_def_pair + .clone() + .into_inner() + .find(|p| p.as_rule() == Rule::statement_list) + { + for stmt_pair in statement_list_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::statement) + { + let built_stb = process_proof_request_statement_template( + &stmt_pair, + processing_ctx, + Some(custom_batch), // Pass as Option<&Arc<...>> + &mut request_wildcard_names, + &mut defined_request_wildcards, + )?; + 
request_statement_builders.push(built_stb); + } + } + + let mut request_templates: Vec = + Vec::with_capacity(request_statement_builders.len()); + for stb in request_statement_builders { + let tmpl = + resolve_request_statement_builder(stb, &request_wildcard_names, processing_ctx.params)?; + request_templates.push(tmpl); + } + + Ok(request_templates) +} + +fn process_proof_request_statement_template( + stmt_pair: &Pair, + processing_ctx: &ProcessingContext, + custom_batch_for_request: Option<&Arc>, + request_wildcard_names: &mut Vec, + defined_request_wildcards: &mut HashSet, +) -> Result { + let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); + let name_pair = inner_stmt_pairs + .find(|p| p.as_rule() == Rule::identifier) + .unwrap(); + let stmt_name_str = name_pair.as_str(); + + let builder_args = parse_statement_args(stmt_pair)?; + let mut temp_stmt_wildcard_names: Vec = Vec::new(); + + for arg in &builder_args { + match arg { + BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()), + BuilderArg::Key(pod_id_str, key_wc_str) => { + if let SelfOrWildcardStr::Wildcard(name) = pod_id_str { + temp_stmt_wildcard_names.push(name.clone()); + } + if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str { + temp_stmt_wildcard_names.push(key_wc_name.clone()); + } + } + _ => {} + } + } + + for name in temp_stmt_wildcard_names { + if defined_request_wildcards.insert(name.clone()) { + request_wildcard_names.push(name); + } + } + + let middleware_predicate_type = + if let Some(native_pred) = native_predicate_from_string(stmt_name_str) { + Predicate::Native(native_pred) + } else if let Some((pred_index, _expected_arity)) = processing_ctx + .custom_predicate_signatures + .get(stmt_name_str) + { + if let Some(batch_ref) = custom_batch_for_request { + Predicate::Custom(CustomPredicateRef::new(batch_ref.clone(), *pred_index)) + } else { + return Err(ProcessorError::Internal(format!( + "Custom predicate '{}' found but no custom batch provided for 
request processing.", + stmt_name_str + ))); + } + } else { + return Err(ProcessorError::UndefinedIdentifier { + name: stmt_name_str.to_string(), + span: Some(get_span(&name_pair)), + }); + }; + + let stb = validate_and_build_statement_template( + stmt_name_str, + &middleware_predicate_type, + builder_args, + processing_ctx, + get_span(stmt_pair), + get_span(&name_pair), + )?; + + Ok(stb.desugar()) +} + +fn process_literal_value(lit_val_pair: &Pair) -> Result { + let inner_lit = lit_val_pair.clone().into_inner().next().unwrap(); + + match inner_lit.as_rule() { + Rule::literal_int => { + let val = inner_lit.as_str().parse::().unwrap(); + Ok(Value::from(val)) + } + Rule::literal_bool => { + let val = inner_lit.as_str().parse::().unwrap(); + Ok(Value::from(val)) + } + Rule::literal_raw => { + let full_literal_str = inner_lit.as_str(); + let hex_str_no_prefix = full_literal_str + .strip_prefix("0x") + .unwrap_or(full_literal_str); + + parse_hex_str_to_raw_value(hex_str_no_prefix) + .map_err(|e| match e { + ProcessorError::InvalidLiteralFormat { kind, value, .. } => { + ProcessorError::InvalidLiteralFormat { + kind, + value, + span: Some(get_span(&inner_lit)), + } + } + ProcessorError::Internal(message) => ProcessorError::InvalidLiteralFormat { + kind: format!("raw hex processing (internal: {})", message), + value: full_literal_str.to_string(), + span: Some(get_span(&inner_lit)), + }, + _ => ProcessorError::InvalidLiteralFormat { + kind: "raw hex processing error".to_string(), + value: full_literal_str.to_string(), + span: Some(get_span(&inner_lit)), + }, + }) + .map(Value::from) + } + Rule::literal_string => Ok(Value::from(parse_pest_string_literal(&inner_lit)?)), + Rule::literal_array => { + let elements: Result, ProcessorError> = inner_lit + .into_inner() + .map(|elem_pair| process_literal_value(&elem_pair)) + .collect(); + let middleware_array = middleware::containers::Array::new(elements?) 
+ .map_err(|e| ProcessorError::Internal(format!("Failed to create Array: {}", e)))?; + Ok(Value::from(middleware_array)) + } + Rule::literal_set => { + let elements: Result, ProcessorError> = inner_lit + .into_inner() + .map(|elem_pair| process_literal_value(&elem_pair)) + .collect(); + let middleware_set = middleware::containers::Set::new(elements?) + .map_err(|e| ProcessorError::Internal(format!("Failed to create Set: {}", e)))?; + Ok(Value::from(middleware_set)) + } + Rule::literal_dict => { + let pairs: Result, ProcessorError> = inner_lit + .into_inner() + .map(|dict_entry_pair| { + let mut entry_inner = dict_entry_pair.clone().into_inner(); + let key_pair = entry_inner.next().unwrap(); + let val_pair = entry_inner.next().unwrap(); + let key_str = parse_pest_string_literal(&key_pair)?; + let val = process_literal_value(&val_pair)?; + Ok((Key::new(key_str), val)) + }) + .collect(); + let middleware_dict = middleware::containers::Dictionary::new(pairs?).map_err(|e| { + ProcessorError::Internal(format!("Failed to create Dictionary: {}", e)) + })?; + Ok(Value::from(middleware_dict)) + } + _ => unreachable!("Unexpected rule: {:?}", inner_lit.as_rule()), + } +} + +fn parse_pest_string_literal(pair: &Pair) -> Result { + let inner_pair = pair.clone().into_inner().next().unwrap(); + + let raw_content = inner_pair.as_str(); + let mut result = String::with_capacity(raw_content.len()); + let mut chars = raw_content.chars().peekable(); + + while let Some(c) = chars.next() { + if c == '\\' { + match chars.next() { + Some('"') => result.push('"'), + Some('\\') => result.push('\\'), + Some('/') => result.push('/'), + Some('b') => result.push('\x08'), + Some('f') => result.push('\x0C'), + Some('n') => result.push('\n'), + Some('r') => result.push('\r'), + Some('t') => result.push('\t'), + Some('u') => { + let mut hex_code = String::with_capacity(4); + for _ in 0..4 { + hex_code.push(chars.next().ok_or_else(|| { + ProcessorError::InvalidLiteralFormat { + kind: "unicode 
escape".to_string(), + value: format!("\\u{}... (incomplete)", hex_code), + span: Some(get_span(&inner_pair)), + } + })?); + } + let char_code = u32::from_str_radix(&hex_code, 16).map_err(|_| { + ProcessorError::InvalidLiteralFormat { + kind: "unicode escape".to_string(), + value: format!("\\u{}", hex_code), + span: Some(get_span(&inner_pair)), + } + })?; + result.push(std::char::from_u32(char_code).ok_or_else(|| { + ProcessorError::InvalidLiteralFormat { + kind: "unicode escape (invalid code point)".to_string(), + value: format!("\\u{}", hex_code), + span: Some(get_span(&inner_pair)), + } + })?); + } + Some(other) => { + return Err(ProcessorError::InvalidLiteralFormat { + kind: "escape sequence".to_string(), + value: format!("\\{}", other), + span: Some(get_span(&inner_pair)), + }) + } + None => { + return Err(ProcessorError::InvalidLiteralFormat { + kind: "escape sequence".to_string(), + value: "\\ (ends with escape)".to_string(), + span: Some(get_span(&inner_pair)), + }) + } + } + } else { + result.push(c); + } + } + Ok(result) +} + +// Translates a big-endian hex string to a little-endian RawValue +fn parse_hex_str_to_raw_value(hex_str: &str) -> Result { + let mut v = [F::ZERO; VALUE_SIZE]; + let value_range = 0..VALUE_SIZE; + for i in value_range { + let start_idx = i * 16; + let end_idx = start_idx + 16; + let hex_part = &hex_str[start_idx..end_idx]; + + let u64_val = u64::from_str_radix(hex_part, 16).unwrap(); + v[VALUE_SIZE - i - 1] = F::from_canonical_u64(u64_val); + } + Ok(middleware::RawValue(v)) +} + +// Helper to resolve a wildcard name string to an indexed middleware::Wildcard +// based on an ordered list of names from the current scope (e.g., request or predicate def). 
+fn resolve_wildcard( + ordered_scope_wildcard_names: &[String], + name_to_resolve: &str, +) -> Result { + ordered_scope_wildcard_names + .iter() + .position(|n| n == name_to_resolve) + .map(|index| Wildcard::new(name_to_resolve.to_string(), index)) + .ok_or_else(|| ProcessorError::UndefinedWildcard { + name: name_to_resolve.to_string(), + span: None, + }) +} + +fn resolve_key_or_wildcard_str( + ordered_scope_wildcard_names: &[String], + kows: &KeyOrWildcardStr, +) -> Result { + match kows { + KeyOrWildcardStr::Key(k_str) => Ok(KeyOrWildcard::Key(Key::new(k_str.clone()))), + KeyOrWildcardStr::Wildcard(wc_name_str) => { + let resolved_wc = resolve_wildcard(ordered_scope_wildcard_names, wc_name_str)?; + Ok(KeyOrWildcard::Wildcard(resolved_wc)) + } + } +} + +fn resolve_request_statement_builder( + stb: StatementTmplBuilder, + ordered_request_wildcard_names: &[String], + params: &Params, +) -> Result { + let stb = stb.desugar(); + + let mut middleware_args = Vec::with_capacity(stb.args.len()); + for builder_arg in stb.args { + let mw_arg = match builder_arg { + BuilderArg::Literal(v) => StatementTmplArg::Literal(v), + BuilderArg::Key(pod_id_str, key_wc_str) => { + let pod_sowc = match pod_id_str { + SelfOrWildcardStr::SELF => MiddlewareSelfOrWildcard::SELF, + SelfOrWildcardStr::Wildcard(name) => MiddlewareSelfOrWildcard::Wildcard( + resolve_wildcard(ordered_request_wildcard_names, &name)?, + ), + }; + let key_or_wc = + resolve_key_or_wildcard_str(ordered_request_wildcard_names, &key_wc_str)?; + StatementTmplArg::AnchoredKey(pod_sowc, key_or_wc) + } + BuilderArg::WildcardLiteral(wc_name) => { + let pod_wc = resolve_wildcard(ordered_request_wildcard_names, &wc_name)?; + StatementTmplArg::WildcardLiteral(pod_wc) + } + }; + middleware_args.push(mw_arg); + } + + if middleware_args.len() > params.max_statement_args { + return Err(ProcessorError::Middleware(middleware::Error::max_length( + format!("Arguments for predicate {:?}", stb.predicate), + middleware_args.len(), + 
params.max_statement_args, + ))); + } + + Ok(StatementTmpl { + pred: stb.predicate, + args: middleware_args, + }) +} + +fn parse_statement_args(stmt_pair: &Pair) -> Result, ProcessorError> { + let mut builder_args = Vec::new(); + let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); + + if let Some(arg_list_pair) = inner_stmt_pairs.find(|p| p.as_rule() == Rule::statement_arg_list) + { + for arg_pair in arg_list_pair + .into_inner() + .filter(|p| p.as_rule() == Rule::statement_arg) + { + let arg_content_pair = arg_pair.into_inner().next().unwrap(); + let builder_arg = pest_pair_to_builder_arg(&arg_content_pair)?; + builder_args.push(builder_arg); + } + } + Ok(builder_args) +} + +#[cfg(test)] +mod processor_tests { + use std::collections::HashMap; + + use pest::iterators::Pairs; + + use super::{first_pass, second_pass, ProcessingContext}; + use crate::{ + lang::{ + error::ProcessorError, + parser::{parse_podlog, Rule}, + }, + middleware::Params, + }; + + fn get_document_content_pairs(input: &str) -> Result, ProcessorError> { + let full_parse_tree = parse_podlog(input) + .map_err(|e| ProcessorError::Internal(format!("Test parsing failed: {:?}", e)))?; + + let document_node = full_parse_tree.peek().ok_or_else(|| { + ProcessorError::Internal("Parser returned no pairs for the document rule.".to_string()) + })?; + + if document_node.as_rule() != Rule::document { + return Err(ProcessorError::Internal(format!( + "Expected top-level pair to be Rule::document, but found {:?}.", + document_node.as_rule() + ))); + } + Ok(full_parse_tree.into_iter().next().unwrap().into_inner()) + } + + #[test] + fn test_fp_empty_input() -> Result<(), ProcessorError> { + let input = ""; + let pairs = get_document_content_pairs(input)?; + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + first_pass(pairs, &mut ctx)?; + assert!(ctx.custom_predicate_signatures.is_empty()); + assert!(ctx.custom_predicate_pairs.is_empty()); + assert!(ctx.request_pair.is_none()); + 
Ok(()) + } + + #[test] + fn test_fp_only_request() -> Result<(), ProcessorError> { + let input = "REQUEST( Equal(?A[\"k\"],?B[\"k\"]) )"; // Escaped quotes + let pairs = get_document_content_pairs(input)?; + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + first_pass(pairs, &mut ctx)?; + assert!(ctx.custom_predicate_signatures.is_empty()); + assert!(ctx.custom_predicate_pairs.is_empty()); + assert!(ctx.request_pair.is_some()); + assert_eq!( + ctx.request_pair.as_ref().unwrap().as_rule(), + Rule::request_def + ); + Ok(()) + } + + #[test] + fn test_fp_simple_predicate() -> Result<(), ProcessorError> { + let input = "my_pred(A, B) = AND( Equal(?A[\"k\"],?B[\"k\"]) )"; // Escaped quotes + let pairs = get_document_content_pairs(input)?; + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + first_pass(pairs, &mut ctx)?; + assert_eq!(ctx.custom_predicate_signatures.len(), 1); + assert_eq!(ctx.custom_predicate_pairs.len(), 1); + assert!(ctx.request_pair.is_none()); + + let (index, arity) = ctx.custom_predicate_signatures.get("my_pred").unwrap(); + assert_eq!(*index, 0); + assert_eq!(*arity, 2); // A, B + assert_eq!( + ctx.custom_predicate_pairs[0].as_rule(), + Rule::custom_predicate_def + ); + Ok(()) + } + + #[test] + fn test_fp_multiple_predicates() -> Result<(), ProcessorError> { + let input = r#" + pred1(X) = AND( Equal(?X["k"],?X["k"]) ) + pred2(Y, Z) = OR( ValueOf(?Y["v"], 123) ) + "#; + let pairs = get_document_content_pairs(input)?; + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + first_pass(pairs, &mut ctx)?; + assert_eq!(ctx.custom_predicate_signatures.len(), 2); + assert_eq!(ctx.custom_predicate_pairs.len(), 2); + + let (idx1, arity1) = ctx.custom_predicate_signatures.get("pred1").unwrap(); + assert_eq!(*idx1, 0); + assert_eq!(*arity1, 1); + + let (idx2, arity2) = ctx.custom_predicate_signatures.get("pred2").unwrap(); + assert_eq!(*idx2, 1); + assert_eq!(*arity2, 2); + Ok(()) 
+ } + + #[test] + fn test_fp_predicate_public_args_count() -> Result<(), ProcessorError> { + let inputs_and_expected_arities = vec![ + ("p1(A) = AND(None()) // One public arg", 1), + ("p3(A,B,C) = AND(None()) // Three public args", 3), + ("p_pub_priv(Pub1, private: Priv1) = AND(None())", 1), + ]; + + for (input_str, expected_arity) in inputs_and_expected_arities { + let pairs = get_document_content_pairs(input_str)?; + let params = Params::default(); + let mut ctx = ProcessingContext { + params: ¶ms, + custom_predicate_signatures: HashMap::new(), + custom_predicate_pairs: Vec::new(), + request_pair: None, + }; + first_pass(pairs, &mut ctx)?; + let pred_name = ctx + .custom_predicate_signatures + .keys() + .next() + .expect("No predicate found in test string"); + let (_, arity) = ctx.custom_predicate_signatures.get(pred_name).unwrap(); + assert_eq!(*arity, expected_arity, "Mismatch for input: {}", input_str); + } + Ok(()) + } + + #[test] + fn test_fp_duplicate_predicate() { + let input = r#" + my_pred(A) = AND(None()) + my_pred(B) = OR(None()) + "#; + let pairs = get_document_content_pairs(input).unwrap(); + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + let result = first_pass(pairs, &mut ctx); + assert!(result.is_err()); + match result.err().unwrap() { + // Use .err().unwrap() for ProcessorError + ProcessorError::DuplicateDefinition { name, .. } => { + assert_eq!(name, "my_pred"); + } + e => panic!("Expected DuplicateDefinition, got {:?}", e), + } + } + + #[test] + fn test_fp_multiple_requests() { + let input = r#" + REQUEST(None()) + REQUEST(None()) + "#; + let pairs = get_document_content_pairs(input).unwrap(); + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + let result = first_pass(pairs, &mut ctx); + assert!(result.is_err()); + match result.err().unwrap() { + // Use .err().unwrap() for ProcessorError + ProcessorError::MultipleRequestDefinitions { .. 
} => { /* Correct error */ } + e => panic!("Expected MultipleRequestDefinitions, got {:?}", e), + } + } + + #[test] + fn test_fp_mixed_content() -> Result<(), ProcessorError> { + let input = r#" + pred_one(X) = AND(None()) + REQUEST( pred_one(?A) ) + pred_two(Y, Z) = OR(None()) + "#; + let pairs = get_document_content_pairs(input)?; + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + first_pass(pairs, &mut ctx)?; + + assert_eq!(ctx.custom_predicate_signatures.len(), 2); + assert_eq!(ctx.custom_predicate_pairs.len(), 2); + assert!(ctx.request_pair.is_some()); + + let (idx1, arity1) = ctx.custom_predicate_signatures.get("pred_one").unwrap(); + assert_eq!(*idx1, 0); + assert_eq!(*arity1, 1); + + let (idx2, arity2) = ctx.custom_predicate_signatures.get("pred_two").unwrap(); + assert_eq!(*idx2, 1); + assert_eq!(*arity2, 2); + + // Check that the pairs were stored in the correct order and have the correct content (simplistic check) + assert!(ctx.custom_predicate_pairs[0].as_str().contains("pred_one")); + assert!(ctx.custom_predicate_pairs[1].as_str().contains("pred_two")); + assert!(ctx + .request_pair + .as_ref() + .unwrap() + .as_str() + .contains("pred_one(?A)")); + + Ok(()) + } + + #[test] + fn test_sp_unknown_predicate() -> Result<(), ProcessorError> { + // Undefined predicates will be flagged as an error on the second pass + let input = r#" + REQUEST( + pred_one(?A) + ) + "#; + let pairs = get_document_content_pairs(input)?; + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + first_pass(pairs, &mut ctx)?; + let result = second_pass(&mut ctx); + assert!(result.is_err()); + match result.err().unwrap() { + ProcessorError::UndefinedIdentifier { name, span: _ } => { + assert_eq!(name, "pred_one") + } + e => panic!("Expected UndefinedIdentifier, got {:?}", e), + } + + // Native predicate names are case-sensitive + let input = r#" + REQUEST( + EQUAL(?A[?B], ?C[?D]) + ) + "#; + let pairs = 
get_document_content_pairs(input)?; + let params = Params::default(); + let mut ctx = ProcessingContext::new(¶ms); + first_pass(pairs, &mut ctx)?; + let result = second_pass(&mut ctx); + assert!(result.is_err()); + match result.err().unwrap() { + ProcessorError::UndefinedIdentifier { name, span: _ } => { + assert_eq!(name, "EQUAL") + } + e => panic!("Expected UndefinedIdentifier, got {:?}", e), + } + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 72c1bbc..da15f61 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod backends; pub mod constants; pub mod frontend; +pub mod lang; pub mod middleware; #[cfg(test)] diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index b43e81f..3b0dbdd 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -190,15 +190,22 @@ impl fmt::Display for Hash { impl FromHex for Hash { type Error = FromHexError; - // TODO make it dependant on backend::Value len fn from_hex>(hex: T) -> Result { - // In little endian + // The input `hex` is a big-endian hex string. let bytes = <[u8; 32]>::from_hex(hex)?; - let mut buf: [u8; 8] = [0; 8]; let mut inner = [F::ZERO; HASH_SIZE]; + for i in 0..HASH_SIZE { - buf.copy_from_slice(&bytes[8 * i..8 * (i + 1)]); - inner[i] = F::from_canonical_u64(u64::from_le_bytes(buf)); + let start = i * 8; + let end = start + 8; + let chunk: [u8; 8] = bytes[start..end] + .try_into() + .expect("slice with incorrect length"); + + // We read big-endian chunks from a big-endian string, + // and place them into a little-endian limb array. 
+ let u64_val = u64::from_be_bytes(chunk); + inner[HASH_SIZE - 1 - i] = F::from_canonical_u64(u64_val); } Ok(Self(inner)) } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index fb5465c..ec9584a 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -373,20 +373,12 @@ impl fmt::Display for CustomPredicate { pub struct CustomPredicateBatch { id: Hash, pub name: String, - predicates: Vec, + pub(crate) predicates: Vec, } impl ToFields for CustomPredicateBatch { fn to_fields(&self, params: &Params) -> Vec { // all the custom predicates in order - - // TODO think if this check should go into the StatementTmpl creation, - // instead of at the `to_fields` method, where we should assume that the - // values are already valid - if self.predicates.len() > params.max_custom_batch_size { - panic!("Predicate batch exceeds maximum size"); - } - let pad_pred = CustomPredicate::empty(); let fields: Vec = self .predicates diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index a2ead6f..81b86ff 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -1,4 +1,7 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Write, +}; use plonky2::field::types::Field; use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; @@ -13,10 +16,16 @@ fn serialize_field_tuple( where S: serde::Serializer, { - serializer.serialize_str(&format!( - "{:016x}{:016x}{:016x}{:016x}", - value[0].0, value[1].0, value[2].0, value[3].0 - )) + // `value` is little-endian in memory. We serialize it as a big-endian hex string + // for human readability. 
+ let s = value + .iter() + .rev() + .fold(String::with_capacity(N * 16), |mut s, limb| { + write!(s, "{:016x}", limb.0).unwrap(); + s + }); + serializer.serialize_str(&s) } fn deserialize_field_tuple<'de, D, const N: usize>(deserializer: D) -> Result<[F; N], D::Error> @@ -25,20 +34,29 @@ where { let hex_str = String::deserialize(deserializer)?; - if !hex_str.chars().count() == 64 || !hex_str.chars().all(|c| c.is_ascii_hexdigit()) { + let expected_len = N * 16; + if hex_str.len() != expected_len { + return Err(serde::de::Error::custom(format!( + "Invalid hex string length: expected {} characters, found {}", + expected_len, + hex_str.len() + ))); + } + if !hex_str.chars().all(|c| c.is_ascii_hexdigit()) { return Err(serde::de::Error::custom( - "Invalid hex string format - expected 64 hexadecimal characters", + "Invalid hex string format: contains non-hexadecimal characters", )); } let mut v = [F::ZERO; N]; - for (i, v_i) in v.iter_mut().enumerate() { + for i in 0..N { let start = i * 16; let end = start + 16; let hex_part = &hex_str[start..end]; - *v_i = F::from_canonical_u64( - u64::from_str_radix(hex_part, 16).map_err(serde::de::Error::custom)?, - ); + let u64_val = u64::from_str_radix(hex_part, 16).map_err(serde::de::Error::custom)?; + // The hex string is big-endian, so the first chunk (i=0) is the most significant. + // We store it in the last position of our little-endian array `v`. + v[N - 1 - i] = F::from_canonical_u64(u64_val); } Ok(v) }