Prototype custom predicates (#74)

* wip

* prototype custom predicates 1b

* feat: implement custom pred recursion

* files reorg, add github CI for rustfmt checks

---------

Co-authored-by: arnaucube <git@arnaucube.com>
This commit is contained in:
Eduard S. 2025-02-21 01:55:36 +01:00 committed by GitHub
parent c2d23b0b1b
commit 2e9719a1ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 529 additions and 124 deletions

21
.github/workflows/rustfmt.yml vendored Normal file
View file

@ -0,0 +1,21 @@
name: Rustfmt Check
on:
pull_request:
branches: [ main ]
types: [ready_for_review, opened, synchronize, reopened]
push:
branches: [ main ]
jobs:
rustfmt:
if: github.event.pull_request.draft == false
name: Rust formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rust-lang/setup-rust-toolchain@v1
with:
components: rustfmt
- name: Check formatting
uses: actions-rust-lang/rustfmt@v1

View file

@ -1,2 +1,3 @@
[default.extend-words] [default.extend-words]
groth = "groth" # to avoid it dectecting it as 'growth' groth = "groth" # to avoid it dectecting it as 'growth'
BA = "BA"

View file

@ -1,19 +1,20 @@
mod operation; use anyhow::{anyhow, Result};
mod statement;
use crate::middleware::{
self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativeStatement, NonePod,
Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
};
use anyhow::Result;
use itertools::Itertools; use itertools::Itertools;
pub use operation::*;
use plonky2::hash::poseidon::PoseidonHash; use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher; use plonky2::plonk::config::Hasher;
pub use statement::*;
use std::any::Any; use std::any::Any;
use std::fmt; use std::fmt;
use crate::middleware::{
self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod,
Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
};
mod operation;
mod statement;
pub use operation::*;
pub use statement::*;
pub const VALUE_TYPE: &str = "MockMainPOD"; pub const VALUE_TYPE: &str = "MockMainPOD";
pub struct MockProver {} pub struct MockProver {}
@ -222,18 +223,17 @@ impl MockMainPod {
fn find_op_arg( fn find_op_arg(
statements: &[Statement], statements: &[Statement],
op_arg: &middleware::Statement, op_arg: &middleware::Statement,
) -> Result<OperationArg, OperationArgError> { ) -> Result<OperationArg> {
match op_arg { match op_arg {
middleware::Statement::None => Ok(OperationArg::None), middleware::Statement::None => Ok(OperationArg::None),
_ => statements _ => statements
.iter() .iter()
.enumerate() .enumerate()
.find_map(|(i, s)| { .find_map(|(i, s)| {
// TODO: Error handling (&middleware::Statement::try_from(s.clone()).ok()? == op_arg).then_some(i)
(&middleware::Statement::try_from(s.clone()).unwrap() == op_arg).then_some(i)
}) })
.map(OperationArg::Index) .map(OperationArg::Index)
.ok_or(OperationArgError::StatementNotFound), .ok_or(anyhow!("statement not found")),
} }
} }
@ -241,7 +241,7 @@ impl MockMainPod {
params: &Params, params: &Params,
statements: &[Statement], statements: &[Statement],
input_operations: &[middleware::Operation], input_operations: &[middleware::Operation],
) -> Result<Vec<Operation>, OperationArgError> { ) -> Result<Vec<Operation>> {
let mut operations = Vec::new(); let mut operations = Vec::new();
for i in 0..params.max_priv_statements() { for i in 0..params.max_priv_statements() {
let op = input_operations let op = input_operations
@ -252,7 +252,7 @@ impl MockMainPod {
let mut args = mid_args let mut args = mid_args
.iter() .iter()
.map(|mid_arg| Self::find_op_arg(statements, mid_arg)) .map(|mid_arg| Self::find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>, OperationArgError>>()?; .collect::<Result<Vec<_>>>()?;
Self::pad_operation_args(params, &mut args); Self::pad_operation_args(params, &mut args);
operations.push(Operation(op.code(), args)); operations.push(Operation(op.code(), args));
} }
@ -265,7 +265,7 @@ impl MockMainPod {
params: &Params, params: &Params,
statements: &[Statement], statements: &[Statement],
mut operations: Vec<Operation>, mut operations: Vec<Operation>,
) -> Result<Vec<Operation>, OperationArgError> { ) -> Result<Vec<Operation>> {
let offset_public_statements = statements.len() - params.max_public_statements; let offset_public_statements = statements.len() - params.max_public_statements;
operations.push(Operation(NativeOperation::NewEntry, vec![])); operations.push(Operation(NativeOperation::NewEntry, vec![]));
for i in 0..(params.max_public_statements - 1) { for i in 0..(params.max_public_statements - 1) {
@ -318,7 +318,7 @@ impl MockMainPod {
statements[statements.len() - params.max_public_statements..].to_vec(); statements[statements.len() - params.max_public_statements..].to_vec();
// get the id out of the public statements // get the id out of the public statements
let id: PodId = PodId(hash_statements(&public_statements)?); let id: PodId = PodId(hash_statements(&public_statements));
Ok(Self { Ok(Self {
params: params.clone(), params: params.clone(),
@ -335,7 +335,7 @@ impl MockMainPod {
fn statement_none(params: &Params) -> Statement { fn statement_none(params: &Params) -> Statement {
let mut args = Vec::with_capacity(params.max_statement_args); let mut args = Vec::with_capacity(params.max_statement_args);
Self::pad_statement_args(&params, &mut args); Self::pad_statement_args(&params, &mut args);
Statement(NativeStatement::None, args) Statement(NativePredicate::None, args)
} }
fn operation_none(params: &Params) -> Operation { fn operation_none(params: &Params) -> Operation {
@ -353,12 +353,12 @@ impl MockMainPod {
} }
} }
pub fn hash_statements(statements: &[Statement]) -> Result<middleware::Hash> { pub fn hash_statements(statements: &[Statement]) -> middleware::Hash {
let field_elems = statements let field_elems = statements
.into_iter() .into_iter()
.flat_map(|statement| statement.clone().to_fields().0) .flat_map(|statement| statement.clone().to_fields().0)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Ok(Hash(PoseidonHash::hash_no_pad(&field_elems).elements)) Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
} }
impl Pod for MockMainPod { impl Pod for MockMainPod {
@ -367,14 +367,14 @@ impl Pod for MockMainPod {
// get the input_statements from the self.statements // get the input_statements from the self.statements
let input_statements = &self.statements[input_statement_offset..]; let input_statements = &self.statements[input_statement_offset..];
// get the id out of the public statements, and ensure it is equal to self.id // get the id out of the public statements, and ensure it is equal to self.id
let ids_match = self.id == PodId(hash_statements(&self.public_statements).unwrap()); let ids_match = self.id == PodId(hash_statements(&self.public_statements));
// find a ValueOf statement from the public statements with key=KEY_TYPE and check that the // find a ValueOf statement from the public statements with key=KEY_TYPE and check that the
// value is PodType::MockMainPod // value is PodType::MockMainPod
let has_type_statement = self let has_type_statement = self
.public_statements .public_statements
.iter() .iter()
.find(|s| { .find(|s| {
s.0 == NativeStatement::ValueOf s.0 == NativePredicate::ValueOf
&& s.1.len() > 0 && s.1.len() > 0
&& if let StatementArg::Key(AnchoredKey(pod_id, key_hash)) = s.1[0] { && if let StatementArg::Key(AnchoredKey(pod_id, key_hash)) = s.1[0] {
pod_id == SELF && key_hash == hash_str(KEY_TYPE) pod_id == SELF && key_hash == hash_str(KEY_TYPE)
@ -402,7 +402,7 @@ impl Pod for MockMainPod {
s, s,
) )
}) })
.filter(|(_, s)| s.0 == NativeStatement::ValueOf) .filter(|(_, s)| s.0 == NativePredicate::ValueOf)
.flat_map(|(i, s)| { .flat_map(|(i, s)| {
if let StatementArg::Key(ak) = &s.1[0] { if let StatementArg::Key(ak) = &s.1[0] {
vec![(i, ak.1, ak.0)] vec![(i, ak.1, ak.0)]
@ -473,22 +473,22 @@ pub mod tests {
use crate::middleware; use crate::middleware;
#[test] #[test]
fn test_mock_main_zu_kyc() { fn test_mock_main_zu_kyc() -> Result<()> {
let params = middleware::Params::default(); let params = middleware::Params::default();
let (gov_id_builder, pay_stub_builder) = zu_kyc_sign_pod_builders(&params); let (gov_id_builder, pay_stub_builder) = zu_kyc_sign_pod_builders(&params);
let mut signer = MockSigner { let mut signer = MockSigner {
pk: "ZooGov".into(), pk: "ZooGov".into(),
}; };
let gov_id_pod = gov_id_builder.sign(&mut signer).unwrap(); let gov_id_pod = gov_id_builder.sign(&mut signer)?;
let mut signer = MockSigner { let mut signer = MockSigner {
pk: "ZooDeel".into(), pk: "ZooDeel".into(),
}; };
let pay_stub_pod = pay_stub_builder.sign(&mut signer).unwrap(); let pay_stub_pod = pay_stub_builder.sign(&mut signer)?;
let kyc_builder = zu_kyc_pod_builder(&params, &gov_id_pod, &pay_stub_pod); let kyc_builder = zu_kyc_pod_builder(&params, &gov_id_pod, &pay_stub_pod);
let mut prover = MockProver {}; let mut prover = MockProver {};
let kyc_pod = kyc_builder.prove(&mut prover).unwrap(); let kyc_pod = kyc_builder.prove(&mut prover)?;
let pod = kyc_pod.pod.into_any().downcast::<MockMainPod>().unwrap(); let pod = kyc_pod.pod.into_any().downcast::<MockMainPod>().unwrap();
println!("{:#}", pod); println!("{:#}", pod);
@ -496,14 +496,15 @@ pub mod tests {
assert_eq!(pod.verify(), true); // TODO assert_eq!(pod.verify(), true); // TODO
// println!("id: {}", pod.id()); // println!("id: {}", pod.id());
// println!("pub_statements: {:?}", pod.pub_statements()); // println!("pub_statements: {:?}", pod.pub_statements());
Ok(())
} }
#[test] #[test]
fn test_mock_main_great_boy() { fn test_mock_main_great_boy() -> Result<()> {
let great_boy_builder = great_boy_pod_full_flow(); let great_boy_builder = great_boy_pod_full_flow();
let mut prover = MockProver {}; let mut prover = MockProver {};
let great_boy_pod = great_boy_builder.prove(&mut prover).unwrap(); let great_boy_pod = great_boy_builder.prove(&mut prover)?;
let pod = great_boy_pod let pod = great_boy_pod
.pod .pod
.into_any() .into_any()
@ -513,16 +514,20 @@ pub mod tests {
println!("{}", pod); println!("{}", pod);
assert_eq!(pod.verify(), true); assert_eq!(pod.verify(), true);
Ok(())
} }
#[test] #[test]
fn test_mock_main_tickets() { fn test_mock_main_tickets() -> Result<()> {
let tickets_builder = tickets_pod_full_flow(); let tickets_builder = tickets_pod_full_flow();
let mut prover = MockProver {}; let mut prover = MockProver {};
let proof_pod = tickets_builder.prove(&mut prover).unwrap(); let proof_pod = tickets_builder.prove(&mut prover)?;
let pod = proof_pod.pod.into_any().downcast::<MockMainPod>().unwrap(); let pod = proof_pod.pod.into_any().downcast::<MockMainPod>().unwrap();
println!("{}", pod); println!("{}", pod);
assert_eq!(pod.verify(), true); assert_eq!(pod.verify(), true);
Ok(())
} }
} }

View file

@ -1,10 +1,8 @@
use anyhow::Result;
use std::fmt; use std::fmt;
use anyhow::Result;
use crate::middleware::{self, NativeOperation};
use super::Statement; use super::Statement;
use crate::middleware::{self, NativeOperation};
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperationArg { pub enum OperationArg {
@ -18,23 +16,6 @@ impl OperationArg {
} }
} }
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperationArgError {
KeyNotFound,
StatementNotFound,
}
impl std::fmt::Display for OperationArgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OperationArgError::KeyNotFound => write!(f, "Key not found"),
OperationArgError::StatementNotFound => write!(f, "Statement not found"),
}
}
}
impl std::error::Error for OperationArgError {}
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Operation(pub NativeOperation, pub Vec<OperationArg>); pub struct Operation(pub NativeOperation, pub Vec<OperationArg>);

View file

@ -1,15 +1,14 @@
use anyhow::{anyhow, Result};
use std::fmt; use std::fmt;
use anyhow::{anyhow, Result}; use crate::middleware::{self, NativePredicate, StatementArg, ToFields};
use crate::middleware::{self, NativeStatement, StatementArg, ToFields};
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Statement(pub NativeStatement, pub Vec<StatementArg>); pub struct Statement(pub NativePredicate, pub Vec<StatementArg>);
impl Statement { impl Statement {
pub fn is_none(&self) -> bool { pub fn is_none(&self) -> bool {
self.0 == NativeStatement::None self.0 == NativePredicate::None
} }
/// Argument method. Trailing Nones are filtered out. /// Argument method. Trailing Nones are filtered out.
pub fn args(&self) -> Vec<StatementArg> { pub fn args(&self) -> Vec<StatementArg> {
@ -44,7 +43,7 @@ impl TryFrom<Statement> for middleware::Statement {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(s: Statement) -> Result<Self> { fn try_from(s: Statement) -> Result<Self> {
type S = middleware::Statement; type S = middleware::Statement;
type NS = NativeStatement; type NP = NativePredicate;
type SA = StatementArg; type SA = StatementArg;
let proper_args = s.args(); let proper_args = s.args();
let args = ( let args = (
@ -53,27 +52,27 @@ impl TryFrom<Statement> for middleware::Statement {
proper_args.get(2).cloned(), proper_args.get(2).cloned(),
); );
Ok(match (s.0, args, proper_args.len()) { Ok(match (s.0, args, proper_args.len()) {
(NS::None, _, 0) => S::None, (NP::None, _, 0) => S::None,
(NS::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => S::ValueOf(ak, v), (NP::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => S::ValueOf(ak, v),
(NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Equal(ak1, ak2), (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Equal(ak1, ak2),
(NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => {
S::NotEqual(ak1, ak2) S::NotEqual(ak1, ak2)
} }
(NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Gt(ak1, ak2), (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Gt(ak1, ak2),
(NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2), (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2),
(NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => {
S::Contains(ak1, ak2) S::Contains(ak1, ak2)
} }
(NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => {
S::NotContains(ak1, ak2) S::NotContains(ak1, ak2)
} }
(NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => {
S::SumOf(ak1, ak2, ak3) S::SumOf(ak1, ak2, ak3)
} }
(NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => {
S::ProductOf(ak1, ak2, ak3) S::ProductOf(ak1, ak2, ak3)
} }
(NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => {
S::MaxOf(ak1, ak2, ak3) S::MaxOf(ak1, ak2, ak3)
} }
_ => Err(anyhow!("Ill-formed statement expression {:?}", s))?, _ => Err(anyhow!("Ill-formed statement expression {:?}", s))?,

View file

@ -1,11 +1,12 @@
use anyhow::Result;
use std::any::Any;
use std::collections::HashMap;
use crate::middleware::{ use crate::middleware::{
containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, PodType, containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, PodType,
Statement, Value, KEY_SIGNER, KEY_TYPE, Statement, Value, KEY_SIGNER, KEY_TYPE,
}; };
use crate::primitives::merkletree::MerkleTree; use crate::primitives::merkletree::MerkleTree;
use anyhow::Result;
use std::any::Any;
use std::collections::HashMap;
pub struct MockSigner { pub struct MockSigner {
pub pk: String, pub pk: String,

View file

@ -1,9 +1,6 @@
//! The frontend includes the user-level abstractions and user-friendly types to define and work //! The frontend includes the user-level abstractions and user-friendly types to define and work
//! with Pods. //! with Pods.
mod operation;
mod statement;
use anyhow::Result; use anyhow::Result;
use itertools::Itertools; use itertools::Itertools;
use std::collections::HashMap; use std::collections::HashMap;
@ -13,9 +10,12 @@ use std::fmt;
use crate::middleware::{ use crate::middleware::{
self, self,
containers::{Array, Dictionary, Set}, containers::{Array, Dictionary, Set},
hash_str, Hash, MainPodInputs, NativeOperation, NativeStatement, Params, PodId, PodProver, hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver,
PodSigner, SELF, PodSigner, SELF,
}; };
mod operation;
mod statement;
pub use operation::*; pub use operation::*;
pub use statement::*; pub use statement::*;
@ -236,7 +236,7 @@ impl MainPodBuilder {
for arg in args.iter_mut() { for arg in args.iter_mut() {
match arg { match arg {
OperationArg::Statement(s) => { OperationArg::Statement(s) => {
if s.0 == NativeStatement::ValueOf { if s.0 == NativePredicate::ValueOf {
st_args.push(s.1[0].clone()) st_args.push(s.1[0].clone())
} else { } else {
panic!("Invalid statement argument."); panic!("Invalid statement argument.");
@ -276,27 +276,27 @@ impl MainPodBuilder {
let Operation(op_type, ref mut args) = op; let Operation(op_type, ref mut args) = op;
// TODO: argument type checking // TODO: argument type checking
let st = match op_type { let st = match op_type {
None => Statement(NativeStatement::None, vec![]), None => Statement(NativePredicate::None, vec![]),
NewEntry => Statement(NativeStatement::ValueOf, self.op_args_entries(public, args)), NewEntry => Statement(NativePredicate::ValueOf, self.op_args_entries(public, args)),
CopyStatement => todo!(), CopyStatement => todo!(),
EqualFromEntries => { EqualFromEntries => {
Statement(NativeStatement::Equal, self.op_args_entries(public, args)) Statement(NativePredicate::Equal, self.op_args_entries(public, args))
} }
NotEqualFromEntries => Statement( NotEqualFromEntries => Statement(
NativeStatement::NotEqual, NativePredicate::NotEqual,
self.op_args_entries(public, args), self.op_args_entries(public, args),
), ),
GtFromEntries => Statement(NativeStatement::Gt, self.op_args_entries(public, args)), GtFromEntries => Statement(NativePredicate::Gt, self.op_args_entries(public, args)),
LtFromEntries => Statement(NativeStatement::Lt, self.op_args_entries(public, args)), LtFromEntries => Statement(NativePredicate::Lt, self.op_args_entries(public, args)),
TransitiveEqualFromStatements => todo!(), TransitiveEqualFromStatements => todo!(),
GtToNotEqual => todo!(), GtToNotEqual => todo!(),
LtToNotEqual => todo!(), LtToNotEqual => todo!(),
ContainsFromEntries => Statement( ContainsFromEntries => Statement(
NativeStatement::Contains, NativePredicate::Contains,
self.op_args_entries(public, args), self.op_args_entries(public, args),
), ),
NotContainsFromEntries => Statement( NotContainsFromEntries => Statement(
NativeStatement::NotContains, NativePredicate::NotContains,
self.op_args_entries(public, args), self.op_args_entries(public, args),
), ),
RenameContainedBy => todo!(), RenameContainedBy => todo!(),

View file

@ -1,8 +1,7 @@
use std::fmt; use std::fmt;
use crate::middleware::{hash_str, NativeOperation, NativeStatement};
use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value}; use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value};
use crate::middleware::{hash_str, NativeOperation, NativePredicate};
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum OperationArg { pub enum OperationArg {
@ -56,7 +55,7 @@ impl From<(&SignedPod, &str)> for OperationArg {
// TODO: Actual value, TryFrom. // TODO: Actual value, TryFrom.
let value = pod.kvs().get(&hash_str(key)).unwrap().clone(); let value = pod.kvs().get(&hash_str(key)).unwrap().clone();
Self::Statement(Statement( Self::Statement(Statement(
NativeStatement::ValueOf, NativePredicate::ValueOf,
vec![ vec![
StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())), StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())),
StatementArg::Literal(Value::Raw(value)), StatementArg::Literal(Value::Raw(value)),

View file

@ -1,10 +1,8 @@
use anyhow::{anyhow, Result};
use std::fmt; use std::fmt;
use anyhow::{anyhow, Result};
use crate::middleware::{self, NativeStatement};
use super::{AnchoredKey, Value}; use super::{AnchoredKey, Value};
use crate::middleware::{self, NativePredicate};
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum StatementArg { pub enum StatementArg {
@ -22,13 +20,13 @@ impl fmt::Display for StatementArg {
} }
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Statement(pub NativeStatement, pub Vec<StatementArg>); pub struct Statement(pub NativePredicate, pub Vec<StatementArg>);
impl TryFrom<Statement> for middleware::Statement { impl TryFrom<Statement> for middleware::Statement {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(s: Statement) -> Result<Self> { fn try_from(s: Statement) -> Result<Self> {
type MS = middleware::Statement; type MS = middleware::Statement;
type NS = NativeStatement; type NP = NativePredicate;
type SA = StatementArg; type SA = StatementArg;
let args = ( let args = (
s.1.get(0).cloned(), s.1.get(0).cloned(),
@ -36,35 +34,35 @@ impl TryFrom<Statement> for middleware::Statement {
s.1.get(2).cloned(), s.1.get(2).cloned(),
); );
Ok(match (s.0, args) { Ok(match (s.0, args) {
(NS::None, (None, None, None)) => MS::None, (NP::None, (None, None, None)) => MS::None,
(NS::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => { (NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => {
MS::ValueOf(ak.into(), (&v).into()) MS::ValueOf(ak.into(), (&v).into())
} }
(NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Equal(ak1.into(), ak2.into()) MS::Equal(ak1.into(), ak2.into())
} }
(NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::NotEqual(ak1.into(), ak2.into()) MS::NotEqual(ak1.into(), ak2.into())
} }
(NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Gt(ak1.into(), ak2.into()) MS::Gt(ak1.into(), ak2.into())
} }
(NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Lt(ak1.into(), ak2.into()) MS::Lt(ak1.into(), ak2.into())
} }
(NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::Contains(ak1.into(), ak2.into()) MS::Contains(ak1.into(), ak2.into())
} }
(NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
MS::NotContains(ak1.into(), ak2.into()) MS::NotContains(ak1.into(), ak2.into())
} }
(NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::SumOf(ak1.into(), ak2.into(), ak3.into()) MS::SumOf(ak1.into(), ak2.into(), ak3.into())
} }
(NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::ProductOf(ak1.into(), ak2.into(), ak3.into()) MS::ProductOf(ak1.into(), ak2.into(), ak3.into())
} }
(NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
MS::MaxOf(ak1.into(), ak2.into(), ak3.into()) MS::MaxOf(ak1.into(), ak2.into(), ak3.into())
} }
_ => Err(anyhow!("Ill-formed statement: {}", s))?, _ => Err(anyhow!("Ill-formed statement: {}", s))?,

397
src/middleware/custom.rs Normal file
View file

@ -0,0 +1,397 @@
use std::fmt;
use std::sync::Arc;
use super::{hash_str, Hash, NativePredicate, ToFields, Value, F};
// BEGIN Custom 1b
#[derive(Debug)]
pub enum HashOrWildcard {
Hash(Hash),
Wildcard(usize),
}
impl fmt::Display for HashOrWildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Hash(h) => write!(f, "{}", h),
Self::Wildcard(n) => write!(f, "*{}", n),
}
}
}
#[derive(Debug)]
pub enum StatementTmplArg {
None,
Literal(Value),
Key(HashOrWildcard, HashOrWildcard),
}
impl fmt::Display for StatementTmplArg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::None => write!(f, "none"),
Self::Literal(v) => write!(f, "{}", v),
Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key),
}
}
}
// END
// BEGIN Custom 2
// pub enum StatementTmplArg {
// None,
// Literal(Value),
// Wildcard(usize),
// }
// END
/// Statement Template for a Custom Predicate
#[derive(Debug)]
pub struct StatementTmpl(Predicate, Vec<StatementTmplArg>);
#[derive(Debug)]
pub struct CustomPredicate {
/// true for "and", false for "or"
pub conjunction: bool,
pub statements: Vec<StatementTmpl>,
pub args_len: usize,
// TODO: Add private args length?
// TODO: Add args type information?
}
impl fmt::Display for CustomPredicate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?;
for st in &self.statements {
write!(f, " {}", st.0)?;
for (i, arg) in st.1.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}", arg)?;
}
writeln!(f, "),")?;
}
write!(f, ">(")?;
for i in 0..self.args_len {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "*{}", i)?;
}
writeln!(f, ")")?;
Ok(())
}
}
#[derive(Debug)]
pub struct CustomPredicateBatch {
predicates: Vec<CustomPredicate>,
}
impl CustomPredicateBatch {
pub fn hash(&self) -> Hash {
// TODO
hash_str(&format!("{:?}", self))
}
}
#[derive(Clone, Debug)]
pub enum Predicate {
Native(NativePredicate),
BatchSelf(usize),
Custom(Arc<CustomPredicateBatch>, usize),
}
impl From<NativePredicate> for Predicate {
fn from(v: NativePredicate) -> Self {
Self::Native(v)
}
}
impl ToFields for Predicate {
fn to_fields(self) -> (Vec<F>, usize) {
todo!()
}
}
impl fmt::Display for Predicate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Native(p) => write!(f, "{:?}", p),
Self::BatchSelf(i) => write!(f, "self.{}", i),
Self::Custom(pb, i) => write!(f, "{}.{}", pb.hash(), i),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::middleware::PodType;
enum HashOrWildcardStr {
Hash(Hash),
Wildcard(String),
}
fn l(s: &str) -> HashOrWildcardStr {
HashOrWildcardStr::Hash(hash_str(s))
}
fn w(s: &str) -> HashOrWildcardStr {
HashOrWildcardStr::Wildcard(s.to_string())
}
enum BuilderArg {
Literal(Value),
Key(HashOrWildcardStr, HashOrWildcardStr),
}
impl From<(HashOrWildcardStr, HashOrWildcardStr)> for BuilderArg {
fn from((pod_id, key): (HashOrWildcardStr, HashOrWildcardStr)) -> Self {
Self::Key(pod_id, key)
}
}
impl<V> From<V> for BuilderArg
where
V: Into<Value>,
{
fn from(v: V) -> Self {
Self::Literal(v.into())
}
}
struct StatementTmplBuilder {
predicate: Predicate,
args: Vec<BuilderArg>,
}
fn st_tmpl(p: impl Into<Predicate>) -> StatementTmplBuilder {
StatementTmplBuilder {
predicate: p.into(),
args: Vec::new(),
}
}
impl StatementTmplBuilder {
fn arg(mut self, a: impl Into<BuilderArg>) -> Self {
self.args.push(a.into());
self
}
}
struct CustomPredicateBatchBuilder {
predicates: Vec<CustomPredicate>,
}
impl CustomPredicateBatchBuilder {
fn new() -> Self {
Self {
predicates: Vec::new(),
}
}
fn predicate_and(
&mut self,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Predicate {
self.predicate(true, args, priv_args, sts)
}
fn predicate_or(
&mut self,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Predicate {
self.predicate(false, args, priv_args, sts)
}
fn predicate(
&mut self,
conjunction: bool,
args: &[&str],
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Predicate {
use BuilderArg as BA;
let statements = sts
.iter()
.map(|sb| {
let args = sb
.args
.iter()
.map(|a| match a {
BA::Literal(v) => StatementTmplArg::Literal(*v),
BA::Key(pod_id, key) => StatementTmplArg::Key(
resolve_wildcard(args, priv_args, pod_id),
resolve_wildcard(args, priv_args, key),
),
})
.collect();
StatementTmpl(sb.predicate.clone(), args)
})
.collect();
let custom_predicate = CustomPredicate {
conjunction,
statements,
args_len: args.len(),
};
self.predicates.push(custom_predicate);
Predicate::BatchSelf(self.predicates.len() - 1)
}
fn finish(self) -> Arc<CustomPredicateBatch> {
Arc::new(CustomPredicateBatch {
predicates: self.predicates,
})
}
}
fn resolve_wildcard(
args: &[&str],
priv_args: &[&str],
v: &HashOrWildcardStr,
) -> HashOrWildcard {
match v {
HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h),
HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard(
args.iter()
.chain(priv_args.iter())
.enumerate()
.find_map(|(i, name)| (&s == name).then_some(i))
.unwrap(),
),
}
}
#[test]
fn test_custom_pred() {
use NativePredicate as NP;
let mut builder = CustomPredicateBatchBuilder::new();
let _eth_friend = builder.predicate_and(
&["src_or", "src_key", "dst_or", "dst_key"],
&["attestation_pod"],
&[
st_tmpl(NP::ValueOf)
.arg((w("attestation_pod"), l("type")))
.arg(PodType::Signed),
st_tmpl(NP::Equal)
.arg((w("attestation_pod"), l("signer")))
.arg((w("src_or"), w("src_key"))),
st_tmpl(NP::Equal)
.arg((w("attestation_pod"), l("attestation")))
.arg((w("dst_or"), w("dst_key"))),
],
);
println!("a.0. eth_friend = {}", builder.predicates.last().unwrap());
let eth_friend = builder.finish();
// This batch only has 1 predicate, so we pick it already for convenience
let eth_friend = Predicate::Custom(eth_friend, 0);
let mut builder = CustomPredicateBatchBuilder::new();
let eth_dos_distance_base = builder.predicate_and(
&[
"src_or",
"src_key",
"dst_or",
"dst_key",
"distance_or",
"distance_key",
],
&[],
&[
st_tmpl(NP::Equal)
.arg((w("src_or"), l("src_key")))
.arg((w("dst_or"), w("dst_key"))),
st_tmpl(NP::ValueOf)
.arg((w("distance_or"), w("distance_key")))
.arg(0),
],
);
println!(
"b.0. eth_dos_distance_base = {}",
builder.predicates.last().unwrap()
);
let eth_dos_distance = Predicate::BatchSelf(3);
let eth_dos_distance_ind = builder.predicate_and(
&[
"src_or",
"src_key",
"dst_or",
"dst_key",
"distance_or",
"distance_key",
],
&[
"one_or",
"one_key",
"shorter_distance_or",
"shorter_distance_key",
"intermed_or",
"intermed_key",
],
&[
st_tmpl(eth_dos_distance)
.arg((w("src_or"), w("src_key")))
.arg((w("intermed_or"), w("intermed_key")))
.arg((w("shorter_distance_or"), w("shorter_distance_key"))),
// distance == shorter_distance + 1
st_tmpl(NP::ValueOf).arg((w("one_or"), w("one_key"))).arg(1),
st_tmpl(NP::SumOf)
.arg((w("distance_or"), w("distance_key")))
.arg((w("shorter_distance_or"), w("shorter_distance_key")))
.arg((w("one_or"), w("one_key"))),
// intermed is a friend of dst
st_tmpl(eth_friend)
.arg((w("intermed_or"), w("intermed_key")))
.arg((w("dst_or"), w("dst_key"))),
],
);
println!(
"b.1. eth_dos_distance_ind = {}",
builder.predicates.last().unwrap()
);
let _eth_dos_distance = builder.predicate_or(
&[
"src_or",
"src_key",
"dst_or",
"dst_key",
"distance_or",
"distance_key",
],
&[],
&[
st_tmpl(eth_dos_distance_base)
.arg((w("src_or"), w("src_key")))
.arg((w("dst_or"), w("dst_key")))
.arg((w("distance_or"), w("distance_key"))),
st_tmpl(eth_dos_distance_ind)
.arg((w("src_or"), w("src_key")))
.arg((w("dst_or"), w("dst_key")))
.arg((w("distance_or"), w("distance_key"))),
],
);
println!(
"b.2. eth_dos_distance = {}",
builder.predicates.last().unwrap()
);
}
}

View file

@ -1,18 +1,20 @@
//! The middleware includes the type definitions and the traits used to connect the frontend and //! The middleware includes the type definitions and the traits used to connect the frontend and
//! the backend. //! the backend.
mod custom;
mod operation; mod operation;
mod statement; mod statement;
pub use custom::*;
pub use operation::*;
pub use statement::*;
use anyhow::{anyhow, Error, Result}; use anyhow::{anyhow, Error, Result};
use dyn_clone::DynClone; use dyn_clone::DynClone;
use hex::{FromHex, FromHexError}; use hex::{FromHex, FromHexError};
pub use operation::*;
use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::{Field, PrimeField64}; use plonky2::field::types::{Field, PrimeField64};
use plonky2::hash::poseidon::PoseidonHash; use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::{Hasher, PoseidonGoldilocksConfig}; use plonky2::plonk::config::{Hasher, PoseidonGoldilocksConfig};
pub use statement::*;
use std::any::Any; use std::any::Any;
use std::cmp::{Ord, Ordering}; use std::cmp::{Ord, Ordering};
use std::collections::HashMap; use std::collections::HashMap;
@ -201,6 +203,7 @@ impl From<PodType> for Value {
pub fn hash_str(s: &str) -> Hash { pub fn hash_str(s: &str) -> Hash {
let mut input = s.as_bytes().to_vec(); let mut input = s.as_bytes().to_vec();
input.push(1); // padding input.push(1); // padding
// Merge 7 bytes into 1 field, because the field is slightly below 64 bits // Merge 7 bytes into 1 field, because the field is slightly below 64 bits
let input: Vec<F> = input let input: Vec<F> = input
.chunks(7) .chunks(7)

View file

@ -1,7 +1,7 @@
use crate::middleware::{AnchoredKey, SELF};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use super::Statement; use super::Statement;
use crate::middleware::{AnchoredKey, SELF};
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NativeOperation { pub enum NativeOperation {

View file

@ -10,7 +10,7 @@ pub const KEY_TYPE: &str = "_type";
pub const STATEMENT_ARG_F_LEN: usize = 8; pub const STATEMENT_ARG_F_LEN: usize = 8;
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)] #[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)]
pub enum NativeStatement { pub enum NativePredicate {
None = 0, None = 0,
ValueOf = 1, ValueOf = 1,
Equal = 2, Equal = 2,
@ -24,7 +24,7 @@ pub enum NativeStatement {
MaxOf = 10, MaxOf = 10,
} }
impl ToFields for NativeStatement { impl ToFields for NativePredicate {
fn to_fields(self) -> (Vec<F>, usize) { fn to_fields(self) -> (Vec<F>, usize) {
(vec![F::from_canonical_u64(self as u64)], 1) (vec![F::from_canonical_u64(self as u64)], 1)
} }
@ -51,19 +51,19 @@ impl Statement {
pub fn is_none(&self) -> bool { pub fn is_none(&self) -> bool {
self == &Self::None self == &Self::None
} }
pub fn code(&self) -> NativeStatement { pub fn code(&self) -> NativePredicate {
match self { match self {
Self::None => NativeStatement::None, Self::None => NativePredicate::None,
Self::ValueOf(_, _) => NativeStatement::ValueOf, Self::ValueOf(_, _) => NativePredicate::ValueOf,
Self::Equal(_, _) => NativeStatement::Equal, Self::Equal(_, _) => NativePredicate::Equal,
Self::NotEqual(_, _) => NativeStatement::NotEqual, Self::NotEqual(_, _) => NativePredicate::NotEqual,
Self::Gt(_, _) => NativeStatement::Gt, Self::Gt(_, _) => NativePredicate::Gt,
Self::Lt(_, _) => NativeStatement::Lt, Self::Lt(_, _) => NativePredicate::Lt,
Self::Contains(_, _) => NativeStatement::Contains, Self::Contains(_, _) => NativePredicate::Contains,
Self::NotContains(_, _) => NativeStatement::NotContains, Self::NotContains(_, _) => NativePredicate::NotContains,
Self::SumOf(_, _, _) => NativeStatement::SumOf, Self::SumOf(_, _, _) => NativePredicate::SumOf,
Self::ProductOf(_, _, _) => NativeStatement::ProductOf, Self::ProductOf(_, _, _) => NativePredicate::ProductOf,
Self::MaxOf(_, _, _) => NativeStatement::MaxOf, Self::MaxOf(_, _, _) => NativePredicate::MaxOf,
} }
} }
pub fn args(&self) -> Vec<StatementArg> { pub fn args(&self) -> Vec<StatementArg> {