Feat/fst order pred part3 & part4 (#457)
* support wildcard predicates in frontend * suport wildcard predicate in podlang * add validation test * test full flow and apply some fixes * fix clippy * fix merge issues * use desugared predicate * Fix parsing of intro statement templates inside custom predicates * Tidy up comments * lang: handle wildcard predicate * add unreachable message --------- Co-authored-by: Rob Knight <mail@robknight.org.uk>
This commit is contained in:
parent
b66f5051b5
commit
498e946612
11 changed files with 324 additions and 180 deletions
|
|
@ -7,7 +7,7 @@ use crate::{
|
|||
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
|
||||
middleware::{
|
||||
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
|
||||
Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
|
||||
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -41,16 +41,34 @@ pub fn literal(v: impl Into<Value>) -> BuilderArg {
|
|||
BuilderArg::Literal(v.into())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum PredicateOrWildcard {
|
||||
Predicate(Predicate),
|
||||
Wildcard(String),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StatementTmplBuilder {
|
||||
pub(crate) predicate: Predicate,
|
||||
pub(crate) pred_or_wc: PredicateOrWildcard,
|
||||
pub(crate) args: Vec<BuilderArg>,
|
||||
}
|
||||
|
||||
impl StatementTmplBuilder {
|
||||
pub fn new(p: impl Into<Predicate>) -> StatementTmplBuilder {
|
||||
pub fn new_from_pred(p: impl Into<Predicate>) -> StatementTmplBuilder {
|
||||
StatementTmplBuilder {
|
||||
predicate: p.into(),
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(p.into()),
|
||||
args: Vec::new(),
|
||||
}
|
||||
}
|
||||
pub fn new_from_wc(p: impl Into<String>) -> StatementTmplBuilder {
|
||||
StatementTmplBuilder {
|
||||
pred_or_wc: PredicateOrWildcard::Wildcard(p.into()),
|
||||
args: Vec::new(),
|
||||
}
|
||||
}
|
||||
pub fn new(pred_or_wc: PredicateOrWildcard) -> StatementTmplBuilder {
|
||||
StatementTmplBuilder {
|
||||
pred_or_wc,
|
||||
args: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
|
@ -62,68 +80,48 @@ impl StatementTmplBuilder {
|
|||
|
||||
/// Desugar the predicate to a simpler form
|
||||
/// Should mirror the logic in `MainPodBuilder::lower_op`
|
||||
pub(crate) fn desugar(self) -> StatementTmplBuilder {
|
||||
match self.predicate {
|
||||
Predicate::Native(NativePredicate::Gt) => {
|
||||
let mut stb = StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::Lt),
|
||||
args: self.args,
|
||||
};
|
||||
stb.args.swap(0, 1);
|
||||
stb
|
||||
}
|
||||
Predicate::Native(NativePredicate::GtEq) => {
|
||||
let mut stb = StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::LtEq),
|
||||
args: self.args,
|
||||
};
|
||||
stb.args.swap(0, 1);
|
||||
stb
|
||||
}
|
||||
Predicate::Native(NativePredicate::ArrayContains)
|
||||
| Predicate::Native(NativePredicate::DictContains) => StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::Contains),
|
||||
args: self.args,
|
||||
},
|
||||
Predicate::Native(NativePredicate::DictNotContains)
|
||||
| Predicate::Native(NativePredicate::SetNotContains) => StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::NotContains),
|
||||
args: self.args,
|
||||
},
|
||||
Predicate::Native(NativePredicate::SetContains) => {
|
||||
let mut new_args = self.args.clone();
|
||||
new_args.push(self.args[1].clone());
|
||||
StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::Contains),
|
||||
args: new_args,
|
||||
pub(crate) fn desugar(mut self) -> StatementTmplBuilder {
|
||||
let pred = match self.pred_or_wc {
|
||||
PredicateOrWildcard::Predicate(p) => p,
|
||||
PredicateOrWildcard::Wildcard(_) => return self,
|
||||
};
|
||||
let pred = match pred {
|
||||
Predicate::Native(nat_pred) => Predicate::Native(match nat_pred {
|
||||
NativePredicate::Gt => {
|
||||
self.args.swap(0, 1);
|
||||
NativePredicate::Lt
|
||||
}
|
||||
}
|
||||
Predicate::Native(NativePredicate::DictInsert) => StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::ContainerInsert),
|
||||
args: self.args,
|
||||
},
|
||||
Predicate::Native(NativePredicate::SetInsert) => {
|
||||
let mut new_args = self.args.clone();
|
||||
new_args.push(self.args[2].clone());
|
||||
StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::ContainerInsert),
|
||||
args: new_args,
|
||||
NativePredicate::GtEq => {
|
||||
self.args.swap(0, 1);
|
||||
NativePredicate::LtEq
|
||||
}
|
||||
}
|
||||
Predicate::Native(NativePredicate::DictUpdate)
|
||||
| Predicate::Native(NativePredicate::ArrayUpdate) => StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::ContainerUpdate),
|
||||
args: self.args,
|
||||
},
|
||||
Predicate::Native(NativePredicate::DictDelete) => StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::ContainerDelete),
|
||||
args: self.args,
|
||||
},
|
||||
Predicate::Native(NativePredicate::SetDelete) => StatementTmplBuilder {
|
||||
predicate: Predicate::Native(NativePredicate::ContainerDelete),
|
||||
args: self.args,
|
||||
},
|
||||
_ => self,
|
||||
NativePredicate::ArrayContains | NativePredicate::DictContains => {
|
||||
NativePredicate::Contains
|
||||
}
|
||||
NativePredicate::DictNotContains | NativePredicate::SetNotContains => {
|
||||
NativePredicate::NotContains
|
||||
}
|
||||
NativePredicate::SetContains => {
|
||||
self.args.push(self.args[1].clone());
|
||||
NativePredicate::Contains
|
||||
}
|
||||
NativePredicate::DictInsert => NativePredicate::ContainerInsert,
|
||||
NativePredicate::SetInsert => {
|
||||
self.args.push(self.args[2].clone());
|
||||
NativePredicate::ContainerInsert
|
||||
}
|
||||
NativePredicate::DictUpdate | NativePredicate::ArrayUpdate => {
|
||||
NativePredicate::ContainerUpdate
|
||||
}
|
||||
NativePredicate::DictDelete => NativePredicate::ContainerDelete,
|
||||
NativePredicate::SetDelete => NativePredicate::ContainerDelete,
|
||||
_ => nat_pred,
|
||||
}),
|
||||
_ => pred,
|
||||
};
|
||||
StatementTmplBuilder {
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(pred),
|
||||
args: self.args,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -200,7 +198,7 @@ impl CustomPredicateBatchBuilder {
|
|||
.iter()
|
||||
.map(|sb| {
|
||||
let stb = sb.clone().desugar();
|
||||
let args = stb
|
||||
let st_tmpl_args = stb
|
||||
.args
|
||||
.iter()
|
||||
.map(|a| {
|
||||
|
|
@ -216,10 +214,17 @@ impl CustomPredicateBatchBuilder {
|
|||
})
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let pred_or_wc = match stb.pred_or_wc {
|
||||
PredicateOrWildcard::Predicate(p) => {
|
||||
middleware::PredicateOrWildcard::Predicate(p)
|
||||
}
|
||||
PredicateOrWildcard::Wildcard(v) => middleware::PredicateOrWildcard::Wildcard(
|
||||
resolve_wildcard(args, priv_args, &v)?,
|
||||
),
|
||||
};
|
||||
Ok(StatementTmpl {
|
||||
// TODO: Support wildcard
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()),
|
||||
args,
|
||||
pred_or_wc,
|
||||
args: st_tmpl_args,
|
||||
})
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
|
@ -299,7 +304,7 @@ mod tests {
|
|||
let vd_set = &*MOCK_VD_SET;
|
||||
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
|
||||
|
||||
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
|
||||
let gt_stb = StatementTmplBuilder::new_from_pred(NativePredicate::Gt)
|
||||
.arg("s1")
|
||||
.arg("s2");
|
||||
|
||||
|
|
@ -344,7 +349,7 @@ mod tests {
|
|||
let mut builder =
|
||||
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
|
||||
|
||||
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
|
||||
let set_contains_stb = StatementTmplBuilder::new_from_pred(NativePredicate::SetContains)
|
||||
.arg("s1")
|
||||
.arg("s2");
|
||||
|
||||
|
|
|
|||
|
|
@ -1,34 +1,13 @@
|
|||
use std::{backtrace::Backtrace, fmt::Debug};
|
||||
|
||||
use crate::middleware::{BackendError, Statement, StatementTmpl, Value};
|
||||
use crate::middleware::BackendError;
|
||||
|
||||
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
||||
|
||||
fn display_wc_map(wc_map: &[Option<Value>]) -> String {
|
||||
let mut out = String::new();
|
||||
use std::fmt::Write;
|
||||
for (i, v) in wc_map.iter().enumerate() {
|
||||
write!(out, "- {}: ", i).unwrap();
|
||||
if let Some(v) = v {
|
||||
writeln!(out, "{}", v).unwrap();
|
||||
} else {
|
||||
writeln!(out, "none").unwrap();
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum InnerError {
|
||||
#[error("{0} {1} is over the limit {2}")]
|
||||
MaxLength(String, usize, usize),
|
||||
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
|
||||
StatementsDontMatch(
|
||||
Statement,
|
||||
StatementTmpl,
|
||||
Vec<Option<Value>>,
|
||||
crate::middleware::Error,
|
||||
),
|
||||
#[error("invalid arguments to {0} operation")]
|
||||
OpInvalidArgs(String),
|
||||
#[error("Podlang parse error: {0}")]
|
||||
|
|
@ -99,14 +78,6 @@ impl Error {
|
|||
pub(crate) fn op_invalid_args(s: String) -> Self {
|
||||
new!(OpInvalidArgs(s))
|
||||
}
|
||||
pub(crate) fn statements_dont_match(
|
||||
s0: Statement,
|
||||
s1: StatementTmpl,
|
||||
wc_map: Vec<Option<Value>>,
|
||||
mid_error: crate::middleware::Error,
|
||||
) -> Self {
|
||||
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
|
||||
}
|
||||
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
|
||||
new!(MaxLength(obj, found, expect))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ use serde::{Deserialize, Serialize};
|
|||
pub use serialization::SerializedMainPod;
|
||||
|
||||
use crate::middleware::{
|
||||
self, check_custom_pred, check_st_tmpl, containers::Dictionary, hash_op, max_op, prod_op,
|
||||
sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, OperationAux,
|
||||
OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, StatementArg, VDSet,
|
||||
Value, ValueRef,
|
||||
self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op,
|
||||
prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation,
|
||||
OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement,
|
||||
StatementArg, VDSet, Value, ValueRef,
|
||||
};
|
||||
|
||||
mod custom;
|
||||
|
|
@ -600,21 +600,7 @@ impl MainPodBuilder {
|
|||
}
|
||||
wildcard_map[index] = Some(value);
|
||||
}
|
||||
for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) {
|
||||
let st_args = st.args();
|
||||
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
||||
if let Err(st_tmpl_check_error) =
|
||||
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
|
||||
{
|
||||
return Err(Error::statements_dont_match(
|
||||
st.clone(),
|
||||
st_tmpl.clone(),
|
||||
wildcard_map,
|
||||
st_tmpl_check_error,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
fill_wildcard_values(&self.params, pred, &args, &mut wildcard_map)?;
|
||||
let v_default = Value::from(0);
|
||||
let st_args: Vec<_> = wildcard_map
|
||||
.into_iter()
|
||||
|
|
@ -817,7 +803,8 @@ pub mod tests {
|
|||
use super::*;
|
||||
use crate::{
|
||||
backends::plonky2::{
|
||||
mock::mainpod::MockProver, primitives::ec::schnorr::SecretKey, signer::Signer,
|
||||
basetypes::DEFAULT_VD_SET, mainpod::Prover, mock::mainpod::MockProver,
|
||||
primitives::ec::schnorr::SecretKey, signer::Signer,
|
||||
},
|
||||
dict,
|
||||
examples::{
|
||||
|
|
@ -1423,6 +1410,75 @@ pub mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_predicate() -> Result<()> {
|
||||
let mock = true;
|
||||
let params = Params::default();
|
||||
|
||||
let mock_prover = MockProver {};
|
||||
let real_prover = Prover {};
|
||||
let (vd_set, prover): (_, &dyn MainPodProver) = if mock {
|
||||
(&VDSet::new(&[]), &mock_prover)
|
||||
} else {
|
||||
println!("Prebuilding circuits to calculate vd_set...");
|
||||
let vd_set = &*DEFAULT_VD_SET;
|
||||
println!("vd_set calculation complete");
|
||||
(vd_set, &real_prover)
|
||||
};
|
||||
|
||||
let input = r#"
|
||||
Test(a, b, private: c) = AND(
|
||||
Equal(a, 5)
|
||||
b(1, 5)
|
||||
c(6, 3)
|
||||
)
|
||||
"#;
|
||||
let batch = parse(input, ¶ms, &[])
|
||||
.unwrap()
|
||||
.first_batch()
|
||||
.unwrap()
|
||||
.clone();
|
||||
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
||||
|
||||
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
||||
let st0 = builder.priv_op(Operation::eq(5, 5)).unwrap();
|
||||
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
|
||||
let st2 = builder.priv_op(Operation::ne(6, 3)).unwrap();
|
||||
let _st = builder
|
||||
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
|
||||
.unwrap();
|
||||
|
||||
let pod = builder.prove(prover).unwrap();
|
||||
pod.pod.verify().unwrap();
|
||||
|
||||
let input = r#"
|
||||
Test(a, b, private: c) = OR(
|
||||
Equal(a, 5)
|
||||
b(1, 5)
|
||||
c(6, 3)
|
||||
)
|
||||
"#;
|
||||
let batch = parse(input, ¶ms, &[])
|
||||
.unwrap()
|
||||
.first_batch()
|
||||
.unwrap()
|
||||
.clone();
|
||||
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
||||
|
||||
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
||||
let st0 = Statement::None;
|
||||
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
|
||||
let st2 = Statement::None;
|
||||
let _st = builder
|
||||
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
|
||||
.unwrap();
|
||||
|
||||
let pod = builder.prove(prover).unwrap();
|
||||
pod.pod.verify().unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_predicate_e2e() -> Result<()> {
|
||||
// End-to-end test of apply_predicate with MockProver
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue