mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-09 12:57:55 -05:00
feat(frontend): support arbitrary bitwidth in from_native
This commit is contained in:
@@ -997,50 +997,37 @@ class Converter:
|
||||
def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
dtype: TFHERSIntegerType = node.properties["attributes"]["type"]
|
||||
input_bit_width, carry_width, msg_width = (
|
||||
output_tfhers_bit_width, carry_width, msg_width = (
|
||||
dtype.bit_width,
|
||||
dtype.carry_width,
|
||||
dtype.msg_width,
|
||||
)
|
||||
native_int = preds[0]
|
||||
|
||||
# output_tfhers_bit_width is the target bitwidth while native_int.bit_width is the number
|
||||
# of bits containing the actual message
|
||||
assert output_tfhers_bit_width >= native_int.bit_width, (
|
||||
f"output_tfhers_bit_width: {output_tfhers_bit_width}, "
|
||||
f"native_int.bit_width: {native_int.bit_width}"
|
||||
)
|
||||
assert (
|
||||
input_bit_width >= native_int.bit_width
|
||||
), f"input_bit_width: {input_bit_width}, native_int.bit_width: {native_int.bit_width}"
|
||||
assert (
|
||||
input_bit_width % msg_width == 0
|
||||
), f"input_bit_width: {input_bit_width}, msg_width: {msg_width}"
|
||||
|
||||
# TODO: we may want to remove the cast and work with the number of bits provided
|
||||
# this will make the operation faster by avoiding unnecessary bit extractions
|
||||
if native_int.bit_width < input_bit_width:
|
||||
native_int = ctx.cast(
|
||||
ctx.tensor(
|
||||
(
|
||||
ctx.eint(input_bit_width)
|
||||
if native_int.is_unsigned
|
||||
else ctx.esint(input_bit_width)
|
||||
),
|
||||
native_int.shape,
|
||||
),
|
||||
native_int,
|
||||
)
|
||||
output_tfhers_bit_width % msg_width == 0
|
||||
), f"output_tfhers_bit_width: {output_tfhers_bit_width}, msg_width: {msg_width}"
|
||||
|
||||
# number of ciphertexts representing a single integer
|
||||
num_cts = input_bit_width // msg_width
|
||||
num_cts = output_tfhers_bit_width // msg_width
|
||||
# number of ciphertexts representing the actual part of the message
|
||||
num_full_cts_msg = native_int.bit_width // msg_width
|
||||
num_half_cts_msg = 0 if native_int.bit_width % msg_width == 0 else 1
|
||||
num_cts_msg = num_full_cts_msg + num_half_cts_msg
|
||||
|
||||
# adds a dimension of ciphertexts for the result
|
||||
result_shape = native_int.shape + (num_cts,)
|
||||
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
|
||||
result_shape_msg = native_int.shape + (num_cts_msg,)
|
||||
result_type_msg = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape_msg)
|
||||
|
||||
# we reshape so that we can concatenate later over the last dim (ciphertext dim)
|
||||
reshaped_native_int = ctx.reshape(native_int, native_int.shape + (1,))
|
||||
|
||||
# TODO: remove this when we want to optimize computation so that we don't compute
|
||||
# on empty ciphertexts, based on the bit_width assignment. (e.g. if only two lsb
|
||||
# ciphertexts are used, then we don't want to extract bits from the remaining ones)
|
||||
reshaped_native_int.set_original_bit_width(input_bit_width)
|
||||
|
||||
# we want to extract `msg_width` bits at a time, and store them
|
||||
# in a `msg_width + carry_width` bits eint
|
||||
bits_shape = ctx.tensor(ctx.eint(msg_width + carry_width), reshaped_native_int.shape)
|
||||
@@ -1051,10 +1038,36 @@ class Converter:
|
||||
reshaped_native_int,
|
||||
bits=slice(i * msg_width, (i + 1) * msg_width, 1),
|
||||
)
|
||||
for i in range(num_cts)
|
||||
for i in range(num_full_cts_msg)
|
||||
]
|
||||
# we need to extract the remaining bits if any
|
||||
if num_half_cts_msg:
|
||||
extracted_bits.append(
|
||||
ctx.extract_bits(
|
||||
bits_shape,
|
||||
reshaped_native_int,
|
||||
bits=slice(
|
||||
num_full_cts_msg * msg_width,
|
||||
num_full_cts_msg * msg_width + native_int.bit_width % msg_width,
|
||||
1,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
result = ctx.concatenate(result_type_msg, extracted_bits, axis=-1)
|
||||
|
||||
# if we expect more ciphertexts than we have in the message: we pad the result with zeros
|
||||
if num_cts_msg != num_cts:
|
||||
result_shape = native_int.shape + (num_cts,)
|
||||
result_type = ctx.tensor(ctx.eint(msg_width + carry_width), result_shape)
|
||||
padding = ctx.zeros(
|
||||
ctx.tensor(
|
||||
ctx.eint(msg_width + carry_width),
|
||||
result_shape_msg[:-1] + (num_cts - num_cts_msg,),
|
||||
)
|
||||
)
|
||||
result = ctx.concatenate(result_type, [result, padding], axis=-1)
|
||||
|
||||
result = ctx.concatenate(result_type, extracted_bits, axis=-1)
|
||||
return ctx.change_partition(result, dest_partition=dtype.params)
|
||||
|
||||
# pylint: enable=missing-function-docstring,unused-argument
|
||||
|
||||
Reference in New Issue
Block a user