diff --git a/src/lib.rs b/src/lib.rs index 963de0b..2dc6107 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub mod constants; pub mod frontend; pub mod middleware; pub mod primitives; +mod util; #[cfg(test)] pub mod examples; diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index ad2a9d8..13066f8 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -2,11 +2,12 @@ use std::sync::Arc; use std::{fmt, hash as h, iter::zip}; use anyhow::{anyhow, Result}; -use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::Field; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::Hasher; +use crate::middleware::{Operation, SELF}; + use super::{ hash_str, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg, ToFields, Value, F, @@ -338,3 +339,275 @@ impl fmt::Display for Predicate { } } } + +#[cfg(test)] +mod tests { + use std::{array, sync::Arc}; + + use anyhow::Result; + use plonky2::field::goldilocks_field::GoldilocksField; + + use crate::middleware::{ + AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash, + HashOrWildcard, NativePredicate, Operation, PodId, PodType, Predicate, Statement, + StatementTmpl, StatementTmplArg, SELF, + }; + + fn st(p: Predicate, args: Vec) -> StatementTmpl { + StatementTmpl(p, args) + } + + type STA = StatementTmplArg; + type HOW = HashOrWildcard; + type P = Predicate; + type NP = NativePredicate; + + #[test] + fn is_double_test() -> Result<()> { + /* + is_double(S1, S2) :- + p:value_of(Constant, 2), + p:product_of(S1, Constant, S2) + */ + let cust_pred_batch = Arc::new(CustomPredicateBatch { + name: "is_double".to_string(), + predicates: vec![CustomPredicate { + conjunction: true, + statements: vec![ + st( + P::Native(NP::ValueOf), + vec![ + STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + STA::Literal(2.into()), + ], + ), + st( + P::Native(NP::ProductOf), + vec![ + STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), + STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + ], + ), + ], + args_len: 4, + }], + }); + + let custom_statement = Statement::Custom( + CustomPredicateRef(cust_pred_batch.clone(), 0), + vec![ + AnchoredKey(SELF, "Some value".into()), + AnchoredKey(SELF, "Some other value".into()), + ], + ); + + let custom_deduction = Operation::Custom( + CustomPredicateRef(cust_pred_batch, 0), + vec![ + Statement::ValueOf(AnchoredKey(SELF, "Some constant".into()), 2.into()), + Statement::ProductOf( + AnchoredKey(SELF, "Some value".into()), + AnchoredKey(SELF, "Some constant".into()), + AnchoredKey(SELF, "Some other value".into()), + ), + ], + ); + + assert!(custom_deduction.check(&custom_statement)?); + + Ok(()) + } + + #[test] + fn ethdos_test() -> Result<()> { + let eth_friend_cp = CustomPredicate { + conjunction: true, + statements: vec![ + st( + P::Native(NP::ValueOf), + vec![ + STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("type".into())), + STA::Literal(PodType::Signed.into()), + ], + ), + st( + P::Native(NP::Equal), + vec![ + STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("signer".into())), + STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), + ], + ), + st( + P::Native(NP::Equal), + vec![ + STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("attestation".into())), + STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + ], + ), + ], + args_len: 4, + }; + + let eth_friend_batch = Arc::new(CustomPredicateBatch { + name: "eth_friend".to_string(), + predicates: vec![eth_friend_cp], + }); + + let eth_dos_base = CustomPredicate { + conjunction: true, + statements: vec![ + st( + P::Native(NP::Equal), + vec![ + STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), + STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + ], + ), + st( + P::Native(NP::ValueOf), + vec![ + STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + STA::Literal(0.into()), + ], + ), + ], + args_len: 6, + }; + + let eth_dos_ind = CustomPredicate { + conjunction: true, + statements: vec![ + st( + P::BatchSelf(2), + vec![ + STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), + STA::Key(HOW::Wildcard(10), HOW::Wildcard(11)), + STA::Key(HOW::Wildcard(8), HOW::Wildcard(9)), + ], + ), + st( + P::Native(NP::ValueOf), + vec![ + STA::Key(HOW::Wildcard(6), HOW::Wildcard(7)), + STA::Literal(1.into()), + ], + ), + st( + P::Native(NP::SumOf), + vec![ + STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + STA::Key(HOW::Wildcard(8), HOW::Wildcard(9)), + STA::Key(HOW::Wildcard(6), HOW::Wildcard(7)), + ], + ), + st( + P::Custom(CustomPredicateRef(eth_friend_batch.clone(), 0)), + vec![ + STA::Key(HOW::Wildcard(10), HOW::Wildcard(11)), + STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + ], + ), + ], + args_len: 6, + }; + + let eth_dos_distance_either = CustomPredicate { + conjunction: false, + statements: vec![ + st( + P::BatchSelf(0), + vec![ + STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), + STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + ], + ), + st( + P::BatchSelf(1), + vec![ + STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), + STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + ], + ), + ], + args_len: 6, + }; + + let eth_dos_distance_batch = Arc::new(CustomPredicateBatch { + name: "ETHDoS_distance".to_string(), + predicates: vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either], + }); + + // Some POD IDs + let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64)))); + let pod_id2 = PodId(Hash(array::from_fn(|i| GoldilocksField((i * i) as u64)))); + let pod_id3 = PodId(Hash(array::from_fn(|i| GoldilocksField((2 * i) as u64)))); + let pod_id4 = PodId(Hash(array::from_fn(|i| GoldilocksField((2 * i) as u64)))); + + // Example statement + let ethdos_example = Statement::Custom( + CustomPredicateRef(eth_dos_distance_batch.clone(), 2), + vec![ + AnchoredKey(pod_id1, "Alice".into()), + AnchoredKey(pod_id2, "Bob".into()), + AnchoredKey(SELF, "Seven".into()), + ], + ); + + // Copies should work. + assert!(Operation::CopyStatement(ethdos_example.clone()).check(ðdos_example)?); + + // This could arise as the inductive step. + let ethdos_ind_example = Statement::Custom( + CustomPredicateRef(eth_dos_distance_batch.clone(), 1), + vec![ + AnchoredKey(pod_id1, "Alice".into()), + AnchoredKey(pod_id2, "Bob".into()), + AnchoredKey(SELF, "Seven".into()), + ], + ); + + assert!(Operation::Custom( + CustomPredicateRef(eth_dos_distance_batch.clone(), 2), + vec![ethdos_ind_example.clone()] + ) + .check(ðdos_example)?); + + // And the inductive step would arise as follows: Say the + // ETHDoS distance from Alice to Charlie is 6, which is one + // less than 7, and Charlie is ETH-friends with Bob. + let ethdos_facts = vec![ + Statement::Custom( + CustomPredicateRef(eth_dos_distance_batch.clone(), 2), + vec![ + AnchoredKey(pod_id1, "Alice".into()), + AnchoredKey(pod_id3, "Charlie".into()), + AnchoredKey(pod_id4, "Six".into()), + ], + ), + Statement::ValueOf(AnchoredKey(SELF, "One".into()), 1.into()), + Statement::SumOf( + AnchoredKey(SELF, "Seven".into()), + AnchoredKey(pod_id4, "Six".into()), + AnchoredKey(SELF, "One".into()), + ), + Statement::Custom( + CustomPredicateRef(eth_friend_batch.clone(), 0), + vec![ + AnchoredKey(pod_id3, "Charlie".into()), + AnchoredKey(pod_id2, "Bob".into()), + ], + ), + ]; + + assert!(Operation::Custom( + CustomPredicateRef(eth_dos_distance_batch.clone(), 1), + ethdos_facts + ) + .check(ðdos_ind_example)?); + + Ok(()) + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 451de8a..9a62ad3 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -185,6 +185,12 @@ impl FromHex for Hash { } } +impl From<&str> for Hash { + fn from(s: &str) -> Self { + hash_str(s) + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)] pub struct PodId(pub Hash); diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index ddcc37d..cf08705 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,7 +1,12 @@ +use std::collections::HashMap; + use anyhow::{anyhow, Result}; use super::{CustomPredicateRef, Statement}; -use crate::middleware::{AnchoredKey, SELF}; +use crate::{ + middleware::{AnchoredKey, CustomPredicate, PodId, Predicate, StatementTmpl, Value, SELF}, + util::hashmap_insert_no_dupe, +}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NativeOperation { @@ -175,9 +180,69 @@ impl Operation { Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3) } ( - Self::Custom(CustomPredicateRef(cpb, i), _args), - Custom(CustomPredicateRef(s_cpb, s_i), _s_args), - ) if cpb == s_cpb && i == s_i => todo!(), + Self::Custom(CustomPredicateRef(cpb, i), args), + Custom(CustomPredicateRef(s_cpb, s_i), s_args), + ) if cpb == s_cpb && i == s_i => { + // Bind statement arguments + let mut bindings = s_args + .into_iter() + .enumerate() + .flat_map(|(i, AnchoredKey(PodId(o), k))| { + vec![ + (2 * i, Value::from(o.clone())), + (2 * i + 1, Value::from(k.clone())), + ] + }) + .collect::>(); + + // Single out custom predicate, replacing batch-self + // references with custom predicate references. + let custom_predicate = { + let cp = (**cpb).predicates[*i].clone(); + CustomPredicate { + conjunction: cp.conjunction, + statements: cp + .statements + .into_iter() + .map(|StatementTmpl(p, args)| { + StatementTmpl( + match p { + Predicate::BatchSelf(i) => { + Predicate::Custom(CustomPredicateRef(cpb.clone(), i)) + } + _ => p, + }, + args, + ) + }) + .collect(), + args_len: cp.args_len, + } + }; + match custom_predicate.conjunction { + true if custom_predicate.statements.len() == args.len() => { + // Match op args against statement templates + let match_bindings = std::iter::zip(custom_predicate.statements, args).map( + |(s_tmpl, s)| s_tmpl.match_against(s) + ).collect::>>() + .map(|v| v.concat())?; + // Add bindings to binding table, throwing if there is an inconsistency. + match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; + Ok(true) + }, + false if args.len() == 1 => { + // Match op arg against each statement template + custom_predicate.statements.into_iter().map( + |s_tmpl| { + let mut bindings = bindings.clone(); + s_tmpl.match_against(&args[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; + Ok::<_, anyhow::Error>(true) + } + ).find(|m| m.is_ok()).unwrap_or(Ok(false)) + }, + _ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, args)) + } + } _ => Err(anyhow!( "Invalid deduction: {:?} ⇏ {:#}", self, diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..b24e2a1 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,22 @@ +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::Hash; + +use anyhow::{anyhow, Result}; + +pub(crate) fn hashmap_insert_no_dupe( + hm: &mut HashMap, + kv: (S, T), +) -> Result<()> { + let (k, v) = kv.clone(); + let res = hm.insert(kv.0, kv.1); + match res { + Some(w) if w != v => Err(anyhow!( + "Key {:?} exists in table with value {:?} != {:?}.", + k, + w, + v + )), + _ => Ok(()), + } +}