Use predicate hash in statements instead of the literal predicate
Resolve #448 Previously a predicate was 6 elements. Now it grows to 8 elements; and the hash is 4 elements. Some parts of the circuit require only require equality checks with the predicate: that works with the predicate hash. Other parts require inspecting or working with particular elements in the predicate, those need the preimage of the predicate hash. Both `StatementTarget` and `StatementTmplTarget` have been updated to include the predicate hash and optionally the predicate. When the predicate is included, constraints are automatically generated for `pred_hash = hash(pred)`. We only include the predicate when needed.
This commit is contained in:
parent
2eb1daeb92
commit
0fca00cc93
7 changed files with 319 additions and 159 deletions
|
|
@ -1396,10 +1396,7 @@ fn make_statement_from_template_circuit(
|
|||
})
|
||||
.collect();
|
||||
measure_gates_end!(builder, measure);
|
||||
StatementTarget {
|
||||
predicate: st_tmpl.pred.clone(),
|
||||
args,
|
||||
}
|
||||
StatementTarget::new(*st_tmpl.pred_hash(), args)
|
||||
}
|
||||
|
||||
/// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard
|
||||
|
|
@ -1434,11 +1431,9 @@ fn make_custom_statement_circuit(
|
|||
let v = builder.select_flattenable(params, mask, arg, &arg_none);
|
||||
StatementArgTarget::wildcard_literal(builder, &v)
|
||||
})
|
||||
.collect();
|
||||
let statement = StatementTarget {
|
||||
predicate: st_predicate,
|
||||
args: st_args,
|
||||
};
|
||||
.collect_vec();
|
||||
let statement_with_pred =
|
||||
StatementTarget::new_with_pred(builder, params, st_predicate, &st_args);
|
||||
|
||||
// Check the operation arguments
|
||||
// From each statement template we generate an expected statement using replacing the
|
||||
|
|
@ -1470,7 +1465,7 @@ fn make_custom_statement_circuit(
|
|||
|
||||
builder.assert_one(is_op_args_ok.target);
|
||||
measure_gates_end!(builder, measure);
|
||||
Ok((statement, op_type))
|
||||
Ok((statement_with_pred, op_type))
|
||||
}
|
||||
|
||||
/// Replace the blank verifier_data_hash slots in intro predicates by `vd_hash`
|
||||
|
|
@ -1480,19 +1475,13 @@ fn normalize_statement_circuit(
|
|||
statement: &StatementTarget,
|
||||
vd_hash: &HashOutTarget,
|
||||
) -> StatementTarget {
|
||||
let is_intro = statement.predicate.is_intro(builder);
|
||||
let old_pred = statement.predicate.elements;
|
||||
let old = HashOutTarget::try_from(&old_pred[1..1 + HASH_SIZE]).expect("len = 4");
|
||||
let new = builder
|
||||
.select_flattenable(params, is_intro, vd_hash, &old)
|
||||
.elements;
|
||||
let is_blank_intro = statement.pred_is_blank_intro(builder);
|
||||
let old_pred_hash = statement.pred_hash();
|
||||
let intro_pred_hash = PredicateTarget::new_intro(builder, *vd_hash).hash(builder);
|
||||
let new_pred_hash =
|
||||
builder.select_flattenable(params, is_blank_intro, &intro_pred_hash, old_pred_hash);
|
||||
|
||||
StatementTarget {
|
||||
predicate: PredicateTarget {
|
||||
elements: [old_pred[0], new[0], new[1], new[2], new[3], old_pred[5]],
|
||||
},
|
||||
args: statement.args.clone(),
|
||||
}
|
||||
StatementTarget::new(new_pred_hash, statement.args.clone())
|
||||
}
|
||||
|
||||
/// `params.num_public_statements_hash` is the total number of statements that will be hashed.
|
||||
|
|
@ -1538,15 +1527,13 @@ fn normalize_st_tmpl_circuit(
|
|||
st_tmpl: &StatementTmplTarget,
|
||||
id: HashOutTarget,
|
||||
) -> StatementTmplTarget {
|
||||
let pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
|
||||
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
|
||||
let is_batch_self = builder.is_equal(st_tmpl.pred.elements[0], prefix_batch_self);
|
||||
let pred_index = st_tmpl.pred.elements[1];
|
||||
let is_batch_self = builder.is_equal(pred.elements[0], prefix_batch_self);
|
||||
let pred_index = pred.elements[1];
|
||||
let custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
|
||||
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, &st_tmpl.pred);
|
||||
StatementTmplTarget {
|
||||
pred,
|
||||
args: st_tmpl.args.clone(),
|
||||
}
|
||||
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, pred);
|
||||
StatementTmplTarget::new(pred.hash(builder), st_tmpl.args.clone())
|
||||
}
|
||||
|
||||
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
||||
|
|
@ -1567,7 +1554,9 @@ fn build_custom_predicate_table_circuit(
|
|||
let statements = cp
|
||||
.statements
|
||||
.iter()
|
||||
.map(|st_tmpl| normalize_st_tmpl_circuit(params, builder, st_tmpl, id))
|
||||
.map(|st_with_pred_tmpl| {
|
||||
normalize_st_tmpl_circuit(params, builder, st_with_pred_tmpl, id)
|
||||
})
|
||||
.collect_vec();
|
||||
let cp = CustomPredicateTarget {
|
||||
conjunction: cp.conjunction,
|
||||
|
|
@ -1625,19 +1614,19 @@ fn verify_main_pod_circuit(
|
|||
|
||||
// NOTE: We use an EmptyPod for padding input pod slots. The EmptyPod is an introduction
|
||||
// pod that declares a statement with no arguments.
|
||||
let is_intro = input_pod_self_statements[0].predicate.is_intro(builder);
|
||||
let is_blank_intro = input_pod_self_statements[0].pred_is_blank_intro(builder);
|
||||
|
||||
// Introduction pods can only have Introduction or None statements
|
||||
let mut intro_ok = is_intro;
|
||||
let mut intro_ok = is_blank_intro;
|
||||
for self_st in &input_pod_self_statements[1..] {
|
||||
let st_is_intro = self_st.predicate.is_intro(builder);
|
||||
let st_is_intro = self_st.pred_is_blank_intro(builder);
|
||||
let st_is_none = self_st.has_native_type(builder, params, NativePredicate::None);
|
||||
let st_is_intro_or_none = builder.or(st_is_intro, st_is_none);
|
||||
intro_ok = builder.and(intro_ok, st_is_intro_or_none);
|
||||
}
|
||||
builder.connect(is_intro.target, intro_ok.target);
|
||||
builder.connect(is_blank_intro.target, intro_ok.target);
|
||||
|
||||
let is_main = builder.not(is_intro);
|
||||
let is_main = builder.not(is_blank_intro);
|
||||
for self_st in input_pod_self_statements {
|
||||
let normalized_st = normalize_statement_circuit(
|
||||
params,
|
||||
|
|
@ -1750,12 +1739,12 @@ impl MainPodVerifyTarget {
|
|||
input_pods_self_statements: (0..params.max_input_pods)
|
||||
.map(|_| {
|
||||
(0..params.max_input_pods_public_statements)
|
||||
.map(|_| builder.add_virtual_statement(params))
|
||||
.map(|_| builder.add_virtual_statement(params, false))
|
||||
.collect_vec()
|
||||
})
|
||||
.collect(),
|
||||
input_statements: (0..params.max_statements)
|
||||
.map(|_| builder.add_virtual_statement(params))
|
||||
.map(|_| builder.add_virtual_statement(params, false))
|
||||
.collect(),
|
||||
operations: (0..params.max_statements)
|
||||
.map(|_| builder.add_virtual_operation(params))
|
||||
|
|
@ -1781,10 +1770,10 @@ impl MainPodVerifyTarget {
|
|||
})
|
||||
.collect(),
|
||||
custom_predicate_batches: (0..params.max_custom_predicate_batches)
|
||||
.map(|_| builder.add_virtual_custom_predicate_batch(params))
|
||||
.map(|_| builder.add_virtual_custom_predicate_batch(params, true))
|
||||
.collect(),
|
||||
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
|
||||
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
|
||||
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder, false))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
|
@ -2084,10 +2073,10 @@ mod tests {
|
|||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::new(config);
|
||||
|
||||
let st_target = builder.add_virtual_statement(¶ms);
|
||||
let st_target = builder.add_virtual_statement(¶ms, false);
|
||||
let op_target = builder.add_virtual_operation(¶ms);
|
||||
let prev_statements_target: Vec<_> = (0..prev_statements.len())
|
||||
.map(|_| builder.add_virtual_statement(¶ms))
|
||||
.map(|_| builder.add_virtual_statement(¶ms, false))
|
||||
.collect();
|
||||
|
||||
let merkle_proofs_target: Vec<_> = aux
|
||||
|
|
@ -3098,7 +3087,7 @@ mod tests {
|
|||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::new(config);
|
||||
|
||||
let st_tmpl_target = builder.add_virtual_statement_tmpl(params);
|
||||
let st_tmpl_target = builder.add_virtual_statement_tmpl(params, false);
|
||||
let args_target: Vec<_> = (0..args.len())
|
||||
.map(|_| builder.add_virtual_value())
|
||||
.collect();
|
||||
|
|
@ -3109,7 +3098,7 @@ mod tests {
|
|||
&args_target,
|
||||
);
|
||||
// TODO: Instead of connect, assign witness to result
|
||||
let expected_st_target = builder.add_virtual_statement(params);
|
||||
let expected_st_target = builder.add_virtual_statement(params, false);
|
||||
builder.connect_flattenable(&expected_st_target, &st_target);
|
||||
|
||||
let mut pw = PartialWitness::<F>::new();
|
||||
|
|
@ -3161,9 +3150,9 @@ mod tests {
|
|||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::new(config);
|
||||
|
||||
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params);
|
||||
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params, false);
|
||||
let op_args_target: Vec<_> = (0..args.len())
|
||||
.map(|_| builder.add_virtual_statement(params))
|
||||
.map(|_| builder.add_virtual_statement(params, false))
|
||||
.collect();
|
||||
let args_target: Vec<_> = (0..args.len())
|
||||
.map(|_| builder.add_virtual_value())
|
||||
|
|
@ -3455,7 +3444,7 @@ mod tests {
|
|||
let mut builder = CircuitBuilder::new(config);
|
||||
|
||||
let statements_target = (0..params.max_public_statements)
|
||||
.map(|_| builder.add_virtual_statement(params))
|
||||
.map(|_| builder.add_virtual_statement(params, false))
|
||||
.collect_vec();
|
||||
let sts_hash_target =
|
||||
calculate_statements_hash_circuit(params, &mut builder, &statements_target);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue