diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index e8cda34..2436cb5 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -153,7 +153,7 @@ impl PartialOrd for Hash { impl fmt::Display for Hash { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let v0 = self.0[0].to_canonical_u64(); - for i in 0..4 { + for i in 0..HASH_SIZE { write!(f, "{:02x}", (v0 >> (i * 8)) & 0xff)?; } write!(f, "…") @@ -168,8 +168,8 @@ impl FromHex for Hash { // In little endian let bytes = <[u8; 32]>::from_hex(hex)?; let mut buf: [u8; 8] = [0; 8]; - let mut inner = [F::ZERO; 4]; - for i in 0..4 { + let mut inner = [F::ZERO; HASH_SIZE]; + for i in 0..HASH_SIZE { buf.copy_from_slice(&bytes[8 * i..8 * (i + 1)]); inner[i] = F::from_canonical_u64(u64::from_le_bytes(buf)); } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 5b34571..1c133cd 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -5,6 +5,7 @@ use std::{fmt, hash as h, iter::zip}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; +use crate::backends::plonky2::basetypes::HASH_SIZE; use crate::util::hashmap_insert_no_dupe; use super::{ @@ -51,12 +52,13 @@ impl ToFields for HashOrWildcard { match self { HashOrWildcard::Hash(h) => h.to_fields(_params), HashOrWildcard::Wildcard(w) => { - let usizes: Vec = vec![0, 0, 0, *w]; + let mut usizes: Vec = vec![0; HASH_SIZE - 1]; + usizes.push(*w); let fields: Vec = usizes .iter() .map(|x| F::from_canonical_u64(*x as u64)) .collect(); - (fields, 4) + (fields, HASH_SIZE) } } } @@ -99,8 +101,7 @@ impl ToFields for StatementTmplArg { // Key(hash_or_wildcard1, hash_or_wildcard2) // => (2, [hash_or_wildcard1], [hash_or_wildcard2]) // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements - let hash_size = 4; - let statement_tmpl_arg_size = 2 * hash_size + 1; + let statement_tmpl_arg_size = 2 * HASH_SIZE + 1; match self { StatementTmplArg::None => { let fields: Vec = std::iter::repeat_with(|| F::from_canonical_u64(0)) @@ -111,7 +112,7 @@ impl ToFields for StatementTmplArg { StatementTmplArg::Literal(v) => { let fields: Vec = std::iter::once(F::from_canonical_u64(1)) .chain(v.to_fields(_params).0) - .chain(std::iter::repeat_with(|| F::from_canonical_u64(0)).take(hash_size)) + .chain(std::iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .collect(); (fields, statement_tmpl_arg_size) }