feat(frontend): support arbitrary bitwidth in from_native

This commit is contained in:
youben11
2025-02-18 13:00:29 +01:00
parent f2bc945edc
commit f776e32d24

View File

@@ -997,50 +997,37 @@ class Converter:
def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
dtype: TFHERSIntegerType = node.properties["attributes"]["type"]
input_bit_width, carry_width, msg_width = (
output_tfhers_bit_width, carry_width, msg_width = (
dtype.bit_width,
dtype.carry_width,
dtype.msg_width,
)
native_int = preds[0]
# output_tfhers_bit_width is the target bitwidth while native_int.bit_width is the number
# of bits containing the actual message
assert output_tfhers_bit_width >= native_int.bit_width, (
f"output_tfhers_bit_width: {output_tfhers_bit_width}, "
f"native_int.bit_width: {native_int.bit_width}"
)
assert (
input_bit_width >= native_int.bit_width
), f"input_bit_width: {input_bit_width}, native_int.bit_width: {native_int.bit_width}"
assert (
input_bit_width % msg_width == 0
), f"input_bit_width: {input_bit_width}, msg_width: {msg_width}"
# TODO: we may want to remove the cast and work with the number of bits provided
# this will make the operation faster by avoiding unnecessary bit extractions
if native_int.bit_width < input_bit_width:
native_int = ctx.cast(
ctx.tensor(
(
ctx.eint(input_bit_width)
if native_int.is_unsigned
else ctx.esint(input_bit_width)
),
native_int.shape,
),
native_int,
)
output_tfhers_bit_width % msg_width == 0
), f"output_tfhers_bit_width: {output_tfhers_bit_width}, msg_width: {msg_width}"
# number of ciphertexts representing a single integer
num_cts = input_bit_width // msg_width
num_cts = output_tfhers_bit_width // msg_width
# number of ciphertexts representing the actual part of the message
num_full_cts_msg = native_int.bit_width // msg_width
num_half_cts_msg = 0 if native_int.bit_width % msg_width == 0 else 1
num_cts_msg = num_full_cts_msg + num_half_cts_msg
# adds a dimension of ciphertexts for the result
result_shape = native_int.shape + (num_cts,)
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
result_shape_msg = native_int.shape + (num_cts_msg,)
result_type_msg = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape_msg)
# we reshape so that we can concatenate later over the last dim (ciphertext dim)
reshaped_native_int = ctx.reshape(native_int, native_int.shape + (1,))
# TODO: remove this when we want to optimize computation so that we don't compute
# on empty ciphertexts, based on the bit_width assignment. (e.g. if only two lsb
# ciphertexts are used, then we don't want to extract bits from the remaining ones)
reshaped_native_int.set_original_bit_width(input_bit_width)
# we want to extract `msg_width` bits at a time, and store them
# in a `msg_width + carry_width` bits eint
bits_shape = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape)
@@ -1051,10 +1038,36 @@ class Converter:
reshaped_native_int,
bits=slice(i * msg_width, (i + 1) * msg_width, 1),
)
for i in range(num_cts)
for i in range(num_full_cts_msg)
]
# we need to extract the remaining bits if any
if num_half_cts_msg:
extracted_bits.append(
ctx.extract_bits(
bits_shape,
reshaped_native_int,
bits=slice(
num_full_cts_msg * msg_width,
num_full_cts_msg * msg_width + native_int.bit_width % msg_width,
1,
),
)
)
result = ctx.concatenate(result_type_msg, extracted_bits, axis=-1)
# if we expect more ciphertexts than we have in the message: we pad the result with zeros
if num_cts_msg != num_cts:
result_shape = native_int.shape + (num_cts,)
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
padding = ctx.zeros(
ctx.tensor(
ctx.eint(msg_width + carry_width),
result_shape_msg[:-1] + (num_cts - num_cts_msg,),
)
)
result = ctx.concatenate(result_type, [result, padding], axis=-1)
result = ctx.concatenate(result_type, extracted_bits, axis=-1)
return ctx.change_partition(result, dest_partition=dtype.params)
# pylint: enable=missing-function-docstring,unused-argument