From 9d60b0ec3a677490790fad792bd2a7b657c97537 Mon Sep 17 00:00:00 2001 From: Ahmad Afuni Date: Wed, 5 Mar 2025 21:02:28 +1000 Subject: [PATCH] Frontend work (#109) --- src/backends/plonky2/mock_signed.rs | 12 +- src/examples.rs | 136 +++++++++++++++++- src/examples/custom.rs | 158 +++++++++++++++++++++ src/frontend/custom.rs | 165 +++------------------- src/frontend/mod.rs | 208 +++++++++++++++++++++++----- src/frontend/operation.rs | 10 +- src/frontend/statement.rs | 23 ++- src/middleware/custom.rs | 71 +++++++++- src/middleware/operation.rs | 90 ++++-------- 9 files changed, 611 insertions(+), 262 deletions(-) create mode 100644 src/examples/custom.rs diff --git a/src/backends/plonky2/mock_signed.rs b/src/backends/plonky2/mock_signed.rs index 51ce54e..8cfb793 100644 --- a/src/backends/plonky2/mock_signed.rs +++ b/src/backends/plonky2/mock_signed.rs @@ -13,16 +13,22 @@ pub struct MockSigner { pub pk: String, } +impl MockSigner { + pub fn pubkey(&self) -> Value { + Value(hash_str(&self.pk).0) + } +} + impl PodSigner for MockSigner { fn sign(&mut self, _params: &Params, kvs: &HashMap) -> Result> { let mut kvs = kvs.clone(); - let pk_hash = hash_str(&self.pk); - kvs.insert(hash_str(&KEY_SIGNER), Value(pk_hash.0)); + let pubkey = self.pubkey(); + kvs.insert(hash_str(&KEY_SIGNER), pubkey); kvs.insert(hash_str(&KEY_TYPE), Value::from(PodType::MockSigned)); let dict = Dictionary::new(&kvs)?; let id = PodId(dict.commitment()); - let signature = format!("{}_signed_by_{}", id, pk_hash); + let signature = format!("{}_signed_by_{}", id, pubkey); Ok(Box::new(MockSignedPod { dict, id, diff --git a/src/examples.rs b/src/examples.rs index 328c000..be19a93 100644 --- a/src/examples.rs +++ b/src/examples.rs @@ -1,11 +1,16 @@ -use anyhow::Result; +pub mod custom; + +use anyhow::{anyhow, Result}; +use custom::{eth_dos_batch, eth_friend_batch}; use std::collections::HashMap; use crate::backends::plonky2::mock_signed::MockSigner; -use crate::frontend::{MainPodBuilder, SignedPod, SignedPodBuilder, Value}; +use crate::frontend::{ + MainPodBuilder, Operation, OperationArg, SignedPod, SignedPodBuilder, Statement, Value, +}; use crate::middleware::containers::Set; -use crate::middleware::hash_str; use crate::middleware::{containers::Dictionary, Params, PodType, KEY_SIGNER, KEY_TYPE}; +use crate::middleware::{hash_str, CustomPredicateRef, NativeOperation, OperationType, Pod}; use crate::op; // ZuKYC @@ -61,6 +66,131 @@ pub fn zu_kyc_pod_builder( Ok(kyc) } +// ETHDoS + +pub fn eth_friend_signed_pod_builder(params: &Params, friend_pubkey: Value) -> SignedPodBuilder { + let mut attestation = SignedPodBuilder::new(params); + attestation.insert("attestation", friend_pubkey); + + attestation +} + +pub fn eth_dos_pod_builder( + params: &Params, + alice_attestation: &SignedPod, + charlie_attestation: &SignedPod, + bob_pubkey: &Value, +) -> Result { + // Will need ETH friend and ETH DoS custom predicate batches. + let eth_friend_batch = eth_friend_batch(params)?; + let eth_dos_batch = eth_dos_batch(params)?; + + // ETHDoS POD builder + let mut alice_bob_ethdos = MainPodBuilder::new(params); + alice_bob_ethdos.add_signed_pod(&alice_attestation); + alice_bob_ethdos.add_signed_pod(&charlie_attestation); + + // Attestation POD entries + let alice_pubkey = alice_attestation + .kvs() + .get(&KEY_SIGNER.into()) + .ok_or(anyhow!("Could not find Alice's public key!"))? + .clone(); + let charlie_pubkey = charlie_attestation + .kvs() + .get(&KEY_SIGNER.into()) + .ok_or(anyhow!("Could not find Charlie's public key!"))? + .clone(); + + // Include Alice and Bob's keys as public statements. + let alice_pubkey_copy = alice_bob_ethdos.pub_op(Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![OperationArg::Entry( + "Alice".to_string(), + alice_pubkey.into(), + )], + ))?; + let bob_pubkey_copy = alice_bob_ethdos.pub_op(Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![OperationArg::Entry( + "Bob".to_string(), + bob_pubkey.clone().into(), + )], + ))?; + let charlie_pubkey = alice_bob_ethdos.priv_op(Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![OperationArg::Entry( + "Charlie".to_string(), + charlie_pubkey.into(), + )], + ))?; + + // The ETHDoS distance from Alice to Alice is 0. + let zero = alice_bob_ethdos.priv_op(Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![OperationArg::Entry("ZERO".to_string(), Value::from(0i64))], + ))?; + let alice_equals_alice = alice_bob_ethdos.priv_op(Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![ + (alice_attestation, KEY_SIGNER).into(), + OperationArg::Statement(alice_pubkey_copy.clone()), + ], + ))?; + let ethdos_alice_alice_is_zero = alice_bob_ethdos.priv_op(Operation( + OperationType::Custom(CustomPredicateRef(eth_dos_batch, 0)), + vec![ + OperationArg::Statement(alice_equals_alice), + OperationArg::Statement(zero.clone()), + ], + ))?; + + // Alice and Charlie are ETH friends. + let attestation_is_signed_pod = Statement::from((alice_attestation, KEY_TYPE)); + let attestation_signed_by_alice = alice_bob_ethdos.priv_op(Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![ + OperationArg::from((alice_attestation, KEY_SIGNER)), + OperationArg::Statement(alice_pubkey_copy), + ], + ))?; + let alice_attests_to_charlie = alice_bob_ethdos.priv_op(Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![ + OperationArg::from((alice_attestation, "attestation")), + OperationArg::Statement(charlie_pubkey), + ], + ))?; + let ethfriends_alice_charlie = alice_bob_ethdos.priv_op(Operation( + OperationType::Custom(CustomPredicateRef(eth_friend_batch, 0)), + vec![ + OperationArg::Statement(attestation_is_signed_pod), + OperationArg::Statement(attestation_signed_by_alice), + OperationArg::Statement(alice_attests_to_charlie), + ], + ))?; + + // The ETHDoS distance from Alice to Charlie is 1. + let one = alice_bob_ethdos.priv_op(Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![OperationArg::Entry("ZERO".to_string(), Value::from(0i64))], + ))?; + // 1 = 0 + 1 + // let ethdos_sum = alice_bob_ethdos.priv_op( + // Operation( + // OperationType::Native(NativeOperation::SumOf + // ), + // vec![ + // OperationArg::Statement(one.clone()), + // OperationArg::Statement(zero.clone()), + // OperationArg::Statement(zero.clone()) + // ] + // ) + // ); + + Ok(alice_bob_ethdos) +} + // GreatBoy pub fn good_boy_sign_pod_builder(params: &Params, user: &str, age: i64) -> SignedPodBuilder { diff --git a/src/examples/custom.rs b/src/examples/custom.rs new file mode 100644 index 0000000..8d3b541 --- /dev/null +++ b/src/examples/custom.rs @@ -0,0 +1,158 @@ +use std::sync::Arc; + +use anyhow::Result; + +use crate::{ + frontend::{literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::{ + CustomPredicateBatch, CustomPredicateRef, NativePredicate, Params, PodType, Predicate, + KEY_SIGNER, KEY_TYPE, + }, +}; + +use NativePredicate as NP; +use StatementTmplBuilder as STB; + +/// Instantiates an ETH friend batch +pub fn eth_friend_batch(params: &Params) -> Result> { + let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); + let _eth_friend = builder.predicate_and( + params, + // arguments: + &["src_ori", "src_key", "dst_ori", "dst_key"], + // private arguments: + &["attestation_pod"], + // statement templates: + &[ + // there is an attestation pod that's a SignedPod + STB::new(NP::ValueOf) + .arg(("attestation_pod", literal(KEY_TYPE))) + .arg(PodType::MockSigned), // TODO + // the attestation pod is signed by (src_or, src_key) + STB::new(NP::Equal) + .arg(("attestation_pod", literal(KEY_SIGNER))) + .arg(("src_ori", "src_key")), + // that same attestation pod has an "attestation" + STB::new(NP::Equal) + .arg(("attestation_pod", literal("attestation"))) + .arg(("dst_ori", "dst_key")), + ], + )?; + + println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); + Ok(builder.finish()) +} + +/// Instantiates an ETHDoS batch +pub fn eth_dos_batch(params: &Params) -> Result> { + let eth_friend = Predicate::Custom(CustomPredicateRef(eth_friend_batch(params)?, 0)); + let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); + + // eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and< + // eq(src_or, src_key, dst_or, dst_key), + // ValueOf(distance_or, distance_key, 0) + // > + let eth_dos_distance_base = builder.predicate_and( + ¶ms, + &[ + // arguments: + "src_ori", + "src_key", + "dst_ori", + "dst_key", + "distance_ori", + "distance_key", + ], + &[ // private arguments: + ], + &[ + // statement templates: + STB::new(NP::Equal) + .arg(("src_ori", "src_key")) + .arg(("dst_ori", "dst_key")), + STB::new(NP::ValueOf) + .arg(("distance_ori", "distance_key")) + .arg(0), + ], + )?; + println!( + "b.0. eth_dos_distance_base = {}", + builder.predicates.last().unwrap() + ); + + let eth_dos_distance = Predicate::BatchSelf(2); + + let eth_dos_distance_ind = builder.predicate_and( + ¶ms, + &[ + // arguments: + "src_ori", + "src_key", + "dst_ori", + "dst_key", + "distance_ori", + "distance_key", + ], + &[ + // private arguments: + "one_ori", + "one_key", + "shorter_distance_ori", + "shorter_distance_key", + "intermed_ori", + "intermed_key", + ], + &[ + // statement templates: + STB::new(eth_dos_distance) + .arg(("src_ori", "src_key")) + .arg(("intermed_ori", "intermed_key")) + .arg(("shorter_distance_ori", "shorter_distance_key")), + // distance == shorter_distance + 1 + STB::new(NP::ValueOf).arg(("one_ori", "one_key")).arg(1), + STB::new(NP::SumOf) + .arg(("distance_ori", "distance_key")) + .arg(("shorter_distance_ori", "shorter_distance_key")) + .arg(("one_ori", "one_key")), + // intermed is a friend of dst + STB::new(eth_friend) + .arg(("intermed_ori", "intermed_key")) + .arg(("dst_ori", "dst_key")), + ], + )?; + + println!( + "b.1. eth_dos_distance_ind = {}", + builder.predicates.last().unwrap() + ); + + let _eth_dos_distance = builder.predicate_or( + ¶ms, + &[ + "src_ori", + "src_key", + "dst_ori", + "dst_key", + "distance_ori", + "distance_key", + ], + &[], + &[ + STB::new(eth_dos_distance_base) + .arg(("src_ori", "src_key")) + .arg(("dst_ori", "dst_key")) + .arg(("distance_ori", "distance_key")), + STB::new(eth_dos_distance_ind) + .arg(("src_ori", "src_key")) + .arg(("dst_ori", "dst_key")) + .arg(("distance_ori", "distance_key")), + ], + )?; + + println!( + "b.2. eth_dos_distance = {}", + builder.predicates.last().unwrap() + ); + + Ok(builder.finish()) +} diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index a589036..668ee71 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -63,39 +63,39 @@ where } } -struct StatementTmplBuilder { +pub struct StatementTmplBuilder { predicate: Predicate, args: Vec, } impl StatementTmplBuilder { - fn new(p: impl Into) -> StatementTmplBuilder { + pub fn new(p: impl Into) -> StatementTmplBuilder { StatementTmplBuilder { predicate: p.into(), args: Vec::new(), } } - fn arg(mut self, a: impl Into) -> Self { + pub fn arg(mut self, a: impl Into) -> Self { self.args.push(a.into()); self } } -struct CustomPredicateBatchBuilder { - name: String, - predicates: Vec, +pub struct CustomPredicateBatchBuilder { + pub name: String, + pub predicates: Vec, } impl CustomPredicateBatchBuilder { - fn new(name: String) -> Self { + pub fn new(name: String) -> Self { Self { name, predicates: Vec::new(), } } - fn predicate_and( + pub fn predicate_and( &mut self, params: &Params, args: &[&str], @@ -105,7 +105,7 @@ impl CustomPredicateBatchBuilder { self.predicate(params, true, args, priv_args, sts) } - fn predicate_or( + pub fn predicate_or( &mut self, params: &Params, args: &[&str], @@ -147,7 +147,7 @@ impl CustomPredicateBatchBuilder { Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } - fn finish(self) -> Arc { + pub fn finish(self) -> Arc { Arc::new(CustomPredicateBatch { name: self.name, predicates: self.predicates, @@ -171,7 +171,10 @@ fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &HashOrWildcardStr) -> #[cfg(test)] mod tests { use super::*; - use crate::middleware::{CustomPredicateRef, Params, PodType}; + use crate::{ + examples::custom::{eth_dos_batch, eth_friend_batch}, + middleware::{CustomPredicateRef, Params, PodType}, + }; #[test] fn test_custom_pred() -> Result<()> { @@ -181,146 +184,14 @@ mod tests { let params = Params::default(); params.print_serialized_sizes(); - let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); - let _eth_friend = builder.predicate_and( - ¶ms, - // arguments: - &["src_ori", "src_key", "dst_ori", "dst_key"], - // private arguments: - &["attestation_pod"], - // statement templates: - &[ - // there is an attestation pod that's a SignedPod - STB::new(NP::ValueOf) - .arg(("attestation_pod", literal("type"))) - .arg(PodType::Signed), - // the attestation pod is signed by (src_or, src_key) - STB::new(NP::Equal) - .arg(("attestation_pod", literal("signer"))) - .arg(("src_ori", "src_key")), - // that same attestation pod has an "attestation" - STB::new(NP::Equal) - .arg(("attestation_pod", literal("attestation"))) - .arg(("dst_ori", "dst_key")), - ], - )?; + // ETH friend custom predicate batch + let eth_friend = eth_friend_batch(¶ms)?; - println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); - let eth_friend = builder.finish(); // This batch only has 1 predicate, so we pick it already for convenience let eth_friend = Predicate::Custom(CustomPredicateRef(eth_friend, 0)); - // next chunk builds: - // eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and< - // eq(src_or, src_key, dst_or, dst_key), - // ValueOf(distance_or, distance_key, 0) - // > - let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); - let eth_dos_distance_base = builder.predicate_and( - ¶ms, - &[ - // arguments: - "src_ori", - "src_key", - "dst_ori", - "dst_key", - "distance_ori", - "distance_key", - ], - &[ // private arguments: - ], - &[ - // statement templates: - STB::new(NP::Equal) - .arg(("src_ori", "src_key")) - .arg(("dst_ori", "dst_key")), - STB::new(NP::ValueOf) - .arg(("distance_ori", "distance_key")) - .arg(0), - ], - )?; - println!( - "b.0. eth_dos_distance_base = {}", - builder.predicates.last().unwrap() - ); - - let eth_dos_distance = Predicate::BatchSelf(2); - - // next chunk builds: - let eth_dos_distance_ind = builder.predicate_and( - ¶ms, - &[ - // arguments: - "src_ori", - "src_key", - "dst_ori", - "dst_key", - "distance_ori", - "distance_key", - ], - &[ - // private arguments: - "one_ori", - "one_key", - "shorter_distance_ori", - "shorter_distance_key", - "intermed_ori", - "intermed_key", - ], - &[ - // statement templates: - STB::new(eth_dos_distance) - .arg(("src_ori", "src_key")) - .arg(("intermed_ori", "intermed_key")) - .arg(("shorter_distance_ori", "shorter_distance_key")), - // distance == shorter_distance + 1 - STB::new(NP::ValueOf).arg(("one_ori", "one_key")).arg(1), - STB::new(NP::SumOf) - .arg(("distance_ori", "distance_key")) - .arg(("shorter_distance_ori", "shorter_distance_key")) - .arg(("one_ori", "one_key")), - // intermed is a friend of dst - STB::new(eth_friend) - .arg(("intermed_ori", "intermed_key")) - .arg(("dst_ori", "dst_key")), - ], - )?; - - println!( - "b.1. eth_dos_distance_ind = {}", - builder.predicates.last().unwrap() - ); - - let _eth_dos_distance = builder.predicate_or( - ¶ms, - &[ - "src_ori", - "src_key", - "dst_ori", - "dst_key", - "distance_ori", - "distance_key", - ], - &[], - &[ - STB::new(eth_dos_distance_base) - .arg(("src_ori", "src_key")) - .arg(("dst_ori", "dst_key")) - .arg(("distance_ori", "distance_key")), - STB::new(eth_dos_distance_ind) - .arg(("src_ori", "src_key")) - .arg(("dst_ori", "dst_key")) - .arg(("distance_ori", "distance_key")), - ], - )?; - - println!( - "b.2. eth_dos_distance = {}", - builder.predicates.last().unwrap() - ); - - let eth_dos_batch_b = builder.finish(); - let fields = eth_dos_batch_b.to_fields(¶ms); + let eth_dos_batch = eth_dos_batch(¶ms)?; + let fields = eth_dos_batch.to_fields(¶ms); println!("Batch b, serialized: {:?}", fields); Ok(()) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 7525752..8004b7c 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -5,6 +5,7 @@ use anyhow::{anyhow, Error, Result}; use itertools::Itertools; use std::collections::HashMap; use std::convert::From; +use std::sync::Arc; use std::{fmt, hash as h}; use crate::middleware::{ @@ -13,7 +14,7 @@ use crate::middleware::{ hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver, PodSigner, SELF, }; -use crate::middleware::{OperationType, Predicate}; +use crate::middleware::{OperationType, Predicate, KEY_SIGNER}; mod custom; mod operation; @@ -77,6 +78,12 @@ impl From<&Value> for middleware::Value { } } +impl From for Value { + fn from(v: middleware::Value) -> Self { + Self::Raw(v) + } +} + impl fmt::Display for Value { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -122,15 +129,19 @@ impl SignedPodBuilder { pub fn sign(&self, signer: &mut S) -> Result { let mut kvs = HashMap::new(); let mut key_string_map = HashMap::new(); + let mut value_hash_map = HashMap::new(); for (k, v) in self.kvs.iter() { let k_hash = hash_str(k); - kvs.insert(k_hash, middleware::Value::from(v)); + let v_hash = middleware::Value::from(v); + kvs.insert(k_hash, v_hash); key_string_map.insert(k_hash, k.clone()); + value_hash_map.insert(v_hash.into(), v.clone()); } let pod = signer.sign(&self.params, &kvs)?; Ok(SignedPod { pod, key_string_map, + value_hash_map, }) } } @@ -142,6 +153,8 @@ pub struct SignedPod { pub pod: Box, /// HashMap to store the reverse relation between key strings and key hashes pub key_string_map: HashMap, + /// HashMap to store the reverse relation between values and their hashes + pub value_hash_map: HashMap, } impl fmt::Display for SignedPod { @@ -202,6 +215,8 @@ pub struct MainPodBuilder { pub public_statements: Vec, // Internal state const_cnt: usize, + key_table: HashMap, + pod_class_table: HashMap, } impl fmt::Display for MainPodBuilder { @@ -235,12 +250,26 @@ impl MainPodBuilder { operations: Vec::new(), public_statements: Vec::new(), const_cnt: 0, + key_table: HashMap::new(), + pod_class_table: HashMap::from_iter([(SELF, PodClass::Main)].into_iter()), } } pub fn add_signed_pod(&mut self, pod: &SignedPod) { self.input_signed_pods.push(pod.clone()); + pod.key_string_map.iter().for_each(|(hash, key)| { + self.key_table.insert(hash.clone(), key.clone()); + }); + self.pod_class_table.insert(pod.id(), PodClass::Signed); } pub fn add_main_pod(&mut self, pod: MainPod) { + self.pod_class_table.insert(pod.id(), PodClass::Main); + pod.key_string_map.iter().for_each(|(hash, key)| { + self.key_table.insert(hash.clone(), key.clone()); + }); + pod.pod_class_map.iter().for_each(|(pod_id, pod_class)| { + self.pod_class_table + .insert(pod_id.clone(), pod_class.clone()); + }); self.input_main_pods.push(pod); } pub fn insert(&mut self, st_op: (Statement, Operation)) { @@ -250,7 +279,11 @@ impl MainPodBuilder { } /// Convert [OperationArg]s to [StatementArg]s for the operations that work with entries - fn op_args_entries(&mut self, public: bool, args: &mut [OperationArg]) -> Vec { + fn op_args_entries( + &mut self, + public: bool, + args: &mut [OperationArg], + ) -> Result> { let mut st_args = Vec::new(); for arg in args.iter_mut() { match arg { @@ -270,7 +303,7 @@ impl MainPodBuilder { OperationType::Native(NativeOperation::NewEntry), vec![OperationArg::Entry(k.clone(), v.clone())], ), - ); + )?; *arg = OperationArg::Statement(value_of_st.clone()); st_args.push(value_of_st.1[0].clone()) } @@ -283,14 +316,18 @@ impl MainPodBuilder { } }; } - st_args + Ok(st_args) } - pub fn pub_op(&mut self, op: Operation) -> Statement { + pub fn pub_op(&mut self, op: Operation) -> Result { self.op(true, op) } - pub fn op(&mut self, public: bool, mut op: Operation) -> Statement { + pub fn priv_op(&mut self, op: Operation) -> Result { + self.op(false, op) + } + + fn op(&mut self, public: bool, mut op: Operation) -> Result { use NativeOperation::*; let Operation(op_type, ref mut args) = &mut op; // TODO: argument type checking @@ -299,49 +336,96 @@ impl MainPodBuilder { None => Statement(Predicate::Native(NativePredicate::None), vec![]), NewEntry => Statement( Predicate::Native(NativePredicate::ValueOf), - self.op_args_entries(public, args), + self.op_args_entries(public, args)?, ), CopyStatement => todo!(), EqualFromEntries => Statement( Predicate::Native(NativePredicate::Equal), - self.op_args_entries(public, args), + self.op_args_entries(public, args)?, ), NotEqualFromEntries => Statement( Predicate::Native(NativePredicate::NotEqual), - self.op_args_entries(public, args), + self.op_args_entries(public, args)?, ), GtFromEntries => Statement( Predicate::Native(NativePredicate::Gt), - self.op_args_entries(public, args), + self.op_args_entries(public, args)?, ), LtFromEntries => Statement( Predicate::Native(NativePredicate::Lt), - self.op_args_entries(public, args), + self.op_args_entries(public, args)?, ), TransitiveEqualFromStatements => todo!(), GtToNotEqual => todo!(), LtToNotEqual => todo!(), ContainsFromEntries => Statement( Predicate::Native(NativePredicate::Contains), - self.op_args_entries(public, args), + self.op_args_entries(public, args)?, ), NotContainsFromEntries => Statement( Predicate::Native(NativePredicate::NotContains), - self.op_args_entries(public, args), + self.op_args_entries(public, args)?, ), RenameContainedBy => todo!(), SumOf => todo!(), ProductOf => todo!(), MaxOf => todo!(), }, - _ => todo!(), + OperationType::Custom(cpr) => { + // All args should be statements to be pattern matched against statement templates. + let args = args.iter().map( + |a| match a { + OperationArg::Statement(s) => middleware::Statement::try_from(s.clone()), + _ => Err(anyhow!("Invalid argument {} to operation corresponding to custom predicate {:?}.", a, cpr)) + } + ).collect::>>()?; + // Match these statements against the custom predicate definition + let bindings = cpr.match_against(&args)?; + let output_arg_values = (0..cpr.arg_len()) + .map(|i| { + bindings.get(&i).cloned().ok_or(anyhow!( + "Wildcard {} of custom predicate {:?} is unbound.", + i, + cpr + )) + }) + .collect::>>()?; + let output_args = output_arg_values + .chunks(2) + .map(|chunk| { + Ok(StatementArg::Key(AnchoredKey( + Origin( + self.pod_class_table + .get(&PodId(chunk[0].into())) + .cloned() + .ok_or(anyhow!("Missing POD class value."))?, + PodId(chunk[0].into()), + ), + self.key_table + .get(&chunk[1].into()) + .cloned() + .ok_or(anyhow!("Missing key corresponding to hash."))?, + ))) + }) + .collect::>>()?; + Statement(Predicate::Custom(cpr.clone()), output_args) + } }; self.operations.push(op); if public { self.public_statements.push(st.clone()); } + + // Add key-hash pairs in statement to table. + (&st).1.iter().for_each(|arg| match arg { + StatementArg::Key(AnchoredKey(_, key)) => { + self.key_table.insert(hash_str(key), key.clone()); + } + _ => (), + }); + self.statements.push(st); - self.statements[self.statements.len() - 1].clone() + Ok(self.statements[self.statements.len() - 1].clone()) } pub fn reveal(&mut self, st: &Statement) { @@ -357,8 +441,29 @@ impl MainPodBuilder { operations: &self.operations, public_statements: &self.public_statements, }; - let (statements, operations, public_statements) = compiler.compile(inputs, params)?; + let key_string_map = inputs + .public_statements + .iter() + .flat_map(|s| &s.1) + .flat_map(|arg| match arg { + StatementArg::Key(AnchoredKey(_, key)) => Some((hash_str(key), key.clone())), + _ => None, + }) + .collect::>(); + let pod_class_map = (&inputs) + .public_statements + .into_iter() + .flat_map(|s| &s.1) + .flat_map(|arg| match arg { + StatementArg::Key(AnchoredKey(Origin(pod_class, pod_id), _)) => { + Some((pod_id.clone(), pod_class.clone())) + } + _ => None, + }) + .collect::>(); + + let (statements, operations, public_statements) = compiler.compile(inputs, params)?; let inputs = MainPodInputs { signed_pods: &self.input_signed_pods.iter().map(|p| &p.pod).collect_vec(), main_pods: &self.input_main_pods.iter().map(|p| &p.pod).collect_vec(), @@ -367,7 +472,11 @@ impl MainPodBuilder { public_statements: &public_statements, }; let pod = prover.prove(&self.params, inputs)?; - Ok(MainPod { pod }) + Ok(MainPod { + pod, + key_string_map, + pod_class_map, + }) } } @@ -375,6 +484,8 @@ impl MainPodBuilder { pub struct MainPod { pub pod: Box, // TODO: metadata + pub key_string_map: HashMap, + pub pod_class_map: HashMap, } impl fmt::Display for MainPod { @@ -433,7 +544,7 @@ impl MainPodCompiler { fn compile_op_arg(&self, op_arg: &OperationArg) -> Option { match op_arg { - OperationArg::Statement(s) => Some(self.compile_st(s)), + OperationArg::Statement(s) => self.compile_st(s).ok(), OperationArg::Literal(_v) => { // OperationArg::Literal is a syntax sugar for the frontend. This is translated to // a new ValueOf statement and it's key used instead. @@ -448,23 +559,23 @@ impl MainPodCompiler { } } - fn compile_st(&self, st: &Statement) -> middleware::Statement { - st.clone().try_into().unwrap() + fn compile_st(&self, st: &Statement) -> Result { + st.clone().try_into() } - fn compile_op(&self, op: &Operation) -> middleware::Operation { + fn compile_op(&self, op: &Operation) -> Result { // TODO let mop_code: OperationType = op.0.clone(); let mop_args = op.1.iter() - .flat_map(|arg| self.compile_op_arg(arg).map(|s| s.try_into().unwrap())) - .collect::>(); - middleware::Operation::op(mop_code, &mop_args).unwrap() + .flat_map(|arg| self.compile_op_arg(arg).map(|s| Ok(s.try_into()?))) + .collect::>>()?; + middleware::Operation::op(mop_code, &mop_args) } fn compile_st_op(&mut self, st: &Statement, op: &Operation, params: &Params) -> Result<()> { - let middle_st = self.compile_st(st); - let middle_op = self.compile_op(op); + let middle_st = self.compile_st(st)?; + let middle_op = self.compile_op(op)?; let is_correct = middle_op.check(params, &middle_st)?; if !is_correct { // todo: improve error handling @@ -504,7 +615,7 @@ impl MainPodCompiler { let public_statements = public_statements .iter() .map(|st| self.compile_st(st)) - .collect_vec(); + .collect::>>()?; Ok((self.statements, self.operations, public_statements)) } } @@ -548,8 +659,8 @@ pub mod tests { use crate::backends::plonky2::mock_main::MockProver; use crate::backends::plonky2::mock_signed::MockSigner; use crate::examples::{ - great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, - zu_kyc_sign_pod_builders, + eth_dos_pod_builder, eth_friend_signed_pod_builder, great_boy_pod_full_flow, + tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders, }; #[test] @@ -563,19 +674,19 @@ pub mod tests { let mut signer = MockSigner { pk: "ZooGov".into(), }; - let gov_id = gov_id.sign(&mut signer).unwrap(); + let gov_id = gov_id.sign(&mut signer)?; println!("{}", gov_id); let mut signer = MockSigner { pk: "ZooDeel".into(), }; - let pay_stub = pay_stub.sign(&mut signer).unwrap(); + let pay_stub = pay_stub.sign(&mut signer)?; println!("{}", pay_stub); let mut signer = MockSigner { pk: "ZooOFAC".into(), }; - let sanction_list = sanction_list.sign(&mut signer).unwrap(); + let sanction_list = sanction_list.sign(&mut signer)?; println!("{}", sanction_list); let kyc = zu_kyc_pod_builder(¶ms, &gov_id, &pay_stub, &sanction_list)?; @@ -590,6 +701,35 @@ pub mod tests { Ok(()) } + #[test] + fn test_ethdos() -> Result<()> { + let params = Params::default(); + + let mut alice = MockSigner { pk: "Alice".into() }; + let mut bob = MockSigner { pk: "Bob".into() }; + let mut charlie = MockSigner { + pk: "Charlie".into(), + }; + + // Alice attests that she is ETH friends with Charlie and Charlie + // attests that he is ETH friends with Bob. + let alice_attestation = + eth_friend_signed_pod_builder(¶ms, charlie.pubkey().into()).sign(&mut alice)?; + let charlie_attestation = + eth_friend_signed_pod_builder(¶ms, bob.pubkey().into()).sign(&mut charlie)?; + + let mut prover = MockProver {}; + let alice_bob_ethdos = eth_dos_pod_builder( + ¶ms, + &alice_attestation, + &charlie_attestation, + &bob.pubkey().into(), + )? + .prove(&mut prover, ¶ms)?; + + Ok(()) + } + #[test] fn test_front_great_boy() -> Result<()> { let great_boy = great_boy_pod_full_flow()?; @@ -625,7 +765,7 @@ pub mod tests { let mut builder = MainPodBuilder::new(¶ms); builder.add_signed_pod(&pod); - builder.pub_op(op!(gt, (&pod, "num"), 5)); + builder.pub_op(op!(gt, (&pod, "num"), 5)).unwrap(); let mut prover = MockProver {}; let false_pod = builder.prove(&mut prover, ¶ms).unwrap(); diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 85ccf4b..4be6a54 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -52,15 +52,7 @@ impl From for OperationArg { impl From<(&SignedPod, &str)> for OperationArg { fn from((pod, key): (&SignedPod, &str)) -> Self { - // TODO: Actual value, TryFrom. - let value = pod.kvs().get(&hash_str(key)).unwrap().clone(); - Self::Statement(Statement( - Predicate::Native(NativePredicate::ValueOf), - vec![ - StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())), - StatementArg::Literal(Value::Raw(value)), - ], - )) + Self::Statement((pod, key).into()) } } diff --git a/src/frontend/statement.rs b/src/frontend/statement.rs index c0dfc25..21bc4ca 100644 --- a/src/frontend/statement.rs +++ b/src/frontend/statement.rs @@ -1,8 +1,8 @@ use anyhow::{anyhow, Result}; use std::fmt; -use super::{AnchoredKey, Value}; -use crate::middleware::{self, NativePredicate, Predicate}; +use super::{AnchoredKey, SignedPod, Value}; +use crate::middleware::{self, hash_str, NativePredicate, Predicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum StatementArg { @@ -22,6 +22,25 @@ impl fmt::Display for StatementArg { #[derive(Clone, Debug, PartialEq, Eq)] pub struct Statement(pub Predicate, pub Vec); +impl From<(&SignedPod, &str)> for Statement { + fn from((pod, key): (&SignedPod, &str)) -> Self { + // TODO: Actual value, TryFrom. + let value_hash = pod.kvs().get(&hash_str(key)).cloned().unwrap(); + let value = pod + .value_hash_map + .get(&value_hash.into()) + .cloned() + .unwrap_or(Value::Raw(value_hash)); + Statement( + Predicate::Native(NativePredicate::ValueOf), + vec![ + StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())), + StatementArg::Literal(value), + ], + ) + } +} + impl TryFrom for middleware::Statement { type Error = anyhow::Error; fn try_from(s: Statement) -> Result { diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 5dbb1f4..3cf0e47 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -1,9 +1,12 @@ +use std::collections::HashMap; use std::sync::Arc; use std::{fmt, hash as h, iter::zip}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; +use crate::util::hashmap_insert_no_dupe; + use super::{ hash_fields, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg, ToFields, Value, F, @@ -25,7 +28,11 @@ impl HashOrWildcard { match self { HashOrWildcard::Hash(h) if &Value::from(h.clone()) == v => Ok(None), HashOrWildcard::Wildcard(i) => Ok(Some((*i, v.clone()))), - _ => Err(anyhow!("Failed to match {} against {}.", self, v)), + _ => Err(anyhow!( + "Failed to match hash or wildcard {} against value {}.", + self, + v + )), } } } @@ -76,7 +83,11 @@ impl StatementTmplArg { let k_corr = tmpl_k.match_against(&k.clone().into())?; Ok([o_corr, k_corr].into_iter().flat_map(|x| x).collect()) } - _ => Err(anyhow!("Failed to match {} against {}.", self, s_arg)), + _ => Err(anyhow!( + "Failed to match statement template argument {:?} against statement argument {:?}.", + self, + s_arg + )), } } } @@ -322,6 +333,62 @@ impl CustomPredicateBatch { #[derive(Clone, Debug, PartialEq, Eq)] pub struct CustomPredicateRef(pub Arc, pub usize); +impl CustomPredicateRef { + pub fn arg_len(&self) -> usize { + (*self.0).predicates[self.1].args_len + } + pub fn match_against(&self, statements: &[Statement]) -> Result> { + let mut bindings = HashMap::new(); + // Single out custom predicate, replacing batch-self + // references with custom predicate references. + let custom_predicate = { + let cp = &Arc::unwrap_or_clone(self.0.clone()).predicates[self.1]; + CustomPredicate { + conjunction: cp.conjunction, + statements: cp + .statements + .iter() + .map(|StatementTmpl(p, args)| { + StatementTmpl( + match p { + Predicate::BatchSelf(i) => { + Predicate::Custom(CustomPredicateRef(self.0.clone(), *i)) + } + _ => p.clone(), + }, + args.to_vec(), + ) + }) + .collect(), + args_len: cp.args_len, + } + }; + match custom_predicate.conjunction { + true if custom_predicate.statements.len() == statements.len() => { + // Match op args against statement templates + let match_bindings = std::iter::zip(custom_predicate.statements, statements).map( + |(s_tmpl, s)| s_tmpl.match_against(s) + ).collect::>>() + .map(|v| v.concat())?; + // Add bindings to binding table, throwing if there is an inconsistency. + match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; + Ok(bindings) + }, + false if statements.len() == 1 => { + // Match op arg against each statement template + custom_predicate.statements.iter().map( + |s_tmpl| { + let mut bindings = bindings.clone(); + s_tmpl.match_against(&statements[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; + Ok::<_, anyhow::Error>(bindings) + } + ).find(|m| m.is_ok()).unwrap_or(Err(anyhow!("Statement {} does not match disjunctive custom predicate {}.", &statements[0], custom_predicate))) + }, + _ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, statements)) + } + } +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum Predicate { Native(NativePredicate), diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index fac4437..584814c 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -194,68 +194,34 @@ impl Operation { let v3: i64 = v3.clone().try_into()?; Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3) } - ( - Self::Custom(CustomPredicateRef(cpb, i), args), - Custom(CustomPredicateRef(s_cpb, s_i), s_args), - ) if cpb == s_cpb && i == s_i => { - // Bind statement arguments - let mut bindings = s_args - .into_iter() - .enumerate() - .flat_map(|(i, AnchoredKey(PodId(o), k))| { - vec![ - (2 * i, Value::from(o.clone())), - (2 * i + 1, Value::from(k.clone())), - ] - }) - .collect::>(); - - // Single out custom predicate, replacing batch-self - // references with custom predicate references. - let custom_predicate = { - let cp = (**cpb).predicates[*i].clone(); - CustomPredicate::new( - params, - cp.conjunction, - cp.statements - .into_iter() - .map(|StatementTmpl(p, args)| { - StatementTmpl( - match p { - Predicate::BatchSelf(i) => { - Predicate::Custom(CustomPredicateRef(cpb.clone(), i)) - } - _ => p, - }, - args, - ) - }) - .collect(), - cp.args_len, - )? - }; - match custom_predicate.conjunction { - true if custom_predicate.statements.len() == args.len() => { - // Match op args against statement templates - let match_bindings = std::iter::zip(custom_predicate.statements, args).map( - |(s_tmpl, s)| s_tmpl.match_against(s) - ).collect::>>() - .map(|v| v.concat())?; - // Add bindings to binding table, throwing if there is an inconsistency. - match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; - Ok(true) - }, - false if args.len() == 1 => { - // Match op arg against each statement template - custom_predicate.statements.into_iter().map( - |s_tmpl| { - let mut bindings = bindings.clone(); - s_tmpl.match_against(&args[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; - Ok::<_, anyhow::Error>(true) - } - ).find(|m| m.is_ok()).unwrap_or(Ok(false)) - }, - _ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, args)) + (Self::Custom(CustomPredicateRef(cpb, i), args), Custom(cpr, s_args)) + if cpb == &cpr.0 && i == &cpr.1 => + { + // Bind according to custom predicate pattern match against arg list. + let bindings = cpr.match_against(args)?; + // Check arg length + let arg_len = cpr.arg_len(); + if arg_len != 2 * s_args.len() { + Err(anyhow!("Custom predicate arg list {:?} must have {} arguments after destructuring.", s_args, arg_len)) + } else { + let bound_args = (0..arg_len) + .map(|i| { + bindings.get(&i).cloned().ok_or(anyhow!( + "Wildcard {} of custom predicate {:?} is unbound.", + i, + cpr + )) + }) + .collect::>>()?; + let s_args = s_args + .into_iter() + .flat_map(|AnchoredKey(o, k)| [Value::from(o.0.clone()), k.clone().into()]) + .collect::>(); + if bound_args != s_args { + Err(anyhow!("Arguments to output statement {} do not match those implied by operation {:?}", output_statement,self)) + } else { + Ok(true) + } } } _ => Err(anyhow!(