Add cond on invert zero and divide zero

This commit is contained in:
Heng Zhang
2021-11-06 01:24:02 +08:00
parent e4fff8db8f
commit 28ecab6a01
3 changed files with 166 additions and 26 deletions

View File

@@ -44,8 +44,8 @@ trait IntegerInstructions<N: FieldExt> {
fn sub(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error>;
fn mul(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error>;
fn square(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error>;
fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error>;
fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error>;
fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<(AssignedInteger<N>, AssignedCondition<N>), Error>;
fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, offset: &mut usize) -> Result<(AssignedInteger<N>, AssignedCondition<N>), Error>;
fn reduce(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error>;
fn assert_equal(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<(), Error>;
fn assert_strict_equal(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<(), Error>;
@@ -80,11 +80,11 @@ impl<W: FieldExt, N: FieldExt> IntegerInstructions<N> for IntegerChip<W, N> {
self._square(region, a, offset)
}
fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error> {
fn div(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, b: &AssignedInteger<N>, offset: &mut usize) -> Result<(AssignedInteger<N>, AssignedCondition<N>), Error> {
self._div(region, a, b, offset)
}
fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, offset: &mut usize) -> Result<AssignedInteger<N>, Error> {
fn invert(&self, region: &mut Region<'_, N>, a: &AssignedInteger<N>, offset: &mut usize) -> Result<(AssignedInteger<N>, AssignedCondition<N>), Error> {
self._invert(region, a, offset)
}
@@ -207,9 +207,9 @@ impl<W: FieldExt, N: FieldExt> IntegerChip<W, N> {
#[cfg(test)]
mod tests {
use super::{IntegerChip, IntegerConfig, IntegerInstructions};
use crate::circuit::main_gate::{MainGate, MainGateConfig};
use crate::circuit::AssignedValue;
use crate::circuit::main_gate::{MainGate, MainGateConfig, MainGateInstructions};
use crate::circuit::range::{RangeChip, RangeInstructions};
use crate::rns::{Integer, Limb, Rns};
use halo2::arithmetic::FieldExt;
@@ -703,6 +703,7 @@ mod tests {
struct TestCircuitInvert<W: FieldExt, N: FieldExt> {
integer_a: Option<Integer<N>>,
integer_b: Option<Integer<N>>,
cond: Option<N>,
rns: Rns<W, N>,
}
@@ -734,10 +735,12 @@ mod tests {
let offset = &mut 0;
let integer_a_0 = &integer_chip.assign_integer(&mut region, self.integer_a.clone(), offset)?.clone();
let integer_b_0 = &integer_chip.assign_integer(&mut region, self.integer_b.clone(), offset)?.clone();
let cond_0 = integer_chip.main_gate().assign_bit(&mut region, self.cond.clone(), offset)?.clone();
let integer_a_1 = &integer_a_0.clone();
let integer_b_1 = &integer_chip.invert(&mut region, integer_a_0, offset)?;
let (integer_b_1, cond_1) = &integer_chip.invert(&mut region, integer_a_0, offset)?;
integer_chip.assert_strict_equal(&mut region, integer_a_0, integer_a_1, offset)?;
integer_chip.assert_strict_equal(&mut region, integer_b_0, integer_b_1, offset)?;
integer_chip.main_gate().assert_equal(&mut region, cond_0, cond_1.clone(), offset)?;
Ok(())
},
@@ -778,6 +781,7 @@ mod tests {
let circuit = TestCircuitInvert::<Wrong, Native> {
integer_a: Some(integer_a),
integer_b: integer_b,
cond: Some(Native::zero()),
rns: rns.clone(),
};
@@ -789,11 +793,44 @@ mod tests {
assert_eq!(prover.verify(), Ok(()));
}
#[test]
fn test_zero_invert_circuit() {
use halo2::pasta::Fp as Wrong;
use halo2::pasta::Fq as Native;
let bit_len_limb = 64;
let rns = Rns::<Wrong, Native>::construct(bit_len_limb);
#[cfg(not(feature = "no_lookup"))]
let K: u32 = (rns.bit_len_lookup + 1) as u32;
#[cfg(feature = "no_lookup")]
let K: u32 = 8;
let integer_a = rns.new_from_big(0u32.into());
let integer_b = rns.new_from_big(1u32.into());
let circuit = TestCircuitInvert::<Wrong, Native> {
integer_a: Some(integer_a),
integer_b: Some(integer_b),
cond: Some(Native::one()),
rns: rns.clone(),
};
let prover = match MockProver::run(K, &circuit, vec![]) {
Ok(prover) => prover,
Err(e) => panic!("{:#?}", e),
};
assert_eq!(prover.verify(), Ok(()));
}
#[derive(Default, Clone, Debug)]
struct TestCircuitDivision<W: FieldExt, N: FieldExt> {
integer_a: Option<Integer<N>>,
integer_b: Option<Integer<N>>,
integer_c: Option<Integer<N>>,
cond: Option<N>,
rns: Rns<W, N>,
}
@@ -826,12 +863,14 @@ mod tests {
let integer_a_0 = &integer_chip.assign_integer(&mut region, self.integer_a.clone(), offset)?.clone();
let integer_b_0 = &integer_chip.assign_integer(&mut region, self.integer_b.clone(), offset)?.clone();
let integer_c_0 = &integer_chip.assign_integer(&mut region, self.integer_c.clone(), offset)?.clone();
let cond_0 = integer_chip.main_gate().assign_bit(&mut region, self.cond.clone(), offset)?.clone();
let integer_a_1 = &integer_a_0.clone();
let integer_b_1 = &integer_b_0.clone();
let integer_c_1 = &integer_chip.div(&mut region, integer_a_0, integer_b_0, offset)?;
let (integer_c_1, cond_1) = &integer_chip.div(&mut region, integer_a_0, integer_b_0, offset)?;
integer_chip.assert_strict_equal(&mut region, integer_a_0, integer_a_1, offset)?;
integer_chip.assert_strict_equal(&mut region, integer_b_0, integer_b_1, offset)?;
integer_chip.assert_equal(&mut region, integer_c_0, integer_c_1, offset)?;
integer_chip.main_gate().assert_equal(&mut region, cond_0, cond_1.clone(), offset)?;
Ok(())
},
@@ -874,6 +913,40 @@ mod tests {
integer_a: Some(integer_a.clone()),
integer_b: Some(integer_b),
integer_c: integer_c,
cond: Some(Native::zero()),
rns: rns.clone(),
};
let prover = match MockProver::run(K, &circuit, vec![]) {
Ok(prover) => prover,
Err(e) => panic!("{:#?}", e),
};
assert_eq!(prover.verify(), Ok(()));
}
#[test]
fn test_zero_division_circuit() {
use halo2::pasta::Fp as Wrong;
use halo2::pasta::Fq as Native;
let bit_len_limb = 64;
let rns = Rns::<Wrong, Native>::construct(bit_len_limb);
#[cfg(not(feature = "no_lookup"))]
let K: u32 = (rns.bit_len_lookup + 1) as u32;
#[cfg(feature = "no_lookup")]
let K: u32 = 8;
let integer_a = rns.rand_prenormalized();
let integer_b = rns.new_from_big(0u32.into());
let integer_c = integer_a.clone();
let circuit = TestCircuitDivision::<Wrong, Native> {
integer_a: Some(integer_a),
integer_b: Some(integer_b),
integer_c: Some(integer_c),
cond: Some(Native::one()),
rns: rns.clone(),
};

View File

@@ -1,6 +1,7 @@
use super::IntegerChip;
use super::IntegerInstructions;
use crate::circuit::{AssignedInteger};
use super::AssignedCondition;
use crate::circuit::AssignedInteger;
use halo2::arithmetic::FieldExt;
use halo2::circuit::Region;
use halo2::plonk::Error;
@@ -12,10 +13,10 @@ impl<W: FieldExt, N: FieldExt> IntegerChip<W, N> {
a: &AssignedInteger<N>,
b: &AssignedInteger<N>,
offset: &mut usize,
) -> Result<AssignedInteger<N>, Error> {
let b_inv = self.invert(region, b, offset)?;
) -> Result<(AssignedInteger<N>, AssignedCondition<N>), Error> {
let (b_inv, cond) = self.invert(region, b, offset)?;
let a_mul_b_inv = self.mul(region, a, &b_inv, offset)?;
Ok(a_mul_b_inv)
Ok((a_mul_b_inv, cond))
}
}

View File

@@ -1,34 +1,100 @@
use super::IntegerChip;
use super::IntegerInstructions;
use crate::{NUMBER_OF_LIMBS};
use crate::circuit::{AssignedInteger};
use super::{AssignedCondition, IntegerChip, IntegerInstructions, MainGateInstructions};
use crate::NUMBER_OF_LIMBS;
use crate::circuit::{Assigned, AssignedInteger};
use crate::circuit::main_gate::{CombinationOption, Term};
use halo2::arithmetic::FieldExt;
use halo2::circuit::Region;
use halo2::plonk::Error;
impl<W: FieldExt, N: FieldExt> IntegerChip<W, N> {
fn inert_inv_range_tune(&self) -> usize {
self.rns.bit_len_prenormalized - (self.rns.bit_len_limb * (NUMBER_OF_LIMBS - 1)) + 1
}
pub(crate) fn _invert(
&self,
region: &mut Region<'_, N>,
a: &AssignedInteger<N>,
offset: &mut usize,
) -> Result<AssignedInteger<N>, Error> {
let integer_inv = a.integer().and_then(|integer_a| {
self.rns.invert(&integer_a)
});
) -> Result<(AssignedInteger<N>, AssignedCondition<N>), Error> {
let main_gate = self.main_gate();
let (zero, one) = (N::zero(), N::one());
let integer_one = self.rns.new_from_big(1u32.into());
// TODO: Shall we just use SynthesisError here?
// Passing None to mul is undefined behavior for invert,
// throwing an error or panic would be a better choice.
let integer_a = a.integer().ok_or(Error::SynthesisError)?;
let integer_inv = self.rns.invert(&integer_a).or(Some(integer_one));
// TODO: For range constraints, we have these options:
// 1. extend mul to support prenormalized value.
// 2. call normalize here.
// 3. add wrong field range check on inv.
let most_significant_limb_bit_len = self.rns.bit_len_prenormalized - (self.rns.bit_len_limb * (NUMBER_OF_LIMBS - 1)) + 1;
let inv = self.range_assign_integer(region, integer_inv.into(), most_significant_limb_bit_len, offset)?;
let one = self.assign_integer(region, Some(integer_one), offset)?;
let a_mul_inv = self._mul(region, &a, &inv, offset)?;
let inv = self.range_assign_integer(region, integer_inv.into(), self.inert_inv_range_tune(), offset)?;
let a_mul_inv = self.mul(region, &a, &inv, offset)?;
self.assert_equal(region, &a_mul_inv, &one, offset)?;
// We believe the mul result is strictly less than wrong modulus, so we add strict constraints here.
// The limbs[1..NUMBER_OF_LIMBS] of a_mul_inv should be 0.
for i in 1..NUMBER_OF_LIMBS {
main_gate.assert_zero(region, a_mul_inv.limbs[i].clone(), offset)?;
}
Ok(inv)
// The limbs[0] of a_mul_inv should be 0 or 1, i.e. limbs[0] * limbs[0] - limbs[0] = 0.
main_gate.combine(
region,
Term::Assigned(&a_mul_inv.limbs[0], zero),
Term::Assigned(&a_mul_inv.limbs[0], zero),
Term::Assigned(&a_mul_inv.limbs[0], -one),
Term::Zero,
zero,
offset,
CombinationOption::SingleLinerMul,
)?;
// If a_mul_inv is 0 (i.e. not 1), then inv must be 1 (i.e. [1, 0, 0, 0]).
// Here we short x.limbs[i] as x[i].
// 1. (a_mul_inv[0] - 1) * inv[1..NUMBER_OF_LIMBS] = 0
// 2. (a_mul_inv[0] - 1) * (inv[0] - 1) = 0
for i in 1..NUMBER_OF_LIMBS {
main_gate.combine(
region,
Term::Assigned(&a_mul_inv.limbs[0], zero),
Term::Assigned(&inv.limbs[i], -one),
Term::Zero,
Term::Zero,
zero,
offset,
CombinationOption::SingleLinerMul,
)?;
}
println!("{}", offset);
main_gate.combine(
region,
Term::Assigned(&a_mul_inv.limbs[0], -one),
Term::Assigned(&inv.limbs[0], -one),
Term::Zero,
Term::Zero,
one,
offset,
CombinationOption::SingleLinerMul,
)?;
// Align with main_gain.invert(), cond = 1 - a_mul_inv
let cond = a_mul_inv.limbs[0].value().map(|a_mul_inv| { one - a_mul_inv });
let (_, cond_cell, _, _) = main_gate.combine(
region,
Term::Assigned(&a_mul_inv.limbs[0], one),
Term::Unassigned(cond, one),
Term::Zero,
Term::Zero,
-one,
offset,
CombinationOption::SingleLinerMul,
)?;
Ok((inv, AssignedCondition::new(cond_cell, cond)))
}
}