Support public key literals and tidy up handling of Raw vs PodId (#319)

* Support public key literals and tidy up handling of Raw vs PodId
This commit is contained in:
Rob Knight 2025-07-01 10:34:35 +02:00 committed by GitHub
parent 6aa4acac4a
commit b123185ee9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 184 additions and 32 deletions

View file

@ -5,6 +5,7 @@ use core::ops::{Add, Mul};
use std::{ use std::{
array, fmt, array, fmt,
ops::{AddAssign, Neg, Sub}, ops::{AddAssign, Neg, Sub},
str::FromStr,
sync::LazyLock, sync::LazyLock,
}; };
@ -121,6 +122,27 @@ impl fmt::Display for Point {
} }
} }
impl FromStr for Point {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let point_bytes = bs58::decode(s)
.into_vec()
.map_err(|e| Error::custom(format!("Base58 decode error: {}", e)))?;
if point_bytes.len() == 80 {
// Non-compressed
Ok(Point {
x: ec_field_from_bytes(&point_bytes[..40])?,
u: ec_field_from_bytes(&point_bytes[40..])?,
})
} else {
// Compressed
Self::from_bytes_into_subgroup(&point_bytes)
}
}
}
impl Serialize for Point { impl Serialize for Point {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
@ -137,19 +159,7 @@ impl<'de> Deserialize<'de> for Point {
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let point_b58 = String::deserialize(deserializer)?; let point_b58 = String::deserialize(deserializer)?;
let point_bytes: Vec<u8> = bs58::decode(point_b58) Self::from_str(&point_b58).map_err(serde::de::Error::custom)
.into_vec()
.map_err(serde::de::Error::custom)?;
if point_bytes.len() == 80 {
// Non-compressed
Ok(Point {
x: ec_field_from_bytes(&point_bytes[..40]).map_err(serde::de::Error::custom)?,
u: ec_field_from_bytes(&point_bytes[40..]).map_err(serde::de::Error::custom)?,
})
} else {
// Compressed
Self::from_bytes_into_subgroup(&point_bytes).map_err(serde::de::Error::custom)
}
} }
} }

View file

@ -32,7 +32,7 @@ document = { SOI ~ (use_statement | custom_predicate_def | request_def)* ~ EOI }
use_statement = { "use" ~ use_predicate_list ~ "from" ~ batch_ref } use_statement = { "use" ~ use_predicate_list ~ "from" ~ batch_ref }
use_predicate_list = { import_name ~ ("," ~ import_name)* } use_predicate_list = { import_name ~ ("," ~ import_name)* }
import_name = { identifier | "_" } import_name = { identifier | "_" }
batch_ref = { literal_raw } batch_ref = { hash_hex }
request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" } request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" }
@ -59,11 +59,13 @@ anchored_key = { wildcard ~ "[" ~ literal_string ~ "]" }
// Literal Values (ordered to avoid ambiguity, e.g., string before int) // Literal Values (ordered to avoid ambiguity, e.g., string before int)
literal_value = { literal_value = {
literal_public_key |
literal_dict | literal_dict |
literal_set | literal_set |
literal_array | literal_array |
literal_bool | literal_bool |
literal_raw | literal_raw |
literal_pod_id |
literal_string | literal_string |
literal_int literal_int
} }
@ -72,9 +74,12 @@ literal_value = {
literal_int = @{ "-"? ~ ASCII_DIGIT+ } literal_int = @{ "-"? ~ ASCII_DIGIT+ }
literal_bool = @{ "true" | "false" } literal_bool = @{ "true" | "false" }
// literal_raw: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters) // hash_hex: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters)
// representing a 32-byte value in big-endian order // representing a 32-byte value in big-endian order
literal_raw = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} } hash_hex = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} }
literal_raw = { "Raw" ~ "(" ~ hash_hex ~ ")" }
literal_pod_id = { hash_hex }
// String literal parsing based on https://pest.rs/book/examples/json.html // String literal parsing based on https://pest.rs/book/examples/json.html
literal_string = ${ "\"" ~ inner ~ "\"" } // Compound atomic string rule literal_string = ${ "\"" ~ inner ~ "\"" } // Compound atomic string rule
@ -85,6 +90,11 @@ char = { // Rule for a single logical character (unescaped or escaped)
| "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) // Unicode escape sequence | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) // Unicode escape sequence
} }
// PublicKey(...)
base58_char = { '1'..'9' | 'A'..'H' | 'J'..'N' | 'P'..'Z' | 'a'..'k' | 'm'..'z' }
base58_string = @{ base58_char+ }
literal_public_key = { "PublicKey" ~ "(" ~ base58_string ~ ")" }
// Container Literals (recursive definition using literal_value) // Container Literals (recursive definition using literal_value)
literal_array = { "[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" } literal_array = { "[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" }
literal_set = { "#[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" } literal_set = { "#[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" }
@ -95,7 +105,9 @@ dict_pair = { literal_string ~ ":" ~ literal_value }
test_identifier = { SOI ~ identifier ~ EOI } test_identifier = { SOI ~ identifier ~ EOI }
test_wildcard = { SOI ~ wildcard ~ EOI } test_wildcard = { SOI ~ wildcard ~ EOI }
test_literal_int = { SOI ~ literal_int ~ EOI } test_literal_int = { SOI ~ literal_int ~ EOI }
test_hash_hex = { SOI ~ hash_hex ~ EOI }
test_literal_raw = { SOI ~ literal_raw ~ EOI } test_literal_raw = { SOI ~ literal_raw ~ EOI }
test_literal_pod_id = { SOI ~ literal_pod_id ~ EOI }
test_literal_value = { SOI ~ literal_value ~ EOI } test_literal_value = { SOI ~ literal_value ~ EOI }
test_statement = { SOI ~ statement ~ EOI } test_statement = { SOI ~ statement ~ EOI }
test_custom_predicate_def = { SOI ~ custom_predicate_def ~ EOI } test_custom_predicate_def = { SOI ~ custom_predicate_def ~ EOI }

