diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 471d255af..8aaf03c58 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -15,7 +15,17 @@ from .compilation import ( Server, ) from .compilation.decorators import circuit, compiler -from .extensions import LookupTable, array, one, ones, round_bit_pattern, univariate, zero, zeros +from .extensions import ( + AutoRounder, + LookupTable, + array, + one, + ones, + round_bit_pattern, + univariate, + zero, + zeros, +) from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH from .representation import Graph from .tracing.typing import ( diff --git a/concrete/numpy/extensions/__init__.py b/concrete/numpy/extensions/__init__.py index 4af2396cc..65229552d 100644 --- a/concrete/numpy/extensions/__init__.py +++ b/concrete/numpy/extensions/__init__.py @@ -4,7 +4,7 @@ Provide additional features that are not present in numpy. from .array import array from .ones import one, ones -from .round_bit_pattern import round_bit_pattern +from .round_bit_pattern import AutoRounder, round_bit_pattern from .table import LookupTable from .univariate import univariate from .zeros import zero, zeros diff --git a/concrete/numpy/extensions/round_bit_pattern.py b/concrete/numpy/extensions/round_bit_pattern.py index dd4af4c88..aef6a6908 100644 --- a/concrete/numpy/extensions/round_bit_pattern.py +++ b/concrete/numpy/extensions/round_bit_pattern.py @@ -2,22 +2,135 @@ Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups. """ +import threading from copy import deepcopy -from typing import List, Union +from typing import Any, Callable, Iterable, List, Tuple, Union import numpy as np +from ..dtypes import Integer +from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH from ..representation import Node from ..tracing import Tracer +from ..values import Value + +local = threading.local() + +# pylint: disable=protected-access +local._is_adjusting = False +# pylint: enable=protected-access + + +class Adjusting(BaseException): + """ + Adjusting class, to be used as early stop signal during adjustment. + """ + + rounder: "AutoRounder" + input_min: int + input_max: int + + def __init__(self, rounder: "AutoRounder", input_min: int, input_max: int): + super().__init__() + self.rounder = rounder + self.input_min = input_min + self.input_max = input_max + + +class AutoRounder: + """ + AutoRounder class, to optimize for number of msbs to keep druing round bit pattern operation. + """ + + target_msbs: int + + is_adjusted: bool + input_min: int + input_max: int + input_bit_width: int + lsbs_to_remove: int + + def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH): + # pylint: disable=protected-access + if local._is_adjusting: + raise RuntimeError( + "AutoRounders cannot be constructed during adjustment, " + "please construct AutoRounders outside the function and reference it" + ) + # pylint: enable=protected-access + + self.target_msbs = target_msbs + + self.is_adjusted = False + self.input_min = 0 + self.input_max = 0 + self.input_bit_width = 0 + self.lsbs_to_remove = 0 + + @staticmethod + def adjust(function: Callable, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]): + """ + Adjust AutoRounders in a function using an inputset. + """ + + # pylint: disable=protected-access,too-many-branches + + try: # extract underlying function for decorators + function = function.function # type: ignore + assert callable(function) + except AttributeError: + pass + + if local._is_adjusting: + raise RuntimeError("AutoRounders cannot be adjusted recursively") + + try: + local._is_adjusting = True + while True: + rounder = None + + for sample in inputset: + if not isinstance(sample, tuple): + sample = (sample,) + + try: + function(*sample) + except Adjusting as adjuster: + rounder = adjuster.rounder + + rounder.input_min = min(rounder.input_min, adjuster.input_min) + rounder.input_max = max(rounder.input_max, adjuster.input_max) + + input_value = Value.of([rounder.input_min, rounder.input_max]) + assert isinstance(input_value.dtype, Integer) + rounder.input_bit_width = input_value.dtype.bit_width + + if rounder.input_bit_width - rounder.lsbs_to_remove > rounder.target_msbs: + rounder.lsbs_to_remove = rounder.input_bit_width - rounder.target_msbs + else: + return + + if rounder is None: + raise ValueError("AutoRounders cannot be adjusted with an empty inputset") + + rounder.is_adjusted = True + + finally: + local._is_adjusting = False + + # pylint: enable=protected-access,too-many-branches def round_bit_pattern( - x: Union[int, List, np.ndarray, Tracer], - lsbs_to_remove: int, -) -> Union[int, List, np.ndarray, Tracer]: + x: Union[int, np.integer, List, np.ndarray, Tracer], + lsbs_to_remove: Union[int, AutoRounder], +) -> Union[int, np.integer, List, np.ndarray, Tracer]: """ Round the bit pattern of an integer. + If `lsbs_to_remove` is an `AutoRounder`: + corresponding integer value will be determined by adjustment process. + x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000 x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000 x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000 @@ -55,19 +168,44 @@ def round_bit_pattern( x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000 Args: - x (Union[int, np.ndarray, Tracer]): + x (Union[int, np.integer, np.ndarray, Tracer]): input to round - lsbs_to_remove (int): - number of the least significant numbers to remove + lsbs_to_remove (Union[int, AutoRounder]): + number of the least significant bits to remove + or an auto rounder object which will be used to determine the integer value Returns: - Union[int, np.ndarray, Tracer]: + Union[int, np.integer, np.ndarray, Tracer]: Tracer that respresents the operation during tracing rounded value(s) otherwise """ - def evaluator(x: Union[int, np.ndarray], lsbs_to_remove: int) -> Union[int, np.ndarray]: + # pylint: disable=protected-access,too-many-branches + + if isinstance(lsbs_to_remove, AutoRounder): + if local._is_adjusting: + if not lsbs_to_remove.is_adjusted: + raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x))) + + elif not lsbs_to_remove.is_adjusted: + raise RuntimeError( + "AutoRounders cannot be used before adjustment, " + "please call AutoRounder.adjust with the function that will be compiled " + "and provide the exact inputset that will be used for compilation" + ) + + lsbs_to_remove = lsbs_to_remove.lsbs_to_remove + + assert isinstance(lsbs_to_remove, int) + + def evaluator( + x: Union[int, np.integer, np.ndarray], + lsbs_to_remove: int, + ) -> Union[int, np.integer, np.ndarray]: + if lsbs_to_remove == 0: + return x + unit = 1 << lsbs_to_remove half = 1 << lsbs_to_remove - 1 rounded = (x + half) // unit @@ -94,7 +232,9 @@ def round_bit_pattern( raise TypeError( f"Expected input elements to be integers but they are {type(x.dtype).__name__}" ) - elif not isinstance(x, int): + elif not isinstance(x, (int, np.integer)): raise TypeError(f"Expected input to be an int or a numpy array but it's {type(x).__name__}") return evaluator(x, lsbs_to_remove) + + # pylint: enable=protected-access,too-many-branches diff --git a/tests/execution/test_round_bit_pattern.py b/tests/execution/test_round_bit_pattern.py index 837ae7789..c14b493a7 100644 --- a/tests/execution/test_round_bit_pattern.py +++ b/tests/execution/test_round_bit_pattern.py @@ -6,11 +6,14 @@ import numpy as np import pytest import concrete.numpy as cnp +from concrete.numpy.representation.utils import format_constant @pytest.mark.parametrize( "sample,lsbs_to_remove,expected_output", [ + (0b_0000_0011, 0, 0b_0000_0011), + (0b_0000_0100, 0, 0b_0000_0100), (0b_0000_0000, 3, 0b_0000_0000), (0b_0000_0001, 3, 0b_0000_0000), (0b_0000_0010, 3, 0b_0000_0000), @@ -33,7 +36,7 @@ def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output): """ Test round bit pattern in evaluation context. """ - assert cnp.round_bit_pattern(sample, lsbs_to_remove) == expected_output + assert cnp.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove) == expected_output @pytest.mark.parametrize( @@ -53,13 +56,18 @@ def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output): ), ], ) -def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message): +def test_bad_plain_round_bit_pattern( + sample, + lsbs_to_remove, + expected_error, + expected_message, +): """ Test round bit pattern in evaluation context with bad parameters. """ with pytest.raises(expected_error) as excinfo: - cnp.round_bit_pattern(sample, lsbs_to_remove) + cnp.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove) assert str(excinfo.value) == expected_message @@ -85,7 +93,191 @@ def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers): @cnp.compiler({"x": "encrypted"}) def function(x): - return np.abs(50 * np.sin(cnp.round_bit_pattern(x, lsbs_to_remove))).astype(np.int64) + x_rounded = cnp.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove) + return np.abs(50 * np.sin(x_rounded)).astype(np.int64) circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True) helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits)) + + +def test_auto_rounding(helpers): + """ + Test round bit pattern with auto rounding. + """ + + # with auto adjust rounders configuration + # --------------------------------------- + + # y has the max value of 1999, so it's 11 bits + # our target msb is 5 bits, which means we need to remove 6 of the least significant bits + + rounder1 = cnp.AutoRounder(target_msbs=5) + + @cnp.compiler({"x": "encrypted"}) + def function1(x): + y = x + 1000 + z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder1) + return np.sqrt(z).astype(np.int64) + + inputset1 = range(1000) + function1.trace(inputset1, helpers.configuration(), auto_adjust_rounders=True) + + assert rounder1.lsbs_to_remove == 6 + + # manual + # ------ + + # y has the max value of 1999, so it's 11 bits + # our target msb is 3 bits, which means we need to remove 8 of the least significant bits + + rounder2 = cnp.AutoRounder(target_msbs=3) + + @cnp.compiler({"x": "encrypted"}) + def function2(x): + y = x + 1000 + z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder2) + return np.sqrt(z).astype(np.int64) + + inputset2 = range(1000) + cnp.AutoRounder.adjust(function2, inputset2) + + assert rounder2.lsbs_to_remove == 8 + + # complicated case + # ---------------- + + # have 2 ** 8 entries during evaluation, it won't matter after compilation + entries3 = list(range(2**8)) + # we have 8-bit inputs for this table, and we only want to use first 5-bits + for i in range(0, 2**8, 2**3): + # so we set every 8th entry to a 4-bit value + entries3[i] = np.random.randint(0, (2**4) - (2**2)) + # when this tlu is applied to an 8-bit value with 5-bit msb rounding, result will be 4-bits + table3 = cnp.LookupTable(entries3) + # and this is the rounder for table1, which should have lsbs_to_remove of 3 + rounder3 = cnp.AutoRounder(target_msbs=5) + + # have 2 ** 8 entries during evaluation, it won't matter after compilation + entries4 = list(range(2**8)) + # we have 4-bit inputs for this table, and we only want to use first 2-bits + for i in range(0, 2**4, 2**2): + # so we set every 4th entry to an 8-bit value + entries4[i] = np.random.randint(0, 2**8) + # when this tlu is applied to a 4-bit value with 2-bit msb rounding, result will be 8-bits + table4 = cnp.LookupTable(entries4) + # and this is the rounder for table2, which should have lsbs_to_remove of 2 + rounder4 = cnp.AutoRounder(target_msbs=2) + + @cnp.compiler({"x": "encrypted"}) + def function3(x): + a = cnp.round_bit_pattern(x, lsbs_to_remove=rounder3) + b = table3[a] + c = cnp.round_bit_pattern(b, lsbs_to_remove=rounder4) + d = table4[c] + return d + + inputset3 = range((2**8) - (2**3)) + circuit3 = function3.compile( + inputset3, + helpers.configuration(), + auto_adjust_rounders=True, + virtual=True, + ) + + assert rounder3.lsbs_to_remove == 3 + assert rounder4.lsbs_to_remove == 2 + + table3_formatted_string = format_constant(table3.table, 25) + table4_formatted_string = format_constant(table4.table, 25) + + helpers.check_str( + f""" + +%0 = x # EncryptedScalar +%1 = round_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar +%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar +%3 = round_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar +%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar +return %4 + + """, + str(circuit3), + ) + + +def test_auto_rounding_without_adjustment(): + """ + Test round bit pattern with auto rounding but without adjustment. + """ + + rounder = cnp.AutoRounder(target_msbs=5) + + def function(x): + y = x + 1000 + z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder) + return np.sqrt(z).astype(np.int64) + + with pytest.raises(RuntimeError) as excinfo: + function(100) + + assert str(excinfo.value) == ( + "AutoRounders cannot be used before adjustment, " + "please call AutoRounder.adjust with the function that will be compiled " + "and provide the exact inputset that will be used for compilation" + ) + + +def test_auto_rounding_with_empty_inputset(): + """ + Test round bit pattern with auto rounding but with empty inputset. + """ + + rounder = cnp.AutoRounder(target_msbs=5) + + def function(x): + y = x + 1000 + z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder) + return np.sqrt(z).astype(np.int64) + + with pytest.raises(ValueError) as excinfo: + cnp.AutoRounder.adjust(function, []) + + assert str(excinfo.value) == "AutoRounders cannot be adjusted with an empty inputset" + + +def test_auto_rounding_recursive_adjustment(): + """ + Test round bit pattern with auto rounding but with recursive adjustment. + """ + + rounder = cnp.AutoRounder(target_msbs=5) + + def function(x): + cnp.AutoRounder.adjust(function, range(10)) + y = x + 1000 + z = cnp.round_bit_pattern(y, lsbs_to_remove=rounder) + return np.sqrt(z).astype(np.int64) + + with pytest.raises(RuntimeError) as excinfo: + cnp.AutoRounder.adjust(function, range(10)) + + assert str(excinfo.value) == "AutoRounders cannot be adjusted recursively" + + +def test_auto_rounding_construct_in_function(): + """ + Test round bit pattern with auto rounding but rounder is constructed within the function. + """ + + def function(x): + y = x + 1000 + z = cnp.round_bit_pattern(y, lsbs_to_remove=cnp.AutoRounder(target_msbs=5)) + return np.sqrt(z).astype(np.int64) + + with pytest.raises(RuntimeError) as excinfo: + cnp.AutoRounder.adjust(function, range(10)) + + assert str(excinfo.value) == ( + "AutoRounders cannot be constructed during adjustment, " + "please construct AutoRounders outside the function and reference it" + )