Rweber/zkp (#202)

Add less-than-equal comparison
This commit is contained in:
rickwebiii
2023-01-17 15:18:04 -08:00
committed by GitHub
parent 2f9acb39a7
commit 216fd1beda
7 changed files with 199 additions and 41 deletions

View File

@@ -5,41 +5,35 @@ use crate::{invoke_gadget, with_zkp_ctx, zkp::ZkpContextOps, ZkpError, ZkpResult
/**
* Expands a field element into N-bit unsigned binary.
*/
pub struct ToUInt<const N: usize>;
trait GetBit {
fn get_bit(&self, i: usize) -> u8;
pub struct ToUInt {
n: usize,
}
impl GetBit for BigInt {
fn get_bit(&self, i: usize) -> u8 {
const LIMB_SIZE: usize = std::mem::size_of::<u64>();
let limb = i / LIMB_SIZE;
let bit = i % LIMB_SIZE;
((self.limbs()[limb].0 & (0x1 << bit)) >> bit) as u8
impl ToUInt {
pub fn new(n: usize) -> Self {
Self { n }
}
}
impl<const N: usize> Gadget for ToUInt<N> {
impl Gadget for ToUInt {
fn compute_inputs(&self, gadget_inputs: &[BigInt]) -> ZkpResult<Vec<BigInt>> {
let val = gadget_inputs[0];
if N == 0 {
if self.n == 0 {
return Err(ZkpError::gadget_error("Cannot create 0-bit uint."));
}
if *val > BigInt::ONE.shl_vartime(N) {
if *val > BigInt::ONE.shl_vartime(self.n) {
return Err(ZkpError::gadget_error(&format!(
"Value too large for {N} bit unsigned int."
"Value too large for {} bit unsigned int.",
self.n
)));
}
let mut bits = vec![];
for i in 0..N {
bits.push(BigInt::from(val.get_bit(i)));
for i in 0..self.n {
bits.push(BigInt::from(val.bit_vartime(i)));
}
Ok(bits)
@@ -55,7 +49,7 @@ impl<const N: usize> Gadget for ToUInt<N> {
let mut muls = vec![];
let hidden_inputs = with_zkp_ctx(|ctx| {
for i in 0..N {
for i in 0..self.n {
let constant = BigInt::from(*BigInt::ONE << i);
let constant = ctx.add_constant(&constant);
@@ -95,7 +89,7 @@ impl<const N: usize> Gadget for ToUInt<N> {
}
fn hidden_input_count(&self) -> usize {
N
self.n
}
}

View File

@@ -98,7 +98,26 @@ where
/**
* Asserts that lhs equals rhs.
*/
fn constraint_eq(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
fn constrain_eq(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self>;
}
/**
* A trait for comparing 2 values.
*/
pub trait ConstrainCmpVarVar
where
Self: Sized + ZkpType,
{
/**
* Asserts that lhs is less than or equal rhs.
*
* # Remarks
* `bits` is the maximum number of bits required to represent
* `rhs - lhs` as an unsigned value.
* This value must be less than the number of bits needed to
* represent the field modulus.
*/
fn constrain_le_bounded(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>, bits: usize);
}
/**

View File

@@ -18,6 +18,8 @@ use crate::types::zkp::{
use crate as sunscreen;
use super::{ConstrainCmpVarVar, SubVar};
// Shouldn't need Clone + Copy, but there appears to be a bug in the Rust
// compiler that prevents ProgramNode from being Copy if we don't.
// https://github.com/rust-lang/rust/issues/104264
@@ -145,6 +147,16 @@ impl<F: BackendField> AddVar for NativeField<F> {
}
}
impl<F: BackendField> SubVar for NativeField<F> {
fn sub(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
with_zkp_ctx(|ctx| {
let o = ctx.add_subtraction(lhs.ids[0], rhs.ids[0]);
ProgramNode::new(&[o])
})
}
}
impl<F: BackendField> MulVar for NativeField<F> {
fn mul(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
with_zkp_ctx(|ctx| {
@@ -166,7 +178,7 @@ impl<F: BackendField> NegVar for NativeField<F> {
}
impl<F: BackendField> ConstrainEqVarVar for NativeField<F> {
fn constraint_eq(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
fn constrain_eq(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>) -> ProgramNode<Self> {
with_zkp_ctx(|ctx| {
let sub = ctx.add_subtraction(lhs.ids[0], rhs.ids[0]);
@@ -177,6 +189,14 @@ impl<F: BackendField> ConstrainEqVarVar for NativeField<F> {
}
}
impl<F: BackendField> ConstrainCmpVarVar for NativeField<F> {
fn constrain_le_bounded(lhs: ProgramNode<Self>, rhs: ProgramNode<Self>, bits: usize) {
let diff = rhs - lhs;
invoke_gadget(ToUInt::new(bits), &[diff.ids[0]]);
}
}
impl<F: BackendField> IntoProgramNode for NativeField<F> {
type Output = NativeField<F>;
@@ -198,7 +218,7 @@ pub trait ToBinary<F: BackendField> {
impl<F: BackendField> ToBinary<F> for ProgramNode<NativeField<F>> {
fn to_unsigned<const N: usize>(&self) -> [ProgramNode<NativeField<F>>; N] {
let bits = invoke_gadget(ToUInt::<N>, self.ids);
let bits = invoke_gadget(ToUInt::new(N), self.ids);
let mut vals = [*self; N];
@@ -215,7 +235,11 @@ mod tests {
use std::ops::{Add, Mul, Neg, Sub};
use curve25519_dalek::scalar::Scalar;
use sunscreen_zkp_backend::ZkpInto;
use sunscreen_compiler_macros::zkp_program;
use sunscreen_runtime::{Runtime, ZkpProgramInput};
use sunscreen_zkp_backend::{bulletproofs::BulletproofsBackend, ZkpBackend, ZkpInto};
use crate::{types::zkp::ConstrainCmp, Compiler};
use super::*;
@@ -277,7 +301,7 @@ mod tests {
}
impl ZkpInto<BigInt> for TestField {
fn into(self) -> BigInt {
fn zkp_into(self) -> BigInt {
unreachable!()
}
}
@@ -314,4 +338,53 @@ mod tests {
assert_eq!(x.val, BigInt::ONE);
}
#[test]
fn can_compare_le_bounded() {
#[zkp_program(backend = "bulletproofs")]
fn le<F: BackendField>(x: NativeField<F>, y: NativeField<F>) {
x.constrain_le_bounded(y, 16);
}
let app = Compiler::new()
.zkp_backend::<BulletproofsBackend>()
.zkp_program(le)
.compile()
.unwrap();
let runtime = Runtime::new_zkp(&BulletproofsBackend::new()).unwrap();
let program = app.get_zkp_program(le).unwrap();
let test_case = |x: i64, y: i64, expect_pass: bool| {
type BpField = NativeField<<BulletproofsBackend as ZkpBackend>::Field>;
let result = runtime.prove(
program,
vec![],
vec![],
vec![BpField::from(x), BpField::from(y)],
);
let proof = if expect_pass {
result.unwrap()
} else {
assert!(result.is_err());
return;
};
runtime
.verify(program, &proof, vec![], Vec::<ZkpProgramInput>::new())
.unwrap();
};
test_case(5, 6, true);
test_case(5, 5, true);
test_case(5, 1024, true);
test_case(-3, -2, true);
test_case(-2, -2, true);
test_case(-1, 3, true);
test_case(-1, -2, false);
test_case(6, 5, false);
}
}

View File

@@ -11,7 +11,7 @@ use crate::{
INDEX_ARENA,
};
use super::ConstrainEqVarVar;
use super::{ConstrainCmpVarVar, ConstrainEqVarVar};
#[derive(Clone, Copy)]
/**
@@ -206,6 +206,36 @@ where
* Constrains this native field to equal the right hand side
*/
fn constrain_eq(self, rhs: T) -> ProgramNode<Self::Output> {
V::constraint_eq(self.into_program_node(), rhs.into_program_node())
V::constrain_eq(self.into_program_node(), rhs.into_program_node())
}
}
/**
* Comparison constraints (e.g. lt, le, gt, ge).
*/
pub trait ConstrainCmp<Rhs> {
/**
* Constrain that this value is less than or equal than the RHS.
*
* # Remarks
* The number of bits is the maximum number of bits required to
* represent `rhs - lhs` as an unsigned integer. This allows you
* to dramatically reduce the number of constrains to perform a
* comparison.
*
* The maximum value for bits is f - 1 where f is the size of
* the backend field.
*/
fn constrain_le_bounded(self, rhs: Rhs, bits: usize);
}
impl<T, U, V> ConstrainCmp<T> for U
where
T: Sized + IntoProgramNode<Output = V>,
U: IntoProgramNode<Output = V> + Sized,
V: ZkpType + Sized + ConstrainCmpVarVar,
{
fn constrain_le_bounded(self, rhs: T, bits: usize) {
V::constrain_le_bounded(self.into_program_node(), rhs.into_program_node(), bits);
}
}

View File

@@ -502,13 +502,13 @@ fn scalar_to_uint<const N: usize>(x: &Scalar) -> UInt<N> {
}
impl crate::ZkpFrom<Scalar> for BigInt {
fn from(val: Scalar) -> BigInt {
fn zkp_from(val: Scalar) -> BigInt {
BigInt(scalar_to_uint(&val))
}
}
impl crate::ZkpFrom<&Scalar> for BigInt {
fn from(val: &Scalar) -> BigInt {
fn zkp_from(val: &Scalar) -> BigInt {
BigInt(scalar_to_uint(val))
}
}
@@ -527,7 +527,7 @@ mod tests {
let scalar = Scalar::try_from(a).unwrap();
assert_eq!(a, <BigInt as crate::ZkpFrom<Scalar>>::from(scalar));
assert_eq!(a, <BigInt as crate::ZkpFrom<Scalar>>::zkp_from(scalar));
}
#[test]
@@ -573,7 +573,7 @@ mod tests {
assert_eq!(
BigInt(l_min_1),
<BigInt as crate::ZkpFrom<Scalar>>::from(scalar)
<BigInt as crate::ZkpFrom<Scalar>>::zkp_from(scalar)
);
}

View File

@@ -329,7 +329,7 @@ where
let parents = query.get_unordered_operands(id)?;
for parent in parents {
if node_outputs[&parent].clone().into() != x {
if node_outputs[&parent].clone().zkp_into() != x {
return Err(Error::UnsatifiableConstraint(id));
}
}
@@ -344,7 +344,7 @@ where
let args = arg_indices
.iter()
.map(|x| node_outputs[x].clone().into())
.map(|x| node_outputs[x].clone().zkp_into())
.collect::<Vec<BigInt>>();
let hidden_inputs = g.compute_inputs(&args)?;
@@ -486,11 +486,11 @@ where
Operation::ConstantInput(x) => {
let val = constant_inputs[x].clone();
NodeInfo::new(ExecOperation::Constant(val.into()))
NodeInfo::new(ExecOperation::Constant(val.zkp_into()))
}
Operation::HiddenInput(_) => match node_outputs.as_ref() {
Some(node_outputs) => NodeInfo::new(ExecOperation::HiddenInput(Some(
node_outputs[&id].clone().into(),
node_outputs[&id].clone().zkp_into(),
))),
None => NodeInfo::new(ExecOperation::HiddenInput(None)),
},
@@ -558,7 +558,7 @@ where
let mut transforms = GraphTransforms::new();
if let Operation::PublicInput(x) = query.get_node(id).unwrap().operation {
let as_bigint: BigInt = public_inputs[x].clone().into();
let as_bigint: BigInt = public_inputs[x].clone().zkp_into();
let constraint = transforms.push(Transform::AddNode(NodeInfo {
operation: Operation::Constraint(as_bigint),

View File

@@ -20,7 +20,7 @@ use std::{
};
pub use crypto_bigint::UInt;
use crypto_bigint::{subtle::ConditionallySelectable, U512};
use crypto_bigint::{subtle::ConditionallySelectable, Limb, U512};
pub use error::*;
pub use exec::ExecutableZkpProgram;
pub use jit::{jit_prover, jit_verifier, CompiledZkpProgram, Operation};
@@ -213,6 +213,35 @@ impl BigInt {
Self(U512::from_be_hex(hex_str))
}
/**
* Returns `ceil(log_2(&self))`.
*
* # Remarks
* Runs in variable time with respect to `self`
*/
pub fn vartime_log2(&self) -> u32 {
let mut log2 = 0;
if *self == BigInt::ZERO {
panic!("Cannot compute log2(0).");
}
let bitlen = self.limbs().len() * std::mem::size_of::<Limb>() * 8;
for i in 0..bitlen {
let i = bitlen - 1 - i;
let bit_val = self.bit_vartime(i);
if bit_val == 1 && log2 == 0 {
log2 = i as u32;
} else if bit_val == 1 {
log2 += 1;
}
}
log2
}
/**
* The value 0.
*/
@@ -309,7 +338,7 @@ pub trait ZkpFrom<T> {
/**
* See [`std::convert::From::from`].
*/
fn from(val: T) -> Self;
fn zkp_from(val: T) -> Self;
}
/**
@@ -320,14 +349,27 @@ pub trait ZkpInto<T> {
/**
* See [`std::convert::Into::into`].
*/
fn into(self) -> T;
fn zkp_into(self) -> T;
}
impl<T, U> ZkpInto<T> for U
where
T: ZkpFrom<U>,
{
fn into(self) -> T {
T::from(self)
fn zkp_into(self) -> T {
T::zkp_from(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn log2_works() {
assert_eq!(BigInt::from(4u16).vartime_log2(), 2);
assert_eq!(BigInt::from(5u16).vartime_log2(), 3);
assert_eq!(BigInt::from(6u16).vartime_log2(), 3);
assert_eq!(BigInt::from(8u16).vartime_log2(), 3);
}
}