mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: introduce round bit pattern extension for virtual circuits
This commit is contained in:
@@ -15,7 +15,7 @@ from .compilation import (
|
||||
Server,
|
||||
)
|
||||
from .compilation.decorators import circuit, compiler
|
||||
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
|
||||
from .extensions import LookupTable, array, one, ones, round_bit_pattern, univariate, zero, zeros
|
||||
from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
from .tracing.typing import (
|
||||
|
||||
@@ -4,6 +4,7 @@ Provide additional features that are not present in numpy.
|
||||
|
||||
from .array import array
|
||||
from .ones import one, ones
|
||||
from .round_bit_pattern import round_bit_pattern
|
||||
from .table import LookupTable
|
||||
from .univariate import univariate
|
||||
from .zeros import zero, zeros
|
||||
|
||||
100
concrete/numpy/extensions/round_bit_pattern.py
Normal file
100
concrete/numpy/extensions/round_bit_pattern.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Declaration of `round_bit_pattern` function, to provide an interface for rounded table lookups.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
|
||||
|
||||
def round_bit_pattern(
|
||||
x: Union[int, List, np.ndarray, Tracer],
|
||||
lsbs_to_remove: int,
|
||||
) -> Union[int, List, np.ndarray, Tracer]:
|
||||
"""
|
||||
Round the bit pattern of an integer.
|
||||
|
||||
x = 0b_0000_0000 , lsbs_to_remove = 3 => 0b_0000_0000
|
||||
x = 0b_0000_0001 , lsbs_to_remove = 3 => 0b_0000_0000
|
||||
x = 0b_0000_0010 , lsbs_to_remove = 3 => 0b_0000_0000
|
||||
x = 0b_0000_0011 , lsbs_to_remove = 3 => 0b_0000_0000
|
||||
x = 0b_0000_0100 , lsbs_to_remove = 3 => 0b_0000_1000
|
||||
x = 0b_0000_0101 , lsbs_to_remove = 3 => 0b_0000_1000
|
||||
x = 0b_0000_0110 , lsbs_to_remove = 3 => 0b_0000_1000
|
||||
x = 0b_0000_0111 , lsbs_to_remove = 3 => 0b_0000_1000
|
||||
|
||||
x = 0b_1010_0000 , lsbs_to_remove = 3 => 0b_1010_0000
|
||||
x = 0b_1010_0001 , lsbs_to_remove = 3 => 0b_1010_0000
|
||||
x = 0b_1010_0010 , lsbs_to_remove = 3 => 0b_1010_0000
|
||||
x = 0b_1010_0011 , lsbs_to_remove = 3 => 0b_1010_0000
|
||||
x = 0b_1010_0100 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
x = 0b_1010_0101 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
x = 0b_1010_0110 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
x = 0b_1010_0111 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
|
||||
x = 0b_1010_1000 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
x = 0b_1010_1001 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
x = 0b_1010_1010 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
x = 0b_1010_1011 , lsbs_to_remove = 3 => 0b_1010_1000
|
||||
x = 0b_1010_1100 , lsbs_to_remove = 3 => 0b_1011_0000
|
||||
x = 0b_1010_1101 , lsbs_to_remove = 3 => 0b_1011_0000
|
||||
x = 0b_1010_1110 , lsbs_to_remove = 3 => 0b_1011_0000
|
||||
x = 0b_1010_1111 , lsbs_to_remove = 3 => 0b_1011_0000
|
||||
|
||||
x = 0b_1011_1000 , lsbs_to_remove = 3 => 0b_1011_1000
|
||||
x = 0b_1011_1001 , lsbs_to_remove = 3 => 0b_1011_1000
|
||||
x = 0b_1011_1010 , lsbs_to_remove = 3 => 0b_1011_1000
|
||||
x = 0b_1011_1011 , lsbs_to_remove = 3 => 0b_1011_1000
|
||||
x = 0b_1011_1100 , lsbs_to_remove = 3 => 0b_1100_0000
|
||||
x = 0b_1011_1101 , lsbs_to_remove = 3 => 0b_1100_0000
|
||||
x = 0b_1011_1110 , lsbs_to_remove = 3 => 0b_1100_0000
|
||||
x = 0b_1011_1111 , lsbs_to_remove = 3 => 0b_1100_0000
|
||||
|
||||
Args:
|
||||
x (Union[int, np.ndarray, Tracer]):
|
||||
input to round
|
||||
|
||||
lsbs_to_remove (int):
|
||||
number of the least significant numbers to remove
|
||||
|
||||
Returns:
|
||||
Union[int, np.ndarray, Tracer]:
|
||||
Tracer that respresents the operation during tracing
|
||||
rounded value(s) otherwise
|
||||
"""
|
||||
|
||||
def evaluator(x: Union[int, np.ndarray], lsbs_to_remove: int) -> Union[int, np.ndarray]:
|
||||
unit = 1 << lsbs_to_remove
|
||||
half = 1 << lsbs_to_remove - 1
|
||||
rounded = (x + half) // unit
|
||||
return rounded * unit
|
||||
|
||||
if isinstance(x, Tracer):
|
||||
computation = Node.generic(
|
||||
"round_bit_pattern",
|
||||
[x.output],
|
||||
deepcopy(x.output),
|
||||
evaluator,
|
||||
kwargs={"lsbs_to_remove": lsbs_to_remove},
|
||||
)
|
||||
return Tracer(computation, [x])
|
||||
|
||||
if isinstance(x, list): # pragma: no cover
|
||||
try:
|
||||
x = np.array(x)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
if isinstance(x, np.ndarray):
|
||||
if not np.issubdtype(x.dtype, np.integer):
|
||||
raise TypeError(
|
||||
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
|
||||
)
|
||||
elif not isinstance(x, int):
|
||||
raise TypeError(f"Expected input to be an int or a numpy array but it's {type(x).__name__}")
|
||||
|
||||
return evaluator(x, lsbs_to_remove)
|
||||
91
tests/execution/test_round_bit_pattern.py
Normal file
91
tests/execution/test_round_bit_pattern.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Tests of execution of round bit pattern operation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample,lsbs_to_remove,expected_output",
|
||||
[
|
||||
(0b_0000_0000, 3, 0b_0000_0000),
|
||||
(0b_0000_0001, 3, 0b_0000_0000),
|
||||
(0b_0000_0010, 3, 0b_0000_0000),
|
||||
(0b_0000_0011, 3, 0b_0000_0000),
|
||||
(0b_0000_0100, 3, 0b_0000_1000),
|
||||
(0b_0000_0101, 3, 0b_0000_1000),
|
||||
(0b_0000_0110, 3, 0b_0000_1000),
|
||||
(0b_0000_0111, 3, 0b_0000_1000),
|
||||
(0b_0000_1000, 3, 0b_0000_1000),
|
||||
(0b_0000_1001, 3, 0b_0000_1000),
|
||||
(0b_0000_1010, 3, 0b_0000_1000),
|
||||
(0b_0000_1011, 3, 0b_0000_1000),
|
||||
(0b_0000_1100, 3, 0b_0001_0000),
|
||||
(0b_0000_1101, 3, 0b_0001_0000),
|
||||
(0b_0000_1110, 3, 0b_0001_0000),
|
||||
(0b_0000_1111, 3, 0b_0001_0000),
|
||||
],
|
||||
)
|
||||
def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
|
||||
"""
|
||||
Test round bit pattern in evaluation context.
|
||||
"""
|
||||
assert cnp.round_bit_pattern(sample, lsbs_to_remove) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample,lsbs_to_remove,expected_error,expected_message",
|
||||
[
|
||||
(
|
||||
np.array([3.2, 4.1]),
|
||||
3,
|
||||
TypeError,
|
||||
"Expected input elements to be integers but they are dtype[float64]",
|
||||
),
|
||||
(
|
||||
"foo",
|
||||
3,
|
||||
TypeError,
|
||||
"Expected input to be an int or a numpy array but it's str",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message):
|
||||
"""
|
||||
Test round bit pattern in evaluation context with bad parameters.
|
||||
"""
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
cnp.round_bit_pattern(sample, lsbs_to_remove)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_bits,lsbs_to_remove",
|
||||
[
|
||||
(3, 1),
|
||||
(3, 2),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 3),
|
||||
(5, 1),
|
||||
(5, 2),
|
||||
(5, 3),
|
||||
(5, 4),
|
||||
],
|
||||
)
|
||||
def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers):
|
||||
"""
|
||||
Test round bit pattern in evaluation context.
|
||||
"""
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return np.abs(50 * np.sin(cnp.round_bit_pattern(x, lsbs_to_remove))).astype(np.int64)
|
||||
|
||||
circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True)
|
||||
helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits))
|
||||
Reference in New Issue
Block a user