Add Podlang pretty-printing (#353)

* Add Podlang pretty-printing

* Review feedback changes

* Formatting

* Use Display impl for printing StatementTmplArg
This commit is contained in:
Rob Knight 2025-07-25 16:43:43 +01:00 committed by GitHub
parent 8429cd224d
commit 9f8335756c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 680 additions and 15 deletions

View file

@ -11,12 +11,6 @@ use crate::{
}, },
}; };
#[derive(Serialize, Deserialize, JsonSchema)]
pub enum SignedPodType {
Signed,
MockSigned,
}
#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
#[schemars(rename = "SignedPod")] #[schemars(rename = "SignedPod")]

View file

@ -1,11 +1,13 @@
pub mod error; pub mod error;
pub mod parser; pub mod parser;
pub mod pretty_print;
pub mod processor; pub mod processor;
use std::sync::Arc; use std::sync::Arc;
pub use error::LangError; pub use error::LangError;
pub use parser::{parse_podlang, Pairs, ParseError, Rule}; pub use parser::{parse_podlang, Pairs, ParseError, Rule};
pub use pretty_print::PrettyPrint;
pub use processor::process_pest_tree; pub use processor::process_pest_tree;
use processor::PodlangOutput; use processor::PodlangOutput;

598
src/lang/pretty_print.rs Normal file
View file

