Feat/fst order pred part1 & part2 (#454)
Implement support for first order predicates in the backend. Now a statement template can have a predicate hash or a wildcard. ## predicate <-> predicate hash constraints To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization. This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard). Then in normalization we recalculate the predicate hash if it was a Batch Self. This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard. How this is achieved: - Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard) - After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf. If it was a predicate but not batch self, the old value was used which was constrained via the constructor. See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit` ## Wildcard predicate resolution It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
This commit is contained in:
parent
1724e7b146
commit
9c9a2c454c
11 changed files with 569 additions and 240 deletions
|
|
@ -1,5 +1,6 @@
|
|||
use std::{fmt, iter, sync::Arc};
|
||||
|
||||
use itertools::Itertools;
|
||||
use plonky2::field::types::Field;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -70,36 +71,28 @@ impl ToFields for StatementTmplArg {
|
|||
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
|
||||
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
|
||||
match self {
|
||||
StatementTmplArg::None => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::None))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::statement_tmpl_arg_size())
|
||||
.collect();
|
||||
fields
|
||||
}
|
||||
StatementTmplArg::Literal(v) => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
|
||||
.chain(v.raw().to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::statement_tmpl_arg_size())
|
||||
.collect();
|
||||
fields
|
||||
}
|
||||
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::statement_tmpl_arg_size())
|
||||
.collect_vec(),
|
||||
StatementTmplArg::Literal(v) => iter::once(F::from(StatementTmplArgPrefix::Literal))
|
||||
.chain(v.raw().to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::statement_tmpl_arg_size())
|
||||
.collect_vec(),
|
||||
StatementTmplArg::AnchoredKey(wc1, kw2) => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
|
||||
iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
|
||||
.chain(wc1.to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
|
||||
.chain(kw2.to_fields(params))
|
||||
.collect();
|
||||
fields
|
||||
.collect_vec()
|
||||
}
|
||||
StatementTmplArg::Wildcard(wc) => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
|
||||
iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
|
||||
.chain(wc.to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::statement_tmpl_arg_size())
|
||||
.collect();
|
||||
fields
|
||||
.collect_vec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -121,16 +114,79 @@ impl fmt::Display for StatementTmplArg {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||
pub enum PredicateOrWildcard {
|
||||
Predicate(Predicate),
|
||||
Wildcard(Wildcard),
|
||||
}
|
||||
|
||||
impl PredicateOrWildcard {
|
||||
pub fn as_pred(&self) -> Option<&Predicate> {
|
||||
match self {
|
||||
Self::Predicate(pred) => Some(pred),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub fn as_wc(&self) -> Option<&Wildcard> {
|
||||
match self {
|
||||
Self::Wildcard(wc) => Some(wc),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PredicateOrWildcard {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::Predicate(pred) => pred.fmt(f),
|
||||
Self::Wildcard(wc) => {
|
||||
write!(f, "?")?;
|
||||
wc.fmt(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum PredicateOrWildcardPrefix {
|
||||
Predicate = 0,
|
||||
Wildcard = 1,
|
||||
}
|
||||
|
||||
impl From<PredicateOrWildcardPrefix> for F {
|
||||
fn from(prefix: PredicateOrWildcardPrefix) -> Self {
|
||||
Self::from_canonical_usize(prefix as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFields for PredicateOrWildcard {
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// Encoding:
|
||||
// Predicate(pred) => (0, [hash(pred) ])
|
||||
// Wildcard(wc) => (1, wc_index, 0...)
|
||||
match self {
|
||||
Self::Predicate(pred) => iter::once(F::from(PredicateOrWildcardPrefix::Predicate))
|
||||
.chain(pred.hash(params).to_fields(params))
|
||||
.collect_vec(),
|
||||
Self::Wildcard(wc) => iter::once(F::from(PredicateOrWildcardPrefix::Wildcard))
|
||||
.chain(wc.to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::pred_hash_or_wc_size())
|
||||
.collect_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Statement Template for a Custom Predicate
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct StatementTmpl {
|
||||
pub pred: Predicate,
|
||||
pub pred_or_wc: PredicateOrWildcard,
|
||||
pub args: Vec<StatementTmplArg>,
|
||||
}
|
||||
|
||||
impl StatementTmpl {
|
||||
pub fn pred(&self) -> &Predicate {
|
||||
&self.pred
|
||||
pub fn pred_or_wc(&self) -> &PredicateOrWildcard {
|
||||
&self.pred_or_wc
|
||||
}
|
||||
pub fn args(&self) -> &[StatementTmplArg] {
|
||||
&self.args
|
||||
|
|
@ -139,7 +195,7 @@ impl StatementTmpl {
|
|||
|
||||
impl fmt::Display for StatementTmpl {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.pred.fmt(f)?;
|
||||
self.pred_or_wc.fmt(f)?;
|
||||
write!(f, "(")?;
|
||||
for (i, arg) in self.args.iter().enumerate() {
|
||||
if i != 0 {
|
||||
|
|
@ -154,7 +210,7 @@ impl fmt::Display for StatementTmpl {
|
|||
impl ToFields for StatementTmpl {
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// serialize as:
|
||||
// predicate (6 field elements)
|
||||
// predicate (4 field elements)
|
||||
// then the StatementTmplArgs
|
||||
|
||||
// TODO think if this check should go into the StatementTmpl creation,
|
||||
|
|
@ -168,15 +224,13 @@ impl ToFields for StatementTmpl {
|
|||
);
|
||||
}
|
||||
|
||||
let mut fields: Vec<F> = self
|
||||
.pred
|
||||
.hash(params)
|
||||
self.pred_or_wc
|
||||
.to_fields(params)
|
||||
.into_iter()
|
||||
.chain(self.args.iter().flat_map(|sta| sta.to_fields(params)))
|
||||
.collect();
|
||||
fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0));
|
||||
fields
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(params.statement_tmpl_size())
|
||||
.collect_vec()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -203,7 +257,9 @@ impl CustomPredicate {
|
|||
name: "empty".to_string(),
|
||||
conjunction: false,
|
||||
statements: vec![StatementTmpl {
|
||||
pred: Predicate::Native(NativePredicate::None),
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(
|
||||
NativePredicate::None,
|
||||
)),
|
||||
args: vec![],
|
||||
}],
|
||||
args_len: 0,
|
||||
|
|
@ -276,11 +332,11 @@ impl CustomPredicate {
|
|||
}
|
||||
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
|
||||
StatementTmpl {
|
||||
pred: Predicate::Native(if self.conjunction {
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(if self.conjunction {
|
||||
NativePredicate::None
|
||||
} else {
|
||||
NativePredicate::False
|
||||
}),
|
||||
})),
|
||||
args: vec![],
|
||||
}
|
||||
}
|
||||
|
|
@ -318,7 +374,7 @@ impl ToFields for CustomPredicate {
|
|||
}
|
||||
|
||||
let pad_st = self.pad_statement_tmpl();
|
||||
let fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
|
||||
iter::once(F::from_bool(self.conjunction))
|
||||
.chain(iter::once(F::from_canonical_usize(self.args_len)))
|
||||
.chain(
|
||||
self.statements
|
||||
|
|
@ -327,8 +383,7 @@ impl ToFields for CustomPredicate {
|
|||
.take(params.max_custom_predicate_arity)
|
||||
.flat_map(|st| st.to_fields(params)),
|
||||
)
|
||||
.collect();
|
||||
fields
|
||||
.collect_vec()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -350,7 +405,7 @@ impl fmt::Display for CustomPredicate {
|
|||
writeln!(f, ") = {}(", if self.conjunction { "AND" } else { "OR" })?;
|
||||
for st in &self.statements {
|
||||
write!(f, " ")?;
|
||||
st.pred.fmt(f)?;
|
||||
st.pred_or_wc.fmt(f)?;
|
||||
write!(f, "(")?;
|
||||
for (i, arg) in st.args.iter().enumerate() {
|
||||
if i != 0 {
|
||||
|
|
@ -382,14 +437,12 @@ impl ToFields for CustomPredicateBatch {
|
|||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// all the custom predicates in order
|
||||
let pad_pred = CustomPredicate::empty();
|
||||
let fields: Vec<F> = self
|
||||
.predicates
|
||||
self.predicates
|
||||
.iter()
|
||||
.chain(iter::repeat(&pad_pred))
|
||||
.take(params.max_custom_batch_size)
|
||||
.flat_map(|p| p.to_fields(params))
|
||||
.collect();
|
||||
fields
|
||||
.collect_vec()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -418,7 +471,6 @@ impl CustomPredicateBatch {
|
|||
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
|
||||
// but ideally we want to use the root of a merkle tree built from the custom predicates.
|
||||
let input = self.to_fields(params);
|
||||
|
||||
hash_fields(&input)
|
||||
}
|
||||
|
||||
|
|
@ -470,7 +522,10 @@ mod tests {
|
|||
};
|
||||
|
||||
fn st(p: Predicate, args: Vec<StatementTmplArg>) -> StatementTmpl {
|
||||
StatementTmpl { pred: p, args }
|
||||
StatementTmpl {
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(p),
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
||||
fn key(name: &str) -> Key {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue