Resolve #448. Previously a predicate was 6 elements. Now it grows to 8 elements, and its hash is 4 elements. Some parts of the circuit only require equality checks with the predicate: that works with the predicate hash. Other parts require inspecting or working with particular elements in the predicate; those need the preimage of the predicate hash. Both `StatementTarget` and `StatementTmplTarget` have been updated to include the predicate hash and optionally the predicate. When the predicate is included, constraints are automatically generated for `pred_hash = hash(pred)`. We only include the predicate when needed.
160 lines
5.8 KiB
Rust
160 lines
5.8 KiB
Rust
use std::{fmt, iter};
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::{
|
|
backends::plonky2::error::{Error, Result},
|
|
middleware::{self, NativePredicate, Params, Predicate, StatementArg, ToFields, Value},
|
|
};
|
|
|
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
|
pub struct Statement(pub Predicate, pub Vec<StatementArg>);
|
|
|
|
impl Eq for Statement {}
|
|
|
|
impl Statement {
|
|
pub fn is_none(&self) -> bool {
|
|
self.0 == Predicate::Native(NativePredicate::None)
|
|
}
|
|
pub fn predicate(&self) -> Predicate {
|
|
self.0.clone()
|
|
}
|
|
/// Argument method. Trailing Nones are filtered out.
|
|
pub fn args(&self) -> Vec<StatementArg> {
|
|
let maybe_last_arg_index = (0..self.1.len()).rev().find(|i| !self.1[*i].is_none());
|
|
match maybe_last_arg_index {
|
|
None => vec![],
|
|
Some(i) => self.1[0..i + 1].to_vec(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl ToFields for Statement {
|
|
fn to_fields(&self, params: &Params) -> Vec<middleware::F> {
|
|
let mut fields = self.0.hash(params).to_fields(params);
|
|
fields.extend(
|
|
self.1
|
|
.iter()
|
|
.chain(iter::repeat(&StatementArg::None))
|
|
.take(params.max_statement_args)
|
|
.flat_map(|arg| arg.to_fields(params)),
|
|
);
|
|
fields
|
|
}
|
|
}
|
|
|
|
impl TryFrom<Statement> for middleware::Statement {
|
|
type Error = Error;
|
|
fn try_from(s: Statement) -> Result<Self> {
|
|
type S = middleware::Statement;
|
|
type NP = NativePredicate;
|
|
type SA = StatementArg;
|
|
let proper_args = s.args();
|
|
Ok(match s.0 {
|
|
Predicate::Native(np) => match (np, &proper_args.as_slice()) {
|
|
(NP::None, &[]) => S::None,
|
|
(NP::Equal, &[a1, a2]) => S::Equal(a1.try_into()?, a2.try_into()?),
|
|
(NP::NotEqual, &[a1, a2]) => S::NotEqual(a1.try_into()?, a2.try_into()?),
|
|
(NP::LtEq, &[a1, a2]) => S::LtEq(a1.try_into()?, a2.try_into()?),
|
|
(NP::Lt, &[a1, a2]) => S::Lt(a1.try_into()?, a2.try_into()?),
|
|
(NP::Contains, &[a1, a2, a3]) => {
|
|
S::Contains(a1.try_into()?, a2.try_into()?, a3.try_into()?)
|
|
}
|
|
(NP::NotContains, &[a1, a2]) => S::NotContains(a1.try_into()?, a2.try_into()?),
|
|
(NP::SumOf, &[a1, a2, a3]) => {
|
|
S::SumOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
|
|
}
|
|
(NP::ProductOf, &[a1, a2, a3]) => {
|
|
S::ProductOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
|
|
}
|
|
(NP::MaxOf, &[a1, a2, a3]) => {
|
|
S::MaxOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
|
|
}
|
|
(NP::HashOf, &[a1, a2, a3]) => {
|
|
S::HashOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
|
|
}
|
|
(NP::PublicKeyOf, &[a1, a2]) => S::PublicKeyOf(a1.try_into()?, a2.try_into()?),
|
|
(NP::SignedBy, &[a1, a2]) => S::SignedBy(a1.try_into()?, a2.try_into()?),
|
|
(NP::ContainerInsert, &[a1, a2, a3, a4]) => S::ContainerInsert(
|
|
a1.try_into()?,
|
|
a2.try_into()?,
|
|
a3.try_into()?,
|
|
a4.try_into()?,
|
|
),
|
|
(NP::ContainerUpdate, &[a1, a2, a3, a4]) => S::ContainerUpdate(
|
|
a1.try_into()?,
|
|
a2.try_into()?,
|
|
a3.try_into()?,
|
|
a4.try_into()?,
|
|
),
|
|
(NP::ContainerDelete, &[a1, a2, a3]) => {
|
|
S::ContainerDelete(a1.try_into()?, a2.try_into()?, a3.try_into()?)
|
|
}
|
|
_ => Err(Error::custom(format!(
|
|
"Ill-formed statement expression {:?}",
|
|
s
|
|
)))?,
|
|
},
|
|
Predicate::Custom(cpr) => {
|
|
let vs: Vec<Value> = proper_args
|
|
.into_iter()
|
|
.filter_map(|arg| match arg {
|
|
SA::None => None,
|
|
SA::Literal(v) => Some(v),
|
|
_ => unreachable!(),
|
|
})
|
|
.collect();
|
|
S::Custom(cpr, vs)
|
|
}
|
|
Predicate::Intro(ir) => {
|
|
let vs: Vec<Value> = proper_args
|
|
.into_iter()
|
|
.filter_map(|arg| match arg {
|
|
SA::None => None,
|
|
SA::Literal(v) => Some(v),
|
|
_ => unreachable!(),
|
|
})
|
|
.collect();
|
|
S::Intro(ir, vs)
|
|
}
|
|
Predicate::BatchSelf(_) => {
|
|
unreachable!()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
impl From<middleware::Statement> for Statement {
|
|
fn from(s: middleware::Statement) -> Self {
|
|
match s.predicate() {
|
|
middleware::Predicate::Native(c) => Statement(
|
|
middleware::Predicate::Native(c),
|
|
s.args().into_iter().collect(),
|
|
),
|
|
middleware::Predicate::Custom(cpr) => Statement(
|
|
middleware::Predicate::Custom(cpr),
|
|
s.args().into_iter().collect(),
|
|
),
|
|
middleware::Predicate::Intro(ir) => Statement(
|
|
middleware::Predicate::Intro(ir),
|
|
s.args().into_iter().collect(),
|
|
),
|
|
middleware::Predicate::BatchSelf(_) => unreachable!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for Statement {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
write!(f, "{:?} ", self.0)?;
|
|
for (i, arg) in self.1.iter().enumerate() {
|
|
if f.alternate() || !arg.is_none() {
|
|
if i != 0 {
|
|
write!(f, " ")?;
|
|
}
|
|
arg.fmt(f)?;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|