Use predicate hash in statements instead of the literal predicate

Resolve #448 

Previously a predicate was 6 elements.  Now it grows to 8 elements; and the hash is 4 elements.

Some parts of the circuit require only require equality checks with the predicate: that works with the predicate hash.  Other parts require inspecting or working with particular elements in the predicate, those need the preimage of the predicate hash.
Both `StatementTarget` and `StatementTmplTarget` have been updated to include the predicate hash and optionally the predicate.  When the predicate is included, constraints are automatically generated for `pred_hash = hash(pred)`.  We only include the predicate when needed.
This commit is contained in:
Eduard S. 2026-01-19 11:02:11 +01:00 committed by GitHub
parent 2eb1daeb92
commit 0fca00cc93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 319 additions and 159 deletions

View file

@ -1396,10 +1396,7 @@ fn make_statement_from_template_circuit(
})
.collect();
measure_gates_end!(builder, measure);
StatementTarget {
predicate: st_tmpl.pred.clone(),
args,
}
StatementTarget::new(*st_tmpl.pred_hash(), args)
}
/// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard
@ -1434,11 +1431,9 @@ fn make_custom_statement_circuit(
let v = builder.select_flattenable(params, mask, arg, &arg_none);
StatementArgTarget::wildcard_literal(builder, &v)
})
.collect();
let statement = StatementTarget {
predicate: st_predicate,
args: st_args,
};
.collect_vec();
let statement_with_pred =
StatementTarget::new_with_pred(builder, params, st_predicate, &st_args);
// Check the operation arguments
// From each statement template we generate an expected statement using replacing the
@ -1470,7 +1465,7 @@ fn make_custom_statement_circuit(
builder.assert_one(is_op_args_ok.target);
measure_gates_end!(builder, measure);
Ok((statement, op_type))
Ok((statement_with_pred, op_type))
}
/// Replace the blank verifier_data_hash slots in intro predicates by `vd_hash`
@ -1480,19 +1475,13 @@ fn normalize_statement_circuit(
statement: &StatementTarget,
vd_hash: &HashOutTarget,
) -> StatementTarget {
let is_intro = statement.predicate.is_intro(builder);
let old_pred = statement.predicate.elements;
let old = HashOutTarget::try_from(&old_pred[1..1 + HASH_SIZE]).expect("len = 4");
let new = builder
.select_flattenable(params, is_intro, vd_hash, &old)
.elements;
let is_blank_intro = statement.pred_is_blank_intro(builder);
let old_pred_hash = statement.pred_hash();
let intro_pred_hash = PredicateTarget::new_intro(builder, *vd_hash).hash(builder);
let new_pred_hash =
builder.select_flattenable(params, is_blank_intro, &intro_pred_hash, old_pred_hash);
StatementTarget {
predicate: PredicateTarget {
elements: [old_pred[0], new[0], new[1], new[2], new[3], old_pred[5]],
},
args: statement.args.clone(),
}
StatementTarget::new(new_pred_hash, statement.args.clone())
}
/// `params.num_public_statements_hash` is the total number of statements that will be hashed.
@ -1538,15 +1527,13 @@ fn normalize_st_tmpl_circuit(
st_tmpl: &StatementTmplTarget,
id: HashOutTarget,
) -> StatementTmplTarget {
let pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
let is_batch_self = builder.is_equal(st_tmpl.pred.elements[0], prefix_batch_self);
let pred_index = st_tmpl.pred.elements[1];
let is_batch_self = builder.is_equal(pred.elements[0], prefix_batch_self);
let pred_index = pred.elements[1];
let custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, &st_tmpl.pred);
StatementTmplTarget {
pred,
args: st_tmpl.args.clone(),
}
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, pred);
StatementTmplTarget::new(pred.hash(builder), st_tmpl.args.clone())
}
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
@ -1567,7 +1554,9 @@ fn build_custom_predicate_table_circuit(
let statements = cp
.statements
.iter()
.map(|st_tmpl| normalize_st_tmpl_circuit(params, builder, st_tmpl, id))
.map(|st_with_pred_tmpl| {
normalize_st_tmpl_circuit(params, builder, st_with_pred_tmpl, id)
})
.collect_vec();
let cp = CustomPredicateTarget {
conjunction: cp.conjunction,
@ -1625,19 +1614,19 @@ fn verify_main_pod_circuit(
// NOTE: We use an EmptyPod for padding input pod slots. The EmptyPod is an introduction
// pod that declares a statement with no arguments.
let is_intro = input_pod_self_statements[0].predicate.is_intro(builder);
let is_blank_intro = input_pod_self_statements[0].pred_is_blank_intro(builder);
// Introduction pods can only have Introduction or None statements
let mut intro_ok = is_intro;
let mut intro_ok = is_blank_intro;
for self_st in &input_pod_self_statements[1..] {
let st_is_intro = self_st.predicate.is_intro(builder);
let st_is_intro = self_st.pred_is_blank_intro(builder);
let st_is_none = self_st.has_native_type(builder, params, NativePredicate::None);
let st_is_intro_or_none = builder.or(st_is_intro, st_is_none);
intro_ok = builder.and(intro_ok, st_is_intro_or_none);
}
builder.connect(is_intro.target, intro_ok.target);
builder.connect(is_blank_intro.target, intro_ok.target);
let is_main = builder.not(is_intro);
let is_main = builder.not(is_blank_intro);
for self_st in input_pod_self_statements {
let normalized_st = normalize_statement_circuit(
params,
@ -1750,12 +1739,12 @@ impl MainPodVerifyTarget {
input_pods_self_statements: (0..params.max_input_pods)
.map(|_| {
(0..params.max_input_pods_public_statements)
.map(|_| builder.add_virtual_statement(params))
.map(|_| builder.add_virtual_statement(params, false))
.collect_vec()
})
.collect(),
input_statements: (0..params.max_statements)
.map(|_| builder.add_virtual_statement(params))
.map(|_| builder.add_virtual_statement(params, false))
.collect(),
operations: (0..params.max_statements)
.map(|_| builder.add_virtual_operation(params))
@ -1781,10 +1770,10 @@ impl MainPodVerifyTarget {
})
.collect(),
custom_predicate_batches: (0..params.max_custom_predicate_batches)
.map(|_| builder.add_virtual_custom_predicate_batch(params))
.map(|_| builder.add_virtual_custom_predicate_batch(params, true))
.collect(),
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder, false))
.collect(),
}
}
@ -2084,10 +2073,10 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let st_target = builder.add_virtual_statement(&params);
let st_target = builder.add_virtual_statement(&params, false);
let op_target = builder.add_virtual_operation(&params);
let prev_statements_target: Vec<_> = (0..prev_statements.len())
.map(|_| builder.add_virtual_statement(&params))
.map(|_| builder.add_virtual_statement(&params, false))
.collect();
let merkle_proofs_target: Vec<_> = aux
@ -3098,7 +3087,7 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let st_tmpl_target = builder.add_virtual_statement_tmpl(params);
let st_tmpl_target = builder.add_virtual_statement_tmpl(params, false);
let args_target: Vec<_> = (0..args.len())
.map(|_| builder.add_virtual_value())
.collect();
@ -3109,7 +3098,7 @@ mod tests {
&args_target,
);
// TODO: Instead of connect, assign witness to result
let expected_st_target = builder.add_virtual_statement(params);
let expected_st_target = builder.add_virtual_statement(params, false);
builder.connect_flattenable(&expected_st_target, &st_target);
let mut pw = PartialWitness::<F>::new();
@ -3161,9 +3150,9 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params);
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params, false);
let op_args_target: Vec<_> = (0..args.len())
.map(|_| builder.add_virtual_statement(params))
.map(|_| builder.add_virtual_statement(params, false))
.collect();
let args_target: Vec<_> = (0..args.len())
.map(|_| builder.add_virtual_value())
@ -3455,7 +3444,7 @@ mod tests {
let mut builder = CircuitBuilder::new(config);
let statements_target = (0..params.max_public_statements)
.map(|_| builder.add_virtual_statement(params))
.map(|_| builder.add_virtual_statement(params, false))
.collect_vec();
let sts_hash_target =
calculate_statements_hash_circuit(params, &mut builder, &statements_target);