feat(frontend-python): support using booleans with LookupTable

This commit is contained in:
Umut
2023-04-17 11:40:46 +02:00
parent 93991dd082
commit fa0e246613
2 changed files with 13 additions and 3 deletions

View File

@@ -59,7 +59,7 @@ class LookupTable:
def __repr__(self):
return str(list(self.table))
def __getitem__(self, key: Union[int, np.integer, np.ndarray, Tracer]):
def __getitem__(self, key: Union[int, np.integer, np.bool_, np.ndarray, Tracer]):
if not isinstance(key, Tracer):
return LookupTable.apply(key, self.table)
@@ -92,14 +92,14 @@ class LookupTable:
@staticmethod
def apply(
key: Union[int, np.integer, np.ndarray],
key: Union[int, np.integer, np.bool_, np.ndarray],
table: np.ndarray,
) -> Union[int, np.integer, np.ndarray]:
"""
Apply lookup table.
Args:
key (Union[int, np.integer, np.ndarray]):
key (Union[int, np.integer, np.bool_, np.ndarray]):
lookup key
table (np.ndarray):
@@ -114,6 +114,9 @@ class LookupTable:
if `table` cannot be looked up with `key`
"""
if isinstance(key, (np.bool_, np.ndarray)) and np.issubdtype(key.dtype, np.bool_):
key = key.astype(np.int64)
if not isinstance(key, (int, np.integer, np.ndarray)) or (
isinstance(key, np.ndarray) and not np.issubdtype(key.dtype, np.integer)
):

View File

@@ -677,6 +677,13 @@ def deterministic_unary_function(x):
},
id="np.squeeze(x, axis=1) where x.shape == (1, 1, 1)",
),
pytest.param(
lambda x: fhe.LookupTable([10, 5])[x > 5],
{
"x": {"status": "encrypted", "range": [0, 10]},
},
id="fhe.LookupTable([10, 5])[x > 5]",
),
],
)
def test_others(function, parameters, helpers):