@ -0,0 +1,598 @@
//! Pretty-printing functionality for POD2 custom predicates
use std::fmt::Write;
use crate::middleware::{
CustomPredicate, CustomPredicateBatch, Predicate, StatementTmpl, StatementTmplArg, Value,
};
/// Trait for converting AST nodes to Podlang source code
///
/// This trait provides a consistent interface for pretty-printing different
/// types of AST nodes back to their Podlang source representation.
pub trait PrettyPrint {
/// Write this AST node to a source writer
///
/// Uses default formatting with no indentation.
fn fmt_podlang(&self, w: &mut dyn Write) -> std::fmt::Result {
self.fmt_podlang_with_indent(w, 0)
}
/// Write this AST node to a source writer with custom indentation
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result;
/// Convert this AST node to a Podlang source string
///
/// Uses default formatting with no indentation.
fn to_podlang_string(&self) -> String {
self.to_podlang_string_with_indent(0)
}
/// Convert this AST node to a Podlang source string with custom indentation
fn to_podlang_string_with_indent(&self, indent: usize) -> String {
let mut result = String::new();
let _ = self.fmt_podlang_with_indent(&mut result, indent);
result
}
}
impl PrettyPrint for CustomPredicate {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result {
fmt_predicate_definition(w, self, indent, None)
}
}
impl PrettyPrint for StatementTmpl {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
self.fmt_podlang_with_batch_context(w, None)
}
}
impl StatementTmpl {
fn fmt_podlang_with_batch_context(
&self,
w: &mut dyn Write,
batch_context: Option<&CustomPredicateBatch>,
) -> std::fmt::Result {
match &self.pred {
Predicate::Native(native_pred) => {
write!(w, "{}", native_pred)?;
}
Predicate::Custom(custom_ref) => {
write!(w, "{}", custom_ref.predicate().name)?;
}
Predicate::BatchSelf(index) => {
if let Some(batch) = batch_context {
if let Some(predicate) = batch.predicates.get(*index) {
write!(w, "{}", predicate.name)?;
} else {
write!(w, "batch_self_{}", index)?;
}
} else {
write!(w, "batch_self_{}", index)?;
}
}
}
write!(w, "(")?;
for (i, arg) in self.args.iter().enumerate() {
if i > 0 {
write!(w, ", ")?;
}
arg.fmt_podlang(w)?;
}
write!(w, ")")?;
Ok(())
}
}
impl PrettyPrint for StatementTmplArg {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
write!(w, "{}", self)
}
}
impl PrettyPrint for CustomPredicateBatch {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result {
for (i, predicate) in self.predicates.iter().enumerate() {
if i > 0 {
write!(w, "\n\n")?;
}
self.fmt_predicate_with_context(w, predicate, indent)?;
}
Ok(())
}
}
impl CustomPredicateBatch {
fn fmt_predicate_with_context(
&self,
w: &mut dyn Write,
predicate: &CustomPredicate,
indent: usize,
) -> std::fmt::Result {
fmt_predicate_definition(w, predicate, indent, Some(self))
}
}
impl PrettyPrint for Value {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
write!(w, "{}", self.typed())
}
}
fn fmt_predicate_definition(
w: &mut dyn Write,
predicate: &CustomPredicate,
indent: usize,
batch_context: Option<&CustomPredicateBatch>,
) -> std::fmt::Result {
let base_indent = " ".repeat(indent);
let statement_indent = " ".repeat(indent + 4);
fmt_predicate_signature(w, predicate, &base_indent)?;
let conjunction_str = if predicate.conjunction { "AND" } else { "OR" };
writeln!(w, " = {}(", conjunction_str)?;
for (i, statement) in predicate.statements.iter().enumerate() {
if i > 0 {
writeln!(w)?;
}
write!(w, "{}", statement_indent)?;
statement.fmt_podlang_with_batch_context(w, batch_context)?;
}
write!(w, "\n{})", base_indent)
}
fn fmt_predicate_signature(
w: &mut dyn Write,
predicate: &CustomPredicate,
base_indent: &str,
) -> std::fmt::Result {
write!(w, "{}{}", base_indent, predicate.name)?;
write!(w, "(")?;
let mut public_args = predicate
.wildcard_names
.iter()
.take(predicate.args_len)
.peekable();
while let Some(arg_name) = public_args.next() {
write!(w, "{}", arg_name)?;
if public_args.peek().is_some() {
write!(w, ", ")?;
}
}
let mut private_args = predicate
.wildcard_names
.iter()
.skip(predicate.args_len)
.peekable();
if private_args.peek().is_some() {
if predicate.args_len > 0 {
write!(w, ", ")?;
}
write!(w, "private: ")?;
while let Some(arg_name) = private_args.next() {
write!(w, "{}", arg_name)?;
if private_args.peek().is_some() {
write!(w, ", ")?;
}
}
}
write!(w, ")")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
lang::parse,
middleware::{
CustomPredicate, Key, NativePredicate, Params, Predicate, StatementTmpl,
StatementTmplArg, Value, Wildcard,
},
};
fn create_test_wildcard(name: &str, index: usize) -> Wildcard {
Wildcard::new(name.to_string(), index)
}
#[test]
fn test_simple_predicate_pretty_print() {
let params = Params::default();
// Create a simple predicate: is_equal(PodA, PodB) = AND(Equal(?PodA["key"], ?PodB["key"]))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("PodA", 0),
Key::new("key".to_string()),
),
StatementTmplArg::AnchoredKey(
create_test_wildcard("PodB", 1),
Key::new("key".to_string()),
),
],
}];
let predicate = CustomPredicate::and(
&params,
"is_equal".to_string(),
statements,
2, // args_len (PodA, PodB are public)
vec!["PodA".to_string(), "PodB".to_string()],
)
.unwrap();
let pretty_printed = predicate.to_podlang_string();
let expected = r#"is_equal(PodA, PodB) = AND(
Equal(?PodA["key"], ?PodB["key"])
)"#;
assert_eq!(pretty_printed, expected);
}
#[test]
fn test_predicate_with_private_args() {
let params = Params::default();
// Create: uses_private(A, private: Temp) = AND(Equal(?A["input"], ?Temp["const"]))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("A", 0),
Key::new("input".to_string()),
),
StatementTmplArg::AnchoredKey(
create_test_wildcard("Temp", 1),
Key::new("const".to_string()),
),
],
}];
let predicate = CustomPredicate::and(
&params,
"uses_private".to_string(),
statements,
1, // args_len (only A is public)
vec!["A".to_string(), "Temp".to_string()],
)
.unwrap();
let pretty_printed = predicate.to_podlang_string();
let expected = r#"uses_private(A, private: Temp) = AND(
Equal(?A["input"], ?Temp["const"])
)"#;
assert_eq!(pretty_printed, expected);
}
#[test]
fn test_statement_with_literal_args() {
let params = Params::default();
// Create: check_value(Pod) = AND(Equal(?Pod["field"], 42))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("Pod", 0),
Key::new("field".to_string()),
),
StatementTmplArg::Literal(Value::from(42i64)),
],
}];
let predicate = CustomPredicate::and(
&params,
"check_value".to_string(),
statements,
1,
vec!["Pod".to_string()],
)
.unwrap();
let pretty_printed = predicate.to_podlang_string();
let expected = r#"check_value(Pod) = AND(
Equal(?Pod["field"], 42)
)"#;
assert_eq!(pretty_printed, expected);
}
#[test]
fn test_or_predicate() {
let params = Params::default();
// Create: either_or(A, B) = OR(Equal(?A["x"], 1), Equal(?B["y"], 2))
let statements = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("A", 0),
Key::new("x".to_string()),
),
StatementTmplArg::Literal(Value::from(1i64)),
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("B", 1),
Key::new("y".to_string()),
),
StatementTmplArg::Literal(Value::from(2i64)),
],
},
];
let predicate = CustomPredicate::or(
&params,
"either_or".to_string(),
statements,
2,
vec!["A".to_string(), "B".to_string()],
)
.unwrap();
let pretty_printed = predicate.to_podlang_string();
let expected = r#"either_or(A, B) = OR(
Equal(?A["x"], 1)
Equal(?B["y"], 2)
)"#;
assert_eq!(pretty_printed, expected);
}
/// Helper function for round-trip testing
fn assert_round_trip(input: &str) {
let params = Params::default();
let available_batches = &[];
// Step 1: Parse the input
let parsed_result =
parse(input, &params, available_batches).expect("Initial parsing should succeed");
// Step 2: Pretty-print the parsed batch
let pretty_printed = parsed_result.custom_batch.to_podlang_string();
// Step 3: Parse the pretty-printed result
let reparsed_result =
parse(&pretty_printed, &params, available_batches).expect("Reparsing should succeed");
// Step 4: Verify the ASTs are equivalent
assert_eq!(
parsed_result.custom_batch.predicates, reparsed_result.custom_batch.predicates,
"Original AST should match reparsed AST.\nOriginal input:\n{}\nPretty-printed:\n{}\n",
input, pretty_printed
);
}
#[test]
fn test_round_trip_simple_predicate() {
let input = r#"
simple_equal(PodA, PodB) = AND(
Equal(?PodA["key"], ?PodB["key"])
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_predicate_with_private_args() {
let input = r#"
uses_private(A, private: Temp) = AND(
Equal(?A["input_key"], ?Temp["const_key"])
Equal(?Temp["const_key"], "some_value")
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_or_predicate() {
let input = r#"
either_condition(X, Y) = OR(
Equal(?X["status"], "active")
Equal(?Y["type"], 1)
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_multiple_predicates() {
let input = r#"
pred_one(A) = AND(
Equal(?A["field"], 42)
)
pred_two(B, C) = AND(
Equal(?B["value"], ?C["value"])
NotEqual(?B["id"], ?C["id"])
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_various_literals() {
let input = r#"
literal_test(Pod) = AND(
Equal(?Pod["int_field"], 123)
Equal(?Pod["string_field"], "hello world")
Equal(?Pod["bool_field"], true)
NotEqual(?Pod["other_bool"], false)
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_complex_predicate() {
let input = r#"
complex_predicate(User, Document, private: Verifier, Timestamp) = AND(
Equal(?User["active"], true)
Equal(?Document["owner"], ?User["id"])
Equal(?Verifier["type"], 1)
Lt(?Timestamp["created"], ?Timestamp["expires"])
NotContains(?Document["blocked_users"], ?User["id"])
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_with_sum_and_hash_operations() {
let input = r#"
math_operations(A, B, C) = AND(
SumOf(?A["value"], ?B["value"], ?C["total"])
ProductOf(?A["factor"], ?B["factor"], ?C["product"])
MaxOf(?A["score"], ?B["score"], ?C["max_score"])
HashOf(?A["data"], ?B["salt"], ?C["hash"])
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_nested_custom_calls() {
let input = r#"
base_check(Pod) = AND(
Equal(?Pod["status"], "valid")
)
derived_check(PodA, PodB) = AND(
base_check(?PodA)
base_check(?PodB)
NotEqual(?PodA["id"], ?PodB["id"])
)
"#;
assert_round_trip(input);
}
#[test]
fn test_round_trip_container_operations() {
let input = r#"
container_checks(List, Item, Dict, Key, Value) = AND(
Contains(?List, ?Item, ?Value)
NotContains(?Dict, ?Key)
)
"#;
assert_round_trip(input);
}
#[test]
fn test_pretty_print_demonstration() {
let input = r#"
base_check(Pod) = AND(
Equal(?Pod["status"], "valid")
)
derived_check(PodA, PodB) = AND(
base_check(?PodA)
base_check(?PodB)
NotEqual(?PodA["id"], ?PodB["id"])
)
"#;
let params = Params::default();
let parsed_result = parse(input, &params, &[]).expect("Parsing should succeed");
let pretty_printed = parsed_result.custom_batch.to_podlang_string();
println!("Original input:\n{}", input);
println!("\nPretty-printed output:\n{}", pretty_printed);
let reparsed = parse(&pretty_printed, &params, &[]).expect("Reparsing should succeed");
assert_eq!(
parsed_result.custom_batch.predicates,
reparsed.custom_batch.predicates
);
}
#[test]
fn test_value_pretty_print_string_escaping() {
// Test basic string
let value = Value::from("hello world");
assert_eq!(value.to_podlang_string(), "\"hello world\"");
// Test string with quotes
let value = Value::from("say \"hello\"");
assert_eq!(value.to_podlang_string(), "\"say \\\"hello\\\"\"");
// Test string with backslashes
let value = Value::from("path\\to\\file");
assert_eq!(value.to_podlang_string(), "\"path\\\\to\\\\file\"");
// Test string with newlines
let value = Value::from("line1\nline2");
assert_eq!(value.to_podlang_string(), "\"line1\\nline2\"");
// Test string with tabs
let value = Value::from("col1\tcol2");
assert_eq!(value.to_podlang_string(), "\"col1\\tcol2\"");
// Test string with multiple escape sequences
let value = Value::from("\"quote\"\n\\backslash\\\ttab");
assert_eq!(
value.to_podlang_string(),
"\"\\\"quote\\\"\\n\\\\backslash\\\\\\ttab\""
);
}
#[test]
fn test_string_escaping_round_trip() {
let test_cases = vec![
"simple string",
"string with \"quotes\"",
"string with \\backslashes\\",
"string with\nnewlines",
"string with\ttabs",
"mixed: \"quotes\" and \\backslashes\\ and\nnewlines",
"unicode: café résumé",
"", // empty string
];
for test_string in test_cases {
let input = format!(
r#"
test_pred(Pod) = AND(
Equal(?Pod["field"], "{}")
)
"#,
// Manually escape for the input - this simulates what would be in actual Podlang source
test_string
.replace('\\', "\\\\")
.replace('"', "\\\"")
.replace('\n', "\\n")
.replace('\t', "\\t")
);
let params = Params::default();
let parsed_result = parse(&input, &params, &[]).expect("Should parse successfully");
let pretty_printed = parsed_result.custom_batch.to_podlang_string();
let reparsed_result =
parse(&pretty_printed, &params, &[]).expect("Should reparse successfully");
assert_eq!(
parsed_result.custom_batch.predicates, reparsed_result.custom_batch.predicates,
"Round-trip failed for string: {:?}\nPretty-printed: {}",
test_string, pretty_printed
);
}
}
}

View file

@ -3,6 +3,8 @@
use std::sync::Arc; use std::sync::Arc;
use hex::ToHex;
use itertools::Itertools;
use strum_macros::FromRepr; use strum_macros::FromRepr;
mod basetypes; mod basetypes;
use std::{ use std::{
@ -189,15 +191,54 @@ impl TryFrom<&TypedValue> for PodId {
impl fmt::Display for TypedValue { impl fmt::Display for TypedValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
TypedValue::String(s) => write!(f, "\"{}\"", s), TypedValue::Int(i) => write!(f, "{}", i),
TypedValue::Int(v) => write!(f, "{}", v), TypedValue::String(s) => {
// Use serde_json for proper JSON-style escaping
match serde_json::to_string(s) {
Ok(escaped) => write!(f, "{}", escaped),
Err(_) => write!(f, "\"{}\"", s),
}
}
TypedValue::Bool(b) => write!(f, "{}", b), TypedValue::Bool(b) => write!(f, "{}", b),
TypedValue::Dictionary(d) => write!(f, "dict:{}", d.commitment()), TypedValue::Array(a) => {
TypedValue::Set(s) => write!(f, "set:{}", s.commitment()), write!(f, "[")?;
TypedValue::Array(a) => write!(f, "arr:{}", a.commitment()), for (i, v) in a.array().iter().enumerate() {
TypedValue::Raw(v) => write!(f, "{}", v), if i > 0 {
TypedValue::PublicKey(p) => write!(f, "pk:{}", p), write!(f, ", ")?;
TypedValue::PodId(id) => write!(f, "pod_id:{}", id), }
write!(f, "{}", v)?;
}
write!(f, "]")
}
TypedValue::Dictionary(d) => {
write!(f, "{{ ")?;
let kvs: Vec<_> = d.kvs().iter().sorted_by_key(|(k, _)| k.name()).collect();
for (i, (k, v)) in kvs.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}: {}", k, v)?;
}
write!(f, " }}")
}
TypedValue::Set(s) => {
write!(f, "#[")?;
let values: Vec<_> = s.set().iter().sorted().collect();
for (i, v) in values.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", v)?;
}
write!(f, "]")
}
TypedValue::PublicKey(p) => write!(f, "PublicKey({})", p),
TypedValue::PodId(p) => {
write!(f, "0x{}", p.0.encode_hex::<String>())
}
TypedValue::Raw(r) => {
write!(f, "Raw(0x{})", r.encode_hex::<String>())
}
} }
} }
} }

View file

@ -1,4 +1,7 @@
use std::{fmt, iter}; use std::{
fmt::{self, Display},
iter,
};
use plonky2::field::types::Field; use plonky2::field::types::Field;
use schemars::JsonSchema; use schemars::JsonSchema;
@ -44,6 +47,33 @@ pub enum NativePredicate {
Gt = 1006, Gt = 1006,
} }
impl Display for NativePredicate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
NativePredicate::None => "None",
NativePredicate::False => "False",
NativePredicate::Equal => "Equal",
NativePredicate::NotEqual => "NotEqual",
NativePredicate::Lt => "Lt",
NativePredicate::LtEq => "LtEq",
NativePredicate::Gt => "Gt",
NativePredicate::GtEq => "GtEq",
NativePredicate::Contains => "Contains",
NativePredicate::NotContains => "NotContains",
NativePredicate::SumOf => "SumOf",
NativePredicate::ProductOf => "ProductOf",
NativePredicate::MaxOf => "MaxOf",
NativePredicate::HashOf => "HashOf",
NativePredicate::DictContains => "DictContains",
NativePredicate::DictNotContains => "DictNotContains",
NativePredicate::ArrayContains => "ArrayContains",
NativePredicate::SetContains => "SetContains",
NativePredicate::SetNotContains => "SetNotContains",
};
write!(f, "{}", s)
}
}
impl ToFields for NativePredicate { impl ToFields for NativePredicate {
fn to_fields(&self, _params: &Params) -> Vec<F> { fn to_fields(&self, _params: &Params) -> Vec<F> {
vec![F::from_canonical_u64(*self as u64)] vec![F::from_canonical_u64(*self as u64)]