feat: introduce round bit pattern extension for virtual circuits

This commit is contained in:
Umut
2022-10-31 15:10:46 +01:00
parent 79951b51b7
commit c552a955c0
4 changed files with 193 additions and 1 deletions

View File

@@ -15,7 +15,7 @@ from .compilation import (
Server,
)
from .compilation.decorators import circuit, compiler
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
from .extensions import LookupTable, array, one, ones, round_bit_pattern, univariate, zero, zeros
from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
from .representation import Graph
from .tracing.typing import (

View File

@@ -4,6 +4,7 @@ Provide additional features that are not present in numpy.
from .array import array
from .ones import one, ones
from .round_bit_pattern import round_bit_pattern
from .table import LookupTable
from .univariate import univariate
from .zeros import zero, zeros

View File

@@ -0,0 +1,100 @@
"""
Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups.
"""
from copy import deepcopy
from typing import List, Union
import numpy as np
from ..representation import Node
from ..tracing import Tracer
def round_bit_pattern(
x: Union[int, List, np.ndarray, Tracer],
lsbs_to_remove: int,
) -> Union[int, List, np.ndarray, Tracer]:
"""
Round the bit pattern of an integer.
x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0011 , lsbs_to_remove = 3 => 0b_0000_0000
x = 0b_0000_0100 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_0000_0101 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_0000_0110 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_0000_0111 , lsbs_to_remove = 3 => 0b_0000_1000
x = 0b_1010_0000 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0001 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0010 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0011 , lsbs_to_remove = 3 => 0b_1010_0000
x = 0b_1010_0100 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_0101 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_0110 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_0111 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1000 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1001 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1010 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1011 , lsbs_to_remove = 3 => 0b_1010_1000
x = 0b_1010_1100 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1010_1101 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1010_1110 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1010_1111 , lsbs_to_remove = 3 => 0b_1011_0000
x = 0b_1011_1000 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1001 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1010 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1011 , lsbs_to_remove = 3 => 0b_1011_1000
x = 0b_1011_1100 , lsbs_to_remove = 3 => 0b_1100_0000
x = 0b_1011_1101 , lsbs_to_remove = 3 => 0b_1100_0000
x = 0b_1011_1110 , lsbs_to_remove = 3 => 0b_1100_0000
x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000
Args:
x (Union[int, np.ndarray, Tracer]):
input to round
lsbs_to_remove (int):
number of the least significant numbers to remove
Returns:
Union[int, np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
rounded value(s) otherwise
"""
def evaluator(x: Union[int, np.ndarray], lsbs_to_remove: int) -> Union[int, np.ndarray]:
unit = 1 << lsbs_to_remove
half = 1 << lsbs_to_remove - 1
rounded = (x + half) // unit
return rounded * unit
if isinstance(x, Tracer):
computation = Node.generic(
"round_bit_pattern",
[x.output],
deepcopy(x.output),
evaluator,
kwargs={"lsbs_to_remove": lsbs_to_remove},
)
return Tracer(computation, [x])
if isinstance(x, list): # pragma: no cover
try:
x = np.array(x)
except Exception: # pylint: disable=broad-except
pass
if isinstance(x, np.ndarray):
if not np.issubdtype(x.dtype, np.integer):
raise TypeError(
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
)
elif not isinstance(x, int):
raise TypeError(f"Expected input to be an int or a numpy array but it's {type(x).__name__}")
return evaluator(x, lsbs_to_remove)

View File

@@ -0,0 +1,91 @@
"""
Tests of execution of round bit pattern operation.
"""
import numpy as np
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"sample,lsbs_to_remove,expected_output",
[
(0b_0000_0000, 3, 0b_0000_0000),
(0b_0000_0001, 3, 0b_0000_0000),
(0b_0000_0010, 3, 0b_0000_0000),
(0b_0000_0011, 3, 0b_0000_0000),
(0b_0000_0100, 3, 0b_0000_1000),
(0b_0000_0101, 3, 0b_0000_1000),
(0b_0000_0110, 3, 0b_0000_1000),
(0b_0000_0111, 3, 0b_0000_1000),
(0b_0000_1000, 3, 0b_0000_1000),
(0b_0000_1001, 3, 0b_0000_1000),
(0b_0000_1010, 3, 0b_0000_1000),
(0b_0000_1011, 3, 0b_0000_1000),
(0b_0000_1100, 3, 0b_0001_0000),
(0b_0000_1101, 3, 0b_0001_0000),
(0b_0000_1110, 3, 0b_0001_0000),
(0b_0000_1111, 3, 0b_0001_0000),
],
)
def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
"""
Test round bit pattern in evaluation context.
"""
assert cnp.round_bit_pattern(sample, lsbs_to_remove) == expected_output
@pytest.mark.parametrize(
"sample,lsbs_to_remove,expected_error,expected_message",
[
(
np.array([3.2, 4.1]),
3,
TypeError,
"Expected input elements to be integers but they are dtype[float64]",
),
(
"foo",
3,
TypeError,
"Expected input to be an int or a numpy array but it's str",
),
],
)
def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message):
"""
Test round bit pattern in evaluation context with bad parameters.
"""
with pytest.raises(expected_error) as excinfo:
cnp.round_bit_pattern(sample, lsbs_to_remove)
assert str(excinfo.value) == expected_message
@pytest.mark.parametrize(
"input_bits,lsbs_to_remove",
[
(3, 1),
(3, 2),
(4, 1),
(4, 2),
(4, 3),
(5, 1),
(5, 2),
(5, 3),
(5, 4),
],
)
def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers):
"""
Test round bit pattern in evaluation context.
"""
@cnp.compiler({"x": "encrypted"})
def function(x):
return np.abs(50 * np.sin(cnp.round_bit_pattern(x, lsbs_to_remove))).astype(np.int64)
circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True)
helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits))