diff --git a/sunscreen_compiler/src/types/fractional.rs b/sunscreen_compiler/src/types/fractional.rs index e27859b33..52e00f6e8 100644 --- a/sunscreen_compiler/src/types/fractional.rs +++ b/sunscreen_compiler/src/types/fractional.rs @@ -2,88 +2,130 @@ use seal::Plaintext as SealPlaintext; use crate::types::{GraphAdd, GraphMul}; use crate::{ + crate_version, types::{BfvType, CircuitNode, FheType, Type, Version}, - with_ctx, Params, crate_version + with_ctx, Params, }; use sunscreen_runtime::{ - InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, TypeName, TypeNameInstance + InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, TypeName, + TypeNameInstance, }; #[derive(Debug, Clone, Copy, PartialEq)] /** - * A quasi fixed-point representation capable of representing values with + * A quasi fixed-point representation capable of storing values with * both integer and fractional components. - * + * * # Remarks * This type is capable of addition, subtraction, and multiplication with no * more overhead than the [`Signed`](crate::types::Signed) type. * That is, addition and multiplication each take exactly one operation. - * + * * ## Representation - * See [SEAL v2.1 documentation](https://eprint.iacr.org/2017/224.pdf) for - * details. - * - * Recall that in BFV, the plaintext consists of 2 polynomials, each with - * `poly_degree` terms. `poly_degree` is a BFV scheme parameter that - * suncreen assigns for you depending on your circuit's noise requirements. - * This type packs the value into the first polynomial and doesn't use the - * second. - * + * Recall that in BFV, the plaintext consists of a polynomial with + * `poly_degree` terms. `poly_degree` is a BFV scheme parameter that (by + * default) suncreen assigns for you depending on your circuit's noise + * requirements. + * * This type represents values with both an integer and fractional component. - * The generic argument `INT_BITS` defines how many bits are reserved for the - * integer portion and the remaining `poly_degree - INT_BITS` bits store the - * fraction. - * - * Internally, this has a fairly funky representation. - * - * The integer bits map to the low order plaintext polynomial coefficients - * with the following relation: - * + * Semantically, you can think of this as a fixed-point value, but the + * implementation is somewhat different. The generic argument `INT_BITS` + * defines how many bits are reserved for the integer portion and the + * remaining `poly_degree - INT_BITS` bits store the fraction. + * + * Internally, this has a fairly funky representation that differs from + * traditional fixed-point. These variations allow the type to function + * properly under addition and multiplication in the absence of carries + * without needing to shift the decimal location after multiply operations. + * + * Each binary digit of the number maps to a single coefficient in the + * polynomial. The integer digits map to the low order plaintext polynomial + * coefficients with the following relation: + * * ```text * int(x) = sum_{i=0..INT_BITS}(c_i * 2^i) * ``` - * + * * where `c_i` is the coefficient for the `x^i` term of the polynomial. - * + * * Then, the fractional parts follow: - * + * * ```text * frac(x) = sum_{i=INT_BITS..n}(-c_i * 2^(n-i)) * ``` - * + * * where `n` is the `poly_degree`. - * + * * Note that the sign of the polynomial coefficient for fractional terms are - * inverted. - * + * inverted. The entire value is simply `int(x) + frac(x)`. + * + * Negative values encode every digit as negative, where a negative + * coefficient is any value above `plain_modulus / 2` up to + * `plain_modulus - 1`. The former is the most negative value, while the + * latter is the value -1. This is analogous to how 2's complement + * defines values above 0x80..00 to be negative with 0x80..00 being INT_MIN + * and 0xFF..FF being -1. In fact, if `plain_modulus = 2^N`, each digit + * behaves exactly like an `N`-bit 2's complement value. + * + * See [SEAL v2.1 documentation](https://eprint.iacr.org/2017/224.pdf) for + * full details. + * * ## Limitations * When encrypting a Fractional type, encoding will fail if: * * The underlying [`f64`] is infinite. * * The underlying [`f64`] is NaN * * The integer portion of the underlying [`f64`] exceeds the precision for * `INT_BITS` - * + * * Subnormals flush to 0, while normals are represented without precision loss. - * + * * While the numbers are binary, addition and multiplication are carryless. - * That is, carries don't propagate but instead increase the digit (i.e. - * polynomial coefficiens) beyond radix 2. However, they're still subject to + * That is, carries don't propagate but instead increase the digit (i.e. + * polynomial coefficients) beyond radix 2. However, they're still subject to * the scheme's `plain_modulus` specified during circuit compilation. * Repeated operations on an encrypted Fractional value will result in garbled * values if *any* digit overflows the `plain_modulus`. - * + * * Additionally numbers can experience more traditional overflow if the integer * portion exceeds `2^INT_BITS`. Finally, repeated multiplications of * numbers with decimal components introduce new decmal digits. If more than - * `2^(n-INT_BITS)` decimals appear, they will overflow into the integer + * `2^(n-INT_BITS)` decimals appear, they will overflow into the integer * portion and garble the number. + * + * To mitigate these issues, you should do some mix of the following: + * * Ensure inputs never result in either of these scenarios. Inputs to a + * circuit need to have small enough digits to avoid digit overflow, values + * are small enough to avoid integer underflow, and have few enough decimal + * places to avoid decimal underflow. + * * Alice can periodically decrypt values, call turn the [`Fractional`] into + * an [`f64`], turn that back into a [`Fractional`], and re-encrypt. This will + * propagate carries and truncate the decimal portion to at most 53 places. + * + * ```rust + * # use sunscreen_compiler::types::Fractional; + * # use sunscreen_compiler::{Ciphertext, PublicKey, Runtime, Result}; + * # use seal::{SecretKey}; + * + * fn normalize( + * runtime: &Runtime, + * ciphertext: &Ciphertext, + * secret_key: &SecretKey, + * public_key: &PublicKey + * ) -> Result { + * let val: Fractional::<64> = runtime.decrypt(&ciphertext, &secret_key)?; + * let val: f64 = val.into(); + * let val = Fractional::<64>::from(val); + * + * Ok(runtime.encrypt(val, &public_key)?) + * } + * ``` */ pub struct Fractional { val: f64, } -impl std::ops::Deref for Fractional { +impl std::ops::Deref for Fractional { type Target = f64; fn deref(&self) -> &Self::Target { @@ -91,27 +133,30 @@ impl std::ops::Deref for Fractional { } } -impl NumCiphertexts for Fractional { +impl NumCiphertexts for Fractional { const NUM_CIPHERTEXTS: usize = 1; } -impl TypeName for Fractional { +impl TypeName for Fractional { fn type_name() -> Type { - Type { name: format!("sunscreen_compiler::types::Fractional<{}>", INT_BITS), version: Version::parse(crate_version!()).expect("Crate version is not a valid semver") } + Type { + name: format!("sunscreen_compiler::types::Fractional<{}>", INT_BITS), + version: Version::parse(crate_version!()).expect("Crate version is not a valid semver"), + } } } -impl TypeNameInstance for Fractional { +impl TypeNameInstance for Fractional { fn type_name_instance(&self) -> Type { Self::type_name() } } -impl FheType for Fractional {} -impl BfvType for Fractional {} +impl FheType for Fractional {} +impl BfvType for Fractional {} -impl Fractional {} +impl Fractional {} -impl GraphAdd for Fractional { +impl GraphAdd for Fractional { type Left = Fractional; type Right = Fractional; @@ -127,7 +172,7 @@ impl GraphAdd for Fractional { } } -impl GraphMul for Fractional { +impl GraphMul for Fractional { type Left = Fractional; type Right = Fractional; @@ -143,17 +188,21 @@ impl GraphMul for Fractional { } } -impl TryIntoPlaintext for Fractional { +impl TryIntoPlaintext for Fractional { fn try_into_plaintext( &self, params: &Params, ) -> std::result::Result { if self.val.is_nan() { - return Err(sunscreen_runtime::Error::FheTypeError("Value is NaN.".to_owned())); + return Err(sunscreen_runtime::Error::FheTypeError( + "Value is NaN.".to_owned(), + )); } if self.val.is_infinite() { - return Err(sunscreen_runtime::Error::FheTypeError("Value is infinite.".to_owned())); + return Err(sunscreen_runtime::Error::FheTypeError( + "Value is infinite.".to_owned(), + )); } let mut seal_plaintext = SealPlaintext::new()?; @@ -163,7 +212,7 @@ impl TryIntoPlaintext for Fractional { // Just flush subnormals, as they're tiny and annoying. if self.val.is_subnormal() || self.val == 0.0 { return Ok(Plaintext { - inner: InnerPlaintext::Seal(vec![seal_plaintext]) + inner: InnerPlaintext::Seal(vec![seal_plaintext]), }); } @@ -175,19 +224,21 @@ impl TryIntoPlaintext for Fractional { // Coerce the f64 into a u64 so we can extract out the // sign, mantissa, and exponent. let as_u64: u64 = unsafe { std::mem::transmute(self.val) }; - + let sign_mask = 0x1 << 63; let mantissa_mask = 0xFFFFFFFFFFFFF; let exp_mask = !mantissa_mask & !sign_mask; - + // Mask of the mantissa and add the implicit 1 let mantissa = as_u64 & mantissa_mask | (mantissa_mask + 1); let exp = as_u64 & exp_mask; let power = (exp >> (f64::MANTISSA_DIGITS - 1)) as i64 - 1023; let sign = (as_u64 & sign_mask) >> 63; - if power > INT_BITS as i64 { - return Err(sunscreen_runtime::Error::FheTypeError("Out of range".to_owned())); + if power + 1 > INT_BITS as i64 { + return Err(sunscreen_runtime::Error::FheTypeError( + "Out of range".to_owned(), + )); } for i in 0..f64::MANTISSA_DIGITS { @@ -201,11 +252,7 @@ impl TryIntoPlaintext for Fractional { }; // For powers less than 0, we invert the sign. - let sign = if bit_power >= 0 { - sign - } else { - !sign & 0x1 - }; + let sign = if bit_power >= 0 { sign } else { !sign & 0x1 }; let coeff = if sign == 0 { bit_value @@ -226,7 +273,7 @@ impl TryIntoPlaintext for Fractional { } } -impl TryFromPlaintext for Fractional { +impl TryFromPlaintext for Fractional { fn try_from_plaintext( plaintext: &Plaintext, params: &Params, @@ -254,11 +301,7 @@ impl TryFromPlaintext for Fractional { let coeff = p[0].get_coefficient(i); // Reverse the sign of negative powers. - let sign = if power >= 0 { - 1f64 - } else { - -1f64 - }; + let sign = if power >= 0 { 1f64 } else { -1f64 }; if coeff < negative_cutoff { val += sign * coeff as f64 * (power as f64).exp2(); @@ -275,13 +318,13 @@ impl TryFromPlaintext for Fractional { } } -impl From for Fractional { +impl From for Fractional { fn from(val: f64) -> Self { Self { val } } } -impl Into for Fractional { +impl Into for Fractional { fn into(self) -> f64 { self.val } @@ -290,7 +333,7 @@ impl Into for Fractional { #[cfg(test)] mod tests { use super::*; - use crate::{SecurityLevel, SchemeType}; + use crate::{SchemeType, SecurityLevel}; #[test] fn can_encode_decode_fractional() { diff --git a/sunscreen_compiler/src/types/mod.rs b/sunscreen_compiler/src/types/mod.rs index e053b374a..5b4157d9f 100644 --- a/sunscreen_compiler/src/types/mod.rs +++ b/sunscreen_compiler/src/types/mod.rs @@ -11,9 +11,9 @@ pub use sunscreen_runtime::{ TypeNameInstance, Version, }; +pub use fractional::Fractional; pub use integer::{Signed, Unsigned}; pub use rational::Rational; -pub use fractional::Fractional; use std::ops::{Add, Div, Mul, Sub}; #[derive(Clone, Copy, Serialize, Deserialize)] diff --git a/sunscreen_compiler/tests/types.rs b/sunscreen_compiler/tests/types.rs index b0c5b1be4..a4ff5643d 100644 --- a/sunscreen_compiler/tests/types.rs +++ b/sunscreen_compiler/tests/types.rs @@ -1,5 +1,6 @@ use sunscreen_compiler::{ - circuit, types::Fractional, types::Rational, types::Signed, Compiler, PlainModulusConstraint, Runtime, + circuit, types::Fractional, types::Rational, types::Signed, Compiler, PlainModulusConstraint, + Runtime, }; #[test] @@ -237,7 +238,7 @@ fn can_sub_rational_numbers() { #[test] fn can_add_fractional_numbers() { #[circuit(scheme = "bfv")] - fn add(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> { + fn add(a: Fractional<64>, b: Fractional<64>) -> Fractional<64> { a + b } @@ -260,7 +261,7 @@ fn can_add_fractional_numbers() { let result = runtime.run(&circuit, vec![a, b], &public).unwrap(); - let c: Fractional::<64> = runtime.decrypt(&result[0], &secret).unwrap(); + let c: Fractional<64> = runtime.decrypt(&result[0], &secret).unwrap(); assert_eq!(c, (-6.28).try_into().unwrap()); } @@ -268,7 +269,7 @@ fn can_add_fractional_numbers() { #[test] fn can_mul_fractional_numbers() { #[circuit(scheme = "bfv")] - fn mul(a: Fractional::<64>, b: Fractional::<64>) -> Fractional::<64> { + fn mul(a: Fractional<64>, b: Fractional<64>) -> Fractional<64> { a * b } @@ -292,11 +293,11 @@ fn can_mul_fractional_numbers() { let result = runtime.run(&circuit, vec![a_c, b_c], &public).unwrap(); - let c: Fractional::<64> = runtime.decrypt(&result[0], &secret).unwrap(); + let c: Fractional<64> = runtime.decrypt(&result[0], &secret).unwrap(); assert_eq!(c, (a * b).try_into().unwrap()); }; - + test_mul(-3.14, -3.14); test_mul(1234., 5678.); test_mul(-1234., 5678.); @@ -308,4 +309,4 @@ fn can_mul_fractional_numbers() { // 4294967296 is 2^32. This should be about the largest multiplication we // can do with 64-bits of precision for the integer. test_mul(4294967295., 4294967296.); -} \ No newline at end of file +}