chore: enums for statement and op types (#69)

* Experiment with statement & op enums

* Clean-up & fixes

* More clean-up

* Add argument length checks

* More clean-up

* Place statement and operation logic in submodules
This commit is contained in:
Ahmad Afuni 2025-02-20 19:08:29 +10:00 committed by GitHub
parent 83a4f8969f
commit c2d23b0b1b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 891 additions and 606 deletions

View file

@ -1,15 +1,21 @@
mod operation;
mod statement;
use crate::middleware::{
self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativeStatement, NonePod,
Params, Pod, PodId, PodProver, Statement, StatementArg, ToFields, KEY_TYPE, SELF,
Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
};
use anyhow::Result;
use itertools::Itertools;
pub use operation::*;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher;
pub use statement::*;
use std::any::Any;
use std::error::Error;
use std::fmt;
pub const VALUE_TYPE: &str = "MockMainPOD";
pub struct MockProver {}
impl PodProver for MockProver {
@ -18,72 +24,6 @@ impl PodProver for MockProver {
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
enum OperationArg {
None,
Index(usize),
}
impl OperationArg {
fn is_none(&self) -> bool {
matches!(self, OperationArg::None)
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
enum OperationArgError {
KeyNotFound,
StatementNotFound,
}
impl std::fmt::Display for OperationArgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OperationArgError::KeyNotFound => write!(f, "Key not found"),
OperationArgError::StatementNotFound => write!(f, "Statement not found"),
}
}
}
impl std::error::Error for OperationArgError {}
#[derive(Clone, Debug, PartialEq, Eq)]
struct Operation(pub NativeOperation, pub Vec<OperationArg>);
impl Operation {
pub fn deref(&self, statements: &[Statement]) -> crate::middleware::Operation {
let deref_args = self
.1
.iter()
.map(|arg| match arg {
OperationArg::None => middleware::OperationArg::None,
OperationArg::Index(i) => {
middleware::OperationArg::Statement(statements[*i].clone())
}
})
.collect();
middleware::Operation(self.0, deref_args)
}
}
impl fmt::Display for Operation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?} ", self.0)?;
for (i, arg) in self.1.iter().enumerate() {
if !(!f.alternate() && arg.is_none()) {
if i != 0 {
write!(f, " ")?;
}
match arg {
OperationArg::None => write!(f, "none")?,
OperationArg::Index(i) => write!(f, "{:02}", i)?,
}
}
}
Ok(())
}
}
#[derive(Clone, Debug)]
pub struct MockMainPod {
params: Params,
@ -188,12 +128,16 @@ impl MockMainPod {
fn offset_public_statements(&self) -> usize {
self.offset_input_statements() + self.params.max_priv_statements()
}
fn pad_statement(params: &Params, s: &mut Statement) {
fill_pad(&mut s.1, StatementArg::None, params.max_statement_args)
}
fn pad_operation(params: &Params, op: &mut Operation) {
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args)
}
fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec<Statement> {
let mut statements = Vec::new();
let st_none = Self::statement_none(params);
// Input signed pods region
let none_sig_pod: Box<dyn Pod> = Box::new(NonePod {});
assert!(inputs.signed_pods.len() <= params.max_input_signed_pods);
@ -206,8 +150,12 @@ impl MockMainPod {
let sts = pod.pub_statements();
assert!(sts.len() <= params.max_signed_pod_values);
for j in 0..params.max_signed_pod_values {
let mut st = sts.get(j).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = sts
.get(j)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}
}
@ -224,8 +172,12 @@ impl MockMainPod {
let sts = pod.pub_statements();
assert!(sts.len() <= params.max_public_statements);
for j in 0..params.max_public_statements {
let mut st = sts.get(j).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = sts
.get(j)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}
}
@ -233,54 +185,55 @@ impl MockMainPod {
// Input statements
assert!(inputs.statements.len() <= params.max_priv_statements());
for i in 0..params.max_priv_statements() {
let mut st = inputs.statements.get(i).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = inputs
.statements
.get(i)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}
// Public statements
assert!(inputs.public_statements.len() < params.max_public_statements);
statements.push(Statement(
NativeStatement::ValueOf,
vec![StatementArg::Key(AnchoredKey(SELF, hash_str(KEY_TYPE)))],
));
let mut type_st = middleware::Statement::ValueOf(
AnchoredKey(SELF, hash_str(KEY_TYPE)),
middleware::Value(hash_str(VALUE_TYPE).0),
)
.into();
Self::pad_statement(params, &mut type_st);
statements.push(type_st);
for i in 0..(params.max_public_statements - 1) {
let mut st = inputs.public_statements.get(i).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = inputs
.public_statements
.get(i)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}
statements
}
pub fn find_op_arg(
fn find_op_arg(
statements: &[Statement],
op_arg: &middleware::OperationArg,
op_arg: &middleware::Statement,
) -> Result<OperationArg, OperationArgError> {
match op_arg {
middleware::OperationArg::None => Ok(OperationArg::None),
middleware::OperationArg::Key(k) => {
statements
.iter()
.enumerate()
.find_map(|(i, s)| match s.0 {
NativeStatement::ValueOf => match &s.1[0] {
StatementArg::Key(sk) => (sk == k).then_some(i),
_ => None,
},
_ => None,
})
.map(OperationArg::Index)
.ok_or(OperationArgError::KeyNotFound)
}
middleware::OperationArg::Statement(st) => {
statements
.iter()
.enumerate()
.find_map(|(i, s)| (s == st).then_some(i))
.map(OperationArg::Index)
.ok_or(OperationArgError::StatementNotFound)
}
middleware::Statement::None => Ok(OperationArg::None),
_ => statements
.iter()
.enumerate()
.find_map(|(i, s)| {
// TODO: Error handling
(&middleware::Statement::try_from(s.clone()).unwrap() == op_arg).then_some(i)
})
.map(OperationArg::Index)
.ok_or(OperationArgError::StatementNotFound),
}
}
@ -289,19 +242,19 @@ impl MockMainPod {
statements: &[Statement],
input_operations: &[middleware::Operation],
) -> Result<Vec<Operation>, OperationArgError> {
let op_none = Self::operation_none(params);
let mut operations = Vec::new();
for i in 0..params.max_priv_statements() {
let op = input_operations.get(i).unwrap_or(&op_none).clone();
let mut mid_args = op.1;
Self::pad_operation_args(params, &mut mid_args);
let mut args = Vec::with_capacity(mid_args.len());
for mid_arg in &mid_args {
let op_arg = Self::find_op_arg(statements, mid_arg)?;
args.push(op_arg)
}
operations.push(Operation(op.0, args));
let op = input_operations
.get(i)
.unwrap_or(&middleware::Operation::None)
.clone();
let mid_args = op.args();
let mut args = mid_args
.iter()
.map(|mid_arg| Self::find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>, OperationArgError>>()?;
Self::pad_operation_args(params, &mut args);
operations.push(Operation(op.code(), args));
}
Ok(operations)
}
@ -320,11 +273,11 @@ impl MockMainPod {
let mut op = if st.is_none() {
Operation(NativeOperation::None, vec![])
} else {
let mid_arg = middleware::OperationArg::Statement(st.clone());
let op_arg = Self::find_op_arg(statements, &mid_arg)?;
let mid_arg = st.clone();
Operation(
NativeOperation::CopyStatement,
vec![op_arg],
// TODO
vec![Self::find_op_arg(statements, &mid_arg.try_into().unwrap())?],
)
};
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args);
@ -351,7 +304,16 @@ impl MockMainPod {
.map(|p| (*p).clone())
.collect_vec();
let input_main_pods = inputs.main_pods.iter().map(|p| (*p).clone()).collect_vec();
let input_statements = inputs.statements.iter().cloned().collect_vec();
let input_statements = inputs
.statements
.iter()
.cloned()
.map(|s| {
let mut s = s.into();
Self::pad_statement(params, &mut s);
s
})
.collect_vec();
let public_statements =
statements[statements.len() - params.max_public_statements..].to_vec();
@ -376,26 +338,22 @@ impl MockMainPod {
Statement(NativeStatement::None, args)
}
fn operation_none(params: &Params) -> middleware::Operation {
let mut args = Vec::with_capacity(params.max_operation_args);
Self::pad_operation_args(&params, &mut args);
middleware::Operation(NativeOperation::None, args)
fn operation_none(params: &Params) -> Operation {
let mut op = Operation(NativeOperation::None, vec![]);
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args);
op
}
fn pad_statement_args(params: &Params, args: &mut Vec<StatementArg>) {
fill_pad(args, StatementArg::None, params.max_statement_args)
}
fn pad_operation_args(params: &Params, args: &mut Vec<middleware::OperationArg>) {
fill_pad(
args,
middleware::OperationArg::None,
params.max_operation_args,
)
fn pad_operation_args(params: &Params, args: &mut Vec<OperationArg>) {
fill_pad(args, OperationArg::None, params.max_operation_args)
}
}
pub fn hash_statements(statements: &[middleware::Statement]) -> Result<middleware::Hash> {
pub fn hash_statements(statements: &[Statement]) -> Result<middleware::Hash> {
let field_elems = statements
.into_iter()
.flat_map(|statement| statement.clone().to_fields().0)
@ -444,7 +402,7 @@ impl Pod for MockMainPod {
s,
)
})
.filter(|(i, s)| s.0 == NativeStatement::ValueOf)
.filter(|(_, s)| s.0 == NativeStatement::ValueOf)
.flat_map(|(i, s)| {
if let StatementArg::Key(ak) = &s.1[0] {
vec![(i, ak.1, ak.0)]
@ -463,7 +421,8 @@ impl Pod for MockMainPod {
.map(|(i, s)| {
self.operations[i]
.deref(&self.statements[..input_statement_offset + i])
.check(s.clone())
.unwrap()
.check(&s.clone().try_into().unwrap())
})
.collect::<Result<Vec<_>>>()
.unwrap();
@ -472,7 +431,7 @@ impl Pod for MockMainPod {
fn id(&self) -> PodId {
self.id
}
fn pub_statements(&self) -> Vec<Statement> {
fn pub_statements(&self) -> Vec<middleware::Statement> {
// return the public statements, where when origin=SELF is replaced by origin=self.id()
self.statements
.iter()
@ -492,6 +451,8 @@ impl Pod for MockMainPod {
})
.collect(),
)
.try_into()
.unwrap()
})
.collect()
}
@ -505,7 +466,10 @@ impl Pod for MockMainPod {
pub mod tests {
use super::*;
use crate::backends::mock_signed::MockSigner;
use crate::examples::{great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders};
use crate::examples::{
great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder,
zu_kyc_sign_pod_builders,
};
use crate::middleware;
#[test]
@ -559,6 +523,6 @@ pub mod tests {
let pod = proof_pod.pod.into_any().downcast::<MockMainPod>().unwrap();
println!("{}", pod);
assert_eq!(pod.verify(), true);
assert_eq!(pod.verify(), true);
}
}