mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: introduce auto rounders
This commit is contained in:
@@ -15,7 +15,17 @@ from .compilation import (
|
||||
Server,
|
||||
)
|
||||
from .compilation.decorators import circuit, compiler
|
||||
from .extensions import LookupTable, array, one, ones, round_bit_pattern, univariate, zero, zeros
|
||||
from .extensions import (
|
||||
AutoRounder,
|
||||
LookupTable,
|
||||
array,
|
||||
one,
|
||||
ones,
|
||||
round_bit_pattern,
|
||||
univariate,
|
||||
zero,
|
||||
zeros,
|
||||
)
|
||||
from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
from .tracing.typing import (
|
||||
|
||||
@@ -4,7 +4,7 @@ Provide additional features that are not present in numpy.
|
||||
|
||||
from .array import array
|
||||
from .ones import one, ones
|
||||
from .round_bit_pattern import round_bit_pattern
|
||||
from .round_bit_pattern import AutoRounder, round_bit_pattern
|
||||
from .table import LookupTable
|
||||
from .univariate import univariate
|
||||
from .zeros import zero, zeros
|
||||
|
||||
@@ -2,22 +2,135 @@
|
||||
Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
from typing import Any, Callable, Iterable, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes import Integer
|
||||
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
|
||||
local = threading.local()
|
||||
|
||||
# pylint: disable=protected-access
|
||||
local._is_adjusting = False
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class Adjusting(BaseException):
|
||||
"""
|
||||
Adjusting class, to be used as early stop signal during adjustment.
|
||||
"""
|
||||
|
||||
rounder: "AutoRounder"
|
||||
input_min: int
|
||||
input_max: int
|
||||
|
||||
def __init__(self, rounder: "AutoRounder", input_min: int, input_max: int):
|
||||
super().__init__()
|
||||
self.rounder = rounder
|
||||
self.input_min = input_min
|
||||
self.input_max = input_max
|
||||
|
||||
|
||||
class AutoRounder:
|
||||
"""
|
||||
AutoRounder class, to optimize for number of msbs to keep druing round bit pattern operation.
|
||||
"""
|
||||
|
||||
target_msbs: int
|
||||
|
||||
is_adjusted: bool
|
||||
input_min: int
|
||||
input_max: int
|
||||
input_bit_width: int
|
||||
lsbs_to_remove: int
|
||||
|
||||
def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH):
|
||||
# pylint: disable=protected-access
|
||||
if local._is_adjusting:
|
||||
raise RuntimeError(
|
||||
"AutoRounders cannot be constructed during adjustment, "
|
||||
"please construct AutoRounders outside the function and reference it"
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
self.target_msbs = target_msbs
|
||||
|
||||
self.is_adjusted = False
|
||||
self.input_min = 0
|
||||
self.input_max = 0
|
||||
self.input_bit_width = 0
|
||||
self.lsbs_to_remove = 0
|
||||
|
||||
@staticmethod
|
||||
def adjust(function: Callable, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
|
||||
"""
|
||||
Adjust AutoRounders in a function using an inputset.
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access,too-many-branches
|
||||
|
||||
try: # extract underlying function for decorators
|
||||
function = function.function # type: ignore
|
||||
assert callable(function)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if local._is_adjusting:
|
||||
raise RuntimeError("AutoRounders cannot be adjusted recursively")
|
||||
|
||||
try:
|
||||
local._is_adjusting = True
|
||||
while True:
|
||||
rounder = None
|
||||
|
||||
for sample in inputset:
|
||||
if not isinstance(sample, tuple):
|
||||
sample = (sample,)
|
||||
|
||||
try:
|
||||
function(*sample)
|
||||
except Adjusting as adjuster:
|
||||
rounder = adjuster.rounder
|
||||
|
||||
rounder.input_min = min(rounder.input_min, adjuster.input_min)
|
||||
rounder.input_max = max(rounder.input_max, adjuster.input_max)
|
||||
|
||||
input_value = Value.of([rounder.input_min, rounder.input_max])
|
||||
assert isinstance(input_value.dtype, Integer)
|
||||
rounder.input_bit_width = input_value.dtype.bit_width
|
||||
|
||||
if rounder.input_bit_width - rounder.lsbs_to_remove > rounder.target_msbs:
|
||||
rounder.lsbs_to_remove = rounder.input_bit_width - rounder.target_msbs
|
||||
else:
|
||||
return
|
||||
|
||||
if rounder is None:
|
||||
raise ValueError("AutoRounders cannot be adjusted with an empty inputset")
|
||||
|
||||
rounder.is_adjusted = True
|
||||
|
||||
finally:
|
||||
local._is_adjusting = False
|
||||
|
||||
# pylint: enable=protected-access,too-many-branches
|
||||
|
||||
|
||||
def round_bit_pattern(
|
||||
x: Union[int, List, np.ndarray, Tracer],
|
||||
lsbs_to_remove: int,
|
||||
) -> Union[int, List, np.ndarray, Tracer]:
|
||||
x: Union[int, np.integer, List, np.ndarray, Tracer],
|
||||
lsbs_to_remove: Union[int, AutoRounder],
|
||||
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
|
||||
"""
|
||||
Round the bit pattern of an integer.
|
||||
|
||||
If `lsbs_to_remove` is an `AutoRounder`:
|
||||
corresponding integer value will be determined by adjustment process.
|
||||
|
||||
x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000
|
||||
x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000
|
||||
x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000
|
||||
@@ -55,19 +168,44 @@ def round_bit_pattern(
|
||||
x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000
|
||||
|
||||
Args:
|
||||
x (Union[int, np.ndarray, Tracer]):
|
||||
x (Union[int, np.integer, np.ndarray, Tracer]):
|
||||
input to round
|
||||
|
||||
lsbs_to_remove (int):
|
||||
number of the least significant numbers to remove
|
||||
lsbs_to_remove (Union[int, AutoRounder]):
|
||||
number of the least significant bits to remove
|
||||
or an auto rounder object which will be used to determine the integer value
|
||||
|
||||
Returns:
|
||||
Union[int, np.ndarray, Tracer]:
|
||||
Union[int, np.integer, np.ndarray, Tracer]:
|
||||
Tracer that respresents the operation during tracing
|
||||
rounded value(s) otherwise
|
||||
"""
|
||||
|
||||
def evaluator(x: Union[int, np.ndarray], lsbs_to_remove: int) -> Union[int, np.ndarray]:
|
||||
# pylint: disable=protected-access,too-many-branches
|
||||
|
||||
if isinstance(lsbs_to_remove, AutoRounder):
|
||||
if local._is_adjusting:
|
||||
if not lsbs_to_remove.is_adjusted:
|
||||
raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x)))
|
||||
|
||||
elif not lsbs_to_remove.is_adjusted:
|
||||
raise RuntimeError(
|
||||
"AutoRounders cannot be used before adjustment, "
|
||||
"please call AutoRounder.adjust with the function that will be compiled "
|
||||
"and provide the exact inputset that will be used for compilation"
|
||||
)
|
||||
|
||||
lsbs_to_remove = lsbs_to_remove.lsbs_to_remove
|
||||
|
||||
assert isinstance(lsbs_to_remove, int)
|
||||
|
||||
def evaluator(
|
||||
x: Union[int, np.integer, np.ndarray],
|
||||
lsbs_to_remove: int,
|
||||
) -> Union[int, np.integer, np.ndarray]:
|
||||
if lsbs_to_remove == 0:
|
||||
return x
|
||||
|
||||
unit = 1 << lsbs_to_remove
|
||||
half = 1 << lsbs_to_remove - 1
|
||||
rounded = (x + half) // unit
|
||||
@@ -94,7 +232,9 @@ def round_bit_pattern(
|
||||
raise TypeError(
|
||||
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
|
||||
)
|
||||
elif not isinstance(x, int):
|
||||
elif not isinstance(x, (int, np.integer)):
|
||||
raise TypeError(f"Expected input to be an int or a numpy array but it's {type(x).__name__}")
|
||||
|
||||
return evaluator(x, lsbs_to_remove)
|
||||
|
||||
# pylint: enable=protected-access,too-many-branches
|
||||
|
||||
@@ -6,11 +6,14 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
from concrete.numpy.representation.utils import format_constant
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample,lsbs_to_remove,expected_output",
|
||||
[
|
||||
(0b_0000_0011, 0, 0b_0000_0011),
|
||||
(0b_0000_0100, 0, 0b_0000_0100),
|
||||
(0b_0000_0000, 3, 0b_0000_0000),
|
||||
(0b_0000_0001, 3, 0b_0000_0000),
|
||||
(0b_0000_0010, 3, 0b_0000_0000),
|
||||
@@ -33,7 +36,7 @@ def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
|
||||
"""
|
||||
Test round bit pattern in evaluation context.
|
||||
"""
|
||||
assert cnp.round_bit_pattern(sample, lsbs_to_remove) == expected_output
|
||||
assert cnp.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -53,13 +56,18 @@ def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message):
|
||||
def test_bad_plain_round_bit_pattern(
|
||||
sample,
|
||||
lsbs_to_remove,
|
||||
expected_error,
|
||||
expected_message,
|
||||
):
|
||||
"""
|
||||
Test round bit pattern in evaluation context with bad parameters.
|
||||
"""
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
cnp.round_bit_pattern(sample, lsbs_to_remove)
|
||||
cnp.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
|
||||
@@ -85,7 +93,191 @@ def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers):
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return np.abs(50 * np.sin(cnp.round_bit_pattern(x, lsbs_to_remove))).astype(np.int64)
|
||||
x_rounded = cnp.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
|
||||
return np.abs(50 * np.sin(x_rounded)).astype(np.int64)
|
||||
|
||||
circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True)
|
||||
helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits))
|
||||
|
||||
|
||||
def test_auto_rounding(helpers):
|
||||
"""
|
||||
Test round bit pattern with auto rounding.
|
||||
"""
|
||||
|
||||
# with auto adjust rounders configuration
|
||||
# ---------------------------------------
|
||||
|
||||
# y has the max value of 1999, so it's 11 bits
|
||||
# our target msb is 5 bits, which means we need to remove 6 of the least significant bits
|
||||
|
||||
rounder1 = cnp.AutoRounder(target_msbs=5)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function1(x):
|
||||
y = x + 1000
|
||||
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder1)
|
||||
return np.sqrt(z).astype(np.int64)
|
||||
|
||||
inputset1 = range(1000)
|
||||
function1.trace(inputset1, helpers.configuration(), auto_adjust_rounders=True)
|
||||
|
||||
assert rounder1.lsbs_to_remove == 6
|
||||
|
||||
# manual
|
||||
# ------
|
||||
|
||||
# y has the max value of 1999, so it's 11 bits
|
||||
# our target msb is 3 bits, which means we need to remove 8 of the least significant bits
|
||||
|
||||
rounder2 = cnp.AutoRounder(target_msbs=3)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function2(x):
|
||||
y = x + 1000
|
||||
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder2)
|
||||
return np.sqrt(z).astype(np.int64)
|
||||
|
||||
inputset2 = range(1000)
|
||||
cnp.AutoRounder.adjust(function2, inputset2)
|
||||
|
||||
assert rounder2.lsbs_to_remove == 8
|
||||
|
||||
# complicated case
|
||||
# ----------------
|
||||
|
||||
# have 2 ** 8 entries during evaluation, it won't matter after compilation
|
||||
entries3 = list(range(2**8))
|
||||
# we have 8-bit inputs for this table, and we only want to use first 5-bits
|
||||
for i in range(0, 2**8, 2**3):
|
||||
# so we set every 8th entry to a 4-bit value
|
||||
entries3[i] = np.random.randint(0, (2**4) - (2**2))
|
||||
# when this tlu is applied to an 8-bit value with 5-bit msb rounding, result will be 4-bits
|
||||
table3 = cnp.LookupTable(entries3)
|
||||
# and this is the rounder for table1, which should have lsbs_to_remove of 3
|
||||
rounder3 = cnp.AutoRounder(target_msbs=5)
|
||||
|
||||
# have 2 ** 8 entries during evaluation, it won't matter after compilation
|
||||
entries4 = list(range(2**8))
|
||||
# we have 4-bit inputs for this table, and we only want to use first 2-bits
|
||||
for i in range(0, 2**4, 2**2):
|
||||
# so we set every 4th entry to an 8-bit value
|
||||
entries4[i] = np.random.randint(0, 2**8)
|
||||
# when this tlu is applied to a 4-bit value with 2-bit msb rounding, result will be 8-bits
|
||||
table4 = cnp.LookupTable(entries4)
|
||||
# and this is the rounder for table2, which should have lsbs_to_remove of 2
|
||||
rounder4 = cnp.AutoRounder(target_msbs=2)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function3(x):
|
||||
a = cnp.round_bit_pattern(x, lsbs_to_remove=rounder3)
|
||||
b = table3[a]
|
||||
c = cnp.round_bit_pattern(b, lsbs_to_remove=rounder4)
|
||||
d = table4[c]
|
||||
return d
|
||||
|
||||
inputset3 = range((2**8) - (2**3))
|
||||
circuit3 = function3.compile(
|
||||
inputset3,
|
||||
helpers.configuration(),
|
||||
auto_adjust_rounders=True,
|
||||
virtual=True,
|
||||
)
|
||||
|
||||
assert rounder3.lsbs_to_remove == 3
|
||||
assert rounder4.lsbs_to_remove == 2
|
||||
|
||||
table3_formatted_string = format_constant(table3.table, 25)
|
||||
table4_formatted_string = format_constant(table4.table, 25)
|
||||
|
||||
helpers.check_str(
|
||||
f"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint8>
|
||||
%1 = round_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar<uint8>
|
||||
%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar<uint4>
|
||||
%3 = round_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar<uint4>
|
||||
%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar<uint8>
|
||||
return %4
|
||||
|
||||
""",
|
||||
str(circuit3),
|
||||
)
|
||||
|
||||
|
||||
def test_auto_rounding_without_adjustment():
|
||||
"""
|
||||
Test round bit pattern with auto rounding but without adjustment.
|
||||
"""
|
||||
|
||||
rounder = cnp.AutoRounder(target_msbs=5)
|
||||
|
||||
def function(x):
|
||||
y = x + 1000
|
||||
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder)
|
||||
return np.sqrt(z).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
function(100)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"AutoRounders cannot be used before adjustment, "
|
||||
"please call AutoRounder.adjust with the function that will be compiled "
|
||||
"and provide the exact inputset that will be used for compilation"
|
||||
)
|
||||
|
||||
|
||||
def test_auto_rounding_with_empty_inputset():
|
||||
"""
|
||||
Test round bit pattern with auto rounding but with empty inputset.
|
||||
"""
|
||||
|
||||
rounder = cnp.AutoRounder(target_msbs=5)
|
||||
|
||||
def function(x):
|
||||
y = x + 1000
|
||||
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder)
|
||||
return np.sqrt(z).astype(np.int64)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
cnp.AutoRounder.adjust(function, [])
|
||||
|
||||
assert str(excinfo.value) == "AutoRounders cannot be adjusted with an empty inputset"
|
||||
|
||||
|
||||
def test_auto_rounding_recursive_adjustment():
|
||||
"""
|
||||
Test round bit pattern with auto rounding but with recursive adjustment.
|
||||
"""
|
||||
|
||||
rounder = cnp.AutoRounder(target_msbs=5)
|
||||
|
||||
def function(x):
|
||||
cnp.AutoRounder.adjust(function, range(10))
|
||||
y = x + 1000
|
||||
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder)
|
||||
return np.sqrt(z).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
cnp.AutoRounder.adjust(function, range(10))
|
||||
|
||||
assert str(excinfo.value) == "AutoRounders cannot be adjusted recursively"
|
||||
|
||||
|
||||
def test_auto_rounding_construct_in_function():
|
||||
"""
|
||||
Test round bit pattern with auto rounding but rounder is constructed within the function.
|
||||
"""
|
||||
|
||||
def function(x):
|
||||
y = x + 1000
|
||||
z = cnp.round_bit_pattern(y, lsbs_to_remove=cnp.AutoRounder(target_msbs=5))
|
||||
return np.sqrt(z).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
cnp.AutoRounder.adjust(function, range(10))
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"AutoRounders cannot be constructed during adjustment, "
|
||||
"please construct AutoRounders outside the function and reference it"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user