Fix custom predicate circuits and add tests for them (#235)
* add tests, fix custom predicates * wip * wip * fix custom predicates * modularize code * fix typos * remove scratch file * update * Update src/backends/plonky2/circuits/mainpod.rs Co-authored-by: Ahmad Afuni <root@ahmadafuni.com> --------- Co-authored-by: Ahmad Afuni <root@ahmadafuni.com>
This commit is contained in:
parent
f5a1aa7523
commit
def0730462
16 changed files with 629 additions and 153 deletions
|
|
@ -14,10 +14,10 @@ use crate::{
|
|||
circuits::{
|
||||
common::{
|
||||
CircuitBuilderPod, CustomPredicateBatchTarget, CustomPredicateEntryTarget,
|
||||
CustomPredicateVerifyEntryTarget, CustomPredicateVerifyQueryTarget, Flattenable,
|
||||
MerkleClaimTarget, OperationTarget, OperationTypeTarget, PredicateTarget,
|
||||
StatementArgTarget, StatementTarget, StatementTmplArgTarget, StatementTmplTarget,
|
||||
ValueTarget,
|
||||
CustomPredicateTarget, CustomPredicateVerifyEntryTarget,
|
||||
CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget, OperationTarget,
|
||||
OperationTypeTarget, PredicateTarget, StatementArgTarget, StatementTarget,
|
||||
StatementTmplArgTarget, StatementTmplTarget, ValueTarget,
|
||||
},
|
||||
signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget},
|
||||
},
|
||||
|
|
@ -30,8 +30,8 @@ use crate::{
|
|||
},
|
||||
middleware::{
|
||||
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
||||
NativePredicate, Params, PodType, Statement, StatementArg, ToFields, Value, WildcardValue,
|
||||
F, KEY_TYPE, SELF, VALUE_SIZE,
|
||||
NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields,
|
||||
Value, WildcardValue, F, KEY_TYPE, SELF, VALUE_SIZE,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -188,8 +188,7 @@ impl OperationVerifyGadget {
|
|||
.concat();
|
||||
|
||||
let ok = builder.any(op_checks);
|
||||
|
||||
builder.connect(ok.target, _true.target);
|
||||
builder.assert_one(ok.target);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -813,6 +812,7 @@ impl CustomOperationVerifyGadget {
|
|||
// expected_sts.len() == self.params.max_custom_predicate_arity
|
||||
// op_args.len() == self.params.max_operation_args;
|
||||
assert!(self.params.max_custom_predicate_arity <= self.params.max_operation_args);
|
||||
|
||||
let sts_eq: Vec<_> = expected_sts
|
||||
.iter()
|
||||
.zip(op_args.iter())
|
||||
|
|
@ -837,6 +837,119 @@ struct MainPodVerifyGadget {
|
|||
}
|
||||
|
||||
impl MainPodVerifyGadget {
|
||||
// Replace predicates of batch-self with the corresponding global custom predicate batch_id and
|
||||
// index
|
||||
fn normalize_st_tmpl(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
st_tmpl: &StatementTmplTarget,
|
||||
id: HashOutTarget,
|
||||
) -> StatementTmplTarget {
|
||||
let params = &self.params;
|
||||
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
|
||||
let is_batch_self = builder.is_equal(st_tmpl.pred.elements[0], prefix_batch_self);
|
||||
let pred_index = st_tmpl.pred.elements[1];
|
||||
let custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
|
||||
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, &st_tmpl.pred);
|
||||
StatementTmplTarget {
|
||||
pred,
|
||||
args: st_tmpl.args.clone(),
|
||||
}
|
||||
}
|
||||
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
||||
/// hash([batch_id, custom_predicate_index, custom_predicate]). While building the table we
|
||||
/// calculate the id of each batch.
|
||||
fn build_custom_predicate_table(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
) -> Result<(Vec<HashOutTarget>, Vec<CustomPredicateBatchTarget>)> {
|
||||
let params = &self.params;
|
||||
let mut custom_predicate_table =
|
||||
Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size);
|
||||
let mut custom_predicate_batches = Vec::with_capacity(params.max_custom_predicate_batches);
|
||||
for _ in 0..params.max_custom_predicate_batches {
|
||||
let cpb = builder.add_virtual_custom_predicate_batch(params);
|
||||
let id = cpb.id(builder); // constrain the id
|
||||
for (index, cp) in cpb.predicates.iter().enumerate() {
|
||||
let statements = cp
|
||||
.statements
|
||||
.iter()
|
||||
.map(|st_tmpl| self.normalize_st_tmpl(builder, st_tmpl, id))
|
||||
.collect_vec();
|
||||
let cp = CustomPredicateTarget {
|
||||
conjunction: cp.conjunction,
|
||||
statements,
|
||||
args_len: cp.args_len,
|
||||
};
|
||||
let entry = CustomPredicateEntryTarget {
|
||||
id, // output
|
||||
index: builder.constant(F::from_canonical_usize(index)), // constant
|
||||
predicate: cp.clone(), // input
|
||||
};
|
||||
|
||||
let in_query_hash = entry.hash(builder);
|
||||
custom_predicate_table.push(in_query_hash);
|
||||
}
|
||||
custom_predicate_batches.push(cpb); // We keep this for witness assignment
|
||||
}
|
||||
Ok((custom_predicate_table, custom_predicate_batches))
|
||||
}
|
||||
|
||||
/// Build table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args]
|
||||
/// with queryable part as hash([st, op, op_args]). While building the table we verify each
|
||||
/// custom predicate against the operation and statement.
|
||||
fn build_custom_predicate_verification_table(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
custom_predicate_table: &[HashOutTarget],
|
||||
) -> Result<(Vec<HashOutTarget>, Vec<CustomPredicateVerifyEntryTarget>)> {
|
||||
let params = &self.params;
|
||||
let mut custom_predicate_verifications =
|
||||
Vec::with_capacity(params.max_custom_predicate_verifications);
|
||||
let mut custom_predicate_verification_table =
|
||||
Vec::with_capacity(params.max_custom_predicate_verifications);
|
||||
for _ in 0..params.max_custom_predicate_verifications {
|
||||
let custom_predicate_table_index = builder.add_virtual_target();
|
||||
let custom_predicate = builder.add_virtual_custom_predicate_entry(params);
|
||||
let args = (0..params.max_custom_predicate_wildcards)
|
||||
.map(|_| builder.add_virtual_value())
|
||||
.collect_vec();
|
||||
let op_args = (0..params.max_operation_args)
|
||||
.map(|_| builder.add_virtual_statement(params))
|
||||
.collect_vec();
|
||||
|
||||
// Verify the custom predicate operation
|
||||
let (statement, op_type) = CustomOperationVerifyGadget {
|
||||
params: params.clone(),
|
||||
}
|
||||
.eval(builder, &custom_predicate, &op_args, &args)?;
|
||||
|
||||
// Check that the batch id is correct by querying the custom predicate batches table
|
||||
let table_query_hash =
|
||||
builder.vec_ref(params, custom_predicate_table, custom_predicate_table_index);
|
||||
let out_query_hash = custom_predicate.hash(builder);
|
||||
builder.connect_array(table_query_hash.elements, out_query_hash.elements);
|
||||
|
||||
let entry = CustomPredicateVerifyEntryTarget {
|
||||
custom_predicate_table_index, // input
|
||||
custom_predicate, // input
|
||||
args, // input
|
||||
query: CustomPredicateVerifyQueryTarget {
|
||||
statement, // output
|
||||
op_type, // output
|
||||
op_args, // input
|
||||
},
|
||||
};
|
||||
let in_query_hash = entry.query.hash(builder);
|
||||
custom_predicate_verification_table.push(in_query_hash);
|
||||
custom_predicate_verifications.push(entry); // We keep this for witness assignment
|
||||
}
|
||||
Ok((
|
||||
custom_predicate_verification_table,
|
||||
custom_predicate_verifications,
|
||||
))
|
||||
}
|
||||
|
||||
fn eval(&self, builder: &mut CircuitBuilder<F, D>) -> Result<MainPodVerifyTarget> {
|
||||
let params = &self.params;
|
||||
// 1. Verify all input signed pods
|
||||
|
|
@ -851,12 +964,17 @@ impl MainPodVerifyGadget {
|
|||
|
||||
// Build the statement array
|
||||
let mut statements = Vec::new();
|
||||
// Statement at index 0 is always None to be used for padding operation arguments in custom
|
||||
// predicate statements
|
||||
let st_none =
|
||||
StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]);
|
||||
statements.push(st_none);
|
||||
for signed_pod in &signed_pods {
|
||||
statements.extend_from_slice(signed_pod.pub_statements(builder, false).as_slice());
|
||||
}
|
||||
debug_assert_eq!(
|
||||
statements.len(),
|
||||
self.params.max_input_signed_pods * self.params.max_signed_pod_values
|
||||
1 + self.params.max_input_signed_pods * self.params.max_signed_pod_values
|
||||
);
|
||||
// TODO: Fill with input main pods
|
||||
for _main_pod in 0..self.params.max_input_main_pods {
|
||||
|
|
@ -895,73 +1013,13 @@ impl MainPodVerifyGadget {
|
|||
.map(|pf| pf.into())
|
||||
.collect();
|
||||
|
||||
// Table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
||||
// hash([batch_id, custom_predicate_index, custom_predicate]). While building the table we
|
||||
// calculate the id of each batch.
|
||||
let mut custom_predicate_table =
|
||||
Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size);
|
||||
let mut custom_predicate_batches = Vec::with_capacity(params.max_custom_predicate_batches);
|
||||
for _ in 0..params.max_custom_predicate_batches {
|
||||
let cpb = builder.add_virtual_custom_predicate_batch(&self.params);
|
||||
let id = cpb.id(builder); // constrain the id
|
||||
for (index, cp) in cpb.predicates.iter().enumerate() {
|
||||
let entry = CustomPredicateEntryTarget {
|
||||
id, // output
|
||||
index: builder.constant(F::from_canonical_usize(index)), // constant
|
||||
predicate: cp.clone(), // input
|
||||
};
|
||||
let in_query_hash = entry.hash(builder);
|
||||
custom_predicate_table.push(in_query_hash);
|
||||
}
|
||||
custom_predicate_batches.push(cpb); // We keep this for witness assignment
|
||||
}
|
||||
// Table of custom predicate batches with batch_id calculation
|
||||
let (custom_predicate_table, custom_predicate_batches) =
|
||||
self.build_custom_predicate_table(builder)?;
|
||||
|
||||
// Table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args]
|
||||
// with queryable part as hash([st, op, op_args]). While building the table we verify each
|
||||
// custom predicate against the operation and statement.
|
||||
let mut custom_predicate_verifications =
|
||||
Vec::with_capacity(params.max_custom_predicate_verifications);
|
||||
let mut custom_predicate_verification_table =
|
||||
Vec::with_capacity(params.max_custom_predicate_verifications);
|
||||
for _ in 0..params.max_custom_predicate_verifications {
|
||||
let custom_predicate_table_index = builder.add_virtual_target();
|
||||
let custom_predicate = builder.add_virtual_custom_predicate_entry(&self.params);
|
||||
let args = (0..params.max_custom_predicate_wildcards)
|
||||
.map(|_| builder.add_virtual_value())
|
||||
.collect_vec();
|
||||
let op_args = (0..params.max_operation_args)
|
||||
.map(|_| builder.add_virtual_statement(&self.params))
|
||||
.collect_vec();
|
||||
|
||||
// Verify the custom predicate operation
|
||||
let (statement, op_type) = CustomOperationVerifyGadget {
|
||||
params: params.clone(),
|
||||
}
|
||||
.eval(builder, &custom_predicate, &op_args, &args)?;
|
||||
|
||||
// Check that the batch id is correct by querying the custom predicate batches table
|
||||
let table_query_hash = builder.vec_ref(
|
||||
&self.params,
|
||||
&custom_predicate_table,
|
||||
custom_predicate_table_index,
|
||||
);
|
||||
let out_query_hash = custom_predicate.hash(builder);
|
||||
builder.connect_array(table_query_hash.elements, out_query_hash.elements);
|
||||
|
||||
let entry = CustomPredicateVerifyEntryTarget {
|
||||
custom_predicate_table_index, // input
|
||||
custom_predicate, // input
|
||||
args, // input
|
||||
query: CustomPredicateVerifyQueryTarget {
|
||||
statement, // output
|
||||
op_type, // output
|
||||
op_args, // input
|
||||
},
|
||||
};
|
||||
let in_query_hash = entry.query.hash(builder);
|
||||
custom_predicate_verification_table.push(in_query_hash);
|
||||
custom_predicate_verifications.push(entry); // We keep this for witness assignment
|
||||
}
|
||||
// Table of custom predicate statements verification against operations
|
||||
let (custom_predicate_verification_table, custom_predicate_verifications) =
|
||||
self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?;
|
||||
|
||||
// 2. Calculate the Pod Id from the public statements
|
||||
let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect();
|
||||
|
|
@ -2193,7 +2251,7 @@ mod tests {
|
|||
custom_predicate: CustomPredicateRef,
|
||||
op_args: Vec<Statement>,
|
||||
args: Vec<WildcardValue>,
|
||||
expected_st: Statement,
|
||||
expected_st: Option<Statement>,
|
||||
) -> Result<()> {
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
|
@ -2226,22 +2284,22 @@ mod tests {
|
|||
arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?;
|
||||
}
|
||||
// Expected Output
|
||||
st_target.set_targets(&mut pw, params, &expected_st.into())?;
|
||||
if let Some(expected_st) = expected_st {
|
||||
st_target.set_targets(&mut pw, params, &expected_st.into())?;
|
||||
}
|
||||
|
||||
let expected_op_type = OperationType::Custom(custom_predicate);
|
||||
op_type_target.set_targets(&mut pw, params, &expected_op_type)?;
|
||||
|
||||
// generate & verify proof
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
data.verify(proof.clone()).unwrap();
|
||||
|
||||
Ok(())
|
||||
let proof = data.prove(pw)?;
|
||||
Ok(data.verify(proof.clone())?)
|
||||
}
|
||||
|
||||
// TODO: Add negative tests
|
||||
#[test]
|
||||
fn test_custom_operation_verify_gadget() -> frontend::Result<()> {
|
||||
fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> {
|
||||
// We set the parameters to the exact sizes we have in the test so that we don't have to
|
||||
// pad.
|
||||
let params = Params {
|
||||
|
|
@ -2298,7 +2356,7 @@ mod tests {
|
|||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
expected_st,
|
||||
Some(expected_st),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -2322,7 +2380,7 @@ mod tests {
|
|||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
expected_st,
|
||||
Some(expected_st),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -2349,10 +2407,190 @@ mod tests {
|
|||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
expected_st,
|
||||
Some(expected_st),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> {
|
||||
// We set the parameters to the exact sizes we have in the test so that we don't have to
|
||||
// pad.
|
||||
let params = Params {
|
||||
max_custom_predicate_arity: 2,
|
||||
max_custom_predicate_wildcards: 2,
|
||||
max_operation_args: 2,
|
||||
max_statement_args: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
use NativePredicate as NP;
|
||||
use StatementTmplBuilder as STB;
|
||||
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
||||
let stb0 = STB::new(NP::ValueOf)
|
||||
.arg(("id", key("score")))
|
||||
.arg(literal(42));
|
||||
let stb1 = STB::new(NP::Equal)
|
||||
.arg(("id", "secret_key"))
|
||||
.arg(("id", key("score")));
|
||||
let _ = builder.predicate_and(
|
||||
"pred_and",
|
||||
&["id"],
|
||||
&["secret_key"],
|
||||
&[stb0.clone(), stb1.clone()],
|
||||
)?;
|
||||
let _ = builder.predicate_or("pred_or", &["id"], &["secret_key"], &[stb0, stb1])?;
|
||||
let batch = builder.finish();
|
||||
|
||||
let pod_id = PodId(hash_str("pod_id"));
|
||||
|
||||
// AND (0) Sanity check with correct values
|
||||
let custom_predicate = CustomPredicateRef::new(batch.clone(), 0);
|
||||
let op_args = vec![
|
||||
Statement::ValueOf(
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
Value::from(42),
|
||||
),
|
||||
Statement::Equal(
|
||||
AnchoredKey::new(pod_id, Key::from("foo")),
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
),
|
||||
];
|
||||
let args = vec![
|
||||
WildcardValue::PodId(pod_id),
|
||||
WildcardValue::Key(Key::from("foo")),
|
||||
];
|
||||
let expected_st = Statement::Custom(
|
||||
custom_predicate.clone(),
|
||||
vec![args[0].clone(), WildcardValue::None],
|
||||
);
|
||||
|
||||
helper_custom_operation_verify_gadget(
|
||||
¶ms,
|
||||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
Some(expected_st),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// AND (1) Different pod_id for same wildcard
|
||||
let custom_predicate = CustomPredicateRef::new(batch.clone(), 0);
|
||||
let op_args = vec![
|
||||
Statement::ValueOf(
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
Value::from(42),
|
||||
),
|
||||
Statement::Equal(
|
||||
AnchoredKey::new(PodId(hash_str("BAD")), Key::from("foo")),
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
),
|
||||
];
|
||||
let args = vec![
|
||||
WildcardValue::PodId(pod_id),
|
||||
WildcardValue::Key(Key::from("foo")),
|
||||
];
|
||||
|
||||
assert!(helper_custom_operation_verify_gadget(
|
||||
¶ms,
|
||||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
None,
|
||||
)
|
||||
.is_err());
|
||||
|
||||
// AND (2) key doesn't match template
|
||||
let custom_predicate = CustomPredicateRef::new(batch.clone(), 0);
|
||||
let op_args = vec![
|
||||
Statement::ValueOf(AnchoredKey::new(pod_id, Key::from("BAD")), Value::from(42)),
|
||||
Statement::Equal(
|
||||
AnchoredKey::new(pod_id, Key::from("foo")),
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
),
|
||||
];
|
||||
let args = vec![
|
||||
WildcardValue::PodId(pod_id),
|
||||
WildcardValue::Key(Key::from("foo")),
|
||||
];
|
||||
|
||||
assert!(helper_custom_operation_verify_gadget(
|
||||
¶ms,
|
||||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
None,
|
||||
)
|
||||
.is_err());
|
||||
|
||||
// AND (3) literal doesn't match template
|
||||
let custom_predicate = CustomPredicateRef::new(batch.clone(), 0);
|
||||
let op_args = vec![
|
||||
Statement::ValueOf(
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
Value::from(0xbad),
|
||||
),
|
||||
Statement::Equal(
|
||||
AnchoredKey::new(pod_id, Key::from("foo")),
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
),
|
||||
];
|
||||
let args = vec![
|
||||
WildcardValue::PodId(pod_id),
|
||||
WildcardValue::Key(Key::from("foo")),
|
||||
];
|
||||
|
||||
assert!(helper_custom_operation_verify_gadget(
|
||||
¶ms,
|
||||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
None,
|
||||
)
|
||||
.is_err());
|
||||
|
||||
// AND (4) predicate doesn't match template
|
||||
let custom_predicate = CustomPredicateRef::new(batch.clone(), 0);
|
||||
let op_args = vec![
|
||||
Statement::ValueOf(
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
Value::from(42),
|
||||
),
|
||||
Statement::NotEqual(
|
||||
AnchoredKey::new(pod_id, Key::from("foo")),
|
||||
AnchoredKey::new(pod_id, Key::from("score")),
|
||||
),
|
||||
];
|
||||
let args = vec![
|
||||
WildcardValue::PodId(pod_id),
|
||||
WildcardValue::Key(Key::from("foo")),
|
||||
];
|
||||
|
||||
assert!(helper_custom_operation_verify_gadget(
|
||||
¶ms,
|
||||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
None,
|
||||
)
|
||||
.is_err());
|
||||
|
||||
// OR (1) Two Nones
|
||||
let custom_predicate = CustomPredicateRef::new(batch.clone(), 1);
|
||||
let op_args = vec![Statement::None, Statement::None];
|
||||
let args = vec![WildcardValue::PodId(pod_id), WildcardValue::None];
|
||||
|
||||
assert!(helper_custom_operation_verify_gadget(
|
||||
¶ms,
|
||||
custom_predicate,
|
||||
op_args,
|
||||
args,
|
||||
None
|
||||
)
|
||||
.is_err());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue