Feat/fst order pred part1 & part2 (#454)
Implement support for first order predicates in the backend. Now a statement template can have a predicate hash or a wildcard. ## predicate <-> predicate hash constraints To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization. This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard). Then in normalization we recalculate the predicate hash if it was a Batch Self. This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard. How this is achieved: - Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard) - After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf. If it was a predicate but not batch self, the old value was used which was constrained via the constructor. See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit` ## Wildcard predicate resolution It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
This commit is contained in:
parent
1724e7b146
commit
9c9a2c454c
11 changed files with 569 additions and 240 deletions
|
|
@ -25,8 +25,8 @@ use crate::{
|
|||
CustomPredicateTarget, CustomPredicateVerifyEntryTarget,
|
||||
CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget,
|
||||
MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget,
|
||||
PredicateTarget, StatementArgTarget, StatementTarget, StatementTmplArgTarget,
|
||||
StatementTmplTarget, ValueTarget,
|
||||
PredicateHashOrWildcardTarget, PredicateTarget, StatementArgTarget,
|
||||
StatementTarget, StatementTmplArgTarget, StatementTmplTarget, ValueTarget,
|
||||
},
|
||||
hash::{hash_from_state_circuit, precompute_hash_state},
|
||||
mux_table::{MuxTableTarget, TableEntryTarget},
|
||||
|
|
@ -341,12 +341,7 @@ fn build_operation_aux_table_circuit(
|
|||
.chain(signed_by.pk.u.components)
|
||||
.collect(),
|
||||
);
|
||||
let entry: MsgPubKeyTarget = HashPairTarget(
|
||||
HashOutTarget {
|
||||
elements: signed_by.msg.elements,
|
||||
},
|
||||
pk_hash,
|
||||
);
|
||||
let entry: MsgPubKeyTarget = HashPairTarget(HashOutTarget::from(signed_by.msg), pk_hash);
|
||||
|
||||
table.push(builder, OperationAuxTableTag::SignedBy as u32, &entry);
|
||||
measure_gates_end!(builder, measure);
|
||||
|
|
@ -1381,6 +1376,26 @@ fn make_statement_arg_from_template_circuit(
|
|||
StatementArgTarget::new(first, second)
|
||||
}
|
||||
|
||||
fn make_predicate_from_template_circuit(
|
||||
params: &Params,
|
||||
builder: &mut CircuitBuilder,
|
||||
pred_hash_or_wc: &PredicateHashOrWildcardTarget,
|
||||
args: &[ValueTarget],
|
||||
) -> HashOutTarget {
|
||||
let zero = builder.zero();
|
||||
let is_pred = pred_hash_or_wc.is_pred(builder);
|
||||
// If the index is not used, use a 0 instead to still pass the range constraints from
|
||||
// vec_ref
|
||||
let index = builder.select(is_pred, zero, pred_hash_or_wc.wc_index());
|
||||
let resolved_pred_hash = HashOutTarget::from(builder.vec_ref_small(params, args, index));
|
||||
builder.select_flattenable(
|
||||
params,
|
||||
is_pred,
|
||||
&pred_hash_or_wc.pred_hash(),
|
||||
&resolved_pred_hash,
|
||||
)
|
||||
}
|
||||
|
||||
fn make_statement_from_template_circuit(
|
||||
params: &Params,
|
||||
builder: &mut CircuitBuilder,
|
||||
|
|
@ -1388,7 +1403,7 @@ fn make_statement_from_template_circuit(
|
|||
args: &[ValueTarget],
|
||||
) -> StatementTarget {
|
||||
let measure = measure_gates_begin!(builder, "StArgFromTmpl");
|
||||
let args = st_tmpl
|
||||
let st_args = st_tmpl
|
||||
.args
|
||||
.iter()
|
||||
.map(|st_tmpl_arg| {
|
||||
|
|
@ -1396,7 +1411,11 @@ fn make_statement_from_template_circuit(
|
|||
})
|
||||
.collect();
|
||||
measure_gates_end!(builder, measure);
|
||||
StatementTarget::new(*st_tmpl.pred_hash(), args)
|
||||
let measure = measure_gates_begin!(builder, "PredFromTmpl");
|
||||
let pred_hash =
|
||||
make_predicate_from_template_circuit(params, builder, st_tmpl.pred_hash_or_wc(), args);
|
||||
measure_gates_end!(builder, measure);
|
||||
StatementTarget::new(pred_hash, st_args)
|
||||
}
|
||||
|
||||
/// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard
|
||||
|
|
@ -1527,13 +1546,29 @@ fn normalize_st_tmpl_circuit(
|
|||
st_tmpl: &StatementTmplTarget,
|
||||
id: HashOutTarget,
|
||||
) -> StatementTmplTarget {
|
||||
let pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
|
||||
// If the custom predicate is self, we normalize it and then hash it.
|
||||
let old_pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
|
||||
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
|
||||
let is_batch_self = builder.is_equal(pred.elements[0], prefix_batch_self);
|
||||
let pred_index = pred.elements[1];
|
||||
let custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
|
||||
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, pred);
|
||||
StatementTmplTarget::new(pred.hash(builder), st_tmpl.args.clone())
|
||||
let is_batch_self = builder.is_equal(old_pred.elements[0], prefix_batch_self);
|
||||
|
||||
let pred_index = old_pred.elements[1];
|
||||
let normalized_custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
|
||||
let normalized_custom_pred_hash = normalized_custom_pred.hash(builder);
|
||||
|
||||
// If the template is using a predicate and it is batch self we use the freshly computed
|
||||
// normalized predicate hash, otherwise we keep the original data.
|
||||
let old_data = st_tmpl.pred_hash_or_wc().data();
|
||||
let is_pred = st_tmpl.pred_hash_or_wc().is_pred(builder);
|
||||
let is_pred_batch_self = builder.and(is_pred, is_batch_self);
|
||||
let data = builder.select_flattenable(
|
||||
params,
|
||||
is_pred_batch_self,
|
||||
&ValueTarget::from(normalized_custom_pred_hash),
|
||||
&old_data,
|
||||
);
|
||||
let pred_hash_or_wc =
|
||||
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
|
||||
StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone())
|
||||
}
|
||||
|
||||
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
||||
|
|
@ -1773,7 +1808,7 @@ impl MainPodVerifyTarget {
|
|||
.map(|_| builder.add_virtual_custom_predicate_batch(params, true))
|
||||
.collect(),
|
||||
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
|
||||
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder, false))
|
||||
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
|
@ -2012,8 +2047,8 @@ mod tests {
|
|||
dict,
|
||||
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
|
||||
middleware::{
|
||||
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, RawValue, StatementArg,
|
||||
StatementTmpl, StatementTmplArg, Wildcard,
|
||||
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard,
|
||||
RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -3124,7 +3159,7 @@ mod tests {
|
|||
let dict = Hash([F(6), F(7), F(8), F(9)]);
|
||||
|
||||
let st_tmpl = StatementTmpl {
|
||||
pred: Predicate::Native(NativePredicate::Equal),
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)),
|
||||
args: vec![
|
||||
StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
|
||||
StatementTmplArg::Literal(Value::from("value")),
|
||||
|
|
@ -3137,6 +3172,21 @@ mod tests {
|
|||
);
|
||||
helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?;
|
||||
|
||||
let st_tmpl = StatementTmpl {
|
||||
pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)),
|
||||
args: vec![
|
||||
StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
|
||||
StatementTmplArg::Literal(Value::from("value")),
|
||||
],
|
||||
};
|
||||
let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(¶ms);
|
||||
let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)];
|
||||
let expected_st = Statement::not_equal(
|
||||
AnchoredKey::new(dict, Key::from("key")),
|
||||
Value::from("value"),
|
||||
);
|
||||
helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -3150,7 +3200,7 @@ mod tests {
|
|||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::new(config);
|
||||
|
||||
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params, false);
|
||||
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params);
|
||||
let op_args_target: Vec<_> = (0..args.len())
|
||||
.map(|_| builder.add_virtual_statement(params, false))
|
||||
.collect();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue