diff --git a/src/circuit/integer.rs b/src/circuit/integer.rs index 37b6eca..3a7e954 100644 --- a/src/circuit/integer.rs +++ b/src/circuit/integer.rs @@ -44,8 +44,8 @@ trait IntegerInstructions { fn sub(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result, Error>; fn mul(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result, Error>; fn square(&self, region: &mut Region<'_, N>, a: &AssignedInteger, offset: &mut usize) -> Result, Error>; - fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result, Error>; - fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger, offset: &mut usize) -> Result, Error>; + fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result<(AssignedInteger, AssignedCondition), Error>; + fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger, offset: &mut usize) -> Result<(AssignedInteger, AssignedCondition), Error>; fn reduce(&self, region: &mut Region<'_, N>, a: &AssignedInteger, offset: &mut usize) -> Result, Error>; fn assert_equal(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result<(), Error>; fn assert_strict_equal(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result<(), Error>; @@ -80,11 +80,11 @@ impl IntegerInstructions for IntegerChip { self._square(region, a, offset) } - fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result, Error> { + fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize) -> Result<(AssignedInteger, AssignedCondition), Error> { self._div(region, a, b, offset) } - fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger, offset: &mut usize) -> Result, Error> { + fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger, offset: &mut usize) -> Result<(AssignedInteger, AssignedCondition), Error> { self._invert(region, a, offset) } @@ -207,9 +207,9 @@ impl IntegerChip { #[cfg(test)] mod tests { - use super::{IntegerChip, IntegerConfig, IntegerInstructions}; - use crate::circuit::main_gate::{MainGate, MainGateConfig}; + use crate::circuit::AssignedValue; + use crate::circuit::main_gate::{MainGate, MainGateConfig, MainGateInstructions}; use crate::circuit::range::{RangeChip, RangeInstructions}; use crate::rns::{Integer, Limb, Rns}; use halo2::arithmetic::FieldExt; @@ -703,6 +703,7 @@ mod tests { struct TestCircuitInvert { integer_a: Option>, integer_b: Option>, + cond: Option, rns: Rns, } @@ -734,10 +735,12 @@ mod tests { let offset = &mut 0; let integer_a_0 = &integer_chip.assign_integer(&mut region, self.integer_a.clone(), offset)?.clone(); let integer_b_0 = &integer_chip.assign_integer(&mut region, self.integer_b.clone(), offset)?.clone(); + let cond_0 = integer_chip.main_gate().assign_bit(&mut region, self.cond.clone(), offset)?.clone(); let integer_a_1 = &integer_a_0.clone(); - let integer_b_1 = &integer_chip.invert(&mut region, integer_a_0, offset)?; + let (integer_b_1, cond_1) = &integer_chip.invert(&mut region, integer_a_0, offset)?; integer_chip.assert_strict_equal(&mut region, integer_a_0, integer_a_1, offset)?; integer_chip.assert_strict_equal(&mut region, integer_b_0, integer_b_1, offset)?; + integer_chip.main_gate().assert_equal(&mut region, cond_0, cond_1.clone(), offset)?; Ok(()) }, @@ -778,6 +781,7 @@ mod tests { let circuit = TestCircuitInvert:: { integer_a: Some(integer_a), integer_b: integer_b, + cond: Some(Native::zero()), rns: rns.clone(), }; @@ -789,11 +793,44 @@ mod tests { assert_eq!(prover.verify(), Ok(())); } + #[test] + fn test_zero_invert_circuit() { + use halo2::pasta::Fp as Wrong; + use halo2::pasta::Fq as Native; + + let bit_len_limb = 64; + let rns = Rns::::construct(bit_len_limb); + + #[cfg(not(feature = "no_lookup"))] + let K: u32 = (rns.bit_len_lookup + 1) as u32; + #[cfg(feature = "no_lookup")] + let K: u32 = 8; + + let integer_a = rns.new_from_big(0u32.into()); + let integer_b = rns.new_from_big(1u32.into()); + + let circuit = TestCircuitInvert:: { + integer_a: Some(integer_a), + integer_b: Some(integer_b), + cond: Some(Native::one()), + rns: rns.clone(), + }; + + let prover = match MockProver::run(K, &circuit, vec![]) { + Ok(prover) => prover, + Err(e) => panic!("{:#?}", e), + }; + + assert_eq!(prover.verify(), Ok(())); + } + + #[derive(Default, Clone, Debug)] struct TestCircuitDivision { integer_a: Option>, integer_b: Option>, integer_c: Option>, + cond: Option, rns: Rns, } @@ -826,12 +863,14 @@ mod tests { let integer_a_0 = &integer_chip.assign_integer(&mut region, self.integer_a.clone(), offset)?.clone(); let integer_b_0 = &integer_chip.assign_integer(&mut region, self.integer_b.clone(), offset)?.clone(); let integer_c_0 = &integer_chip.assign_integer(&mut region, self.integer_c.clone(), offset)?.clone(); + let cond_0 = integer_chip.main_gate().assign_bit(&mut region, self.cond.clone(), offset)?.clone(); let integer_a_1 = &integer_a_0.clone(); let integer_b_1 = &integer_b_0.clone(); - let integer_c_1 = &integer_chip.div(&mut region, integer_a_0, integer_b_0, offset)?; + let (integer_c_1, cond_1) = &integer_chip.div(&mut region, integer_a_0, integer_b_0, offset)?; integer_chip.assert_strict_equal(&mut region, integer_a_0, integer_a_1, offset)?; integer_chip.assert_strict_equal(&mut region, integer_b_0, integer_b_1, offset)?; integer_chip.assert_equal(&mut region, integer_c_0, integer_c_1, offset)?; + integer_chip.main_gate().assert_equal(&mut region, cond_0, cond_1.clone(), offset)?; Ok(()) }, @@ -874,6 +913,40 @@ mod tests { integer_a: Some(integer_a.clone()), integer_b: Some(integer_b), integer_c: integer_c, + cond: Some(Native::zero()), + rns: rns.clone(), + }; + + let prover = match MockProver::run(K, &circuit, vec![]) { + Ok(prover) => prover, + Err(e) => panic!("{:#?}", e), + }; + + assert_eq!(prover.verify(), Ok(())); + } + + #[test] + fn test_zero_division_circuit() { + use halo2::pasta::Fp as Wrong; + use halo2::pasta::Fq as Native; + + let bit_len_limb = 64; + let rns = Rns::::construct(bit_len_limb); + + #[cfg(not(feature = "no_lookup"))] + let K: u32 = (rns.bit_len_lookup + 1) as u32; + #[cfg(feature = "no_lookup")] + let K: u32 = 8; + + let integer_a = rns.rand_prenormalized(); + let integer_b = rns.new_from_big(0u32.into()); + let integer_c = integer_a.clone(); + + let circuit = TestCircuitDivision:: { + integer_a: Some(integer_a), + integer_b: Some(integer_b), + integer_c: Some(integer_c), + cond: Some(Native::one()), rns: rns.clone(), }; diff --git a/src/circuit/integer/div.rs b/src/circuit/integer/div.rs index 05cf761..4f57e1d 100644 --- a/src/circuit/integer/div.rs +++ b/src/circuit/integer/div.rs @@ -1,6 +1,7 @@ use super::IntegerChip; use super::IntegerInstructions; -use crate::circuit::{AssignedInteger}; +use super::AssignedCondition; +use crate::circuit::AssignedInteger; use halo2::arithmetic::FieldExt; use halo2::circuit::Region; use halo2::plonk::Error; @@ -12,10 +13,10 @@ impl IntegerChip { a: &AssignedInteger, b: &AssignedInteger, offset: &mut usize, - ) -> Result, Error> { - let b_inv = self.invert(region, b, offset)?; + ) -> Result<(AssignedInteger, AssignedCondition), Error> { + let (b_inv, cond) = self.invert(region, b, offset)?; let a_mul_b_inv = self.mul(region, a, &b_inv, offset)?; - Ok(a_mul_b_inv) + Ok((a_mul_b_inv, cond)) } } diff --git a/src/circuit/integer/invert.rs b/src/circuit/integer/invert.rs index c8fdada..5997067 100644 --- a/src/circuit/integer/invert.rs +++ b/src/circuit/integer/invert.rs @@ -1,34 +1,100 @@ -use super::IntegerChip; -use super::IntegerInstructions; -use crate::{NUMBER_OF_LIMBS}; -use crate::circuit::{AssignedInteger}; +use super::{AssignedCondition, IntegerChip, IntegerInstructions, MainGateInstructions}; +use crate::NUMBER_OF_LIMBS; +use crate::circuit::{Assigned, AssignedInteger}; +use crate::circuit::main_gate::{CombinationOption, Term}; use halo2::arithmetic::FieldExt; use halo2::circuit::Region; use halo2::plonk::Error; impl IntegerChip { + fn inert_inv_range_tune(&self) -> usize { + self.rns.bit_len_prenormalized - (self.rns.bit_len_limb * (NUMBER_OF_LIMBS - 1)) + 1 + } + pub(crate) fn _invert( &self, region: &mut Region<'_, N>, a: &AssignedInteger, offset: &mut usize, - ) -> Result, Error> { - let integer_inv = a.integer().and_then(|integer_a| { - self.rns.invert(&integer_a) - }); + ) -> Result<(AssignedInteger, AssignedCondition), Error> { + let main_gate = self.main_gate(); + + let (zero, one) = (N::zero(), N::one()); let integer_one = self.rns.new_from_big(1u32.into()); + // TODO: Shall we just use SynthesisError here? + // Passing None to mul is undefined behavior for invert, + // throwing an error or panic would be a better choice. + let integer_a = a.integer().ok_or(Error::SynthesisError)?; + let integer_inv = self.rns.invert(&integer_a).or(Some(integer_one)); // TODO: For range constraints, we have these options: // 1. extend mul to support prenormalized value. // 2. call normalize here. // 3. add wrong field range check on inv. - let most_significant_limb_bit_len = self.rns.bit_len_prenormalized - (self.rns.bit_len_limb * (NUMBER_OF_LIMBS - 1)) + 1; - let inv = self.range_assign_integer(region, integer_inv.into(), most_significant_limb_bit_len, offset)?; - let one = self.assign_integer(region, Some(integer_one), offset)?; - let a_mul_inv = self._mul(region, &a, &inv, offset)?; + let inv = self.range_assign_integer(region, integer_inv.into(), self.inert_inv_range_tune(), offset)?; + let a_mul_inv = self.mul(region, &a, &inv, offset)?; - self.assert_equal(region, &a_mul_inv, &one, offset)?; + // We believe the mul result is strictly less than wrong modulus, so we add strict constraints here. + // The limbs[1..NUMBER_OF_LIMBS] of a_mul_inv should be 0. + for i in 1..NUMBER_OF_LIMBS { + main_gate.assert_zero(region, a_mul_inv.limbs[i].clone(), offset)?; + } - Ok(inv) + // The limbs[0] of a_mul_inv should be 0 or 1, i.e. limbs[0] * limbs[0] - limbs[0] = 0. + main_gate.combine( + region, + Term::Assigned(&a_mul_inv.limbs[0], zero), + Term::Assigned(&a_mul_inv.limbs[0], zero), + Term::Assigned(&a_mul_inv.limbs[0], -one), + Term::Zero, + zero, + offset, + CombinationOption::SingleLinerMul, + )?; + + // If a_mul_inv is 0 (i.e. not 1), then inv must be 1 (i.e. [1, 0, 0, 0]). + // Here we short x.limbs[i] as x[i]. + // 1. (a_mul_inv[0] - 1) * inv[1..NUMBER_OF_LIMBS] = 0 + // 2. (a_mul_inv[0] - 1) * (inv[0] - 1) = 0 + for i in 1..NUMBER_OF_LIMBS { + main_gate.combine( + region, + Term::Assigned(&a_mul_inv.limbs[0], zero), + Term::Assigned(&inv.limbs[i], -one), + Term::Zero, + Term::Zero, + zero, + offset, + CombinationOption::SingleLinerMul, + )?; + } + + println!("{}", offset); + + main_gate.combine( + region, + Term::Assigned(&a_mul_inv.limbs[0], -one), + Term::Assigned(&inv.limbs[0], -one), + Term::Zero, + Term::Zero, + one, + offset, + CombinationOption::SingleLinerMul, + )?; + + // Align with main_gain.invert(), cond = 1 - a_mul_inv + let cond = a_mul_inv.limbs[0].value().map(|a_mul_inv| { one - a_mul_inv }); + let (_, cond_cell, _, _) = main_gate.combine( + region, + Term::Assigned(&a_mul_inv.limbs[0], one), + Term::Unassigned(cond, one), + Term::Zero, + Term::Zero, + -one, + offset, + CombinationOption::SingleLinerMul, + )?; + + Ok((inv, AssignedCondition::new(cond_cell, cond))) } }