This commit is contained in:
Daniel Gulotta 2025-09-15 07:14:24 -07:00 committed by GitHub
parent 1d14338351
commit 26548cf612
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 106 additions and 47 deletions

View file

@ -1,9 +1,9 @@
pub mod operation; pub mod operation;
use crate::middleware::PodType; use crate::middleware::{wildcard_values_from_op_st, PodType};
pub mod statement; pub mod statement;
use std::{iter, sync::Arc}; use std::{iter, sync::Arc};
use itertools::Itertools; use itertools::{zip_eq, Itertools};
use num_bigint::BigUint; use num_bigint::BigUint;
pub use operation::*; pub use operation::*;
use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher}; use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher};
@ -37,9 +37,9 @@ use crate::{
serialize_proof, serialize_verifier_only, serialize_proof, serialize_verifier_only,
}, },
middleware::{ middleware::{
self, resolve_wildcard_values, value_from_op, CustomPredicateBatch, self, value_from_op, CustomPredicateBatch, Error as MiddlewareError, Hash, MainPodInputs,
Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation, MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg,
OperationType, Params, Pod, RawValue, StatementArg, ToFields, VDSet, ToFields, VDSet,
}, },
timed, timed,
}; };
@ -97,13 +97,17 @@ pub(crate) fn extract_custom_predicate_verifications(
params: &Params, params: &Params,
aux_list: &mut [OperationAux], aux_list: &mut [OperationAux],
operations: &[middleware::Operation], operations: &[middleware::Operation],
statements: &[middleware::Statement],
custom_predicate_batches: &[Arc<CustomPredicateBatch>], custom_predicate_batches: &[Arc<CustomPredicateBatch>],
) -> Result<Vec<CustomPredicateVerification>> { ) -> Result<Vec<CustomPredicateVerification>> {
let mut table = Vec::new(); let mut table = Vec::new();
for (i, op) in operations.iter().enumerate() { for (i, (op, st)) in zip_eq(operations.iter(), statements.iter()).enumerate() {
if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Operation::Custom(cpr, sts) = op {
if let middleware::Statement::Custom(st_cpr, st_args) = st {
assert_eq!(cpr, st_cpr);
let wildcard_values = let wildcard_values =
resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards"); wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args)
.expect("resolved wildcards");
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
let batch_index = custom_predicate_batches let batch_index = custom_predicate_batches
.iter() .iter()
@ -119,6 +123,9 @@ pub(crate) fn extract_custom_predicate_verifications(
args: wildcard_values, args: wildcard_values,
op_args: sts, op_args: sts,
}); });
} else {
panic!("Custom operation paired with non-custom statement");
}
} }
} }
@ -499,6 +506,7 @@ impl MainPodProver for Prover {
params, params,
&mut aux_list, &mut aux_list,
inputs.operations, inputs.operations,
inputs.statements,
&custom_predicate_batches, &custom_predicate_batches,
)?; )?;
let public_key_of_sks = let public_key_of_sks =
@ -823,6 +831,7 @@ pub mod tests {
frontend::{ frontend::{
self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB, self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB,
}, },
lang::parse,
middleware::{ middleware::{
self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _, self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _,
DEFAULT_VD_LIST, DEFAULT_VD_SET, DEFAULT_VD_LIST, DEFAULT_VD_SET,
@ -1154,4 +1163,40 @@ pub mod tests {
builder.prove(&prover)?; builder.prove(&prover)?;
Ok(()) Ok(())
} }
#[test]
fn test_undetermined_values() {
let params = Default::default();
let batch = parse(
r#"
two_equal(x,y,z) = OR(
Equal(x,y)
Equal(y,z)
Equal(x,z)
)
"#,
&params,
&[],
)
.unwrap()
.custom_batch;
let mut builder = MainPodBuilder::new(&params, &DEFAULT_VD_SET);
let cpr = CustomPredicateRef { batch, index: 0 };
let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap();
let op = frontend::Operation::custom(
cpr.clone(),
[
eq_st,
middleware::Statement::None,
middleware::Statement::None,
],
);
let st = middleware::Statement::Custom(
cpr,
[1, 1, 2].into_iter().map(middleware::Value::from).collect(),
);
builder.insert(true, (st, op)).unwrap();
let prover = Prover {};
builder.prove(&prover).unwrap();
}
} }

View file

@ -1,5 +1,6 @@
use std::{fmt, iter}; use std::{fmt, iter};
use itertools::Itertools;
use log::error; use log::error;
use plonky2::field::types::Field; use plonky2::field::types::Field;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -14,8 +15,8 @@ use crate::{
}, },
middleware::{ middleware::{
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmpl, MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg,
StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
}, },
}; };
@ -613,16 +614,11 @@ pub fn check_st_tmpl(
} }
} }
pub fn resolve_wildcard_values( pub fn fill_wildcard_values(
params: &Params,
pred: &CustomPredicate, pred: &CustomPredicate,
args: &[Statement], args: &[Statement],
) -> Result<Vec<Value>> { wildcard_map: &mut [Option<Value>],
// Check that all wildcard have consistent values as assigned in the statements while storing a ) -> Result<()> {
// map of their values.
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
// disjunctions we expect Statement::None for the unused statements.
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
for (st_tmpl, st) in pred.statements.iter().zip(args) { for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args(); let st_args = st.args();
st_tmpl st_tmpl
@ -630,10 +626,25 @@ pub fn resolve_wildcard_values(
.iter() .iter()
.zip(&st_args) .zip(&st_args)
.try_for_each(|(st_tmpl_arg, st_arg)| { .try_for_each(|(st_tmpl_arg, st_arg)| {
check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map)
})?; })?;
} }
Ok(())
}
pub fn wildcard_values_from_op_st(
params: &Params,
pred: &CustomPredicate,
op_args: &[Statement],
st_args: &[Value],
) -> Result<Vec<Value>> {
let mut wildcard_map = st_args
.iter()
.map(|v| Some(v.clone()))
.chain(core::iter::repeat(None))
.take(params.max_custom_predicate_wildcards)
.collect_vec();
fill_wildcard_values(pred, op_args, &mut wildcard_map)?;
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
// they are beyond the number of used wildcards in this custom predicate, or they could be // they are beyond the number of used wildcards in this custom predicate, or they could be
// private arguments that are unused in a particular disjunction. // private arguments that are unused in a particular disjunction.
@ -717,21 +728,24 @@ pub(crate) fn check_custom_pred(
)); ));
} }
let wildcard_map = resolve_wildcard_values(params, pred, args)?;
// Check that the resolved wildcards match the statement arguments. // Check that the resolved wildcards match the statement arguments.
for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() { match wildcard_values_from_op_st(params, pred, args, s_args) {
if *wc_value != *s_arg { Ok(_) => Ok(()),
return Err(Error::mismatched_wildcard_value_and_statement_arg( Err(Error::Inner { inner, backtrace }) => match *inner {
wc_value.clone(), MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
s_arg.clone(), if wc.index <= s_args.len() =>
arg_index, {
Err(Error::mismatched_wildcard_value_and_statement_arg(
v,
prev,
wc.index,
pred.clone(), pred.clone(),
)); ))
} }
_ => Err(Error::Inner { inner, backtrace }),
},
_ => unreachable!(),
} }
Ok(())
} }
impl ToFields for Operation { impl ToFields for Operation {