From b123185ee9cbaf4dd33d7ec142e984a6a364e447 Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Tue, 1 Jul 2025 10:34:35 +0200 Subject: [PATCH] Support public key literals and tidy up handling of Raw vs PodId (#319) * Support public key literals and tidy up handling of Raw vs PodId --- src/backends/plonky2/primitives/ec/curve.rs | 36 ++++++---- src/lang/grammar.pest | 20 ++++-- src/lang/mod.rs | 78 ++++++++++++++++++++- src/lang/parser.rs | 42 +++++++++-- src/lang/processor.rs | 38 ++++++++-- src/middleware/mod.rs | 2 + 6 files changed, 184 insertions(+), 32 deletions(-) diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index af064d5..47f37ad 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -5,6 +5,7 @@ use core::ops::{Add, Mul}; use std::{ array, fmt, ops::{AddAssign, Neg, Sub}, + str::FromStr, sync::LazyLock, }; @@ -121,6 +122,27 @@ impl fmt::Display for Point { } } +impl FromStr for Point { + type Err = Error; + + fn from_str(s: &str) -> Result { + let point_bytes = bs58::decode(s) + .into_vec() + .map_err(|e| Error::custom(format!("Base58 decode error: {}", e)))?; + + if point_bytes.len() == 80 { + // Non-compressed + Ok(Point { + x: ec_field_from_bytes(&point_bytes[..40])?, + u: ec_field_from_bytes(&point_bytes[40..])?, + }) + } else { + // Compressed + Self::from_bytes_into_subgroup(&point_bytes) + } + } +} + impl Serialize for Point { fn serialize(&self, serializer: S) -> Result where @@ -137,19 +159,7 @@ impl<'de> Deserialize<'de> for Point { D: Deserializer<'de>, { let point_b58 = String::deserialize(deserializer)?; - let point_bytes: Vec = bs58::decode(point_b58) - .into_vec() - .map_err(serde::de::Error::custom)?; - if point_bytes.len() == 80 { - // Non-compressed - Ok(Point { - x: ec_field_from_bytes(&point_bytes[..40]).map_err(serde::de::Error::custom)?, - u: ec_field_from_bytes(&point_bytes[40..]).map_err(serde::de::Error::custom)?, - }) - } 
else { - // Compressed - Self::from_bytes_into_subgroup(&point_bytes).map_err(serde::de::Error::custom) - } + Self::from_str(&point_b58).map_err(serde::de::Error::custom) } } diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index ef0ce01..3d3d350 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -32,7 +32,7 @@ document = { SOI ~ (use_statement | custom_predicate_def | request_def)* ~ EOI } use_statement = { "use" ~ use_predicate_list ~ "from" ~ batch_ref } use_predicate_list = { import_name ~ ("," ~ import_name)* } import_name = { identifier | "_" } -batch_ref = { literal_raw } +batch_ref = { hash_hex } request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" } @@ -59,11 +59,13 @@ anchored_key = { wildcard ~ "[" ~ literal_string ~ "]" } // Literal Values (ordered to avoid ambiguity, e.g., string before int) literal_value = { + literal_public_key | literal_dict | literal_set | literal_array | literal_bool | literal_raw | + literal_pod_id | literal_string | literal_int } @@ -72,9 +74,12 @@ literal_value = { literal_int = @{ "-"? ~ ASCII_DIGIT+ } literal_bool = @{ "true" | "false" } -// literal_raw: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters) +// hash_hex: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters) // representing a 32-byte value in big-endian order -literal_raw = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} } +hash_hex = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} } + +literal_raw = { "Raw" ~ "(" ~ hash_hex ~ ")" } +literal_pod_id = { hash_hex } // String literal parsing based on https://pest.rs/book/examples/json.html literal_string = ${ "\"" ~ inner ~ "\"" } // Compound atomic string rule @@ -85,6 +90,11 @@ char = { // Rule for a single logical character (unescaped or escaped) | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) // Unicode escape sequence } +// PublicKey(...) 
+base58_char = { '1'..'9' | 'A'..'H' | 'J'..'N' | 'P'..'Z' | 'a'..'k' | 'm'..'z' } +base58_string = @{ base58_char+ } +literal_public_key = { "PublicKey" ~ "(" ~ base58_string ~ ")" } + // Container Literals (recursive definition using literal_value) literal_array = { "[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" } literal_set = { "#[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" } @@ -95,7 +105,9 @@ dict_pair = { literal_string ~ ":" ~ literal_value } test_identifier = { SOI ~ identifier ~ EOI } test_wildcard = { SOI ~ wildcard ~ EOI } test_literal_int = { SOI ~ literal_int ~ EOI } -test_literal_raw = { SOI ~ literal_raw ~ EOI } +test_hash_hex = { SOI ~ hash_hex ~ EOI } +test_literal_raw = { SOI ~ literal_raw ~ EOI } +test_literal_pod_id = { SOI ~ literal_pod_id ~ EOI } test_literal_value = { SOI ~ literal_value ~ EOI } test_statement = { SOI ~ statement ~ EOI } test_custom_predicate_def = { SOI ~ custom_predicate_def ~ EOI } diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 1cdee76..1f6dc8c 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -29,9 +29,9 @@ mod tests { use crate::{ lang::error::ProcessorError, middleware::{ - CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, - Params, PodType, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard, - SELF_ID_HASH, + hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, + NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl, + StatementTmplArg, Value, Wildcard, SELF_ID_HASH, }, }; @@ -854,6 +854,78 @@ mod tests { Ok(()) } + #[test] + fn test_e2e_literals() -> Result<(), LangError> { + let pk = crate::backends::plonky2::primitives::ec::curve::Point::generator(); + let pk_b58 = pk.to_string(); + let pod_id = PodId(hash_str("test")); + let raw = RawValue::from(1); + let string = "hello"; + let int = 123; + let bool = true; + + let input = format!( + r#" + REQUEST( + Equal(?A["pk"], PublicKey({})) + Equal(?B["pod_id"], 
{:#}) + Equal(?C["raw"], Raw({:#})) + Equal(?D["string"], "{}") + Equal(?E["int"], {}) + Equal(?F["bool"], {}) + ) + "#, + pk_b58, pod_id, raw, string, int, bool + ); + /* + REQUEST( + Equal(?A["pk"], PublicKey(3t9fNuU194n7mSJPRdeaJRMqw6ZQCUddzvECWNe1k2b1rdBezXpJxF)) + Equal(?B["pod_id"], 0x735b31d3aad0f5b66002ffe1dc7d2eaa0ee9c59c09b641e8261530c5f3a02f29) + Equal(?C["raw"], Raw(0x0000000000000000000000000000000000000000000000000000000000000001)) + Equal(?D["string"], "hello") + Equal(?E["int"], 123) + Equal(?F["bool"], true) + ) + */ + + let params = Params::default(); + let processed = parse(&input, &params, &[])?; + let request_templates = processed.request_templates; + + assert_eq!(request_templates.len(), 6); + + let expected_templates = vec![ + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![sta_ak(("B", 1), "pod_id"), sta_lit(Value::from(pod_id))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![sta_ak(("C", 2), "raw"), sta_lit(Value::from(raw))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![sta_ak(("D", 3), "string"), sta_lit(Value::from(string))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![sta_ak(("E", 4), "int"), sta_lit(Value::from(int))], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![sta_ak(("F", 5), "bool"), sta_lit(Value::from(bool))], + }, + ]; + + assert_eq!(request_templates, expected_templates); + + Ok(()) + } + #[test] fn test_e2e_use_unknown_batch() { let params = Params::default(); diff --git a/src/lang/parser.rs b/src/lang/parser.rs index d191769..f8995f7 100644 --- a/src/lang/parser.rs +++ b/src/lang/parser.rs @@ -106,19 +106,42 @@ mod tests { // Raw - Require 64 hex digits (32 bytes, equal to 4 * 64-bit field
elements) assert_parses( Rule::literal_raw, - "0x0000000000000000000000000000000000000000000000000000000000000000", + "Raw(0x0000000000000000000000000000000000000000000000000000000000000000)", ); assert_parses( Rule::literal_raw, - "0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd", + "Raw(0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd)", ); - let long_valid_raw = format!("0x{}", "a".repeat(64)); + let long_valid_raw = format!("Raw(0x{})", "a".repeat(64)); assert_parses(Rule::literal_raw, &long_valid_raw); // Use anchored rule for failure cases - assert_fails(Rule::test_literal_raw, "0xabc"); // Fails (string is too short) - assert_fails(Rule::test_literal_raw, "0x"); // Fails (needs at least one pair) - assert_fails(Rule::test_literal_raw, &format!("0x{}", "a".repeat(66))); // Fails (string is too long) + assert_fails( + Rule::test_literal_raw, + "0x0000000000000000000000000000000000000000000000000000000000000000)", + ); // Missing Raw() wrapper + assert_fails(Rule::test_literal_raw, "Raw(0xabc)"); // Fails (string is too short) + assert_fails(Rule::test_literal_raw, "Raw(0x)"); // Fails (needs at least one pair) + assert_fails( + Rule::test_literal_raw, + &format!("Raw(0x{})", "a".repeat(66)), + ); // Fails (string is too long) + + // PodId (essentially identical to Raw but without the wrapper) + assert_parses( + Rule::literal_pod_id, + "0x0000000000000000000000000000000000000000000000000000000000000000", + ); + assert_parses( + Rule::literal_pod_id, + "0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd", + ); + let long_valid_pod_id = format!("0x{}", "a".repeat(64)); + assert_parses(Rule::literal_pod_id, &long_valid_pod_id); + + assert_fails(Rule::test_literal_pod_id, "0xabc"); // Fails (string is too short) + assert_fails(Rule::test_literal_pod_id, "0x"); // Fails (needs at least one pair) + assert_fails(Rule::test_literal_pod_id, &format!("0x{}", "a".repeat(66))); // Fails (string is too long) // String 
assert_parses(Rule::literal_string, "\"hello\""); @@ -126,10 +149,16 @@ mod tests { assert_parses(Rule::literal_string, "\"\\\\ backslash\""); assert_parses(Rule::literal_string, "\"\\uABCD\""); assert_fails(Rule::literal_string, "\"unterminated"); + + // PublicKey + assert_parses(Rule::literal_public_key, "PublicKey(base58string)"); + assert_fails(Rule::literal_public_key, "PublicKey(OhNo)"); // Fails because O is not valid base58 + // Array assert_parses(Rule::literal_array, "[]"); assert_parses(Rule::literal_array, "[1, \"two\", true]"); assert_parses(Rule::literal_array, "[ [1], #[2] ]"); + // Set assert_parses(Rule::literal_set, "#[]"); assert_parses(Rule::literal_set, "#[1, 2, 3]"); @@ -137,6 +166,7 @@ mod tests { Rule::literal_set, "#[ \"a\", 0x0000000000000000000000000000000000000000000000000000000000000000 ]", ); + // Dict assert_parses(Rule::literal_dict, "{}"); assert_parses(Rule::literal_dict, "{ \"name\": \"Alice\", \"age\": 30 }"); diff --git a/src/lang/processor.rs b/src/lang/processor.rs index a80f504..6672ed6 100644 --- a/src/lang/processor.rs +++ b/src/lang/processor.rs @@ -8,6 +8,7 @@ use plonky2::field::types::Field; use super::error::ProcessorError; use crate::{ + backends::plonky2::primitives::ec::curve::Point, frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder}, lang::parser::Rule, middleware::{ @@ -335,10 +336,10 @@ fn validate_and_build_statement_template( | NativePredicate::Lt | NativePredicate::LtEq | NativePredicate::SetContains + | NativePredicate::NotContains | NativePredicate::DictNotContains | NativePredicate::SetNotContains => 2, - NativePredicate::NotContains - | NativePredicate::Contains + NativePredicate::Contains | NativePredicate::ArrayContains | NativePredicate::DictContains | NativePredicate::SumOf @@ -523,7 +524,6 @@ fn process_and_add_custom_predicate_to_batch( let public_args_strs: Vec<&str> = public_arg_strings.iter().map(AsRef::as_ref).collect(); let private_args_strs: Vec<&str> = 
private_arg_strings.iter().map(AsRef::as_ref).collect(); let sts_slice: &[StatementTmplBuilder] = &statement_builders; - if conjunction { cpb_builder.predicate_and(&name, &public_args_strs, &private_args_strs, sts_slice)?; } else { @@ -667,11 +667,11 @@ fn process_literal_value( Ok(Value::from(val)) } Rule::literal_raw => { - let full_literal_str = inner_lit.as_str(); + let full_literal_str = inner_lit.clone().into_inner().next().unwrap(); let hex_str_no_prefix = full_literal_str + .as_str() .strip_prefix("0x") - .unwrap_or(full_literal_str); - + .unwrap_or(full_literal_str.as_str()); parse_hex_str_to_raw_value(hex_str_no_prefix) .map_err(|e| match e { ProcessorError::InvalidLiteralFormat { kind, value, .. } => { @@ -694,6 +694,27 @@ fn process_literal_value( }) .map(Value::from) } + Rule::literal_pod_id => { + let hex_str_no_prefix = inner_lit + .as_str() + .strip_prefix("0x") + .unwrap_or(inner_lit.as_str()); + let pod_id = parse_hex_str_to_pod_id(hex_str_no_prefix)?; + Ok(Value::from(pod_id)) + } + Rule::literal_public_key => { + let pk_str_pair = inner_lit.into_inner().next().unwrap(); + let pk_b58 = pk_str_pair.as_str(); + let point: Point = + pk_b58 + .parse() + .map_err(|e| ProcessorError::InvalidLiteralFormat { + kind: "PublicKey".to_string(), + value: format!("{} (error: {})", pk_b58, e), + span: Some(get_span(&pk_str_pair)), + })?; + Ok(Value::from(point)) + } Rule::literal_string => Ok(Value::from(parse_pest_string_literal(&inner_lit)?)), Rule::literal_array => { let elements: Result, ProcessorError> = inner_lit @@ -823,6 +844,11 @@ fn parse_hex_str_to_raw_value(hex_str: &str) -> Result Result { + let raw = parse_hex_str_to_raw_value(hex_str)?; + Ok(middleware::PodId(raw.into())) +} + // Helper to resolve a wildcard name string to an indexed middleware::Wildcard // based on an ordered list of names from the current scope (e.g., request or predicate def). 
fn resolve_wildcard( diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 04f1272..83ca711 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -440,6 +440,8 @@ impl fmt::Display for PodId { write!(f, "self") } else if self.0 == EMPTY_HASH { write!(f, "null") + } else if f.alternate() { + write!(f, "{:#}", self.0) } else { write!(f, "{}", self.0) }