feat: introduce auto rounders

This commit is contained in:
Umut
2022-11-08 13:30:22 +01:00
parent eb601f5948
commit ccd3f9af6a
4 changed files with 358 additions and 16 deletions

View File

@@ -15,7 +15,17 @@ from .compilation import (
Server,
)
from .compilation.decorators import circuit, compiler
from .extensions import LookupTable, array, one, ones, round_bit_pattern, univariate, zero, zeros
from .extensions import (
AutoRounder,
LookupTable,
array,
one,
ones,
round_bit_pattern,
univariate,
zero,
zeros,
)
from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
from .representation import Graph
from .tracing.typing import (

View File

@@ -4,7 +4,7 @@ Provide additional features that are not present in numpy.
from .array import array
from .ones import one, ones
from .round_bit_pattern import round_bit_pattern
from .round_bit_pattern import AutoRounder, round_bit_pattern
from .table import LookupTable
from .univariate import univariate
from .zeros import zero, zeros

View File

@@ -2,22 +2,135 @@
Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups.
"""
import threading
from copy import deepcopy
from typing import List, Union
from typing import Any, Callable, Iterable, List, Tuple, Union
import numpy as np
from ..dtypes import Integer
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
local = threading.local()
# pylint: disable=protected-access
local._is_adjusting = False
# pylint: enable=protected-access
class Adjusting(BaseException):
"""
Adjusting class, to be used as early stop signal during adjustment.
"""
rounder: "AutoRounder"
input_min: int
input_max: int
def __init__(self, rounder: "AutoRounder", input_min: int, input_max: int):
super().__init__()
self.rounder = rounder
self.input_min = input_min
self.input_max = input_max
class AutoRounder:
"""
AutoRounder class, to optimize for number of msbs to keep druing round bit pattern operation.
"""
target_msbs: int
is_adjusted: bool
input_min: int
input_max: int
input_bit_width: int
lsbs_to_remove: int
def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH):
# pylint: disable=protected-access
if local._is_adjusting:
raise RuntimeError(
"AutoRounders cannot be constructed during adjustment, "
"please construct AutoRounders outside the function and reference it"
)
# pylint: enable=protected-access
self.target_msbs = target_msbs
self.is_adjusted = False
self.input_min = 0
self.input_max = 0
self.input_bit_width = 0
self.lsbs_to_remove = 0
@staticmethod
def adjust(function: Callable, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
"""
Adjust AutoRounders in a function using an inputset.
"""
# pylint: disable=protected-access,too-many-branches
try: # extract underlying function for decorators
function = function.function # type: ignore
assert callable(function)
except AttributeError:
pass
if local._is_adjusting:
raise RuntimeError("AutoRounders cannot be adjusted recursively")
try:
local._is_adjusting = True
while True:
rounder = None
for sample in inputset:
if not isinstance(sample, tuple):
sample = (sample,)
try:
function(*sample)
except Adjusting as adjuster:
rounder = adjuster.rounder
rounder.input_min = min(rounder.input_min, adjuster.input_min)
rounder.input_max = max(rounder.input_max, adjuster.input_max)
input_value = Value.of([rounder.input_min, rounder.input_max])
assert isinstance(input_value.dtype, Integer)
rounder.input_bit_width = input_value.dtype.bit_width
if rounder.input_bit_width - rounder.lsbs_to_remove > rounder.target_msbs:
rounder.lsbs_to_remove = rounder.input_bit_width - rounder.target_msbs
else:
return
if rounder is None:
raise ValueError("AutoRounders cannot be adjusted with an empty inputset")
rounder.is_adjusted = True
finally:
local._is_adjusting = False
# pylint: enable=protected-access,too-many-branches
def round_bit_pattern(
x: Union[int, List, np.ndarray, Tracer],
lsbs_to_remove: int,
) -> Union[int, List, np.ndarray, Tracer]:
x: Union[int, np.integer, List, np.ndarray, Tracer],
lsbs_to_remove: Union[int, AutoRounder],
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
"""
Round the bit pattern of an integer.
If `lsbs_to_remove` is an `AutoRounder`:
corresponding integer value will be determined by adjustment process.
x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000
@@ -55,19 +168,44 @@ def round_bit_pattern(
x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000
Args:
x (Union[int, np.ndarray, Tracer]):
x (Union[int, np.integer, np.ndarray, Tracer]):
input to round
lsbs_to_remove (int):
number of the least significant numbers to remove
lsbs_to_remove (Union[int, AutoRounder]):
number of the least significant bits to remove
or an auto rounder object which will be used to determine the integer value
Returns:
Union[int, np.ndarray, Tracer]:
Union[int, np.integer, np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
rounded value(s) otherwise
"""
def evaluator(x: Union[int, np.ndarray], lsbs_to_remove: int) -> Union[int, np.ndarray]:
# pylint: disable=protected-access,too-many-branches
if isinstance(lsbs_to_remove, AutoRounder):
if local._is_adjusting:
if not lsbs_to_remove.is_adjusted:
raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x)))
elif not lsbs_to_remove.is_adjusted:
raise RuntimeError(
"AutoRounders cannot be used before adjustment, "
"please call AutoRounder.adjust with the function that will be compiled "
"and provide the exact inputset that will be used for compilation"
)
lsbs_to_remove = lsbs_to_remove.lsbs_to_remove
assert isinstance(lsbs_to_remove, int)
def evaluator(
x: Union[int, np.integer, np.ndarray],
lsbs_to_remove: int,
) -> Union[int, np.integer, np.ndarray]:
if lsbs_to_remove == 0:
return x
unit = 1 << lsbs_to_remove
half = 1 << lsbs_to_remove - 1
rounded = (x + half) // unit
@@ -94,7 +232,9 @@ def round_bit_pattern(
raise TypeError(
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
)
elif not isinstance(x, int):
elif not isinstance(x, (int, np.integer)):
raise TypeError(f"Expected input to be an int or a numpy array but it's {type(x).__name__}")
return evaluator(x, lsbs_to_remove)
# pylint: enable=protected-access,too-many-branches

View File

@@ -6,11 +6,14 @@ import numpy as np
import pytest
import concrete.numpy as cnp
from concrete.numpy.representation.utils import format_constant
@pytest.mark.parametrize(
"sample,lsbs_to_remove,expected_output",
[
(0b_0000_0011, 0, 0b_0000_0011),
(0b_0000_0100, 0, 0b_0000_0100),
(0b_0000_0000, 3, 0b_0000_0000),
(0b_0000_0001, 3, 0b_0000_0000),
(0b_0000_0010, 3, 0b_0000_0000),
@@ -33,7 +36,7 @@ def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
"""
Test round bit pattern in evaluation context.
"""
assert cnp.round_bit_pattern(sample, lsbs_to_remove) == expected_output
assert cnp.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove) == expected_output
@pytest.mark.parametrize(
@@ -53,13 +56,18 @@ def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
),
],
)
def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message):
def test_bad_plain_round_bit_pattern(
sample,
lsbs_to_remove,
expected_error,
expected_message,
):
"""
Test round bit pattern in evaluation context with bad parameters.
"""
with pytest.raises(expected_error) as excinfo:
cnp.round_bit_pattern(sample, lsbs_to_remove)
cnp.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove)
assert str(excinfo.value) == expected_message
@@ -85,7 +93,191 @@ def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers):
@cnp.compiler({"x": "encrypted"})
def function(x):
return np.abs(50 * np.sin(cnp.round_bit_pattern(x, lsbs_to_remove))).astype(np.int64)
x_rounded = cnp.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
return np.abs(50 * np.sin(x_rounded)).astype(np.int64)
circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True)
helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits))
def test_auto_rounding(helpers):
"""
Test round bit pattern with auto rounding.
"""
# with auto adjust rounders configuration
# ---------------------------------------
# y has the max value of 1999, so it's 11 bits
# our target msb is 5 bits, which means we need to remove 6 of the least significant bits
rounder1 = cnp.AutoRounder(target_msbs=5)
@cnp.compiler({"x": "encrypted"})
def function1(x):
y = x + 1000
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder1)
return np.sqrt(z).astype(np.int64)
inputset1 = range(1000)
function1.trace(inputset1, helpers.configuration(), auto_adjust_rounders=True)
assert rounder1.lsbs_to_remove == 6
# manual
# ------
# y has the max value of 1999, so it's 11 bits
# our target msb is 3 bits, which means we need to remove 8 of the least significant bits
rounder2 = cnp.AutoRounder(target_msbs=3)
@cnp.compiler({"x": "encrypted"})
def function2(x):
y = x + 1000
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder2)
return np.sqrt(z).astype(np.int64)
inputset2 = range(1000)
cnp.AutoRounder.adjust(function2, inputset2)
assert rounder2.lsbs_to_remove == 8
# complicated case
# ----------------
# have 2 ** 8 entries during evaluation, it won't matter after compilation
entries3 = list(range(2**8))
# we have 8-bit inputs for this table, and we only want to use first 5-bits
for i in range(0, 2**8, 2**3):
# so we set every 8th entry to a 4-bit value
entries3[i] = np.random.randint(0, (2**4) - (2**2))
# when this tlu is applied to an 8-bit value with 5-bit msb rounding, result will be 4-bits
table3 = cnp.LookupTable(entries3)
# and this is the rounder for table1, which should have lsbs_to_remove of 3
rounder3 = cnp.AutoRounder(target_msbs=5)
# have 2 ** 8 entries during evaluation, it won't matter after compilation
entries4 = list(range(2**8))
# we have 4-bit inputs for this table, and we only want to use first 2-bits
for i in range(0, 2**4, 2**2):
# so we set every 4th entry to an 8-bit value
entries4[i] = np.random.randint(0, 2**8)
# when this tlu is applied to a 4-bit value with 2-bit msb rounding, result will be 8-bits
table4 = cnp.LookupTable(entries4)
# and this is the rounder for table2, which should have lsbs_to_remove of 2
rounder4 = cnp.AutoRounder(target_msbs=2)
@cnp.compiler({"x": "encrypted"})
def function3(x):
a = cnp.round_bit_pattern(x, lsbs_to_remove=rounder3)
b = table3[a]
c = cnp.round_bit_pattern(b, lsbs_to_remove=rounder4)
d = table4[c]
return d
inputset3 = range((2**8) - (2**3))
circuit3 = function3.compile(
inputset3,
helpers.configuration(),
auto_adjust_rounders=True,
virtual=True,
)
assert rounder3.lsbs_to_remove == 3
assert rounder4.lsbs_to_remove == 2
table3_formatted_string = format_constant(table3.table, 25)
table4_formatted_string = format_constant(table4.table, 25)
helpers.check_str(
f"""
%0 = x # EncryptedScalar<uint8>
%1 = round_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar<uint8>
%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar<uint4>
%3 = round_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar<uint4>
%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar<uint8>
return %4
""",
str(circuit3),
)
def test_auto_rounding_without_adjustment():
"""
Test round bit pattern with auto rounding but without adjustment.
"""
rounder = cnp.AutoRounder(target_msbs=5)
def function(x):
y = x + 1000
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder)
return np.sqrt(z).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
function(100)
assert str(excinfo.value) == (
"AutoRounders cannot be used before adjustment, "
"please call AutoRounder.adjust with the function that will be compiled "
"and provide the exact inputset that will be used for compilation"
)
def test_auto_rounding_with_empty_inputset():
"""
Test round bit pattern with auto rounding but with empty inputset.
"""
rounder = cnp.AutoRounder(target_msbs=5)
def function(x):
y = x + 1000
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder)
return np.sqrt(z).astype(np.int64)
with pytest.raises(ValueError) as excinfo:
cnp.AutoRounder.adjust(function, [])
assert str(excinfo.value) == "AutoRounders cannot be adjusted with an empty inputset"
def test_auto_rounding_recursive_adjustment():
"""
Test round bit pattern with auto rounding but with recursive adjustment.
"""
rounder = cnp.AutoRounder(target_msbs=5)
def function(x):
cnp.AutoRounder.adjust(function, range(10))
y = x + 1000
z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder)
return np.sqrt(z).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
cnp.AutoRounder.adjust(function, range(10))
assert str(excinfo.value) == "AutoRounders cannot be adjusted recursively"
def test_auto_rounding_construct_in_function():
"""
Test round bit pattern with auto rounding but rounder is constructed within the function.
"""
def function(x):
y = x + 1000
z = cnp.round_bit_pattern(y, lsbs_to_remove=cnp.AutoRounder(target_msbs=5))
return np.sqrt(z).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
cnp.AutoRounder.adjust(function, range(10))
assert str(excinfo.value) == (
"AutoRounders cannot be constructed during adjustment, "
"please construct AutoRounders outside the function and reference it"
)