mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: sanitize signed inputs without table lookups
This commit is contained in:
@@ -158,14 +158,13 @@ class Client:
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
if isinstance(arg, int) and arg < 0:
|
||||
sanitized_args[index] = 2 * (expected_max + 1) + arg
|
||||
is_signed = self.specs.input_signs[index]
|
||||
sanitizer = 0 if not is_signed else 2 ** (width - 1)
|
||||
|
||||
if isinstance(arg, int):
|
||||
sanitized_args[index] = arg + sanitizer
|
||||
else:
|
||||
sanitized_args[index] = np.where(
|
||||
arg >= 0,
|
||||
arg,
|
||||
2 * (expected_max + 1) + arg,
|
||||
).astype(np.uint64)
|
||||
sanitized_args[index] = (arg + sanitizer).astype(np.uint64)
|
||||
|
||||
if not is_valid:
|
||||
actual_value = Value.of(arg, is_encrypted=is_encrypted)
|
||||
|
||||
@@ -13,9 +13,10 @@ import numpy as np
|
||||
from concrete.lang.dialects import fhe, fhelinalg
|
||||
from mlir.dialects import arith, func
|
||||
from mlir.ir import (
|
||||
Attribute,
|
||||
Context,
|
||||
DenseElementsAttr,
|
||||
InsertionPoint,
|
||||
IntegerAttr,
|
||||
IntegerType,
|
||||
Location,
|
||||
Module,
|
||||
@@ -551,25 +552,26 @@ class GraphConverter:
|
||||
|
||||
if input_dtype.is_signed:
|
||||
assert_that(input_value.is_encrypted)
|
||||
|
||||
n = input_dtype.bit_width
|
||||
lut_range = np.arange(2**n)
|
||||
|
||||
lut_values = np.where(lut_range < (2 ** (n - 1)), lut_range, lut_range - (2**n))
|
||||
lut_type = RankedTensorType.get(
|
||||
(2**n,), IntegerType.get_signless(64, context=ctx)
|
||||
)
|
||||
lut_attr = DenseElementsAttr.get(lut_values, context=ctx)
|
||||
# ConstantOp is being decorated, and the init function is supposed to take more
|
||||
# arguments than those pylint is considering
|
||||
sanitizer_type = IntegerType.get_signless(n + 1)
|
||||
sanitizer = 2 ** (n - 1)
|
||||
|
||||
if input_value.is_scalar:
|
||||
sanitizer_attr = IntegerAttr.get(sanitizer_type, sanitizer)
|
||||
else:
|
||||
sanitizer_type = RankedTensorType.get((1,), sanitizer_type)
|
||||
sanitizer_attr = Attribute.parse(f"dense<[{sanitizer}]> : {sanitizer_type}")
|
||||
|
||||
# pylint: disable=too-many-function-args
|
||||
lut = arith.ConstantOp(lut_type, lut_attr).result
|
||||
sanitizer_cst = arith.ConstantOp(sanitizer_type, sanitizer_attr)
|
||||
# pylint: enable=too-many-function-args
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
|
||||
if input_value.is_scalar:
|
||||
sanitized = fhe.ApplyLookupTableEintOp(resulting_type, arg, lut).result
|
||||
sanitized = fhe.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
|
||||
else:
|
||||
sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result
|
||||
sanitized = fhelinalg.SubEintIntOp(resulting_type, arg, sanitizer_cst).result
|
||||
|
||||
sanitized_args.append(sanitized)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user