refactor: sanitize signed inputs without table lookups

This commit is contained in:
Umut
2022-12-12 12:18:26 +01:00
parent 416ca6938b
commit d4428eaac5
2 changed files with 21 additions and 20 deletions

View File

@@ -158,14 +158,13 @@ class Client:
)
if is_valid:
if isinstance(arg, int) and arg < 0:
sanitized_args[index] = 2 * (expected_max + 1) + arg
is_signed = self.specs.input_signs[index]
sanitizer = 0 if not is_signed else 2 ** (width - 1)
if isinstance(arg, int):
sanitized_args[index] = arg + sanitizer
else:
sanitized_args[index] = np.where(
arg >= 0,
arg,
2 * (expected_max + 1) + arg,
).astype(np.uint64)
sanitized_args[index] = (arg + sanitizer).astype(np.uint64)
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)

View File

@@ -13,9 +13,10 @@ import numpy as np
from concrete.lang.dialects import fhe, fhelinalg
from mlir.dialects import arith, func
from mlir.ir import (
Attribute,
Context,
DenseElementsAttr,
InsertionPoint,
IntegerAttr,
IntegerType,
Location,
Module,
@@ -551,25 +552,26 @@ class GraphConverter:
if input_dtype.is_signed:
assert_that(input_value.is_encrypted)
n = input_dtype.bit_width
lut_range = np.arange(2**n)
lut_values = np.where(lut_range < (2 ** (n - 1)), lut_range, lut_range - (2**n))
lut_type = RankedTensorType.get(
(2**n,), IntegerType.get_signless(64, context=ctx)
)
lut_attr = DenseElementsAttr.get(lut_values, context=ctx)
# ConstantOp is being decorated, and the init function is supposed to take more
# arguments than those pylint is considering
sanitizer_type = IntegerType.get_signless(n + 1)
sanitizer = 2 ** (n - 1)
if input_value.is_scalar:
sanitizer_attr = IntegerAttr.get(sanitizer_type, sanitizer)
else:
sanitizer_type = RankedTensorType.get((1,), sanitizer_type)
sanitizer_attr = Attribute.parse(f"dense<[{sanitizer}]> : {sanitizer_type}")
# pylint: disable=too-many-function-args
lut = arith.ConstantOp(lut_type, lut_attr).result
sanitizer_cst = arith.ConstantOp(sanitizer_type, sanitizer_attr)
# pylint: enable=too-many-function-args
resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
if input_value.is_scalar:
sanitized = fhe.ApplyLookupTableEintOp(resulting_type, arg, lut).result
sanitized = fhe.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
else:
sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result
sanitized = fhelinalg.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
sanitized_args.append(sanitized)
else: