From 24cafde231d532353d9f29cbabf39d24ce563e6c Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Wed, 2 Jul 2025 18:27:54 +0200 Subject: [PATCH] Assorted tweaks to support external playground crate (#322) * Assorted tweaks to support external playground crate * Fix schemas * Fixed schema again * Add ToHex for RawValue * Add FromHex to RawValue --- Cargo.toml | 1 + src/frontend/mod.rs | 2 +- src/frontend/serialization.rs | 8 ++--- src/lang/processor.rs | 4 +-- src/lib.rs | 2 +- src/middleware/basetypes.rs | 34 +++++++++++++++++++ src/middleware/custom.rs | 33 ++++++++++++++---- src/middleware/mod.rs | 64 ++++++++++++++++++++--------------- src/middleware/operation.rs | 2 +- src/middleware/statement.rs | 7 ++-- 10 files changed, 111 insertions(+), 46 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dfea0c9..e941c1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,3 +48,4 @@ backend_plonky2 = ["plonky2"] zk = [] metrics = [] time = [] +examples = [] diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 316ed50..a1eb39d 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, convert::From, fmt}; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use serialization::{SerializedMainPod, SerializedSignedPod}; +pub use serialization::{SerializedMainPod, SerializedSignedPod}; use crate::middleware::{ self, check_st_tmpl, hash_op, hash_str, max_op, prod_op, sum_op, AnchoredKey, Key, diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 05d5de4..7581418 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -17,7 +17,7 @@ pub enum SignedPodType { MockSigned, } -#[derive(Serialize, Deserialize, JsonSchema)] +#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[schemars(rename = "SignedPod")] pub struct SerializedSignedPod { @@ -27,7 +27,7 @@ pub struct SerializedSignedPod { data: serde_json::Value, } -#[derive(Serialize, Deserialize, JsonSchema)] +#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] #[schemars(rename = "MainPod")] pub struct SerializedMainPod { @@ -154,11 +154,11 @@ mod tests { ])) .unwrap(), ), - "{\"Dictionary\":{\"max_depth\":32,\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}}", + "{\"max_depth\":32,\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}", ), ( TypedValue::Set(Set::new(params.max_depth_mt_containers, HashSet::from(["foo".into(), "bar".into()])).unwrap()), - "{\"Set\":{\"max_depth\":32,\"set\":[\"bar\",\"foo\"]}}", + "{\"max_depth\":32,\"set\":[\"bar\",\"foo\"]}", ), ]; diff --git a/src/lang/processor.rs b/src/lang/processor.rs index 6672ed6..f0b5316 100644 --- a/src/lang/processor.rs +++ b/src/lang/processor.rs @@ -336,9 +336,9 @@ fn validate_and_build_statement_template( | NativePredicate::Lt | NativePredicate::LtEq | NativePredicate::SetContains - | NativePredicate::NotContains | NativePredicate::DictNotContains - | NativePredicate::SetNotContains => 2, + | NativePredicate::SetNotContains + | NativePredicate::NotContains => 2, NativePredicate::Contains | NativePredicate::ArrayContains | NativePredicate::DictContains diff --git a/src/lib.rs b/src/lib.rs index ee423c7..11b20bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ pub mod frontend; pub mod lang; pub mod middleware; -#[cfg(test)] +#[cfg(any(test, feature = "examples"))] pub mod examples; #[cfg(feature = "time")] diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index 86cf2b6..fa9e8e8 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -125,6 +125,40 @@ impl fmt::Display for RawValue { } } +impl ToHex for RawValue { + fn encode_hex>(&self) -> T { + self.0 + .iter() + .rev() + .fold(String::with_capacity(64), |mut s, limb| { + write!(s, "{:016x}", limb.0).unwrap(); + s + }) + .chars() + .collect() + } + + fn encode_hex_upper>(&self) -> T { + self.0 + .iter() + .rev() + .fold(String::with_capacity(64), |mut s, limb| { + write!(s, "{:016X}", limb.0).unwrap(); + s + }) + .chars() + .collect() + } +} + +impl FromHex for RawValue { + type Error = FromHexError; + + fn from_hex>(hex: T) -> Result { + Hash::from_hex(hex).map(|h| RawValue(h.0)) + } +} + #[derive(Clone, Copy, Debug, Default, Hash, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] pub struct Hash( #[serde( diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index e7168f3..3b54613 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -9,7 +9,7 @@ use crate::middleware::{ EMPTY_HASH, F, VALUE_SIZE, }; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub struct Wildcard { pub name: String, pub index: usize, @@ -37,7 +37,7 @@ impl ToFields for Wildcard { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] #[serde(tag = "type", content = "value")] pub enum StatementTmplArg { None, @@ -122,7 +122,7 @@ impl fmt::Display for StatementTmplArg { } /// Statement Template for a Custom Predicate -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub struct StatementTmpl { pub pred: Predicate, pub args: Vec, @@ -179,7 +179,7 @@ impl ToFields for StatementTmpl { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] /// NOTE: fields are not public (outside of crate) to enforce the struct instantiation through /// the `::and/or` methods, which performs checks on the values. @@ -275,6 +275,21 @@ impl CustomPredicate { args: vec![], } } + pub fn is_conjunction(&self) -> bool { + self.conjunction + } + pub fn is_disjunction(&self) -> bool { + !self.conjunction + } + pub fn statements(&self) -> &[StatementTmpl] { + &self.statements + } + pub fn args_len(&self) -> usize { + self.args_len + } + pub fn wildcard_names(&self) -> &[String] { + &self.wildcard_names + } } impl ToFields for CustomPredicate { @@ -341,13 +356,19 @@ impl fmt::Display for CustomPredicate { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] pub struct CustomPredicateBatch { id: Hash, pub name: String, pub(crate) predicates: Vec, } +impl std::hash::Hash for CustomPredicateBatch { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } +} + impl ToFields for CustomPredicateBatch { fn to_fields(&self, params: &Params) -> Vec { // all the custom predicates in order @@ -401,7 +422,7 @@ impl CustomPredicateBatch { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub struct CustomPredicateRef { pub batch: Arc, pub index: usize, diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 4689946..0b79613 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -53,8 +53,6 @@ pub enum TypedValue { // 53-bit precision for integers, integers are represented as tagged // strings, with a custom serializer and deserializer. // TAGGED TYPES: - Set(Set), - Dictionary(Dictionary), Int( #[serde(serialize_with = "serialize_i64", deserialize_with = "deserialize_i64")] // #[schemars(with = "String", regex(pattern = r"^\d+$"))] @@ -67,6 +65,10 @@ pub enum TypedValue { PodId(PodId), // UNTAGGED TYPES: #[serde(untagged)] + Set(Set), + #[serde(untagged)] + Dictionary(Dictionary), + #[serde(untagged)] Array(Array), #[serde(untagged)] String(String), @@ -170,6 +172,20 @@ impl TryFrom for Key { } } +impl TryFrom<&TypedValue> for PodId { + type Error = Error; + fn try_from(v: &TypedValue) -> Result { + match v { + TypedValue::PodId(id) => Ok(*id), + TypedValue::Raw(v) => Ok(PodId(Hash(v.0))), + _ => Err(Error::custom(format!( + "Value {} cannot be converted to a PodId.", + v + ))), + } + } +} + impl fmt::Display for TypedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -216,30 +232,6 @@ impl JsonSchema for TypedValue { fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { use schemars::schema::{InstanceType, Schema, SchemaObject, SingleOrVec}; - let set_schema = schemars::schema::SchemaObject { - instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))), - object: Some(Box::new(schemars::schema::ObjectValidation { - properties: [("Set".to_string(), gen.subschema_for::())] - .into_iter() - .collect(), - required: ["Set".to_string()].into_iter().collect(), - ..Default::default() - })), - ..Default::default() - }; - - let dictionary_schema = schemars::schema::SchemaObject { - instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))), - object: Some(Box::new(schemars::schema::ObjectValidation { - properties: [("Dictionary".to_string(), gen.subschema_for::())] - .into_iter() - .collect(), - required: ["Dictionary".to_string()].into_iter().collect(), - ..Default::default() - })), - ..Default::default() - }; - // Int is serialized/deserialized as a tagged string let int_schema = schemars::schema::SchemaObject { instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))), @@ -275,20 +267,36 @@ impl JsonSchema for TypedValue { ..Default::default() }; + let public_key_schema = schemars::schema::SchemaObject { + instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))), + object: Some(Box::new(schemars::schema::ObjectValidation { + // PublicKey is serialized as a string + properties: [("PublicKey".to_string(), gen.subschema_for::())] + .into_iter() + .collect(), + required: ["PublicKey".to_string()].into_iter().collect(), + ..Default::default() + })), + ..Default::default() + }; + // This is the part that Schemars can't generate automatically: let untagged_array_schema = gen.subschema_for::(); + let untagged_set_schema = gen.subschema_for::(); + let untagged_dictionary_schema = gen.subschema_for::(); let untagged_string_schema = gen.subschema_for::(); let untagged_bool_schema = gen.subschema_for::(); Schema::Object(SchemaObject { subschemas: Some(Box::new(schemars::schema::SubschemaValidation { any_of: Some(vec![ - Schema::Object(set_schema), - Schema::Object(dictionary_schema), Schema::Object(int_schema), Schema::Object(raw_schema), + Schema::Object(public_key_schema), untagged_array_schema, + untagged_dictionary_schema, untagged_string_schema, + untagged_set_schema, untagged_bool_schema, ]), ..Default::default() diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 75bc367..5da69fb 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -56,7 +56,7 @@ impl ToFields for OperationType { } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum NativeOperation { None = 0, NewEntry = 1, diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 551f323..bfdd222 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -50,7 +50,7 @@ impl ToFields for NativePredicate { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] #[serde(tag = "type", content = "value")] pub enum Predicate { Native(NativePredicate), @@ -129,7 +129,7 @@ impl fmt::Display for Predicate { } /// Type encapsulating statements with their associated arguments. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] #[serde(tag = "predicate", content = "args")] pub enum Statement { None, @@ -368,7 +368,8 @@ impl ToFields for StatementArg { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", content = "value")] pub enum ValueRef { Literal(Value), Key(AnchoredKey),