diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index c438217..98e4fa9 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1,9 +1,9 @@ pub mod operation; -use crate::middleware::PodType; +use crate::middleware::{wildcard_values_from_op_st, PodType}; pub mod statement; use std::{iter, sync::Arc}; -use itertools::Itertools; +use itertools::{zip_eq, Itertools}; use num_bigint::BigUint; pub use operation::*; use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher}; @@ -37,9 +37,9 @@ use crate::{ serialize_proof, serialize_verifier_only, }, middleware::{ - self, resolve_wildcard_values, value_from_op, CustomPredicateBatch, - Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation, - OperationType, Params, Pod, RawValue, StatementArg, ToFields, VDSet, + self, value_from_op, CustomPredicateBatch, Error as MiddlewareError, Hash, MainPodInputs, + MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, + ToFields, VDSet, }, timed, }; @@ -97,28 +97,35 @@ pub(crate) fn extract_custom_predicate_verifications( params: &Params, aux_list: &mut [OperationAux], operations: &[middleware::Operation], + statements: &[middleware::Statement], custom_predicate_batches: &[Arc], ) -> Result> { let mut table = Vec::new(); - for (i, op) in operations.iter().enumerate() { + for (i, (op, st)) in zip_eq(operations.iter(), statements.iter()).enumerate() { if let middleware::Operation::Custom(cpr, sts) = op { - let wildcard_values = - resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards"); - let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); - let batch_index = custom_predicate_batches - .iter() - .enumerate() - .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) - .expect("find the custom predicate from the extracted unique list"); - let custom_predicate_table_index = - batch_index * params.max_custom_batch_size + cpr.index; - aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len()); - table.push(CustomPredicateVerification { - custom_predicate_table_index, - custom_predicate: cpr.clone(), - args: wildcard_values, - op_args: sts, - }); + if let middleware::Statement::Custom(st_cpr, st_args) = st { + assert_eq!(cpr, st_cpr); + let wildcard_values = + wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) + .expect("resolved wildcards"); + let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); + let batch_index = custom_predicate_batches + .iter() + .enumerate() + .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) + .expect("find the custom predicate from the extracted unique list"); + let custom_predicate_table_index = + batch_index * params.max_custom_batch_size + cpr.index; + aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len()); + table.push(CustomPredicateVerification { + custom_predicate_table_index, + custom_predicate: cpr.clone(), + args: wildcard_values, + op_args: sts, + }); + } else { + panic!("Custom operation paired with non-custom statement"); + } } } @@ -499,6 +506,7 @@ impl MainPodProver for Prover { params, &mut aux_list, inputs.operations, + inputs.statements, &custom_predicate_batches, )?; let public_key_of_sks = @@ -823,6 +831,7 @@ pub mod tests { frontend::{ self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB, }, + lang::parse, middleware::{ self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _, DEFAULT_VD_LIST, DEFAULT_VD_SET, @@ -1154,4 +1163,40 @@ pub mod tests { builder.prove(&prover)?; Ok(()) } + + #[test] + fn test_undetermined_values() { + let params = Default::default(); + let batch = parse( + r#" + two_equal(x,y,z) = OR( + Equal(x,y) + Equal(y,z) + Equal(x,z) + ) + "#, + ¶ms, + &[], + ) + .unwrap() + .custom_batch; + let mut builder = MainPodBuilder::new(¶ms, &DEFAULT_VD_SET); + let cpr = CustomPredicateRef { batch, index: 0 }; + let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap(); + let op = frontend::Operation::custom( + cpr.clone(), + [ + eq_st, + middleware::Statement::None, + middleware::Statement::None, + ], + ); + let st = middleware::Statement::Custom( + cpr, + [1, 1, 2].into_iter().map(middleware::Value::from).collect(), + ); + builder.insert(true, (st, op)).unwrap(); + let prover = Prover {}; + builder.prove(&prover).unwrap(); + } } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index d8b7f5a..21d34f0 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,5 +1,6 @@ use std::{fmt, iter}; +use itertools::Itertools; use log::error; use plonky2::field::types::Field; use serde::{Deserialize, Serialize}; @@ -14,8 +15,8 @@ use crate::{ }, middleware::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, - NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmpl, - StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F, + MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg, + StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F, }, }; @@ -613,16 +614,11 @@ pub fn check_st_tmpl( } } -pub fn resolve_wildcard_values( - params: &Params, +pub fn fill_wildcard_values( pred: &CustomPredicate, args: &[Statement], -) -> Result> { - // Check that all wildcard have consistent values as assigned in the statements while storing a - // map of their values. - // NOTE: We assume the statements have the same order as defined in the custom predicate. For - // disjunctions we expect Statement::None for the unused statements. - let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards]; + wildcard_map: &mut [Option], +) -> Result<()> { for (st_tmpl, st) in pred.statements.iter().zip(args) { let st_args = st.args(); st_tmpl @@ -630,10 +626,25 @@ pub fn resolve_wildcard_values( .iter() .zip(&st_args) .try_for_each(|(st_tmpl_arg, st_arg)| { - check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) + check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map) })?; } + Ok(()) +} +pub fn wildcard_values_from_op_st( + params: &Params, + pred: &CustomPredicate, + op_args: &[Statement], + st_args: &[Value], +) -> Result> { + let mut wildcard_map = st_args + .iter() + .map(|v| Some(v.clone())) + .chain(core::iter::repeat(None)) + .take(params.max_custom_predicate_wildcards) + .collect_vec(); + fill_wildcard_values(pred, op_args, &mut wildcard_map)?; // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because // they are beyond the number of used wildcards in this custom predicate, or they could be // private arguments that are unused in a particular disjunction. @@ -717,21 +728,24 @@ pub(crate) fn check_custom_pred( )); } - let wildcard_map = resolve_wildcard_values(params, pred, args)?; - // Check that the resolved wildcards match the statement arguments. - for (arg_index, (s_arg, wc_value)) in s_args.iter().zip(wildcard_map.iter()).enumerate() { - if *wc_value != *s_arg { - return Err(Error::mismatched_wildcard_value_and_statement_arg( - wc_value.clone(), - s_arg.clone(), - arg_index, - pred.clone(), - )); - } + match wildcard_values_from_op_st(params, pred, args, s_args) { + Ok(_) => Ok(()), + Err(Error::Inner { inner, backtrace }) => match *inner { + MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev) + if wc.index <= s_args.len() => + { + Err(Error::mismatched_wildcard_value_and_statement_arg( + v, + prev, + wc.index, + pred.clone(), + )) + } + _ => Err(Error::Inner { inner, backtrace }), + }, + _ => unreachable!(), } - - Ok(()) } impl ToFields for Operation {