Feat/fst order pred part1 & part2 (#454)
Implement support for first order predicates in the backend. Now a statement template can have a predicate hash or a wildcard. ## predicate <-> predicate hash constraints To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization. This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard). Then in normalization we recalculate the predicate hash if it was a Batch Self. This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard. How this is achieved: - Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard) - After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf. If it was a predicate but not batch self, the old value was used which was constrained via the constructor. See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit` ## Wildcard predicate resolution It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
This commit is contained in:
parent
1724e7b146
commit
9c9a2c454c
11 changed files with 569 additions and 240 deletions
|
|
@ -32,9 +32,10 @@ use crate::{
|
|||
},
|
||||
middleware::{
|
||||
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
||||
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
|
||||
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F,
|
||||
HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
||||
NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard,
|
||||
PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl,
|
||||
StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE,
|
||||
STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -46,6 +47,22 @@ pub struct ValueTarget {
|
|||
pub elements: [Target; VALUE_SIZE],
|
||||
}
|
||||
|
||||
impl From<ValueTarget> for HashOutTarget {
|
||||
fn from(v: ValueTarget) -> HashOutTarget {
|
||||
HashOutTarget {
|
||||
elements: v.elements,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HashOutTarget> for ValueTarget {
|
||||
fn from(h: HashOutTarget) -> ValueTarget {
|
||||
ValueTarget {
|
||||
elements: h.elements,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueTarget {
|
||||
pub fn zero(builder: &mut CircuitBuilder) -> Self {
|
||||
Self {
|
||||
|
|
@ -524,18 +541,112 @@ impl StatementTmplArgTarget {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct PredicateHashOrWildcardTarget {
|
||||
/// layout: `prefix | [data]`, where data is predicate_hash or wildcard_index
|
||||
pub elements: [Target; Params::pred_hash_or_wc_size()],
|
||||
}
|
||||
|
||||
impl PredicateHashOrWildcardTarget {
|
||||
pub fn new(prefix: Target, data: ValueTarget) -> Self {
|
||||
let v = data.elements;
|
||||
Self {
|
||||
elements: [prefix, v[0], v[1], v[2], v[3]],
|
||||
}
|
||||
}
|
||||
pub fn new_pred_hash(builder: &mut CircuitBuilder, pred_hash: HashOutTarget) -> Self {
|
||||
Self::new(
|
||||
builder.constant(F::from(PredicateOrWildcardPrefix::Predicate)),
|
||||
ValueTarget::from(pred_hash),
|
||||
)
|
||||
}
|
||||
pub fn is_pred(&self, builder: &mut CircuitBuilder) -> BoolTarget {
|
||||
let prefix_pred = builder.constant(F::from(PredicateOrWildcardPrefix::Predicate));
|
||||
builder.is_equal(self.elements[0], prefix_pred)
|
||||
}
|
||||
pub fn data(&self) -> ValueTarget {
|
||||
ValueTarget {
|
||||
elements: self.elements[1..].try_into().expect("4 elements"),
|
||||
}
|
||||
}
|
||||
pub fn pred_hash(&self) -> HashOutTarget {
|
||||
HashOutTarget::from(self.data())
|
||||
}
|
||||
pub fn wc_index(&self) -> Target {
|
||||
self.elements[1]
|
||||
}
|
||||
pub fn set_targets_raw(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
prefix: PredicateOrWildcardPrefix,
|
||||
data: RawValue,
|
||||
) -> Result<()> {
|
||||
pw.set_target(self.elements[0], F::from(prefix))?;
|
||||
pw.set_target_arr(&self.elements[1..], &data.0)?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
pred: &PredicateOrWildcard,
|
||||
) -> Result<()> {
|
||||
match pred {
|
||||
PredicateOrWildcard::Predicate(pred) => {
|
||||
self.set_targets_raw(
|
||||
pw,
|
||||
PredicateOrWildcardPrefix::Predicate,
|
||||
RawValue::from(pred.hash(params)),
|
||||
)?;
|
||||
}
|
||||
PredicateOrWildcard::Wildcard(wc) => {
|
||||
self.set_targets_raw(
|
||||
pw,
|
||||
PredicateOrWildcardPrefix::Wildcard,
|
||||
RawValue([F::from_canonical_usize(wc.index), F::ZERO, F::ZERO, F::ZERO]),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Flattenable for PredicateHashOrWildcardTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
self.elements.to_vec()
|
||||
}
|
||||
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
|
||||
Self {
|
||||
elements: vs.try_into().expect("5 elements"),
|
||||
}
|
||||
}
|
||||
fn size(_params: &Params) -> usize {
|
||||
Params::pred_hash_or_wc_size()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct StatementTmplTarget {
|
||||
/// The preimage of the predicate_hash. This predicate is needed only to build the custom
|
||||
/// predicate table because it needs to normalize statement templates with predicates that
|
||||
/// refer to self into content-addressed predicates (using the batch id and index). The
|
||||
/// predicate type is inspected to do this normalization. After the table is built we only use
|
||||
/// the predicate hash for equality checks.
|
||||
pred: Option<PredicateTarget>,
|
||||
pred_hash: HashOutTarget,
|
||||
/// This is constrained to be `hash(pred)` through the type constructor when we have `pred`
|
||||
/// and the template uses a predicate and not a wildcard.
|
||||
pred_hash_or_wc: PredicateHashOrWildcardTarget,
|
||||
pub args: Vec<StatementTmplArgTarget>,
|
||||
}
|
||||
|
||||
impl StatementTmplTarget {
|
||||
pub fn new(pred_hash: HashOutTarget, args: Vec<StatementTmplArgTarget>) -> Self {
|
||||
pub fn new(
|
||||
pred_hash_or_wc: PredicateHashOrWildcardTarget,
|
||||
args: Vec<StatementTmplArgTarget>,
|
||||
) -> Self {
|
||||
Self {
|
||||
pred: None,
|
||||
pred_hash,
|
||||
pred_hash_or_wc,
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
|
@ -546,9 +657,22 @@ impl StatementTmplTarget {
|
|||
st_tmpl: &StatementTmpl,
|
||||
) -> Result<()> {
|
||||
if let Some(pred) = &self.pred {
|
||||
pred.set_targets(pw, params, &st_tmpl.pred)?;
|
||||
match &st_tmpl.pred_or_wc {
|
||||
PredicateOrWildcard::Predicate(p) => {
|
||||
// We store a predicate (not a wildcard) and we have it available. In this
|
||||
// case the hash will be calculated by constraints later on and we should not
|
||||
// rely on the original data.
|
||||
pred.set_targets(pw, params, p)?
|
||||
}
|
||||
PredicateOrWildcard::Wildcard(_wc) => {
|
||||
// Fill in with a recognizable constant for better debugging; this value is
|
||||
// not supposed to be used.
|
||||
pw.set_target_arr(&pred.elements, &[F(0xdead); Params::predicate_size()])?
|
||||
}
|
||||
}
|
||||
}
|
||||
pw.set_hash_target(self.pred_hash, HashOut::from(st_tmpl.pred.hash(params)))?;
|
||||
self.pred_hash_or_wc
|
||||
.set_targets(pw, params, &st_tmpl.pred_or_wc)?;
|
||||
let arg_pad = StatementTmplArg::None;
|
||||
for (i, arg) in st_tmpl
|
||||
.args
|
||||
|
|
@ -564,8 +688,8 @@ impl StatementTmplTarget {
|
|||
pub fn pred(&self) -> Option<&PredicateTarget> {
|
||||
self.pred.as_ref()
|
||||
}
|
||||
pub fn pred_hash(&self) -> &HashOutTarget {
|
||||
&self.pred_hash
|
||||
pub fn pred_hash_or_wc(&self) -> &PredicateHashOrWildcardTarget {
|
||||
&self.pred_hash_or_wc
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -603,6 +727,8 @@ impl CustomPredicateTarget {
|
|||
}
|
||||
}
|
||||
|
||||
/// This type is used to build the custom predicate table, which exposes the custom predicates with
|
||||
/// normalized statement templates indexed by batch_id and custom_predicate_index.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct CustomPredicateBatchTarget {
|
||||
pub predicates: Vec<CustomPredicateTarget>,
|
||||
|
|
@ -660,15 +786,17 @@ impl CustomPredicateEntryTarget {
|
|||
.clone()
|
||||
.into_iter()
|
||||
.map(|st_tmpl| {
|
||||
let pred = match st_tmpl.pred {
|
||||
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
||||
batch: batch.clone(),
|
||||
index: i,
|
||||
}),
|
||||
p => p,
|
||||
let pred_or_wc = match st_tmpl.pred_or_wc {
|
||||
PredicateOrWildcard::Predicate(Predicate::BatchSelf(i)) => {
|
||||
PredicateOrWildcard::Predicate(Predicate::Custom(CustomPredicateRef {
|
||||
batch: batch.clone(),
|
||||
index: i,
|
||||
}))
|
||||
}
|
||||
x => x.clone(),
|
||||
};
|
||||
StatementTmpl {
|
||||
pred,
|
||||
pred_or_wc,
|
||||
args: st_tmpl.args,
|
||||
}
|
||||
})
|
||||
|
|
@ -724,7 +852,7 @@ pub struct CustomPredicateVerifyEntryTarget {
|
|||
}
|
||||
|
||||
impl CustomPredicateVerifyEntryTarget {
|
||||
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder, with_pred: bool) -> Self {
|
||||
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
|
||||
let custom_predicate_table_len =
|
||||
params.max_custom_predicate_batches * params.max_custom_batch_size;
|
||||
CustomPredicateVerifyEntryTarget {
|
||||
|
|
@ -732,7 +860,7 @@ impl CustomPredicateVerifyEntryTarget {
|
|||
custom_predicate_table_len,
|
||||
builder,
|
||||
),
|
||||
custom_predicate: builder.add_virtual_custom_predicate_entry(params, with_pred),
|
||||
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
|
||||
args: (0..params.max_custom_predicate_wildcards)
|
||||
.map(|_| builder.add_virtual_value())
|
||||
.collect(),
|
||||
|
|
@ -1062,7 +1190,7 @@ impl Flattenable for CustomPredicateTarget {
|
|||
|
||||
impl Flattenable for StatementTmplTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
self.pred_hash
|
||||
self.pred_hash_or_wc
|
||||
.flatten()
|
||||
.into_iter()
|
||||
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
|
||||
|
|
@ -1071,24 +1199,27 @@ impl Flattenable for StatementTmplTarget {
|
|||
|
||||
fn from_flattened(params: &Params, v: &[Target]) -> Self {
|
||||
assert_eq!(v.len(), Self::size(params));
|
||||
let pred_hash_end = HASH_SIZE;
|
||||
let pred_hash = HashOutTarget::from_flattened(params, &v[..pred_hash_end]);
|
||||
let pred_hash_or_wc_end = Params::pred_hash_or_wc_size();
|
||||
let pred_hash_or_wc =
|
||||
PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]);
|
||||
let sta_size = Params::statement_tmpl_arg_size();
|
||||
let args = (0..params.max_statement_args)
|
||||
.map(|i| {
|
||||
let sta_v = &v[pred_hash_end + sta_size * i..pred_hash_end + sta_size * (i + 1)];
|
||||
let sta_v = &v
|
||||
[pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)];
|
||||
StatementTmplArgTarget::from_flattened(params, sta_v)
|
||||
})
|
||||
.collect();
|
||||
Self {
|
||||
pred: None,
|
||||
pred_hash,
|
||||
pred_hash_or_wc,
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
||||
fn size(params: &Params) -> usize {
|
||||
HASH_SIZE + params.max_statement_args * StatementTmplArgTarget::size(params)
|
||||
Params::pred_hash_or_wc_size()
|
||||
+ params.max_statement_args * StatementTmplArgTarget::size(params)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1168,11 +1299,8 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
|||
params: &Params,
|
||||
with_pred: bool,
|
||||
) -> CustomPredicateBatchTarget;
|
||||
fn add_virtual_custom_predicate_entry(
|
||||
&mut self,
|
||||
params: &Params,
|
||||
with_pred: bool,
|
||||
) -> CustomPredicateEntryTarget;
|
||||
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
|
||||
-> CustomPredicateEntryTarget;
|
||||
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||
fn select_statement_arg(
|
||||
&mut self,
|
||||
|
|
@ -1320,24 +1448,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
|||
}
|
||||
}
|
||||
|
||||
/// If `with_pred = true` a predicate is included and its hash constrained.
|
||||
/// If `with_pred = true` a predicate is included.
|
||||
/// If `with_pred = false` only the predicate hash is included.
|
||||
/// The pred_hash is constrained to be hash(pred) conditionally on the template using a
|
||||
/// predicate and not a wildcard.
|
||||
fn add_virtual_statement_tmpl(
|
||||
&mut self,
|
||||
params: &Params,
|
||||
with_pred: bool,
|
||||
) -> StatementTmplTarget {
|
||||
let (pred, pred_hash) = if with_pred {
|
||||
let pred_hash_or_wc =
|
||||
PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value());
|
||||
let pred = if with_pred {
|
||||
let pred = self.add_virtual_predicate();
|
||||
let pred_hash = pred.hash(self);
|
||||
(Some(pred), pred_hash)
|
||||
let is_pred = pred_hash_or_wc.is_pred(self);
|
||||
let data = pred_hash_or_wc.data();
|
||||
for i in 0..VALUE_SIZE {
|
||||
self.conditional_assert_eq(is_pred.target, data.elements[i], pred_hash.elements[i]);
|
||||
}
|
||||
Some(pred)
|
||||
} else {
|
||||
let pred_hash = self.add_virtual_hash();
|
||||
(None, pred_hash)
|
||||
None
|
||||
};
|
||||
StatementTmplTarget {
|
||||
pred,
|
||||
pred_hash,
|
||||
pred_hash_or_wc,
|
||||
args: (0..params.max_statement_args)
|
||||
.map(|_| self.add_virtual_statement_tmpl_arg())
|
||||
.collect(),
|
||||
|
|
@ -1377,12 +1513,11 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
|||
fn add_virtual_custom_predicate_entry(
|
||||
&mut self,
|
||||
params: &Params,
|
||||
with_pred: bool,
|
||||
) -> CustomPredicateEntryTarget {
|
||||
CustomPredicateEntryTarget {
|
||||
id: self.add_virtual_hash(),
|
||||
index: self.add_virtual_target(),
|
||||
predicate: self.add_virtual_custom_predicate(params, with_pred),
|
||||
predicate: self.add_virtual_custom_predicate(params, false),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue