Split Params into base and developer-defined (#458)

I thought it would be nice to have a Predicate for the typed value so that the developer can work with predicates as values comfortably.  Then I noticed that hashing a predicate required `Params` which would have been annoying for converting a `TypedValue::Predicate` to `RawValue` and this led to a small refactor over how `Params` work.

We already had some fields in the `Params` struct that determine compatibility between encoded data.  They can be seen as determining a kind of ABI compatibility.  In general it's better if those parameters don't change so that different circuit configurations can still verify proofs from each other.  So I decided to force those parameters to be constant in the code base and not allow the user of our library to change them.  Many field element serialization/deserialization functions in our code depended on those parameters, and since now they are constant many functions get rid of the `Params` argument, which simplifies the code.  This includes the serialization of a `Predicate` which was required to calculate its hash.
This commit is contained in:
Eduard S. 2026-02-02 16:23:32 +01:00 committed by GitHub
parent 498e946612
commit a7a30176a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 376 additions and 468 deletions

View file

@ -101,13 +101,8 @@ pub struct StatementArgTarget {
}
impl StatementArgTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
arg: &StatementArg,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?)
pub fn set_targets(&self, pw: &mut PartialWitness<F>, arg: &StatementArg) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?)
}
pub fn new(first: ValueTarget, second: ValueTarget) -> Self {
@ -190,7 +185,7 @@ impl StatementTarget {
.iter()
.cloned()
.chain(iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(params.max_statement_args)
.take(Params::max_statement_args())
.collect(),
}
}
@ -205,24 +200,19 @@ impl StatementTarget {
Self::new_with_pred(builder, params, pred, args)
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st: &Statement,
) -> Result<()> {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, st: &Statement) -> Result<()> {
if let Some(pred) = &self.pred {
pred.set_targets(pw, params, &st.predicate())?;
pred.set_targets(pw, &st.predicate())?;
}
pw.set_hash_target(self.pred_hash, HashOut::from(st.predicate().hash(params)))?;
pw.set_hash_target(self.pred_hash, HashOut::from(st.predicate().hash()))?;
for (i, arg) in st
.args()
.iter()
.chain(iter::repeat(&StatementArg::None))
.take(params.max_statement_args)
.take(Params::max_statement_args())
.enumerate()
{
self.args[i].set_targets(pw, params, arg)?;
self.args[i].set_targets(pw, arg)?;
}
Ok(())
}
@ -235,14 +225,9 @@ impl StatementTarget {
builder.is_equal_flattenable(&self.pred_hash, &blank_intro)
}
pub fn has_native_type(
&self,
builder: &mut CircuitBuilder,
params: &Params,
t: NativePredicate,
) -> BoolTarget {
pub fn has_native_type(&self, builder: &mut CircuitBuilder, t: NativePredicate) -> BoolTarget {
let expected_predicate_hash =
builder.constant_hash(HashOut::from(Predicate::Native(t).hash(params)));
builder.constant_hash(HashOut::from(Predicate::Native(t).hash()));
builder.is_equal_flattenable(&self.pred_hash, &expected_predicate_hash)
}
}
@ -252,8 +237,8 @@ pub trait Build<T> {
}
impl Build<NativePredicateTarget> for NativePredicate {
fn build(self, builder: &mut CircuitBuilder, params: &Params) -> NativePredicateTarget {
NativePredicateTarget::constant(builder, params, self)
fn build(self, builder: &mut CircuitBuilder, _params: &Params) -> NativePredicateTarget {
NativePredicateTarget::constant(builder, self)
}
}
@ -301,13 +286,8 @@ impl OperationTypeTarget {
builder.and(op_is_native, op_code_matches)
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
op_type: &OperationType,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &op_type.to_fields(params))?)
pub fn set_targets(&self, pw: &mut PartialWitness<F>, op_type: &OperationType) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &op_type.to_fields())?)
}
fn size(_params: &Params) -> usize {
@ -330,7 +310,7 @@ impl OperationTarget {
params: &Params,
op: &Operation,
) -> Result<()> {
self.op_type.set_targets(pw, params, &op.op_type())?;
self.op_type.set_targets(pw, &op.op_type())?;
for (i, arg) in op
.args()
.iter()
@ -354,12 +334,8 @@ impl OperationTarget {
pub struct NativePredicateTarget(Target);
impl NativePredicateTarget {
pub fn constant(
builder: &mut CircuitBuilder,
params: &Params,
native_predicate: NativePredicate,
) -> Self {
let id = native_predicate.to_fields(params);
pub fn constant(builder: &mut CircuitBuilder, native_predicate: NativePredicate) -> Self {
let id = native_predicate.to_fields();
assert_eq!(1, id.len());
Self(builder.constant(id[0]))
}
@ -367,10 +343,9 @@ impl NativePredicateTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
native_predicate: NativePredicate,
) -> Result<()> {
let id = native_predicate.to_fields(params);
let id = native_predicate.to_fields();
assert_eq!(1, id.len());
Ok(pw.set_target(self.0, id[0])?)
}
@ -431,13 +406,8 @@ impl PredicateTarget {
builder.is_equal(prefix, self.elements[0])
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
predicate: &Predicate,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &predicate.to_fields(params))?)
pub fn set_targets(&self, pw: &mut PartialWitness<F>, predicate: &Predicate) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &predicate.to_fields())?)
}
pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget {
@ -534,10 +504,9 @@ impl StatementTmplArgTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st_tmpl_arg: &StatementTmplArg,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields(params))?)
Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields())?)
}
}
@ -588,7 +557,6 @@ impl PredicateHashOrWildcardTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
pred: &PredicateOrWildcard,
) -> Result<()> {
match pred {
@ -596,7 +564,7 @@ impl PredicateHashOrWildcardTarget {
self.set_targets_raw(
pw,
PredicateOrWildcardPrefix::Predicate,
RawValue::from(pred.hash(params)),
RawValue::from(pred.hash()),
)?;
}
PredicateOrWildcard::Wildcard(wc) => {
@ -650,19 +618,14 @@ impl StatementTmplTarget {
args,
}
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st_tmpl: &StatementTmpl,
) -> Result<()> {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, st_tmpl: &StatementTmpl) -> Result<()> {
if let Some(pred) = &self.pred {
match &st_tmpl.pred_or_wc {
PredicateOrWildcard::Predicate(p) => {
// We store a predicate (not a wildcard) and we have it available. In this
// case the hash will be calculated by constraints later on and we should not
// rely on the original data.
pred.set_targets(pw, params, p)?
pred.set_targets(pw, p)?
}
PredicateOrWildcard::Wildcard(_wc) => {
// Fill in with a recognizable constant for better debugging; this value is
@ -671,17 +634,16 @@ impl StatementTmplTarget {
}
}
}
self.pred_hash_or_wc
.set_targets(pw, params, &st_tmpl.pred_or_wc)?;
self.pred_hash_or_wc.set_targets(pw, &st_tmpl.pred_or_wc)?;
let arg_pad = StatementTmplArg::None;
for (i, arg) in st_tmpl
.args
.iter()
.chain(iter::repeat(&arg_pad))
.take(params.max_statement_args)
.take(Params::max_statement_args())
.enumerate()
{
self.args[i].set_targets(pw, params, arg)?;
self.args[i].set_targets(pw, arg)?;
}
Ok(())
}
@ -705,7 +667,6 @@ impl CustomPredicateTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
custom_pred: &CustomPredicate,
) -> Result<()> {
pw.set_target(
@ -717,10 +678,10 @@ impl CustomPredicateTarget {
.statements
.iter()
.chain(iter::repeat(&st_tmpl_pad))
.take(params.max_custom_predicate_arity)
.take(Params::max_custom_predicate_arity())
.enumerate()
{
self.statements[i].set_targets(pw, params, st_tmpl)?;
self.statements[i].set_targets(pw, st_tmpl)?;
}
pw.set_target(self.args_len, F::from_canonical_usize(custom_pred.args_len))?;
Ok(())
@ -743,7 +704,6 @@ impl CustomPredicateBatchTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
custom_predicate_batch: &CustomPredicateBatch,
) -> Result<()> {
let pad_predicate = CustomPredicate::empty();
@ -751,10 +711,10 @@ impl CustomPredicateBatchTarget {
.predicates()
.iter()
.chain(iter::repeat(&pad_predicate))
.take(params.max_custom_batch_size)
.take(Params::max_custom_batch_size())
.enumerate()
{
self.predicates[i].set_targets(pw, params, predicate)?;
self.predicates[i].set_targets(pw, predicate)?;
}
Ok(())
}
@ -772,7 +732,6 @@ impl CustomPredicateEntryTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
predicate: &CustomPredicateRef,
) -> Result<()> {
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
@ -808,7 +767,7 @@ impl CustomPredicateEntryTarget {
args_len: predicate.args_len,
wildcard_names: predicate.wildcard_names.clone(),
};
self.predicate.set_targets(pw, params, &predicate)?;
self.predicate.set_targets(pw, &predicate)?;
Ok(())
}
}
@ -854,18 +813,18 @@ pub struct CustomPredicateVerifyEntryTarget {
impl CustomPredicateVerifyEntryTarget {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
let custom_predicate_table_len =
params.max_custom_predicate_batches * params.max_custom_batch_size;
params.max_custom_predicate_batches * Params::max_custom_batch_size();
CustomPredicateVerifyEntryTarget {
custom_predicate_table_index: IndexTarget::new_virtual(
custom_predicate_table_len,
builder,
),
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
custom_predicate: builder.add_virtual_custom_predicate_entry(),
args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect(),
op_args: (0..params.max_operation_args)
.map(|_| builder.add_virtual_statement(params, false))
.map(|_| builder.add_virtual_statement(false))
.collect(),
}
}
@ -879,7 +838,7 @@ impl CustomPredicateVerifyEntryTarget {
.set_targets(pw, cpv.custom_predicate_table_index)?;
// Replace statement templates of batch-self with (id,index)
self.custom_predicate
.set_targets(pw, params, &cpv.custom_predicate)?;
.set_targets(pw, &cpv.custom_predicate)?;
let pad_arg = Value::from(0);
for (arg_target, arg) in self.args.iter().zip_eq(
cpv.args
@ -896,7 +855,7 @@ impl CustomPredicateVerifyEntryTarget {
.chain(iter::repeat(&pad_op_arg))
.take(params.max_operation_args),
) {
op_arg_target.set_targets(pw, params, op_arg)?
op_arg_target.set_targets(pw, op_arg)?
}
Ok(())
}
@ -1138,7 +1097,7 @@ impl Flattenable for StatementTarget {
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
let predicate_hash = HashOutTarget::from_flattened(params, &v[..HASH_SIZE]);
let args = (0..params.max_statement_args)
let args = (0..Params::max_statement_args())
.map(|i| StatementArgTarget {
elements: array::from_fn(|j| v[HASH_SIZE + i * STATEMENT_ARG_F_LEN + j]),
})
@ -1152,7 +1111,7 @@ impl Flattenable for StatementTarget {
}
fn size(params: &Params) -> usize {
HASH_SIZE + params.max_statement_args * StatementArgTarget::size(params)
HASH_SIZE + Params::max_statement_args() * StatementArgTarget::size(params)
}
}
@ -1170,8 +1129,8 @@ impl Flattenable for CustomPredicateTarget {
// this `BoolTarget` should actually safe.
let conjunction = BoolTarget::new_unsafe(v[0]);
let args_len = v[1];
let st_tmpl_size = params.statement_tmpl_size();
let statements = (0..params.max_custom_predicate_arity)
let st_tmpl_size = Params::statement_tmpl_size();
let statements = (0..Params::max_custom_predicate_arity())
.map(|i| {
let st_v = &v[2 + st_tmpl_size * i..2 + st_tmpl_size * (i + 1)];
StatementTmplTarget::from_flattened(params, st_v)
@ -1184,7 +1143,7 @@ impl Flattenable for CustomPredicateTarget {
}
}
fn size(params: &Params) -> usize {
2 + params.max_custom_predicate_arity * StatementTmplTarget::size(params)
2 + Params::max_custom_predicate_arity() * StatementTmplTarget::size(params)
}
}
@ -1203,7 +1162,7 @@ impl Flattenable for StatementTmplTarget {
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]);
let sta_size = Params::statement_tmpl_arg_size();
let args = (0..params.max_statement_args)
let args = (0..Params::max_statement_args())
.map(|i| {
let sta_v = &v
[pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)];
@ -1219,7 +1178,7 @@ impl Flattenable for StatementTmplTarget {
fn size(params: &Params) -> usize {
Params::pred_hash_or_wc_size()
+ params.max_statement_args * StatementTmplArgTarget::size(params)
+ Params::max_statement_args() * StatementTmplArgTarget::size(params)
}
}
@ -1278,29 +1237,17 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget);
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
fn add_virtual_value(&mut self) -> ValueTarget;
fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget;
fn add_virtual_statement(&mut self, with_pred: bool) -> StatementTarget;
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget;
fn add_virtual_predicate(&mut self) -> PredicateTarget;
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget;
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget;
fn add_virtual_statement_tmpl(
&mut self,
params: &Params,
with_pred: bool,
) -> StatementTmplTarget;
fn add_virtual_custom_predicate(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateTarget;
fn add_virtual_custom_predicate_batch(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateBatchTarget;
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
-> CustomPredicateEntryTarget;
fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget;
fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget;
fn add_virtual_custom_predicate_batch(&mut self, with_pred: bool)
-> CustomPredicateBatchTarget;
fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_statement_arg(
&mut self,
@ -1396,7 +1343,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
/// If `with_pred = true` a predicate is included and its hash constrained.
/// If `with_pred = false` only the predicate hash is included.
fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget {
fn add_virtual_statement(&mut self, with_pred: bool) -> StatementTarget {
let (pred, pred_hash) = if with_pred {
let pred = self.add_virtual_predicate();
let pred_hash = pred.hash(self);
@ -1408,7 +1355,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
StatementTarget {
pred,
pred_hash,
args: (0..params.max_statement_args)
args: (0..Params::max_statement_args())
.map(|_| self.add_virtual_statement_arg())
.collect(),
}
@ -1452,11 +1399,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
/// If `with_pred = false` only the predicate hash is included.
/// The pred_hash is constrained to be hash(pred) conditionally on the template using a
/// predicate and not a wildcard.
fn add_virtual_statement_tmpl(
&mut self,
params: &Params,
with_pred: bool,
) -> StatementTmplTarget {
fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget {
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value());
let pred = if with_pred {
@ -1474,20 +1417,16 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
StatementTmplTarget {
pred,
pred_hash_or_wc,
args: (0..params.max_statement_args)
args: (0..Params::max_statement_args())
.map(|_| self.add_virtual_statement_tmpl_arg())
.collect(),
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateTarget {
let statements = (0..params.max_custom_predicate_arity)
.map(|_| self.add_virtual_statement_tmpl(params, with_pred))
fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget {
let statements = (0..Params::max_custom_predicate_arity())
.map(|_| self.add_virtual_statement_tmpl(with_pred))
.collect();
CustomPredicateTarget {
conjunction: self.add_virtual_bool_target_safe(),
@ -1499,25 +1438,21 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate_batch(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateBatchTarget {
CustomPredicateBatchTarget {
predicates: (0..params.max_custom_batch_size)
.map(|_| self.add_virtual_custom_predicate(params, with_pred))
predicates: (0..Params::max_custom_batch_size())
.map(|_| self.add_virtual_custom_predicate(with_pred))
.collect(),
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate_entry(
&mut self,
params: &Params,
) -> CustomPredicateEntryTarget {
fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget {
CustomPredicateEntryTarget {
id: self.add_virtual_hash(),
index: self.add_virtual_target(),
predicate: self.add_virtual_custom_predicate(params, false),
predicate: self.add_virtual_custom_predicate(false),
}
}
@ -1998,7 +1933,7 @@ pub(crate) mod tests {
for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() {
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let flattened = cp.to_fields(&params);
let flattened = cp.to_fields();
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
let cp_target = CustomPredicateTarget::from_flattened(&params, &flatteend_target);
// Round trip of from_flattened to flattened
@ -2018,20 +1953,18 @@ pub(crate) mod tests {
}
fn helper_custom_predicate_batch_target_id(
params: &Params,
custom_predicate_batch: &CustomPredicateBatch,
) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let custom_predicate_batch_target =
builder.add_virtual_custom_predicate_batch(params, false);
let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(false);
// Calculate the id in constraints and compare it against the id calculated natively
let id_target = custom_predicate_batch_target.id(&mut builder);
let mut pw = PartialWitness::<F>::new();
custom_predicate_batch_target.set_targets(&mut pw, params, custom_predicate_batch)?;
custom_predicate_batch_target.set_targets(&mut pw, custom_predicate_batch)?;
let id = custom_predicate_batch.id();
pw.set_target_arr(&id_target.elements, &id.0)?;
@ -2046,7 +1979,6 @@ pub(crate) mod tests {
#[test]
fn test_custom_predicate_batch_target_id() -> frontend::Result<()> {
let params = Params {
max_statement_args: 6,
max_custom_predicate_wildcards: 12,
..Default::default()
};
@ -2055,15 +1987,15 @@ pub(crate) mod tests {
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish();
helper_custom_predicate_batch_target_id(&params, &custom_predicate_batch).unwrap();
helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap();
// Some cases from the examples
let custom_predicate_batch = eth_dos_batch(&params)?;
helper_custom_predicate_batch_target_id(&params, &custom_predicate_batch).unwrap();
helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap();
let custom_predicate_batch =
CustomPredicateBatch::new(&params, "empty".to_string(), vec![CustomPredicate::empty()]);
helper_custom_predicate_batch_target_id(&params, &custom_predicate_batch).unwrap();
helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap();
Ok(())
}
@ -2079,17 +2011,13 @@ pub(crate) mod tests {
let sum_target = builder.i64_add(x_target, y_target);
let data = builder.build::<PoseidonGoldilocksConfig>();
let params = Params::default();
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
let mut pw = PartialWitness::<F>::new();
let (sum, overflow) = x.overflowing_add(y);
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(&params))?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(&params))?;
pw.set_target_arr(
&sum_target.elements,
&RawValue::from(sum).to_fields(&params),
)?;
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields())?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields())?;
pw.set_target_arr(&sum_target.elements, &RawValue::from(sum).to_fields())?;
let proof = data.prove(pw);
@ -2113,18 +2041,14 @@ pub(crate) mod tests {
let prod_target = builder.i64_mul(x_target, y_target);
let data = builder.build::<PoseidonGoldilocksConfig>();
let params = Params::default();
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
println!("{}, {}", x, y);
let mut pw = PartialWitness::<F>::new();
let (prod, overflow) = x.overflowing_mul(y);
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(&params))?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(&params))?;
pw.set_target_arr(
&prod_target.elements,
&RawValue::from(prod).to_fields(&params),
)?;
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields())?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields())?;
pw.set_target_arr(&prod_target.elements, &RawValue::from(prod).to_fields())?;
let proof = data.prove(pw);

View file

@ -104,13 +104,12 @@ impl StatementCache {
.collect::<Vec<_>>()
};
assert!(params.max_operation_args >= MAX_VALUE_ARGS);
assert!(params.max_statement_args >= MAX_VALUE_ARGS);
assert!(Params::max_statement_args() >= MAX_VALUE_ARGS);
let equations = array::from_fn(|i| {
let pred_is_none = op_args[i].has_native_type(builder, params, NativePredicate::None);
let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None);
let arg_is_value = builder.statement_arg_is_value(&st.args[i]);
let is_literal = builder.and(pred_is_none, arg_is_value);
let pred_is_contains =
op_args[i].has_native_type(builder, params, NativePredicate::Contains);
let pred_is_contains = op_args[i].has_native_type(builder, NativePredicate::Contains);
let ref_is_value_arg: [_; 3] =
array::from_fn(|j| builder.statement_arg_is_value(&op_args[i].args[j]));
let ref_is_value = builder.and(ref_is_value_arg[0], ref_is_value_arg[1]);
@ -435,8 +434,8 @@ fn verify_operation_circuit(
if !cache.op_args.is_empty() {
op_checks.extend_from_slice(&[
verify_copy_circuit(builder, st, &op.op_type, &cache.op_args),
verify_eq_neq_from_entries_circuit(params, builder, st, &op.op_type, &cache),
verify_lt_lteq_from_entries_circuit(params, builder, st, &op.op_type, &cache),
verify_eq_neq_from_entries_circuit(builder, st, &op.op_type, &cache),
verify_lt_lteq_from_entries_circuit(builder, st, &op.op_type, &cache),
verify_transitive_eq_circuit(params, builder, st, &op.op_type, &cache.op_args),
verify_lt_to_neq_circuit(params, builder, st, &op.op_type, &cache.op_args),
verify_hash_of_circuit(params, builder, st, &op.op_type, &cache),
@ -881,7 +880,6 @@ fn verify_custom_circuit(
/// Carries out the checks necessary for EqualFromEntries and
/// NotEqualFromEntries.
fn verify_eq_neq_from_entries_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
@ -890,12 +888,12 @@ fn verify_eq_neq_from_entries_circuit(
let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries");
let eq_op_st_code_ok = {
let op_code_ok = op_type.has_native(builder, NativeOperation::EqualFromEntries);
let st_code_ok = st.has_native_type(builder, params, NativePredicate::Equal);
let st_code_ok = st.has_native_type(builder, NativePredicate::Equal);
builder.and(op_code_ok, st_code_ok)
};
let neq_op_st_code_ok = {
let op_code_ok = op_type.has_native(builder, NativeOperation::NotEqualFromEntries);
let st_code_ok = st.has_native_type(builder, params, NativePredicate::NotEqual);
let st_code_ok = st.has_native_type(builder, NativePredicate::NotEqual);
builder.and(op_code_ok, st_code_ok)
};
let op_st_code_ok = builder.or(eq_op_st_code_ok, neq_op_st_code_ok);
@ -911,7 +909,7 @@ fn verify_eq_neq_from_entries_circuit(
let expected_st_args: Vec<_> = [arg1_expected, arg2_expected]
.into_iter()
.chain(std::iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(params.max_statement_args)
.take(Params::max_statement_args())
.flat_map(|arg| arg.elements)
.collect();
@ -931,7 +929,6 @@ fn verify_eq_neq_from_entries_circuit(
/// Carries out the checks necessary for LtFromEntries and
/// LtEqFromEntries.
fn verify_lt_lteq_from_entries_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
@ -943,12 +940,12 @@ fn verify_lt_lteq_from_entries_circuit(
let lt_op_st_code_ok = {
let op_code_ok = op_type.has_native(builder, NativeOperation::LtFromEntries);
let st_code_ok = st.has_native_type(builder, params, NativePredicate::Lt);
let st_code_ok = st.has_native_type(builder, NativePredicate::Lt);
builder.and(op_code_ok, st_code_ok)
};
let lteq_op_st_code_ok = {
let op_code_ok = op_type.has_native(builder, NativeOperation::LtEqFromEntries);
let st_code_ok = st.has_native_type(builder, params, NativePredicate::LtEq);
let st_code_ok = st.has_native_type(builder, NativePredicate::LtEq);
builder.and(op_code_ok, st_code_ok)
};
let op_st_code_ok = builder.or(lt_op_st_code_ok, lteq_op_st_code_ok);
@ -981,7 +978,7 @@ fn verify_lt_lteq_from_entries_circuit(
let expected_st_args: Vec<_> = [arg1_expected, arg2_expected]
.into_iter()
.chain(std::iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(params.max_statement_args)
.take(Params::max_statement_args())
.flat_map(|arg| arg.elements)
.collect();
@ -1233,8 +1230,8 @@ fn verify_transitive_eq_circuit(
let measure = measure_gates_begin!(builder, "OpTransitiveEq");
let op_code_ok = op_type.has_native(builder, NativeOperation::TransitiveEqualFromStatements);
let arg1_type_ok = resolved_op_args[0].has_native_type(builder, params, NativePredicate::Equal);
let arg2_type_ok = resolved_op_args[1].has_native_type(builder, params, NativePredicate::Equal);
let arg1_type_ok = resolved_op_args[0].has_native_type(builder, NativePredicate::Equal);
let arg2_type_ok = resolved_op_args[1].has_native_type(builder, NativePredicate::Equal);
let arg_types_ok = builder.all([arg1_type_ok, arg2_type_ok]);
let arg1_lhs = &resolved_op_args[0].args[0];
@ -1285,7 +1282,7 @@ fn verify_lt_to_neq_circuit(
let measure = measure_gates_begin!(builder, "OpLtToNeq");
let op_code_ok = op_type.has_native(builder, NativeOperation::LtToNotEqual);
let arg_type_ok = resolved_op_args[0].has_native_type(builder, params, NativePredicate::Lt);
let arg_type_ok = resolved_op_args[0].has_native_type(builder, NativePredicate::Lt);
let arg1_expected = resolved_op_args[0].args[0].clone();
let arg2_expected = resolved_op_args[0].args[1].clone();
@ -1442,7 +1439,7 @@ fn make_custom_statement_circuit(
let st_predicate = PredicateTarget::new_custom(builder, batch_id, index);
let arg_none = ValueTarget::zero(builder);
let lt_mask = builder.lt_mask(
params.max_statement_args,
Params::max_statement_args(),
custom_predicate.predicate.args_len,
);
let st_args = std::iter::zip(lt_mask, args)
@ -1466,7 +1463,7 @@ fn make_custom_statement_circuit(
.collect();
// expected_sts.len() == params.max_custom_predicate_arity
// op_args.len() == params.max_operation_args;
assert!(params.max_custom_predicate_arity <= params.max_operation_args);
assert!(Params::max_custom_predicate_arity() <= params.max_operation_args);
let sts_eq: Vec<_> = expected_sts
.iter()
@ -1508,19 +1505,18 @@ fn normalize_statement_circuit(
/// statements reversed. The part of the hash from the front-padded none-statements is
/// precomputed.
pub fn calculate_statements_hash_circuit(
params: &Params,
builder: &mut CircuitBuilder,
// These statements will be padded to reach `num_statements`
statements: &[StatementTarget],
) -> HashOutTarget {
assert!(statements.len() <= params.num_public_statements_hash);
assert!(statements.len() <= Params::num_public_statements_hash());
let measure = measure_gates_begin!(builder, "CalculateStsHash");
let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten());
let mut none_st = mainpod::Statement::from(Statement::None);
pad_statement(params, &mut none_st);
pad_statement(&mut none_st);
let front_pad_elts = iter::repeat(&none_st)
.take(params.num_public_statements_hash - statements.len())
.flat_map(|s| s.to_fields(params))
.take(Params::num_public_statements_hash() - statements.len())
.flat_map(|s| s.to_fields())
.collect_vec();
let (perm, front_pad_elts_rem) =
precompute_hash_state::<F, PoseidonPermutation<F>>(&front_pad_elts);
@ -1581,7 +1577,7 @@ fn build_custom_predicate_table_circuit(
) -> Result<Vec<HashOutTarget>> {
let measure = measure_gates_begin!(builder, "BuildCustomPredTbl");
let mut custom_predicate_table =
Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size);
Vec::with_capacity(params.max_custom_predicate_batches * Params::max_custom_batch_size());
for cpb in custom_predicate_batches {
let measure_cpb = measure_gates_begin!(builder, "CustomPredBatch");
let id = cpb.id(builder); // constrain the id
@ -1655,7 +1651,7 @@ fn verify_main_pod_circuit(
let mut intro_ok = is_blank_intro;
for self_st in &input_pod_self_statements[1..] {
let st_is_intro = self_st.pred_is_blank_intro(builder);
let st_is_none = self_st.has_native_type(builder, params, NativePredicate::None);
let st_is_none = self_st.has_native_type(builder, NativePredicate::None);
let st_is_intro_or_none = builder.or(st_is_intro, st_is_none);
intro_ok = builder.and(intro_ok, st_is_intro_or_none);
}
@ -1671,8 +1667,7 @@ fn verify_main_pod_circuit(
);
statements.push(normalized_st);
}
let sts_hash =
calculate_statements_hash_circuit(params, builder, input_pod_self_statements);
let sts_hash = calculate_statements_hash_circuit(builder, input_pod_self_statements);
builder.connect_hashes(expected_sts_hash, sts_hash);
//
@ -1730,7 +1725,7 @@ fn verify_main_pod_circuit(
)?;
// 2. Calculate the Pod Id from the public statements
let sts_hash = calculate_statements_hash_circuit(params, builder, pub_statements);
let sts_hash = calculate_statements_hash_circuit(builder, pub_statements);
// 5. Verify input statements
for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() {
@ -1774,12 +1769,12 @@ impl MainPodVerifyTarget {
input_pods_self_statements: (0..params.max_input_pods)
.map(|_| {
(0..params.max_input_pods_public_statements)
.map(|_| builder.add_virtual_statement(params, false))
.map(|_| builder.add_virtual_statement(false))
.collect_vec()
})
.collect(),
input_statements: (0..params.max_statements)
.map(|_| builder.add_virtual_statement(params, false))
.map(|_| builder.add_virtual_statement(false))
.collect(),
operations: (0..params.max_statements)
.map(|_| builder.add_virtual_operation(params))
@ -1805,7 +1800,7 @@ impl MainPodVerifyTarget {
})
.collect(),
custom_predicate_batches: (0..params.max_custom_predicate_batches)
.map(|_| builder.add_virtual_custom_predicate_batch(params, true))
.map(|_| builder.add_virtual_custom_predicate_batch(true))
.collect(),
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
@ -1849,16 +1844,16 @@ fn set_targets_input_pods_self_statements(
statements_target.len(),
params.max_input_pods_public_statements
);
assert!(statements.len() <= params.num_public_statements_hash);
assert!(statements.len() <= Params::num_public_statements_hash());
for (i, statement) in statements.iter().enumerate() {
statements_target[i].set_targets(pw, params, &statement.clone().into())?;
statements_target[i].set_targets(pw, &statement.clone().into())?;
}
// Padding
let mut none_st = mainpod::Statement::from(Statement::None);
pad_statement(params, &mut none_st);
pad_statement(&mut none_st);
for statement_target in statements_target.iter().skip(statements.len()) {
statement_target.set_targets(pw, params, &none_st)?;
statement_target.set_targets(pw, &none_st)?;
}
Ok(())
}
@ -1903,7 +1898,7 @@ impl InnerCircuit for MainPodVerifyTarget {
}
// Padding
if input_pods_len != self.params.max_input_pods {
let empty_pod = EmptyPod::new_boxed(&self.params, input.vds_set.clone());
let empty_pod = EmptyPod::new_boxed(input.vds_set.clone());
let empty_pod_statements = empty_pod.pub_statements();
let empty_mt_proof = MerkleClaimAndProof {
root: input.vds_set.root(),
@ -1924,7 +1919,7 @@ impl InnerCircuit for MainPodVerifyTarget {
assert_eq!(input.statements.len(), self.params.max_statements);
for (i, (st, op)) in zip_eq(&input.statements, &input.operations).enumerate() {
self.input_statements[i].set_targets(pw, &self.params, st)?;
self.input_statements[i].set_targets(pw, st)?;
self.operations[i].set_targets(pw, &self.params, op)?;
}
@ -1979,7 +1974,7 @@ impl InnerCircuit for MainPodVerifyTarget {
assert!(input.custom_predicate_batches.len() <= self.params.max_custom_predicate_batches);
for (i, cpb) in input.custom_predicate_batches.iter().enumerate() {
self.custom_predicate_batches[i].set_targets(pw, &self.params, cpb)?;
self.custom_predicate_batches[i].set_targets(pw, cpb)?;
}
// Padding
let pad_cpb = CustomPredicateBatch::new(
@ -1988,7 +1983,7 @@ impl InnerCircuit for MainPodVerifyTarget {
vec![CustomPredicate::empty()],
);
for i in input.custom_predicate_batches.len()..self.params.max_custom_predicate_batches {
self.custom_predicate_batches[i].set_targets(pw, &self.params, &pad_cpb)?;
self.custom_predicate_batches[i].set_targets(pw, &pad_cpb)?;
}
assert!(
@ -2048,7 +2043,7 @@ mod tests {
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
middleware::{
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard,
RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard,
RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, EMPTY_VALUE,
},
};
@ -2108,10 +2103,10 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let st_target = builder.add_virtual_statement(&params, false);
let st_target = builder.add_virtual_statement(false);
let op_target = builder.add_virtual_operation(&params);
let prev_statements_target: Vec<_> = (0..prev_statements.len())
.map(|_| builder.add_virtual_statement(&params, false))
.map(|_| builder.add_virtual_statement(false))
.collect();
let merkle_proofs_target: Vec<_> = aux
@ -2166,10 +2161,10 @@ mod tests {
)?;
let mut pw = PartialWitness::<F>::new();
st_target.set_targets(&mut pw, &params, &st)?;
st_target.set_targets(&mut pw, &st)?;
op_target.set_targets(&mut pw, &params, &op)?;
for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) {
prev_st_target.set_targets(&mut pw, &params, prev_st)?;
prev_st_target.set_targets(&mut pw, prev_st)?;
}
for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) {
signed_by_target.set_targets(&mut pw, signed_by)?
@ -3065,11 +3060,11 @@ mod tests {
let mut pw = PartialWitness::<F>::new();
st_tmpl_arg_target.set_targets(&mut pw, params, &st_tmpl_arg)?;
st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?;
for (arg_target, arg) in args_target.iter().zip(args.iter()) {
arg_target.set_targets(&mut pw, arg)?;
}
expected_st_arg_target.set_targets(&mut pw, params, &expected_st_arg)?;
expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?;
// generate & verify proof
let data = builder.build::<C>();
@ -3122,7 +3117,7 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let st_tmpl_target = builder.add_virtual_statement_tmpl(params, false);
let st_tmpl_target = builder.add_virtual_statement_tmpl(false);
let args_target: Vec<_> = (0..args.len())
.map(|_| builder.add_virtual_value())
.collect();
@ -3133,16 +3128,16 @@ mod tests {
&args_target,
);
// TODO: Instead of connect, assign witness to result
let expected_st_target = builder.add_virtual_statement(params, false);
let expected_st_target = builder.add_virtual_statement(false);
builder.connect_flattenable(&expected_st_target, &st_target);
let mut pw = PartialWitness::<F>::new();
st_tmpl_target.set_targets(&mut pw, params, &st_tmpl)?;
st_tmpl_target.set_targets(&mut pw, &st_tmpl)?;
for (arg_target, arg) in args_target.iter().zip(args.iter()) {
arg_target.set_targets(&mut pw, arg)?;
}
expected_st_target.set_targets(&mut pw, params, &expected_st.into())?;
expected_st_target.set_targets(&mut pw, &expected_st.into())?;
// generate & verify proof
let data = builder.build::<C>();
@ -3179,7 +3174,7 @@ mod tests {
StatementTmplArg::Literal(Value::from("value")),
],
};
let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(&params);
let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash();
let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)];
let expected_st = Statement::not_equal(
AnchoredKey::new(dict, Key::from("key")),
@ -3193,16 +3188,24 @@ mod tests {
fn helper_custom_operation_verify_gadget(
params: &Params,
custom_predicate: CustomPredicateRef,
op_args: Vec<Statement>,
args: Vec<Value>,
mut op_args: Vec<Statement>,
mut args: Vec<Value>,
expected_st: Option<Statement>,
) -> Result<()> {
// Pad
for _ in op_args.len()..params.max_operation_args {
op_args.push(Statement::None);
}
for _ in args.len()..params.max_custom_predicate_wildcards {
args.push(Value::from(EMPTY_VALUE));
}
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params);
let op_args_target: Vec<_> = (0..args.len())
.map(|_| builder.add_virtual_statement(params, false))
let custom_predicate_target = builder.add_virtual_custom_predicate_entry();
let op_args_target: Vec<_> = (0..op_args.len())
.map(|_| builder.add_virtual_statement(false))
.collect();
let args_target: Vec<_> = (0..args.len())
.map(|_| builder.add_virtual_value())
@ -3218,20 +3221,20 @@ mod tests {
let mut pw = PartialWitness::<F>::new();
// Input
custom_predicate_target.set_targets(&mut pw, params, &custom_predicate)?;
custom_predicate_target.set_targets(&mut pw, &custom_predicate)?;
for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) {
op_arg_target.set_targets(&mut pw, params, &op_arg.into())?;
op_arg_target.set_targets(&mut pw, &op_arg.into())?;
}
for (arg_target, arg) in args_target.iter().zip(args.iter()) {
arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?;
}
// Expected Output
if let Some(expected_st) = expected_st {
st_target.set_targets(&mut pw, params, &expected_st.into())?;
st_target.set_targets(&mut pw, &expected_st.into())?;
}
let expected_op_type = OperationType::Custom(custom_predicate);
op_type_target.set_targets(&mut pw, params, &expected_op_type)?;
op_type_target.set_targets(&mut pw, &expected_op_type)?;
// generate & verify proof
let data = builder.build::<C>();
@ -3242,15 +3245,7 @@ mod tests {
// TODO: Add negative tests
#[test]
fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> {
// We set the parameters to the exact sizes we have in the test so that we don't have to
// pad.
let params = Params {
max_custom_predicate_arity: 2,
max_custom_predicate_wildcards: 2,
max_operation_args: 2,
max_statement_args: 2,
..Default::default()
};
let params = Params::default();
use NativePredicate as NP;
use StatementTmplBuilder as STB;
@ -3340,15 +3335,7 @@ mod tests {
#[test]
fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> {
// We set the parameters to the exact sizes we have in the test so that we don't have to
// pad.
let params = Params {
max_custom_predicate_arity: 2,
max_custom_predicate_wildcards: 2,
max_operation_args: 2,
max_statement_args: 2,
..Default::default()
};
let params = Params::default();
use NativePredicate as NP;
use StatementTmplBuilder as STB;
@ -3500,10 +3487,9 @@ mod tests {
let mut builder = CircuitBuilder::new(config);
let statements_target = (0..params.max_public_statements)
.map(|_| builder.add_virtual_statement(params, false))
.map(|_| builder.add_virtual_statement(false))
.collect_vec();
let sts_hash_target =
calculate_statements_hash_circuit(params, &mut builder, &statements_target);
let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target);
let mut pw = PartialWitness::<F>::new();
@ -3512,15 +3498,15 @@ mod tests {
.iter()
.map(|st| {
let mut st = mainpod::Statement::from(st.clone());
pad_statement(params, &mut st);
pad_statement(&mut st);
st
})
.collect_vec();
for (st_target, st) in statements_target.iter().zip(statements.iter()) {
st_target.set_targets(&mut pw, params, st)?;
st_target.set_targets(&mut pw, st)?;
}
// Expected Output
let expected_sts_hash = calculate_statements_hash(&statements, params);
let expected_sts_hash = calculate_statements_hash(&statements);
pw.set_hash_target(
sts_hash_target,
HashOut {
@ -3536,10 +3522,10 @@ mod tests {
#[test]
fn test_calculate_sts_hash() -> frontend::Result<()> {
assert_eq!(Params::num_public_statements_hash(), 16);
// Case with no public public statements
let params = Params {
max_public_statements: 0,
num_public_statements_hash: 8,
..Default::default()
};
@ -3547,30 +3533,20 @@ mod tests {
// Case with number of statements for the sts_hash equal to number of public statements
let params = Params {
max_public_statements: 2,
num_public_statements_hash: 2,
max_public_statements: Params::num_public_statements_hash(),
..Default::default()
};
let dict = Hash([F(1), F(2), F(3), F(4)]);
let statements = [
Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)),
Statement::equal(
AnchoredKey::from((dict, "bar")),
AnchoredKey::from((dict, "baz")),
),
]
.into_iter()
.chain(iter::repeat(Statement::None))
.take(params.max_public_statements)
.collect_vec();
let statements = (0..Params::num_public_statements_hash())
.map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64)))
.collect_vec();
helper_calculate_statements_hash(&params, &statements).unwrap();
// Case with more statements for the sts_hash than the number of public statements
// Case with more statements for the sts_hash than the number of public statements
let params = Params {
max_public_statements: 4,
num_public_statements_hash: 6,
..Default::default()
};

View file

@ -67,11 +67,9 @@ fn verify_empty_pod_circuit(
builder: &mut CircuitBuilder,
empty_pod: &EmptyPodVerifyTarget,
) {
let empty_statement = StatementTarget::from_flattened(
params,
&builder.constants(&empty_statement().to_fields(params)),
);
let sts_hash = calculate_statements_hash_circuit(params, builder, &[empty_statement]);
let empty_statement =
StatementTarget::from_flattened(params, &builder.constants(&empty_statement().to_fields()));
let sts_hash = calculate_statements_hash_circuit(builder, &[empty_statement]);
builder.register_public_inputs(&sts_hash.elements);
builder.register_public_inputs(&empty_pod.vds_root.elements);
}
@ -126,7 +124,7 @@ fn build() -> Result<(EmptyPodVerifyTarget, CircuitData)> {
}
impl EmptyPod {
fn new(params: &Params, vd_set: VDSet) -> Result<EmptyPod> {
fn new(vd_set: VDSet) -> Result<EmptyPod> {
let (empty_pod_verify_target, data) = &*cache_get_standard_empty_pod_circuit_data();
let mut pw = PartialWitness::<F>::new();
@ -139,7 +137,7 @@ impl EmptyPod {
};
let common_hash = hash_common_data(&data.common).expect("hash ok");
Ok(EmptyPod {
params: params.clone(),
params: Params::default(),
verifier_only: VerifierOnlyCircuitDataSerializer(data.verifier_only.clone()),
common_hash,
sts_hash,
@ -147,15 +145,10 @@ impl EmptyPod {
proof: proof.proof,
})
}
pub fn new_boxed(params: &Params, vd_set: VDSet) -> Box<dyn Pod> {
let default_params = Params::default();
assert_eq!(default_params.id_params(), params.id_params());
let empty_pod = cache::get(
"empty_pod",
&(default_params, vd_set),
|(params, vd_set)| Self::new(params, vd_set.clone()).expect("prove EmptyPod"),
)
pub fn new_boxed(vd_set: VDSet) -> Box<dyn Pod> {
let empty_pod = cache::get("empty_pod", &vd_set, |vd_set| {
Self::new(vd_set.clone()).expect("prove EmptyPod")
})
.expect("cache ok");
Box::new(empty_pod.clone())
}
@ -178,13 +171,13 @@ impl Pod for EmptyPod {
.into_iter()
.map(mainpod::Statement::from)
.collect_vec();
let sts_hash = calculate_statements_hash(&statements, &self.params);
let sts_hash = calculate_statements_hash(&statements);
if sts_hash != self.sts_hash {
return Err(Error::statements_hash_not_equal(self.sts_hash, sts_hash));
}
let public_inputs = sts_hash
.to_fields(&self.params)
.to_fields()
.iter()
.chain(self.vd_set.root().0.iter())
.cloned()
@ -258,9 +251,7 @@ pub mod tests {
#[test]
fn test_empty_pod() {
let params = Params::default();
let empty_pod = EmptyPod::new_boxed(&params, VDSet::new(&[]));
let empty_pod = EmptyPod::new_boxed(VDSet::new(&[]));
empty_pod.verify().unwrap();
}
}

View file

@ -50,21 +50,20 @@ use crate::{
/// circuits with a small `max_public_statements` only pay for `max_public_statements` by starting
/// the poseidon state with a precomputed constant corresponding to the front-padding part: `id =
/// hash(serialize(reverse(statements || none-statements)))`
pub fn calculate_statements_hash(statements: &[Statement], params: &Params) -> middleware::Hash {
assert!(statements.len() <= params.num_public_statements_hash);
assert!(params.max_public_statements <= params.num_public_statements_hash);
pub fn calculate_statements_hash(statements: &[Statement]) -> middleware::Hash {
assert!(statements.len() <= Params::num_public_statements_hash());
let mut none_st: Statement = middleware::Statement::None.into();
pad_statement(params, &mut none_st);
pad_statement(&mut none_st);
let statements_back_padded = statements
.iter()
.chain(iter::repeat(&none_st))
.take(params.num_public_statements_hash)
.take(Params::num_public_statements_hash())
.collect_vec();
let field_elems = statements_back_padded
.iter()
.rev()
.flat_map(|statement| statement.to_fields(params))
.flat_map(|statement| statement.to_fields())
.collect::<Vec<_>>();
Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
}
@ -115,7 +114,7 @@ pub(crate) fn extract_custom_predicate_verifications(
.find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i))
.expect("find the custom predicate from the extracted unique list");
let custom_predicate_table_index =
batch_index * params.max_custom_batch_size + cpr.index;
batch_index * Params::max_custom_batch_size() + cpr.index;
aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len());
table.push(CustomPredicateVerification {
custom_predicate_table_index,
@ -326,8 +325,8 @@ fn fill_pad<T: Clone>(v: &mut Vec<T>, pad_value: T, len: usize) {
}
}
pub fn pad_statement(params: &Params, s: &mut Statement) {
fill_pad(&mut s.1, StatementArg::None, params.max_statement_args)
pub fn pad_statement(s: &mut Statement) {
fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args())
}
fn pad_operation_args(params: &Params, args: &mut Vec<OperationArg>) {
@ -353,7 +352,7 @@ pub(crate) fn layout_statements(
// We mocking or we don't need padding so we skip creating an EmptyPod
MockEmptyPod::new_boxed(params, inputs.vd_set.clone())
} else {
EmptyPod::new_boxed(params, inputs.vd_set.clone())
EmptyPod::new_boxed(inputs.vd_set.clone())
};
let empty_pod = empty_pod_box.as_ref();
assert!(inputs.pods.len() <= params.max_input_pods);
@ -367,7 +366,7 @@ pub(crate) fn layout_statements(
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
pad_statement(params, &mut st);
pad_statement(&mut st);
statements.push(st);
}
}
@ -386,7 +385,7 @@ pub(crate) fn layout_statements(
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
pad_statement(params, &mut st);
pad_statement(&mut st);
statements.push(st);
}
@ -399,7 +398,7 @@ pub(crate) fn layout_statements(
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
pad_statement(params, &mut st);
pad_statement(&mut st);
statements.push(st);
}
@ -475,7 +474,7 @@ impl MainPodProver for Prover {
// We don't need padding so we skip creating an EmptyPod
MockEmptyPod::new_boxed(params, inputs.vd_set.clone())
} else {
EmptyPod::new_boxed(params, inputs.vd_set.clone())
EmptyPod::new_boxed(inputs.vd_set.clone())
};
let inputs = MainPodInputs {
pods: &inputs
@ -491,10 +490,7 @@ impl MainPodProver for Prover {
let input_pods_pub_self_statements = inputs
.pods
.iter()
.map(|pod| {
assert_eq!(params.id_params(), pod.params().id_params());
pod.pub_self_statements()
})
.map(|pod| pod.pub_self_statements())
.collect_vec();
// Aux values for backend::Operation
@ -527,7 +523,7 @@ impl MainPodProver for Prover {
let operations = process_public_statements_operations(params, &statements, operations)?;
// get the id out of the public statements
let sts_hash = calculate_statements_hash(&public_statements, params);
let sts_hash = calculate_statements_hash(&public_statements);
let common_hash: String = cache_get_rec_main_pod_common_hash(params).clone();
let proofs = inputs
@ -718,7 +714,7 @@ impl Pod for MainPod {
)));
}
// 2. get the id out of the public statements
let sts_hash = calculate_statements_hash(&self.public_statements, &self.params);
let sts_hash = calculate_statements_hash(&self.public_statements);
if sts_hash != self.sts_hash {
return Err(Error::statements_hash_not_equal(self.sts_hash, sts_hash));
}
@ -738,7 +734,7 @@ impl Pod for MainPod {
let rec_main_pod_verifier_circuit_data =
&*cache_get_rec_main_pod_verifier_circuit_data(&self.params);
let public_inputs = sts_hash
.to_fields(&self.params)
.to_fields()
.iter()
.chain(self.vd_set.root().0.iter())
.cloned()
@ -998,14 +994,10 @@ pub mod tests {
max_input_pods_public_statements: 2,
max_statements: 5,
max_public_statements: 2,
num_public_statements_hash: 4,
max_statement_args: 4,
max_operation_args: 4,
max_operation_args: 5,
max_custom_predicate_batches: 2,
max_custom_predicate_verifications: 2,
max_custom_predicate_arity: 2,
max_custom_predicate_wildcards: 3,
max_custom_batch_size: 2,
max_merkle_proofs_containers: 2,
max_merkle_tree_state_transition_proofs_containers: 2,
max_public_key_of: 2,
@ -1067,10 +1059,7 @@ pub mod tests {
max_input_pods: 0,
max_statements: 9,
max_public_statements: 4,
max_statement_args: 4,
max_operation_args: 4,
max_custom_predicate_arity: 3,
max_custom_batch_size: 3,
max_operation_args: 5,
max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2,
max_merkle_proofs_containers: 3,

View file

@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::error::{Error, Result},
middleware::{self, NativePredicate, Params, Predicate, StatementArg, ToFields, Value},
middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS},
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@ -30,14 +30,14 @@ impl Statement {
}
impl ToFields for Statement {
fn to_fields(&self, params: &Params) -> Vec<middleware::F> {
let mut fields = self.0.hash(params).to_fields(params);
fn to_fields(&self) -> Vec<middleware::F> {
let mut fields = self.0.hash().to_fields();
fields.extend(
self.1
.iter()
.chain(iter::repeat(&StatementArg::None))
.take(params.max_statement_args)
.flat_map(|arg| arg.to_fields(params)),
.take(BASE_PARAMS.max_statement_args)
.flat_map(|arg| arg.to_fields()),
);
fields
}

View file

@ -30,7 +30,7 @@ fn empty_statement() -> Statement {
impl MockEmptyPod {
pub fn new_boxed(params: &Params, vd_set: VDSet) -> Box<dyn Pod> {
let statements = [mainpod::Statement::from(empty_statement())];
let sts_hash = calculate_statements_hash(&statements, params);
let sts_hash = calculate_statements_hash(&statements);
Box::new(Self {
params: params.clone(),
sts_hash,
@ -49,7 +49,7 @@ impl Pod for MockEmptyPod {
.into_iter()
.map(mainpod::Statement::from)
.collect_vec();
let sts_hash = calculate_statements_hash(&statements, &self.params);
let sts_hash = calculate_statements_hash(&statements);
if sts_hash != self.sts_hash {
return Err(Error::statements_hash_not_equal(self.sts_hash, sts_hash));
}

View file

@ -167,7 +167,7 @@ impl MockMainPod {
let operations = process_public_statements_operations(params, &statements, operations)?;
// get the id out of the public statements
let sts_hash = calculate_statements_hash(&public_statements, params);
let sts_hash = calculate_statements_hash(&public_statements);
let pad_pod = MockEmptyPod::new_boxed(params, inputs.vd_set.clone());
let input_pods: Vec<Box<dyn Pod>> = inputs