From c552a955c00dafa4e9779ae9a534d4ce021217af Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 31 Oct 2022 15:10:46 +0100 Subject: [PATCH] feat: introduce round bit pattern extension for virtual circuits --- concrete/numpy/__init__.py | 2 +- concrete/numpy/extensions/__init__.py | 1 + .../numpy/extensions/round_bit_pattern.py | 100 ++++++++++++++++++ tests/execution/test_round_bit_pattern.py | 91 ++++++++++++++++ 4 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 concrete/numpy/extensions/round_bit_pattern.py create mode 100644 tests/execution/test_round_bit_pattern.py diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index bb8af20ae..471d255af 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -15,7 +15,7 @@ from .compilation import ( Server, ) from .compilation.decorators import circuit, compiler -from .extensions import LookupTable, array, one, ones, univariate, zero, zeros +from .extensions import LookupTable, array, one, ones, round_bit_pattern, univariate, zero, zeros from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH from .representation import Graph from .tracing.typing import ( diff --git a/concrete/numpy/extensions/__init__.py b/concrete/numpy/extensions/__init__.py index 1dbb09aa2..4af2396cc 100644 --- a/concrete/numpy/extensions/__init__.py +++ b/concrete/numpy/extensions/__init__.py @@ -4,6 +4,7 @@ Provide additional features that are not present in numpy. from .array import array from .ones import one, ones +from .round_bit_pattern import round_bit_pattern from .table import LookupTable from .univariate import univariate from .zeros import zero, zeros diff --git a/concrete/numpy/extensions/round_bit_pattern.py b/concrete/numpy/extensions/round_bit_pattern.py new file mode 100644 index 000000000..dd4af4c88 --- /dev/null +++ b/concrete/numpy/extensions/round_bit_pattern.py @@ -0,0 +1,100 @@ +""" +Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups. +""" + +from copy import deepcopy +from typing import List, Union + +import numpy as np + +from ..representation import Node +from ..tracing import Tracer + + +def round_bit_pattern( + x: Union[int, List, np.ndarray, Tracer], + lsbs_to_remove: int, +) -> Union[int, List, np.ndarray, Tracer]: + """ + Round the bit pattern of an integer. + + x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000 + x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000 + x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000 + x = 0b_0000_0011 , lsbs_to_remove = 3 => 0b_0000_0000 + x = 0b_0000_0100 , lsbs_to_remove = 3 => 0b_0000_1000 + x = 0b_0000_0101 , lsbs_to_remove = 3 => 0b_0000_1000 + x = 0b_0000_0110 , lsbs_to_remove = 3 => 0b_0000_1000 + x = 0b_0000_0111 , lsbs_to_remove = 3 => 0b_0000_1000 + + x = 0b_1010_0000 , lsbs_to_remove = 3 => 0b_1010_0000 + x = 0b_1010_0001 , lsbs_to_remove = 3 => 0b_1010_0000 + x = 0b_1010_0010 , lsbs_to_remove = 3 => 0b_1010_0000 + x = 0b_1010_0011 , lsbs_to_remove = 3 => 0b_1010_0000 + x = 0b_1010_0100 , lsbs_to_remove = 3 => 0b_1010_1000 + x = 0b_1010_0101 , lsbs_to_remove = 3 => 0b_1010_1000 + x = 0b_1010_0110 , lsbs_to_remove = 3 => 0b_1010_1000 + x = 0b_1010_0111 , lsbs_to_remove = 3 => 0b_1010_1000 + + x = 0b_1010_1000 , lsbs_to_remove = 3 => 0b_1010_1000 + x = 0b_1010_1001 , lsbs_to_remove = 3 => 0b_1010_1000 + x = 0b_1010_1010 , lsbs_to_remove = 3 => 0b_1010_1000 + x = 0b_1010_1011 , lsbs_to_remove = 3 => 0b_1010_1000 + x = 0b_1010_1100 , lsbs_to_remove = 3 => 0b_1011_0000 + x = 0b_1010_1101 , lsbs_to_remove = 3 => 0b_1011_0000 + x = 0b_1010_1110 , lsbs_to_remove = 3 => 0b_1011_0000 + x = 0b_1010_1111 , lsbs_to_remove = 3 => 0b_1011_0000 + + x = 0b_1011_1000 , lsbs_to_remove = 3 => 0b_1011_1000 + x = 0b_1011_1001 , lsbs_to_remove = 3 => 0b_1011_1000 + x = 0b_1011_1010 , lsbs_to_remove = 3 => 0b_1011_1000 + x = 0b_1011_1011 , lsbs_to_remove = 3 => 0b_1011_1000 + x = 0b_1011_1100 , lsbs_to_remove = 3 => 0b_1100_0000 + x = 0b_1011_1101 , lsbs_to_remove = 3 => 0b_1100_0000 + x = 0b_1011_1110 , lsbs_to_remove = 3 => 0b_1100_0000 + x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000 + + Args: + x (Union[int, np.ndarray, Tracer]): + input to round + + lsbs_to_remove (int): + number of the least significant numbers to remove + + Returns: + Union[int, np.ndarray, Tracer]: + Tracer that respresents the operation during tracing + rounded value(s) otherwise + """ + + def evaluator(x: Union[int, np.ndarray], lsbs_to_remove: int) -> Union[int, np.ndarray]: + unit = 1 << lsbs_to_remove + half = 1 << lsbs_to_remove - 1 + rounded = (x + half) // unit + return rounded * unit + + if isinstance(x, Tracer): + computation = Node.generic( + "round_bit_pattern", + [x.output], + deepcopy(x.output), + evaluator, + kwargs={"lsbs_to_remove": lsbs_to_remove}, + ) + return Tracer(computation, [x]) + + if isinstance(x, list): # pragma: no cover + try: + x = np.array(x) + except Exception: # pylint: disable=broad-except + pass + + if isinstance(x, np.ndarray): + if not np.issubdtype(x.dtype, np.integer): + raise TypeError( + f"Expected input elements to be integers but they are {type(x.dtype).__name__}" + ) + elif not isinstance(x, int): + raise TypeError(f"Expected input to be an int or a numpy array but it's {type(x).__name__}") + + return evaluator(x, lsbs_to_remove) diff --git a/tests/execution/test_round_bit_pattern.py b/tests/execution/test_round_bit_pattern.py new file mode 100644 index 000000000..837ae7789 --- /dev/null +++ b/tests/execution/test_round_bit_pattern.py @@ -0,0 +1,91 @@ +""" +Tests of execution of round bit pattern operation. +""" + +import numpy as np +import pytest + +import concrete.numpy as cnp + + +@pytest.mark.parametrize( + "sample,lsbs_to_remove,expected_output", + [ + (0b_0000_0000, 3, 0b_0000_0000), + (0b_0000_0001, 3, 0b_0000_0000), + (0b_0000_0010, 3, 0b_0000_0000), + (0b_0000_0011, 3, 0b_0000_0000), + (0b_0000_0100, 3, 0b_0000_1000), + (0b_0000_0101, 3, 0b_0000_1000), + (0b_0000_0110, 3, 0b_0000_1000), + (0b_0000_0111, 3, 0b_0000_1000), + (0b_0000_1000, 3, 0b_0000_1000), + (0b_0000_1001, 3, 0b_0000_1000), + (0b_0000_1010, 3, 0b_0000_1000), + (0b_0000_1011, 3, 0b_0000_1000), + (0b_0000_1100, 3, 0b_0001_0000), + (0b_0000_1101, 3, 0b_0001_0000), + (0b_0000_1110, 3, 0b_0001_0000), + (0b_0000_1111, 3, 0b_0001_0000), + ], +) +def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output): + """ + Test round bit pattern in evaluation context. + """ + assert cnp.round_bit_pattern(sample, lsbs_to_remove) == expected_output + + +@pytest.mark.parametrize( + "sample,lsbs_to_remove,expected_error,expected_message", + [ + ( + np.array([3.2, 4.1]), + 3, + TypeError, + "Expected input elements to be integers but they are dtype[float64]", + ), + ( + "foo", + 3, + TypeError, + "Expected input to be an int or a numpy array but it's str", + ), + ], +) +def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message): + """ + Test round bit pattern in evaluation context with bad parameters. + """ + + with pytest.raises(expected_error) as excinfo: + cnp.round_bit_pattern(sample, lsbs_to_remove) + + assert str(excinfo.value) == expected_message + + +@pytest.mark.parametrize( + "input_bits,lsbs_to_remove", + [ + (3, 1), + (3, 2), + (4, 1), + (4, 2), + (4, 3), + (5, 1), + (5, 2), + (5, 3), + (5, 4), + ], +) +def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers): + """ + Test round bit pattern in evaluation context. + """ + + @cnp.compiler({"x": "encrypted"}) + def function(x): + return np.abs(50 * np.sin(cnp.round_bit_pattern(x, lsbs_to_remove))).astype(np.int64) + + circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True) + helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits))