diff --git a/executor/src/witgen/affine_expression.rs b/executor/src/witgen/affine_expression.rs index 6069c27f9..d368efb93 100644 --- a/executor/src/witgen/affine_expression.rs +++ b/executor/src/witgen/affine_expression.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use itertools::Itertools; +use num_traits::Zero; use number::{BigInt, FieldElement}; use super::range_constraints::RangeConstraintSet; @@ -34,8 +35,8 @@ where { pub fn from_variable_id(var_id: K) -> AffineExpression { Self { - coefficients: BTreeMap::from([(var_id, 1.into())]), - offset: 0.into(), + coefficients: BTreeMap::from([(var_id, T::one())]), + offset: T::zero(), } } @@ -81,9 +82,9 @@ where // c * a + o = 0 <=> a = -o/c Ok(EvalValue::complete(vec![( i, - Constraint::Assignment(if *c == 1.into() { + Constraint::Assignment(if c.is_one() { -self.offset - } else if *c == (-1).into() { + } else if *c == -T::one() { self.offset } else { -self.offset / *c @@ -94,7 +95,7 @@ where IncompleteCause::MultipleLinearSolutions, )), (None, None) => { - if self.offset == 0.into() { + if self.offset.is_zero() { Ok(EvalValue::complete(vec![])) } else { Err(()) @@ -231,13 +232,13 @@ where } // Check if they are mutually exclusive and compute assignments. - let mut covered_bits: ::Integer = 0u32.into(); + let mut covered_bits: ::Integer = 0.into(); let mut assignments = EvalValue::complete(vec![]); let mut offset = (-self.offset).to_integer(); for (i, coeff, constraint) in parts { let constraint = constraint.clone().unwrap(); let mask = constraint.mask(); - if *mask & covered_bits != 0u32.into() { + if *mask & covered_bits != 0.into() { return Ok(EvalValue::incomplete( IncompleteCause::OverlappingBitConstraints, )); @@ -255,7 +256,7 @@ where offset &= !*mask; } - if offset != 0u32.into() { + if !offset.is_zero() { // We were not able to cover all of the offset, so this equation cannot be solved. Err(ConflictingRangeConstraints) } else { @@ -346,15 +347,15 @@ where "{}", self.nonzero_coefficients() .map(|(i, c)| { - if *c == 1.into() { + if c.is_one() { i.to_string() - } else if *c == (-1).into() { + } else if *c == -T::one() { format!("-{i}") } else { format!("{c} * {i}") } }) - .chain((self.offset != 0.into()).then_some(self.offset.to_string())) + .chain((!self.offset.is_zero()).then(|| self.offset.to_string())) .join(" + ") ) } diff --git a/executor/src/witgen/range_constraints.rs b/executor/src/witgen/range_constraints.rs index 14fd7fe78..08860172d 100644 --- a/executor/src/witgen/range_constraints.rs +++ b/executor/src/witgen/range_constraints.rs @@ -1,6 +1,8 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; +use num_traits::Zero; + use ast::analyzed::{Expression, Identity, IdentityKind, PolynomialReference}; use ast::parsed::BinaryOperator; use number::{log2_exact, BigInt, FieldElement}; @@ -29,7 +31,7 @@ impl RangeConstraint { pub fn from_max_bit(max_bit: u64) -> Self { assert!(max_bit < 1024); RangeConstraint { - mask: T::Integer::from(((1 << (max_bit + 1)) - 1) as u32), + mask: T::Integer::from((1u64 << (max_bit + 1)) - 1), } } @@ -39,7 +41,7 @@ impl RangeConstraint { /// The range constraint of the sum of two expressions. pub fn try_combine_sum(&self, other: &RangeConstraint) -> Option> { - if self.mask & other.mask == 0u32.into() { + if self.mask & other.mask == 0.into() { Some(RangeConstraint { mask: self.mask | other.mask, }) @@ -164,16 +166,16 @@ pub fn determine_global_constraints<'a, T: FieldElement>( /// TODO do this on the symbolic definition instead of the values. fn process_fixed_column(fixed: &[T]) -> Option<(RangeConstraint, bool)> { if let Some(bit) = smallest_period_candidate(fixed) { - let mask = T::Integer::from(((1 << bit) - 1) as u32); + let mask = T::Integer::from((1u64 << bit) - 1); if fixed .iter() .enumerate() - .all(|(i, v)| v.to_integer() == T::Integer::from(i as u32) & mask) + .all(|(i, v)| v.to_integer() == T::Integer::from(i as u64) & mask) { return Some((RangeConstraint::from_mask(mask), true)); } } - let mut mask = T::Integer::from(0u32); + let mut mask = T::Integer::zero(); for v in fixed.iter() { mask |= v.to_integer(); } diff --git a/number/src/macros.rs b/number/src/macros.rs index e3489bb78..7e47d6b36 100644 --- a/number/src/macros.rs +++ b/number/src/macros.rs @@ -6,7 +6,7 @@ macro_rules! powdr_field { }; use ark_ff::{BigInteger, Field, PrimeField}; use num_bigint::BigUint; - use num_traits::Num; + use num_traits::{Num, One, Zero}; use std::fmt; use std::ops::*; use std::str::FromStr; @@ -160,6 +160,29 @@ macro_rules! powdr_field { } } + impl AddAssign for BigIntImpl { + fn add_assign(&mut self, other: Self) { + self.value.add_with_carry(&other.value); + } + } + + impl Add for BigIntImpl { + type Output = Self; + fn add(mut self, other: Self) -> Self { + self.add_assign(other); + self + } + } + + impl Zero for BigIntImpl { + fn zero() -> Self { + BigIntImpl::new(<$ark_type as PrimeField>::BigInt::zero()) + } + fn is_zero(&self) -> bool { + self.value.is_zero() + } + } + impl TryFrom for BigIntImpl { type Error = (); @@ -171,9 +194,19 @@ macro_rules! powdr_field { } impl BigInt for BigIntImpl { + const NUM_BITS: usize = <$ark_type as PrimeField>::BigInt::NUM_LIMBS * 64; fn to_arbitrary_integer(self) -> BigUint { self.value.into() } + fn num_bits(&self) -> u32 { + self.value.num_bits() + } + fn one() -> Self { + BigIntImpl::new(<$ark_type as PrimeField>::BigInt::one()) + } + fn is_one(&self) -> bool { + self.value == <$ark_type as PrimeField>::BigInt::one() + } } impl From for $name { @@ -228,6 +261,7 @@ macro_rules! powdr_field { impl FieldElement for $name { type Integer = BigIntImpl; + const BITS: u32 = <$ark_type>::MODULUS_BIT_SIZE; fn from_str(s: &str) -> Self { Self { @@ -352,6 +386,24 @@ macro_rules! powdr_field { } } + impl Zero for $name { + fn zero() -> Self { + <$ark_type as Zero>::zero().into() + } + fn is_zero(&self) -> bool { + self.value.is_zero() + } + } + + impl One for $name { + fn one() -> Self { + <$ark_type as One>::one().into() + } + fn is_one(&self) -> bool { + self.value.is_one() + } + } + impl fmt::Display for $name { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let value = self.to_integer().value; diff --git a/number/src/serialize.rs b/number/src/serialize.rs index 71a95a2ca..b46c83a64 100644 --- a/number/src/serialize.rs +++ b/number/src/serialize.rs @@ -1,6 +1,6 @@ use std::io::{Read, Write}; -use crate::{BigInt, DegreeType, FieldElement}; +use crate::{DegreeType, FieldElement}; fn ceil_div(num: usize, div: usize) -> usize { (num + div - 1) / div @@ -11,7 +11,7 @@ pub fn write_polys_file( degree: DegreeType, polys: &Vec<(&str, Vec)>, ) { - let width = ceil_div(T::modulus().to_arbitrary_integer().bits() as usize, 64) * 8; + let width = ceil_div(T::BITS as usize, 64) * 8; for i in 0..degree as usize { for (_name, constant) in polys { @@ -26,7 +26,7 @@ pub fn read_polys_file<'a, T: FieldElement>( file: &mut impl Read, columns: &[&'a str], ) -> (Vec<(&'a str, Vec)>, DegreeType) { - let width = ceil_div(T::modulus().to_arbitrary_integer().bits() as usize, 64) * 8; + let width = ceil_div(T::BITS as usize, 64) * 8; let bytes_to_read = width * columns.len(); diff --git a/number/src/traits.rs b/number/src/traits.rs index d0570d188..00efcf3f0 100644 --- a/number/src/traits.rs +++ b/number/src/traits.rs @@ -1,5 +1,7 @@ use std::{fmt, hash::Hash, ops::*}; +use num_traits::{One, Zero}; + use crate::{AbstractNumberType, DegreeType}; /// A fixed-width integer type @@ -9,11 +11,15 @@ pub trait BigInt: + Sync + PartialEq + Eq - + From + + PartialOrd + + Ord + + From + BitAnd + BitOr + BitOrAssign + BitAndAssign + + AddAssign + + Add + fmt::Display + fmt::Debug + Copy @@ -21,10 +27,23 @@ pub trait BigInt: + Shl + Shr + BitXor + + Zero + fmt::LowerHex + TryFrom { + /// Number of bits of this base type. Not to be confused with the number of bits + /// of the field elements! + const NUM_BITS: usize; fn to_arbitrary_integer(self) -> AbstractNumberType; + /// Number of bits required to encode this particular number. + fn num_bits(&self) -> u32; + + /// Returns the constant one. + /// We are not implementing num_traits::One because it also requires multiplication. + fn one() -> Self; + + /// Checks if the number is one. + fn is_one(&self) -> bool; } /// A field element @@ -47,6 +66,8 @@ pub trait FieldElement: + Mul + Div + Neg + + Zero + + One + fmt::Display + fmt::Debug + From @@ -60,6 +81,8 @@ pub trait FieldElement: { /// The underlying fixed-width integer type type Integer: BigInt; + /// Number of bits required to represent elements of this field. + const BITS: u32; fn to_degree(&self) -> DegreeType;