halo2_gadgets: Enable more inversions to be batched during synthesis

This commit is contained in:
Jack Grigg
2022-05-12 22:21:29 +00:00
parent 5f1fb166d1
commit 515f97769f
11 changed files with 89 additions and 107 deletions

View File

@@ -7,11 +7,11 @@ use crate::{
};
use arrayvec::ArrayVec;
use ff::{Field, PrimeField};
use ff::PrimeField;
use group::prime::PrimeCurveAffine;
use halo2_proofs::{
circuit::{AssignedCell, Chip, Layouter},
plonk::{Advice, Column, ConstraintSystem, Error, Fixed},
plonk::{Advice, Assigned, Column, ConstraintSystem, Error, Fixed},
};
use pasta_curves::{arithmetic::CurveAffine, pallas};
@@ -35,9 +35,13 @@ pub(crate) use mul::incomplete::DoubleAndAdd;
#[derive(Clone, Debug)]
pub struct EccPoint {
/// x-coordinate
x: AssignedCell<pallas::Base, pallas::Base>,
///
/// Stored as an `Assigned<F>` to enable batching inversions.
x: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
/// y-coordinate
y: AssignedCell<pallas::Base, pallas::Base>,
///
/// Stored as an `Assigned<F>` to enable batching inversions.
y: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
}
impl EccPoint {
@@ -46,8 +50,8 @@ impl EccPoint {
/// This is an internal API that we only use where we know we have a valid curve point
/// (specifically inside Sinsemilla).
pub(crate) fn from_coordinates_unchecked(
x: AssignedCell<pallas::Base, pallas::Base>,
y: AssignedCell<pallas::Base, pallas::Base>,
x: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
y: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
) -> Self {
EccPoint { x, y }
}
@@ -59,7 +63,7 @@ impl EccPoint {
if x.is_zero_vartime() && y.is_zero_vartime() {
Some(pallas::Affine::identity())
} else {
Some(pallas::Affine::from_xy(*x, *y).unwrap())
Some(pallas::Affine::from_xy(x.evaluate(), y.evaluate()).unwrap())
}
}
_ => None,
@@ -68,12 +72,12 @@ impl EccPoint {
/// The cell containing the affine short-Weierstrass x-coordinate,
/// or 0 for the zero point.
pub fn x(&self) -> AssignedCell<pallas::Base, pallas::Base> {
self.x.clone()
self.x.clone().evaluate()
}
/// The cell containing the affine short-Weierstrass y-coordinate,
/// or 0 for the zero point.
pub fn y(&self) -> AssignedCell<pallas::Base, pallas::Base> {
self.y.clone()
self.y.clone().evaluate()
}
#[cfg(test)]
@@ -87,9 +91,13 @@ impl EccPoint {
#[derive(Clone, Debug)]
pub struct NonIdentityEccPoint {
/// x-coordinate
x: AssignedCell<pallas::Base, pallas::Base>,
///
/// Stored as an `Assigned<F>` to enable batching inversions.
x: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
/// y-coordinate
y: AssignedCell<pallas::Base, pallas::Base>,
///
/// Stored as an `Assigned<F>` to enable batching inversions.
y: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
}
impl NonIdentityEccPoint {
@@ -98,8 +106,8 @@ impl NonIdentityEccPoint {
/// This is an internal API that we only use where we know we have a valid non-identity
/// curve point (specifically inside Sinsemilla).
pub(crate) fn from_coordinates_unchecked(
x: AssignedCell<pallas::Base, pallas::Base>,
y: AssignedCell<pallas::Base, pallas::Base>,
x: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
y: AssignedCell<Assigned<pallas::Base>, pallas::Base>,
) -> Self {
NonIdentityEccPoint { x, y }
}
@@ -109,18 +117,18 @@ impl NonIdentityEccPoint {
match (self.x.value(), self.y.value()) {
(Some(x), Some(y)) => {
assert!(!x.is_zero_vartime() && !y.is_zero_vartime());
Some(pallas::Affine::from_xy(*x, *y).unwrap())
Some(pallas::Affine::from_xy(x.evaluate(), y.evaluate()).unwrap())
}
_ => None,
}
}
/// The cell containing the affine short-Weierstrass x-coordinate.
pub fn x(&self) -> AssignedCell<pallas::Base, pallas::Base> {
self.x.clone()
self.x.clone().evaluate()
}
/// The cell containing the affine short-Weierstrass y-coordinate.
pub fn y(&self) -> AssignedCell<pallas::Base, pallas::Base> {
self.y.clone()
self.y.clone().evaluate()
}
}

View File

@@ -1,8 +1,7 @@
use super::EccPoint;
use ff::{BatchInvert, Field};
use halo2_proofs::{
circuit::Region,
plonk::{Advice, Column, ConstraintSystem, Constraints, Error, Expression, Selector},
plonk::{Advice, Assigned, Column, ConstraintSystem, Constraints, Error, Expression, Selector},
poly::Rotation,
};
use pasta_curves::{arithmetic::FieldExt, pallas};
@@ -227,58 +226,31 @@ impl Config {
let (x_p, y_p) = (p.x.value(), p.y.value());
let (x_q, y_q) = (q.x.value(), q.y.value());
// [alpha, beta, gamma, delta]
// = [inv0(x_q - x_p), inv0(x_p), inv0(x_q), inv0(y_q + y_p)]
// where inv0(x) = 0 if x = 0, 1/x otherwise.
//
let (alpha, beta, gamma, delta) = {
let inverses = x_p
.zip(x_q)
.zip(y_p)
.zip(y_q)
.map(|(((x_p, x_q), y_p), y_q)| {
let alpha = x_q - x_p;
let beta = x_p;
let gamma = x_q;
let delta = y_q + y_p;
let mut inverses = [alpha, *beta, *gamma, delta];
inverses.batch_invert();
inverses
});
if let Some([alpha, beta, gamma, delta]) = inverses {
(Some(alpha), Some(beta), Some(gamma), Some(delta))
} else {
(None, None, None, None)
}
};
// Assign α = inv0(x_q - x_p)
let alpha = x_p.zip(x_q).map(|(x_p, x_q)| (x_q - x_p).invert());
region.assign_advice(|| "α", self.alpha, offset, || alpha.ok_or(Error::Synthesis))?;
// Assign β = inv0(x_p)
let beta = x_p.map(|x_p| x_p.invert());
region.assign_advice(|| "β", self.beta, offset, || beta.ok_or(Error::Synthesis))?;
// Assign γ = inv0(x_q)
let gamma = x_q.map(|x_q| x_q.invert());
region.assign_advice(|| "γ", self.gamma, offset, || gamma.ok_or(Error::Synthesis))?;
// Assign δ = inv0(y_q + y_p) if x_q = x_p, 0 otherwise
region.assign_advice(
|| "δ",
self.delta,
offset,
|| {
let x_p = x_p.ok_or(Error::Synthesis)?;
let x_q = x_q.ok_or(Error::Synthesis)?;
let delta = x_p
.zip(x_q)
.zip(y_p)
.zip(y_q)
.map(|(((x_p, x_q), y_p), y_q)| {
if x_q == x_p {
delta.ok_or(Error::Synthesis)
(y_q + y_p).invert()
} else {
Ok(pallas::Base::zero())
Assigned::Zero
}
},
)?;
});
region.assign_advice(|| "δ", self.delta, offset, || delta.ok_or(Error::Synthesis))?;
#[allow(clippy::collapsible_else_if)]
// Assign lambda
@@ -296,13 +268,13 @@ impl Config {
} else {
if !y_p.is_zero_vartime() {
// 3(x_p)^2
let three_x_p_sq = pallas::Base::from(3) * x_p.square();
let three_x_p_sq = x_p.square() * pallas::Base::from(3);
// 1 / 2(y_p)
let inv_two_y_p = y_p.invert().unwrap() * pallas::Base::TWO_INV;
let inv_two_y_p = y_p.invert() * pallas::Base::TWO_INV;
// λ = 3(x_p)^2 / 2(y_p)
three_x_p_sq * inv_two_y_p
} else {
pallas::Base::zero()
Assigned::Zero
}
}
});
@@ -329,7 +301,7 @@ impl Config {
(*x_p, *y_p)
} else if (x_q == x_p) && (*y_q == -y_p) {
// P + (-P) maps to (0,0)
(pallas::Base::zero(), pallas::Base::zero())
(Assigned::Zero, Assigned::Zero)
} else {
// x_r = λ^2 - x_p - x_q
let x_r = lambda.square() - x_p - x_q;

View File

@@ -1,14 +1,12 @@
use std::collections::HashSet;
use super::NonIdentityEccPoint;
use ff::Field;
use group::Curve;
use halo2_proofs::{
circuit::Region,
plonk::{Advice, Column, ConstraintSystem, Constraints, Error, Selector},
poly::Rotation,
};
use pasta_curves::{arithmetic::CurveAffine, pallas};
use pasta_curves::pallas;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Config {
@@ -121,20 +119,24 @@ impl Config {
q.y.copy_advice(|| "y_q", region, self.y_qr, offset)?;
// Compute the sum `P + Q = R`
let r = {
let p = p.point();
let q = q.point();
let r = p
.zip(q)
.map(|(p, q)| (p + q).to_affine().coordinates().unwrap());
let r_x = r.map(|r| *r.x());
let r_y = r.map(|r| *r.y());
(r_x, r_y)
};
let r = x_p
.zip(y_p)
.zip(x_q)
.zip(y_q)
.map(|(((x_p, y_p), x_q), y_q)| {
{
// λ = (y_q - y_p)/(x_q - x_p)
let lambda = (y_q - y_p) * (x_q - x_p).invert();
// x_r = λ^2 - x_p - x_q
let x_r = lambda.square() - x_p - x_q;
// y_r = λ(x_p - x_r) - y_p
let y_r = lambda * (x_p - x_r) - y_p;
(x_r, y_r)
}
});
// Assign the sum to `x_qr`, `y_qr` columns in the next row
let x_r = r.0;
let x_r = r.map(|r| r.0);
let x_r_var = region.assign_advice(
|| "x_r",
self.x_qr,
@@ -142,7 +144,7 @@ impl Config {
|| x_r.ok_or(Error::Synthesis),
)?;
let y_r = r.1;
let y_r = r.map(|r| r.1);
let y_r_var = region.assign_advice(
|| "y_r",
self.y_qr,

View File

@@ -12,7 +12,7 @@ use ff::PrimeField;
use halo2_proofs::{
arithmetic::FieldExt,
circuit::{AssignedCell, Layouter, Region},
plonk::{Advice, Column, ConstraintSystem, Constraints, Error, Selector},
plonk::{Advice, Assigned, Column, ConstraintSystem, Constraints, Error, Selector},
poly::Rotation,
};
use uint::construct_uint;
@@ -360,7 +360,7 @@ impl Config {
if !lsb {
base.x.value().cloned()
} else {
Some(pallas::Base::zero())
Some(Assigned::Zero)
}
} else {
None
@@ -370,7 +370,7 @@ impl Config {
if !lsb {
base.y.value().map(|y_p| -y_p)
} else {
Some(pallas::Base::zero())
Some(Assigned::Zero)
}
} else {
None
@@ -404,9 +404,9 @@ impl Config {
#[derive(Clone, Debug)]
// `x`-coordinate of the accumulator.
struct X<F: FieldExt>(AssignedCell<F, F>);
struct X<F: FieldExt>(AssignedCell<Assigned<F>, F>);
impl<F: FieldExt> Deref for X<F> {
type Target = AssignedCell<F, F>;
type Target = AssignedCell<Assigned<F>, F>;
fn deref(&self) -> &Self::Target {
&self.0
@@ -415,9 +415,9 @@ impl<F: FieldExt> Deref for X<F> {
#[derive(Clone, Debug)]
// `y`-coordinate of the accumulator.
struct Y<F: FieldExt>(AssignedCell<F, F>);
struct Y<F: FieldExt>(AssignedCell<Assigned<F>, F>);
impl<F: FieldExt> Deref for Y<F> {
type Target = AssignedCell<F, F>;
type Target = AssignedCell<Assigned<F>, F>;
fn deref(&self) -> &Self::Target {
&self.0

View File

@@ -1,7 +1,6 @@
use super::super::NonIdentityEccPoint;
use super::{X, Y, Z};
use crate::utilities::bool_check;
use ff::Field;
use halo2_proofs::{
circuit::Region,
plonk::{
@@ -335,7 +334,7 @@ impl<const NUM_BITS: usize> Config<NUM_BITS> {
.zip(y_p)
.zip(x_a.value())
.zip(x_p)
.map(|(((y_a, y_p), x_a), x_p)| (y_a - y_p) * (x_a - x_p).invert().unwrap());
.map(|(((y_a, y_p), x_a), x_p)| (y_a - y_p) * (x_a - x_p).invert());
region.assign_advice(
|| "lambda1",
self.double_and_add.lambda_1,
@@ -356,7 +355,7 @@ impl<const NUM_BITS: usize> Config<NUM_BITS> {
.zip(x_a.value())
.zip(x_r)
.map(|(((lambda1, y_a), x_a), x_r)| {
pallas::Base::from(2) * y_a * (x_a - x_r).invert().unwrap() - lambda1
y_a * pallas::Base::from(2) * (x_a - x_r).invert() - lambda1
});
region.assign_advice(
|| "lambda2",

View File

@@ -6,11 +6,10 @@ use crate::{
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::{
circuit::Layouter,
plonk::{Advice, Column, ConstraintSystem, Constraints, Error, Expression, Selector},
plonk::{Advice, Assigned, Column, ConstraintSystem, Constraints, Error, Expression, Selector},
poly::Rotation,
};
use ff::Field;
use pasta_curves::{arithmetic::FieldExt, pallas};
use std::iter;
@@ -149,13 +148,7 @@ impl Config {
// Witness η = inv0(z_130), where inv0(x) = 0 if x = 0, 1/x otherwise
{
let eta = zs[130].value().map(|z_130| {
if z_130.is_zero_vartime() {
pallas::Base::zero()
} else {
z_130.invert().unwrap()
}
});
let eta = zs[130].value().map(|z_130| Assigned::from(z_130).invert());
region.assign_advice(
|| "η = inv0(z_130)",
self.advices[0],

View File

@@ -278,7 +278,7 @@ impl<FixedPoints: super::FixedPoints<pallas::Affine>> Config<FixedPoints> {
let x = mul_b.map(|mul_b| {
let x = *mul_b.x();
assert!(x != pallas::Base::zero());
x
x.into()
});
let x = region.assign_advice(
|| format!("mul_b_x, window {}", w),
@@ -290,7 +290,7 @@ impl<FixedPoints: super::FixedPoints<pallas::Affine>> Config<FixedPoints> {
let y = mul_b.map(|mul_b| {
let y = *mul_b.y();
assert!(y != pallas::Base::zero());
y
y.into()
});
let y = region.assign_advice(
|| format!("mul_b_y, window {}", w),

View File

@@ -177,7 +177,7 @@ impl<Fixed: FixedPoints<pallas::Affine>> Config<Fixed> {
// Conditionally negate `y`-coordinate
let y_val = if let Some(sign) = sign.value() {
if sign == &-pallas::Base::one() {
magnitude_mul.y.value().cloned().map(|y: pallas::Base| -y)
magnitude_mul.y.value().cloned().map(|y| -y)
} else {
magnitude_mul.y.value().cloned()
}

View File

@@ -5,15 +5,16 @@ use group::prime::PrimeCurveAffine;
use halo2_proofs::{
circuit::{AssignedCell, Region},
plonk::{
Advice, Column, ConstraintSystem, Constraints, Error, Expression, Selector, VirtualCells,
Advice, Assigned, Column, ConstraintSystem, Constraints, Error, Expression, Selector,
VirtualCells,
},
poly::Rotation,
};
use pasta_curves::{arithmetic::CurveAffine, pallas};
type Coordinates = (
AssignedCell<pallas::Base, pallas::Base>,
AssignedCell<pallas::Base, pallas::Base>,
AssignedCell<Assigned<pallas::Base>, pallas::Base>,
AssignedCell<Assigned<pallas::Base>, pallas::Base>,
);
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
@@ -86,7 +87,7 @@ impl Config {
fn assign_xy(
&self,
value: Option<(pallas::Base, pallas::Base)>,
value: Option<(Assigned<pallas::Base>, Assigned<pallas::Base>)>,
offset: usize,
region: &mut Region<'_, pallas::Base>,
) -> Result<Coordinates, Error> {
@@ -116,10 +117,10 @@ impl Config {
let value = value.map(|value| {
// Map the identity to (0, 0).
if value == pallas::Affine::identity() {
(pallas::Base::zero(), pallas::Base::zero())
(Assigned::Zero, Assigned::Zero)
} else {
let value = value.coordinates().unwrap();
(*value.x(), *value.y())
(value.x().into(), value.y().into())
}
});
@@ -146,7 +147,7 @@ impl Config {
let value = value.map(|value| {
let value = value.coordinates().unwrap();
(*value.x(), *value.y())
(value.x().into(), value.y().into())
});
self.assign_xy(value, offset, region)

View File

@@ -177,7 +177,7 @@ where
}
}
Ok((
NonIdentityEccPoint::from_coordinates_unchecked(x_a.0.evaluate(), y_a.evaluate()),
NonIdentityEccPoint::from_coordinates_unchecked(x_a.0, y_a),
zs_sum,
))
}

View File

@@ -82,6 +82,13 @@ impl<F: Field> Neg for Assigned<F> {
}
}
impl<F: Field> Neg for &Assigned<F> {
type Output = Assigned<F>;
fn neg(self) -> Self::Output {
-*self
}
}
impl<F: Field> Add for Assigned<F> {
type Output = Assigned<F>;
fn add(self, rhs: Assigned<F>) -> Assigned<F> {