View file

@ -29,9 +29,9 @@ mod tests {
use crate::{ use crate::{
lang::error::ProcessorError, lang::error::ProcessorError,
middleware::{ middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
Params, PodType, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard, NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl,
SELF_ID_HASH, StatementTmplArg, Value, Wildcard, SELF_ID_HASH,
}, },
}; };
@ -854,6 +854,78 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_e2e_literals() -> Result<(), LangError> {
let pk = crate::backends::plonky2::primitives::ec::curve::Point::generator();
let pk_b58 = pk.to_string();
let pod_id = PodId(hash_str("test"));
let raw = RawValue::from(1);
let string = "hello";
let int = 123;
let bool = true;
let input = format!(
r#"
REQUEST(
Equal(?A["pk"], PublicKey({}))
Equal(?B["pod_id"], {:#})
Equal(?C["raw"], Raw({:#}))
Equal(?D["string"], "{}")
Equal(?E["int"], {})
Equal(?F["bool"], {})
)
"#,
pk_b58, pod_id, raw, string, int, bool
);
/*
REQUEST(
Equal(?A["pk"], PublicKey(3t9fNuU194n7mSJPRdeaJRMqw6ZQCUddzvECWNe1k2b1rdBezXpJxF))
Equal(?B["pod_id"], 0x735b31d3aad0f5b66002ffe1dc7d2eaa0ee9c59c09b641e8261530c5f3a02f29)
Equal(?C["raw"], Raw(0x0000000000000000000000000000000000000000000000000000000000000001))
Equal(?D["string"], "hello")
Equal(?E["int"], 123)
Equal(?F["bool"], true)
)
*/
let params = Params::default();
let processed = parse(&input, &params, &[])?;
let request_templates = processed.request_templates;
assert_eq!(request_templates.len(), 6);
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("B", 1), "pod_id"), sta_lit(Value::from(pod_id))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("C", 2), "raw"), sta_lit(Value::from(raw))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("D", 3), "string"), sta_lit(Value::from(string))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("E", 4), "int"), sta_lit(Value::from(int))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("F", 5), "bool"), sta_lit(Value::from(bool))],
},
];
assert_eq!(request_templates, expected_templates);
Ok(())
}
#[test] #[test]
fn test_e2e_use_unknown_batch() { fn test_e2e_use_unknown_batch() {
let params = Params::default(); let params = Params::default();

View file

@ -106,19 +106,42 @@ mod tests {
// Raw - Require 64 hex digits (32 bytes, equal to 4 * 64-bit field elements) // Raw - Require 64 hex digits (32 bytes, equal to 4 * 64-bit field elements)
assert_parses( assert_parses(
Rule::literal_raw, Rule::literal_raw,
"0x0000000000000000000000000000000000000000000000000000000000000000", "Raw(0x0000000000000000000000000000000000000000000000000000000000000000)",
); );
assert_parses( assert_parses(
Rule::literal_raw, Rule::literal_raw,
"0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd", "Raw(0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd)",
); );
let long_valid_raw = format!("0x{}", "a".repeat(64)); let long_valid_raw = format!("Raw(0x{})", "a".repeat(64));
assert_parses(Rule::literal_raw, &long_valid_raw); assert_parses(Rule::literal_raw, &long_valid_raw);
// Use anchored rule for failure cases // Use anchored rule for failure cases
assert_fails(Rule::test_literal_raw, "0xabc"); // Fails (string is too short) assert_fails(
assert_fails(Rule::test_literal_raw, "0x"); // Fails (needs at least one pair) Rule::test_literal_raw,
assert_fails(Rule::test_literal_raw, &format!("0x{}", "a".repeat(66))); // Fails (string is too long) "0x0000000000000000000000000000000000000000000000000000000000000000)",
); // Missing Raw() wrapper
assert_fails(Rule::test_literal_raw, "Raw(0xabc)"); // Fails (string is too short)
assert_fails(Rule::test_literal_raw, "Raw(0x)"); // Fails (needs at least one pair)
assert_fails(
Rule::test_literal_raw,
&format!("Raw(0x{})", "a".repeat(66)),
); // Fails (string is too long)
// PodId (essentially identical to Raw but without the wrapper)
assert_parses(
Rule::literal_pod_id,
"0x0000000000000000000000000000000000000000000000000000000000000000",
);
assert_parses(
Rule::literal_pod_id,
"0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd",
);
let long_valid_pod_id = format!("0x{}", "a".repeat(64));
assert_parses(Rule::literal_pod_id, &long_valid_pod_id);
assert_fails(Rule::test_literal_pod_id, "0xabc"); // Fails (string is too short)
assert_fails(Rule::test_literal_pod_id, "0x"); // Fails (needs at least one pair)
assert_fails(Rule::test_literal_pod_id, &format!("0x{}", "a".repeat(66))); // Fails (string is too long)
// String // String
assert_parses(Rule::literal_string, "\"hello\""); assert_parses(Rule::literal_string, "\"hello\"");
@ -126,10 +149,16 @@ mod tests {
assert_parses(Rule::literal_string, "\"\\\\ backslash\""); assert_parses(Rule::literal_string, "\"\\\\ backslash\"");
assert_parses(Rule::literal_string, "\"\\uABCD\""); assert_parses(Rule::literal_string, "\"\\uABCD\"");
assert_fails(Rule::literal_string, "\"unterminated"); assert_fails(Rule::literal_string, "\"unterminated");
// PublicKey
assert_parses(Rule::literal_public_key, "PublicKey(base58string)");
assert_fails(Rule::literal_public_key, "PublicKey(OhNo)"); // Fails because O is not valid base58
// Array // Array
assert_parses(Rule::literal_array, "[]"); assert_parses(Rule::literal_array, "[]");
assert_parses(Rule::literal_array, "[1, \"two\", true]"); assert_parses(Rule::literal_array, "[1, \"two\", true]");
assert_parses(Rule::literal_array, "[ [1], #[2] ]"); assert_parses(Rule::literal_array, "[ [1], #[2] ]");
// Set // Set
assert_parses(Rule::literal_set, "#[]"); assert_parses(Rule::literal_set, "#[]");
assert_parses(Rule::literal_set, "#[1, 2, 3]"); assert_parses(Rule::literal_set, "#[1, 2, 3]");
@ -137,6 +166,7 @@ mod tests {
Rule::literal_set, Rule::literal_set,
"#[ \"a\", 0x0000000000000000000000000000000000000000000000000000000000000000 ]", "#[ \"a\", 0x0000000000000000000000000000000000000000000000000000000000000000 ]",
); );
// Dict // Dict
assert_parses(Rule::literal_dict, "{}"); assert_parses(Rule::literal_dict, "{}");
assert_parses(Rule::literal_dict, "{ \"name\": \"Alice\", \"age\": 30 }"); assert_parses(Rule::literal_dict, "{ \"name\": \"Alice\", \"age\": 30 }");

View file

@ -8,6 +8,7 @@ use plonky2::field::types::Field;
use super::error::ProcessorError; use super::error::ProcessorError;
use crate::{ use crate::{
backends::plonky2::primitives::ec::curve::Point,
frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder}, frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder},
lang::parser::Rule, lang::parser::Rule,
middleware::{ middleware::{
@ -335,10 +336,10 @@ fn validate_and_build_statement_template(
| NativePredicate::Lt | NativePredicate::Lt
| NativePredicate::LtEq | NativePredicate::LtEq
| NativePredicate::SetContains | NativePredicate::SetContains
| NativePredicate::NotContains
| NativePredicate::DictNotContains | NativePredicate::DictNotContains
| NativePredicate::SetNotContains => 2, | NativePredicate::SetNotContains => 2,
NativePredicate::NotContains NativePredicate::Contains
| NativePredicate::Contains
| NativePredicate::ArrayContains | NativePredicate::ArrayContains
| NativePredicate::DictContains | NativePredicate::DictContains
| NativePredicate::SumOf | NativePredicate::SumOf
@ -523,7 +524,6 @@ fn process_and_add_custom_predicate_to_batch(
let public_args_strs: Vec<&str> = public_arg_strings.iter().map(AsRef::as_ref).collect(); let public_args_strs: Vec<&str> = public_arg_strings.iter().map(AsRef::as_ref).collect();
let private_args_strs: Vec<&str> = private_arg_strings.iter().map(AsRef::as_ref).collect(); let private_args_strs: Vec<&str> = private_arg_strings.iter().map(AsRef::as_ref).collect();
let sts_slice: &[StatementTmplBuilder] = &statement_builders; let sts_slice: &[StatementTmplBuilder] = &statement_builders;
if conjunction { if conjunction {
cpb_builder.predicate_and(&name, &public_args_strs, &private_args_strs, sts_slice)?; cpb_builder.predicate_and(&name, &public_args_strs, &private_args_strs, sts_slice)?;
} else { } else {
@ -667,11 +667,11 @@ fn process_literal_value(
Ok(Value::from(val)) Ok(Value::from(val))
} }
Rule::literal_raw => { Rule::literal_raw => {
let full_literal_str = inner_lit.as_str(); let full_literal_str = inner_lit.clone().into_inner().next().unwrap();
let hex_str_no_prefix = full_literal_str let hex_str_no_prefix = full_literal_str
.as_str()
.strip_prefix("0x") .strip_prefix("0x")
.unwrap_or(full_literal_str); .unwrap_or(full_literal_str.as_str());
parse_hex_str_to_raw_value(hex_str_no_prefix) parse_hex_str_to_raw_value(hex_str_no_prefix)
.map_err(|e| match e { .map_err(|e| match e {
ProcessorError::InvalidLiteralFormat { kind, value, .. } => { ProcessorError::InvalidLiteralFormat { kind, value, .. } => {
@ -694,6 +694,27 @@ fn process_literal_value(
}) })
.map(Value::from) .map(Value::from)
} }
Rule::literal_pod_id => {
let hex_str_no_prefix = inner_lit
.as_str()
.strip_prefix("0x")
.unwrap_or(inner_lit.as_str());
let pod_id = parse_hex_str_to_pod_id(hex_str_no_prefix)?;
Ok(Value::from(pod_id))
}
Rule::literal_public_key => {
let pk_str_pair = inner_lit.into_inner().next().unwrap();
let pk_b58 = pk_str_pair.as_str();
let point: Point =
pk_b58
.parse()
.map_err(|e| ProcessorError::InvalidLiteralFormat {
kind: "PublicKey".to_string(),
value: format!("{} (error: {})", pk_b58, e),
span: Some(get_span(&pk_str_pair)),
})?;
Ok(Value::from(point))
}
Rule::literal_string => Ok(Value::from(parse_pest_string_literal(&inner_lit)?)), Rule::literal_string => Ok(Value::from(parse_pest_string_literal(&inner_lit)?)),
Rule::literal_array => { Rule::literal_array => {
let elements: Result<Vec<Value>, ProcessorError> = inner_lit let elements: Result<Vec<Value>, ProcessorError> = inner_lit
@ -823,6 +844,11 @@ fn parse_hex_str_to_raw_value(hex_str: &str) -> Result<middleware::RawValue, Pro
Ok(middleware::RawValue(v)) Ok(middleware::RawValue(v))
} }
fn parse_hex_str_to_pod_id(hex_str: &str) -> Result<middleware::PodId, ProcessorError> {
let raw = parse_hex_str_to_raw_value(hex_str)?;
Ok(middleware::PodId(raw.into()))
}
// Helper to resolve a wildcard name string to an indexed middleware::Wildcard // Helper to resolve a wildcard name string to an indexed middleware::Wildcard
// based on an ordered list of names from the current scope (e.g., request or predicate def). // based on an ordered list of names from the current scope (e.g., request or predicate def).
fn resolve_wildcard( fn resolve_wildcard(

View file

@ -440,6 +440,8 @@ impl fmt::Display for PodId {
write!(f, "self") write!(f, "self")
} else if self.0 == EMPTY_HASH { } else if self.0 == EMPTY_HASH {
write!(f, "null") write!(f, "null")
} else if f.alternate() {
write!(f, "{:#}", self.0)
} else { } else {
write!(f, "{}", self.0) write!(f, "{}", self.0)
} }