Feat/fst order pred part3 & part4 (#457)

* support wildcard predicates in frontend

* suport wildcard predicate in podlang

* add validation test

* test full flow and apply some fixes

* fix clippy

* fix merge issues

* use desugared predicate

* Fix parsing of intro statement templates inside custom predicates

* Tidy up comments

* lang: handle wildcard predicate

* add unreachable message

---------

Co-authored-by: Rob Knight <mail@robknight.org.uk>
This commit is contained in:
Eduard S. 2026-02-02 10:59:33 +01:00 committed by GitHub
parent b66f5051b5
commit 498e946612
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 324 additions and 180 deletions

View file

@ -7,7 +7,7 @@ use crate::{
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
middleware::{
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
},
};
@ -41,16 +41,34 @@ pub fn literal(v: impl Into<Value>) -> BuilderArg {
BuilderArg::Literal(v.into())
}
#[derive(Clone, Debug)]
pub enum PredicateOrWildcard {
Predicate(Predicate),
Wildcard(String),
}
#[derive(Clone)]
pub struct StatementTmplBuilder {
pub(crate) predicate: Predicate,
pub(crate) pred_or_wc: PredicateOrWildcard,
pub(crate) args: Vec<BuilderArg>,
}
impl StatementTmplBuilder {
pub fn new(p: impl Into<Predicate>) -> StatementTmplBuilder {
pub fn new_from_pred(p: impl Into<Predicate>) -> StatementTmplBuilder {
StatementTmplBuilder {
predicate: p.into(),
pred_or_wc: PredicateOrWildcard::Predicate(p.into()),
args: Vec::new(),
}
}
pub fn new_from_wc(p: impl Into<String>) -> StatementTmplBuilder {
StatementTmplBuilder {
pred_or_wc: PredicateOrWildcard::Wildcard(p.into()),
args: Vec::new(),
}
}
pub fn new(pred_or_wc: PredicateOrWildcard) -> StatementTmplBuilder {
StatementTmplBuilder {
pred_or_wc,
args: Vec::new(),
}
}
@ -62,68 +80,48 @@ impl StatementTmplBuilder {
/// Desugar the predicate to a simpler form
/// Should mirror the logic in `MainPodBuilder::lower_op`
pub(crate) fn desugar(self) -> StatementTmplBuilder {
match self.predicate {
Predicate::Native(NativePredicate::Gt) => {
let mut stb = StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Lt),
args: self.args,
};
stb.args.swap(0, 1);
stb
}
Predicate::Native(NativePredicate::GtEq) => {
let mut stb = StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::LtEq),
args: self.args,
};
stb.args.swap(0, 1);
stb
}
Predicate::Native(NativePredicate::ArrayContains)
| Predicate::Native(NativePredicate::DictContains) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Contains),
args: self.args,
},
Predicate::Native(NativePredicate::DictNotContains)
| Predicate::Native(NativePredicate::SetNotContains) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::NotContains),
args: self.args,
},
Predicate::Native(NativePredicate::SetContains) => {
let mut new_args = self.args.clone();
new_args.push(self.args[1].clone());
StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Contains),
args: new_args,
pub(crate) fn desugar(mut self) -> StatementTmplBuilder {
let pred = match self.pred_or_wc {
PredicateOrWildcard::Predicate(p) => p,
PredicateOrWildcard::Wildcard(_) => return self,
};
let pred = match pred {
Predicate::Native(nat_pred) => Predicate::Native(match nat_pred {
NativePredicate::Gt => {
self.args.swap(0, 1);
NativePredicate::Lt
}
}
Predicate::Native(NativePredicate::DictInsert) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::ContainerInsert),
args: self.args,
},
Predicate::Native(NativePredicate::SetInsert) => {
let mut new_args = self.args.clone();
new_args.push(self.args[2].clone());
StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::ContainerInsert),
args: new_args,
NativePredicate::GtEq => {
self.args.swap(0, 1);
NativePredicate::LtEq
}
}
Predicate::Native(NativePredicate::DictUpdate)
| Predicate::Native(NativePredicate::ArrayUpdate) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::ContainerUpdate),
args: self.args,
},
Predicate::Native(NativePredicate::DictDelete) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::ContainerDelete),
args: self.args,
},
Predicate::Native(NativePredicate::SetDelete) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::ContainerDelete),
args: self.args,
},
_ => self,
NativePredicate::ArrayContains | NativePredicate::DictContains => {
NativePredicate::Contains
}
NativePredicate::DictNotContains | NativePredicate::SetNotContains => {
NativePredicate::NotContains
}
NativePredicate::SetContains => {
self.args.push(self.args[1].clone());
NativePredicate::Contains
}
NativePredicate::DictInsert => NativePredicate::ContainerInsert,
NativePredicate::SetInsert => {
self.args.push(self.args[2].clone());
NativePredicate::ContainerInsert
}
NativePredicate::DictUpdate | NativePredicate::ArrayUpdate => {
NativePredicate::ContainerUpdate
}
NativePredicate::DictDelete => NativePredicate::ContainerDelete,
NativePredicate::SetDelete => NativePredicate::ContainerDelete,
_ => nat_pred,
}),
_ => pred,
};
StatementTmplBuilder {
pred_or_wc: PredicateOrWildcard::Predicate(pred),
args: self.args,
}
}
}
@ -200,7 +198,7 @@ impl CustomPredicateBatchBuilder {
.iter()
.map(|sb| {
let stb = sb.clone().desugar();
let args = stb
let st_tmpl_args = stb
.args
.iter()
.map(|a| {
@ -216,10 +214,17 @@ impl CustomPredicateBatchBuilder {
})
})
.collect::<Result<_>>()?;
let pred_or_wc = match stb.pred_or_wc {
PredicateOrWildcard::Predicate(p) => {
middleware::PredicateOrWildcard::Predicate(p)
}
PredicateOrWildcard::Wildcard(v) => middleware::PredicateOrWildcard::Wildcard(
resolve_wildcard(args, priv_args, &v)?,
),
};
Ok(StatementTmpl {
// TODO: Support wildcard
pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()),
args,
pred_or_wc,
args: st_tmpl_args,
})
})
.collect::<Result<_>>()?;
@ -299,7 +304,7 @@ mod tests {
let vd_set = &*MOCK_VD_SET;
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
let gt_stb = StatementTmplBuilder::new_from_pred(NativePredicate::Gt)
.arg("s1")
.arg("s2");
@ -344,7 +349,7 @@ mod tests {
let mut builder =
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
let set_contains_stb = StatementTmplBuilder::new_from_pred(NativePredicate::SetContains)
.arg("s1")
.arg("s2");

View file

@ -1,34 +1,13 @@
use std::{backtrace::Backtrace, fmt::Debug};
use crate::middleware::{BackendError, Statement, StatementTmpl, Value};
use crate::middleware::BackendError;
pub type Result<T, E = Error> = core::result::Result<T, E>;
fn display_wc_map(wc_map: &[Option<Value>]) -> String {
let mut out = String::new();
use std::fmt::Write;
for (i, v) in wc_map.iter().enumerate() {
write!(out, "- {}: ", i).unwrap();
if let Some(v) = v {
writeln!(out, "{}", v).unwrap();
} else {
writeln!(out, "none").unwrap();
}
}
out
}
#[derive(thiserror::Error, Debug)]
pub enum InnerError {
#[error("{0} {1} is over the limit {2}")]
MaxLength(String, usize, usize),
#[error("{0} doesn't match {1:#}.\nWildcard map:\n{map}\nInternal error: {3}", map=display_wc_map(.2))]
StatementsDontMatch(
Statement,
StatementTmpl,
Vec<Option<Value>>,
crate::middleware::Error,
),
#[error("invalid arguments to {0} operation")]
OpInvalidArgs(String),
#[error("Podlang parse error: {0}")]
@ -99,14 +78,6 @@ impl Error {
pub(crate) fn op_invalid_args(s: String) -> Self {
new!(OpInvalidArgs(s))
}
pub(crate) fn statements_dont_match(
s0: Statement,
s1: StatementTmpl,
wc_map: Vec<Option<Value>>,
mid_error: crate::middleware::Error,
) -> Self {
new!(StatementsDontMatch(s0, s1, wc_map, mid_error))
}
pub(crate) fn max_length(obj: String, found: usize, expect: usize) -> Self {
new!(MaxLength(obj, found, expect))
}

View file

@ -13,10 +13,10 @@ use serde::{Deserialize, Serialize};
pub use serialization::SerializedMainPod;
use crate::middleware::{
self, check_custom_pred, check_st_tmpl, containers::Dictionary, hash_op, max_op, prod_op,
sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, OperationAux,
OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, StatementArg, VDSet,
Value, ValueRef,
self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op,
prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation,
OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement,
StatementArg, VDSet, Value, ValueRef,
};
mod custom;
@ -600,21 +600,7 @@ impl MainPodBuilder {
}
wildcard_map[index] = Some(value);
}
for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) {
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if let Err(st_tmpl_check_error) =
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map)
{
return Err(Error::statements_dont_match(
st.clone(),
st_tmpl.clone(),
wildcard_map,
st_tmpl_check_error,
));
}
}
}
fill_wildcard_values(&self.params, pred, &args, &mut wildcard_map)?;
let v_default = Value::from(0);
let st_args: Vec<_> = wildcard_map
.into_iter()
@ -817,7 +803,8 @@ pub mod tests {
use super::*;
use crate::{
backends::plonky2::{
mock::mainpod::MockProver, primitives::ec::schnorr::SecretKey, signer::Signer,
basetypes::DEFAULT_VD_SET, mainpod::Prover, mock::mainpod::MockProver,
primitives::ec::schnorr::SecretKey, signer::Signer,
},
dict,
examples::{
@ -1423,6 +1410,75 @@ pub mod tests {
Ok(())
}
#[test]
fn test_wildcard_predicate() -> Result<()> {
let mock = true;
let params = Params::default();
let mock_prover = MockProver {};
let real_prover = Prover {};
let (vd_set, prover): (_, &dyn MainPodProver) = if mock {
(&VDSet::new(&[]), &mock_prover)
} else {
println!("Prebuilding circuits to calculate vd_set...");
let vd_set = &*DEFAULT_VD_SET;
println!("vd_set calculation complete");
(vd_set, &real_prover)
};
let input = r#"
Test(a, b, private: c) = AND(
Equal(a, 5)
b(1, 5)
c(6, 3)
)
"#;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
let st0 = builder.priv_op(Operation::eq(5, 5)).unwrap();
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
let st2 = builder.priv_op(Operation::ne(6, 3)).unwrap();
let _st = builder
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
.unwrap();
let pod = builder.prove(prover).unwrap();
pod.pod.verify().unwrap();
let input = r#"
Test(a, b, private: c) = OR(
Equal(a, 5)
b(1, 5)
c(6, 3)
)
"#;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
let st0 = Statement::None;
let st1 = builder.priv_op(Operation::lt(1, 5)).unwrap();
let st2 = Statement::None;
let _st = builder
.op(true, vec![], Operation::custom(pred_test, [st0, st1, st2]))
.unwrap();
let pod = builder.prove(prover).unwrap();
pod.pod.verify().unwrap();
Ok(())
}
#[test]
fn test_apply_predicate_e2e() -> Result<()> {
// End-to-end test of apply_predicate with MockProver