From 9f8335756c00e524247d90ea7672f4364a0a7312 Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Fri, 25 Jul 2025 16:43:43 +0100 Subject: [PATCH] Add Podlang pretty-printing (#353) * Add Podlang pretty-printing * Review feedback changes * Formatting * Use Display impl for printing StatementTmplArg --- src/frontend/serialization.rs | 6 - src/lang/mod.rs | 2 + src/lang/pretty_print.rs | 598 ++++++++++++++++++++++++++++++++++ src/middleware/mod.rs | 57 +++- src/middleware/statement.rs | 32 +- 5 files changed, 680 insertions(+), 15 deletions(-) create mode 100644 src/lang/pretty_print.rs diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 2090a67..aa4ac86 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -11,12 +11,6 @@ use crate::{ }, }; -#[derive(Serialize, Deserialize, JsonSchema)] -pub enum SignedPodType { - Signed, - MockSigned, -} - #[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[schemars(rename = "SignedPod")] diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 9248ced..233d566 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -1,11 +1,13 @@ pub mod error; pub mod parser; +pub mod pretty_print; pub mod processor; use std::sync::Arc; pub use error::LangError; pub use parser::{parse_podlang, Pairs, ParseError, Rule}; +pub use pretty_print::PrettyPrint; pub use processor::process_pest_tree; use processor::PodlangOutput; diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs new file mode 100644 index 0000000..b36a159 --- /dev/null +++ b/src/lang/pretty_print.rs @@ -0,0 +1,598 @@ +//! Pretty-printing functionality for POD2 custom predicates + +use std::fmt::Write; + +use crate::middleware::{ + CustomPredicate, CustomPredicateBatch, Predicate, StatementTmpl, StatementTmplArg, Value, +}; + +/// Trait for converting AST nodes to Podlang source code +/// +/// This trait provides a consistent interface for pretty-printing different +/// types of AST nodes back to their Podlang source representation. +pub trait PrettyPrint { + /// Write this AST node to a source writer + /// + /// Uses default formatting with no indentation. + fn fmt_podlang(&self, w: &mut dyn Write) -> std::fmt::Result { + self.fmt_podlang_with_indent(w, 0) + } + + /// Write this AST node to a source writer with custom indentation + fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result; + + /// Convert this AST node to a Podlang source string + /// + /// Uses default formatting with no indentation. + fn to_podlang_string(&self) -> String { + self.to_podlang_string_with_indent(0) + } + + /// Convert this AST node to a Podlang source string with custom indentation + fn to_podlang_string_with_indent(&self, indent: usize) -> String { + let mut result = String::new(); + let _ = self.fmt_podlang_with_indent(&mut result, indent); + result + } +} + +impl PrettyPrint for CustomPredicate { + fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result { + fmt_predicate_definition(w, self, indent, None) + } +} + +impl PrettyPrint for StatementTmpl { + fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { + self.fmt_podlang_with_batch_context(w, None) + } +} + +impl StatementTmpl { + fn fmt_podlang_with_batch_context( + &self, + w: &mut dyn Write, + batch_context: Option<&CustomPredicateBatch>, + ) -> std::fmt::Result { + match &self.pred { + Predicate::Native(native_pred) => { + write!(w, "{}", native_pred)?; + } + Predicate::Custom(custom_ref) => { + write!(w, "{}", custom_ref.predicate().name)?; + } + Predicate::BatchSelf(index) => { + if let Some(batch) = batch_context { + if let Some(predicate) = batch.predicates.get(*index) { + write!(w, "{}", predicate.name)?; + } else { + write!(w, "batch_self_{}", index)?; + } + } else { + write!(w, "batch_self_{}", index)?; + } + } + } + + write!(w, "(")?; + for (i, arg) in self.args.iter().enumerate() { + if i > 0 { + write!(w, ", ")?; + } + arg.fmt_podlang(w)?; + } + write!(w, ")")?; + + Ok(()) + } +} + +impl PrettyPrint for StatementTmplArg { + fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { + write!(w, "{}", self) + } +} + +impl PrettyPrint for CustomPredicateBatch { + fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result { + for (i, predicate) in self.predicates.iter().enumerate() { + if i > 0 { + write!(w, "\n\n")?; + } + self.fmt_predicate_with_context(w, predicate, indent)?; + } + Ok(()) + } +} + +impl CustomPredicateBatch { + fn fmt_predicate_with_context( + &self, + w: &mut dyn Write, + predicate: &CustomPredicate, + indent: usize, + ) -> std::fmt::Result { + fmt_predicate_definition(w, predicate, indent, Some(self)) + } +} + +impl PrettyPrint for Value { + fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { + write!(w, "{}", self.typed()) + } +} + +fn fmt_predicate_definition( + w: &mut dyn Write, + predicate: &CustomPredicate, + indent: usize, + batch_context: Option<&CustomPredicateBatch>, +) -> std::fmt::Result { + let base_indent = " ".repeat(indent); + let statement_indent = " ".repeat(indent + 4); + + fmt_predicate_signature(w, predicate, &base_indent)?; + + let conjunction_str = if predicate.conjunction { "AND" } else { "OR" }; + writeln!(w, " = {}(", conjunction_str)?; + + for (i, statement) in predicate.statements.iter().enumerate() { + if i > 0 { + writeln!(w)?; + } + write!(w, "{}", statement_indent)?; + statement.fmt_podlang_with_batch_context(w, batch_context)?; + } + + write!(w, "\n{})", base_indent) +} + +fn fmt_predicate_signature( + w: &mut dyn Write, + predicate: &CustomPredicate, + base_indent: &str, +) -> std::fmt::Result { + write!(w, "{}{}", base_indent, predicate.name)?; + write!(w, "(")?; + + let mut public_args = predicate + .wildcard_names + .iter() + .take(predicate.args_len) + .peekable(); + while let Some(arg_name) = public_args.next() { + write!(w, "{}", arg_name)?; + if public_args.peek().is_some() { + write!(w, ", ")?; + } + } + + let mut private_args = predicate + .wildcard_names + .iter() + .skip(predicate.args_len) + .peekable(); + if private_args.peek().is_some() { + if predicate.args_len > 0 { + write!(w, ", ")?; + } + write!(w, "private: ")?; + while let Some(arg_name) = private_args.next() { + write!(w, "{}", arg_name)?; + if private_args.peek().is_some() { + write!(w, ", ")?; + } + } + } + + write!(w, ")") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + lang::parse, + middleware::{ + CustomPredicate, Key, NativePredicate, Params, Predicate, StatementTmpl, + StatementTmplArg, Value, Wildcard, + }, + }; + + fn create_test_wildcard(name: &str, index: usize) -> Wildcard { + Wildcard::new(name.to_string(), index) + } + + #[test] + fn test_simple_predicate_pretty_print() { + let params = Params::default(); + + // Create a simple predicate: is_equal(PodA, PodB) = AND(Equal(?PodA["key"], ?PodB["key"])) + let statements = vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + StatementTmplArg::AnchoredKey( + create_test_wildcard("PodA", 0), + Key::new("key".to_string()), + ), + StatementTmplArg::AnchoredKey( + create_test_wildcard("PodB", 1), + Key::new("key".to_string()), + ), + ], + }]; + + let predicate = CustomPredicate::and( + ¶ms, + "is_equal".to_string(), + statements, + 2, // args_len (PodA, PodB are public) + vec!["PodA".to_string(), "PodB".to_string()], + ) + .unwrap(); + + let pretty_printed = predicate.to_podlang_string(); + + let expected = r#"is_equal(PodA, PodB) = AND( + Equal(?PodA["key"], ?PodB["key"]) +)"#; + assert_eq!(pretty_printed, expected); + } + + #[test] + fn test_predicate_with_private_args() { + let params = Params::default(); + + // Create: uses_private(A, private: Temp) = AND(Equal(?A["input"], ?Temp["const"])) + let statements = vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + StatementTmplArg::AnchoredKey( + create_test_wildcard("A", 0), + Key::new("input".to_string()), + ), + StatementTmplArg::AnchoredKey( + create_test_wildcard("Temp", 1), + Key::new("const".to_string()), + ), + ], + }]; + + let predicate = CustomPredicate::and( + ¶ms, + "uses_private".to_string(), + statements, + 1, // args_len (only A is public) + vec!["A".to_string(), "Temp".to_string()], + ) + .unwrap(); + + let pretty_printed = predicate.to_podlang_string(); + + let expected = r#"uses_private(A, private: Temp) = AND( + Equal(?A["input"], ?Temp["const"]) +)"#; + assert_eq!(pretty_printed, expected); + } + + #[test] + fn test_statement_with_literal_args() { + let params = Params::default(); + + // Create: check_value(Pod) = AND(Equal(?Pod["field"], 42)) + let statements = vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + StatementTmplArg::AnchoredKey( + create_test_wildcard("Pod", 0), + Key::new("field".to_string()), + ), + StatementTmplArg::Literal(Value::from(42i64)), + ], + }]; + + let predicate = CustomPredicate::and( + ¶ms, + "check_value".to_string(), + statements, + 1, + vec!["Pod".to_string()], + ) + .unwrap(); + + let pretty_printed = predicate.to_podlang_string(); + + let expected = r#"check_value(Pod) = AND( + Equal(?Pod["field"], 42) +)"#; + assert_eq!(pretty_printed, expected); + } + + #[test] + fn test_or_predicate() { + let params = Params::default(); + + // Create: either_or(A, B) = OR(Equal(?A["x"], 1), Equal(?B["y"], 2)) + let statements = vec![ + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + StatementTmplArg::AnchoredKey( + create_test_wildcard("A", 0), + Key::new("x".to_string()), + ), + StatementTmplArg::Literal(Value::from(1i64)), + ], + }, + StatementTmpl { + pred: Predicate::Native(NativePredicate::Equal), + args: vec![ + StatementTmplArg::AnchoredKey( + create_test_wildcard("B", 1), + Key::new("y".to_string()), + ), + StatementTmplArg::Literal(Value::from(2i64)), + ], + }, + ]; + + let predicate = CustomPredicate::or( + ¶ms, + "either_or".to_string(), + statements, + 2, + vec!["A".to_string(), "B".to_string()], + ) + .unwrap(); + + let pretty_printed = predicate.to_podlang_string(); + + let expected = r#"either_or(A, B) = OR( + Equal(?A["x"], 1) + Equal(?B["y"], 2) +)"#; + assert_eq!(pretty_printed, expected); + } + + /// Helper function for round-trip testing + fn assert_round_trip(input: &str) { + let params = Params::default(); + let available_batches = &[]; + + // Step 1: Parse the input + let parsed_result = + parse(input, ¶ms, available_batches).expect("Initial parsing should succeed"); + + // Step 2: Pretty-print the parsed batch + let pretty_printed = parsed_result.custom_batch.to_podlang_string(); + + // Step 3: Parse the pretty-printed result + let reparsed_result = + parse(&pretty_printed, ¶ms, available_batches).expect("Reparsing should succeed"); + + // Step 4: Verify the ASTs are equivalent + assert_eq!( + parsed_result.custom_batch.predicates, reparsed_result.custom_batch.predicates, + "Original AST should match reparsed AST.\nOriginal input:\n{}\nPretty-printed:\n{}\n", + input, pretty_printed + ); + } + + #[test] + fn test_round_trip_simple_predicate() { + let input = r#" + simple_equal(PodA, PodB) = AND( + Equal(?PodA["key"], ?PodB["key"]) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_predicate_with_private_args() { + let input = r#" + uses_private(A, private: Temp) = AND( + Equal(?A["input_key"], ?Temp["const_key"]) + Equal(?Temp["const_key"], "some_value") + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_or_predicate() { + let input = r#" + either_condition(X, Y) = OR( + Equal(?X["status"], "active") + Equal(?Y["type"], 1) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_multiple_predicates() { + let input = r#" + pred_one(A) = AND( + Equal(?A["field"], 42) + ) + + pred_two(B, C) = AND( + Equal(?B["value"], ?C["value"]) + NotEqual(?B["id"], ?C["id"]) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_various_literals() { + let input = r#" + literal_test(Pod) = AND( + Equal(?Pod["int_field"], 123) + Equal(?Pod["string_field"], "hello world") + Equal(?Pod["bool_field"], true) + NotEqual(?Pod["other_bool"], false) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_complex_predicate() { + let input = r#" + complex_predicate(User, Document, private: Verifier, Timestamp) = AND( + Equal(?User["active"], true) + Equal(?Document["owner"], ?User["id"]) + Equal(?Verifier["type"], 1) + Lt(?Timestamp["created"], ?Timestamp["expires"]) + NotContains(?Document["blocked_users"], ?User["id"]) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_with_sum_and_hash_operations() { + let input = r#" + math_operations(A, B, C) = AND( + SumOf(?A["value"], ?B["value"], ?C["total"]) + ProductOf(?A["factor"], ?B["factor"], ?C["product"]) + MaxOf(?A["score"], ?B["score"], ?C["max_score"]) + HashOf(?A["data"], ?B["salt"], ?C["hash"]) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_nested_custom_calls() { + let input = r#" + base_check(Pod) = AND( + Equal(?Pod["status"], "valid") + ) + + derived_check(PodA, PodB) = AND( + base_check(?PodA) + base_check(?PodB) + NotEqual(?PodA["id"], ?PodB["id"]) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_container_operations() { + let input = r#" + container_checks(List, Item, Dict, Key, Value) = AND( + Contains(?List, ?Item, ?Value) + NotContains(?Dict, ?Key) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_pretty_print_demonstration() { + let input = r#" + base_check(Pod) = AND( + Equal(?Pod["status"], "valid") + ) + + derived_check(PodA, PodB) = AND( + base_check(?PodA) + base_check(?PodB) + NotEqual(?PodA["id"], ?PodB["id"]) + ) + "#; + + let params = Params::default(); + let parsed_result = parse(input, ¶ms, &[]).expect("Parsing should succeed"); + + let pretty_printed = parsed_result.custom_batch.to_podlang_string(); + + println!("Original input:\n{}", input); + println!("\nPretty-printed output:\n{}", pretty_printed); + + let reparsed = parse(&pretty_printed, ¶ms, &[]).expect("Reparsing should succeed"); + + assert_eq!( + parsed_result.custom_batch.predicates, + reparsed.custom_batch.predicates + ); + } + + #[test] + fn test_value_pretty_print_string_escaping() { + // Test basic string + let value = Value::from("hello world"); + assert_eq!(value.to_podlang_string(), "\"hello world\""); + + // Test string with quotes + let value = Value::from("say \"hello\""); + assert_eq!(value.to_podlang_string(), "\"say \\\"hello\\\"\""); + + // Test string with backslashes + let value = Value::from("path\\to\\file"); + assert_eq!(value.to_podlang_string(), "\"path\\\\to\\\\file\""); + + // Test string with newlines + let value = Value::from("line1\nline2"); + assert_eq!(value.to_podlang_string(), "\"line1\\nline2\""); + + // Test string with tabs + let value = Value::from("col1\tcol2"); + assert_eq!(value.to_podlang_string(), "\"col1\\tcol2\""); + + // Test string with multiple escape sequences + let value = Value::from("\"quote\"\n\\backslash\\\ttab"); + assert_eq!( + value.to_podlang_string(), + "\"\\\"quote\\\"\\n\\\\backslash\\\\\\ttab\"" + ); + } + + #[test] + fn test_string_escaping_round_trip() { + let test_cases = vec![ + "simple string", + "string with \"quotes\"", + "string with \\backslashes\\", + "string with\nnewlines", + "string with\ttabs", + "mixed: \"quotes\" and \\backslashes\\ and\nnewlines", + "unicode: café résumé", + "", // empty string + ]; + + for test_string in test_cases { + let input = format!( + r#" + test_pred(Pod) = AND( + Equal(?Pod["field"], "{}") + ) + "#, + // Manually escape for the input - this simulates what would be in actual Podlang source + test_string + .replace('\\', "\\\\") + .replace('"', "\\\"") + .replace('\n', "\\n") + .replace('\t', "\\t") + ); + + let params = Params::default(); + let parsed_result = parse(&input, ¶ms, &[]).expect("Should parse successfully"); + + let pretty_printed = parsed_result.custom_batch.to_podlang_string(); + + let reparsed_result = + parse(&pretty_printed, ¶ms, &[]).expect("Should reparse successfully"); + + assert_eq!( + parsed_result.custom_batch.predicates, reparsed_result.custom_batch.predicates, + "Round-trip failed for string: {:?}\nPretty-printed: {}", + test_string, pretty_printed + ); + } + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 8b43ab7..21f7145 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -3,6 +3,8 @@ use std::sync::Arc; +use hex::ToHex; +use itertools::Itertools; use strum_macros::FromRepr; mod basetypes; use std::{ @@ -189,15 +191,54 @@ impl TryFrom<&TypedValue> for PodId { impl fmt::Display for TypedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TypedValue::String(s) => write!(f, "\"{}\"", s), - TypedValue::Int(v) => write!(f, "{}", v), + TypedValue::Int(i) => write!(f, "{}", i), + TypedValue::String(s) => { + // Use serde_json for proper JSON-style escaping + match serde_json::to_string(s) { + Ok(escaped) => write!(f, "{}", escaped), + Err(_) => write!(f, "\"{}\"", s), + } + } TypedValue::Bool(b) => write!(f, "{}", b), - TypedValue::Dictionary(d) => write!(f, "dict:{}", d.commitment()), - TypedValue::Set(s) => write!(f, "set:{}", s.commitment()), - TypedValue::Array(a) => write!(f, "arr:{}", a.commitment()), - TypedValue::Raw(v) => write!(f, "{}", v), - TypedValue::PublicKey(p) => write!(f, "pk:{}", p), - TypedValue::PodId(id) => write!(f, "pod_id:{}", id), + TypedValue::Array(a) => { + write!(f, "[")?; + for (i, v) in a.array().iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", v)?; + } + write!(f, "]") + } + TypedValue::Dictionary(d) => { + write!(f, "{{ ")?; + let kvs: Vec<_> = d.kvs().iter().sorted_by_key(|(k, _)| k.name()).collect(); + for (i, (k, v)) in kvs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}: {}", k, v)?; + } + write!(f, " }}") + } + TypedValue::Set(s) => { + write!(f, "#[")?; + let values: Vec<_> = s.set().iter().sorted().collect(); + for (i, v) in values.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", v)?; + } + write!(f, "]") + } + TypedValue::PublicKey(p) => write!(f, "PublicKey({})", p), + TypedValue::PodId(p) => { + write!(f, "0x{}", p.0.encode_hex::()) + } + TypedValue::Raw(r) => { + write!(f, "Raw(0x{})", r.encode_hex::()) + } } } } diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index bfdd222..2daca16 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -1,4 +1,7 @@ -use std::{fmt, iter}; +use std::{ + fmt::{self, Display}, + iter, +}; use plonky2::field::types::Field; use schemars::JsonSchema; @@ -44,6 +47,33 @@ pub enum NativePredicate { Gt = 1006, } +impl Display for NativePredicate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + NativePredicate::None => "None", + NativePredicate::False => "False", + NativePredicate::Equal => "Equal", + NativePredicate::NotEqual => "NotEqual", + NativePredicate::Lt => "Lt", + NativePredicate::LtEq => "LtEq", + NativePredicate::Gt => "Gt", + NativePredicate::GtEq => "GtEq", + NativePredicate::Contains => "Contains", + NativePredicate::NotContains => "NotContains", + NativePredicate::SumOf => "SumOf", + NativePredicate::ProductOf => "ProductOf", + NativePredicate::MaxOf => "MaxOf", + NativePredicate::HashOf => "HashOf", + NativePredicate::DictContains => "DictContains", + NativePredicate::DictNotContains => "DictNotContains", + NativePredicate::ArrayContains => "ArrayContains", + NativePredicate::SetContains => "SetContains", + NativePredicate::SetNotContains => "SetNotContains", + }; + write!(f, "{}", s) + } +} + impl ToFields for NativePredicate { fn to_fields(&self, _params: &Params) -> Vec { vec![F::from_canonical_u64(*self as u64)]