mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(TLU): extend TLU to 2 ** bit_width elements
This commit is contained in:
@@ -10,6 +10,7 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_scalar_integer,
|
||||
)
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import ArbitraryFunction
|
||||
|
||||
|
||||
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
|
||||
@@ -63,3 +64,21 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
):
|
||||
max_bit_width = max(max_bit_width, value_out.data_type.bit_width)
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
|
||||
|
||||
def extend_direct_lookup_tables(op_graph: OPGraph):
|
||||
"""Extend direct lookup tables to the maximum length the input bit width can support.
|
||||
|
||||
Args:
|
||||
op_graph: graph to update lookup tables for
|
||||
"""
|
||||
for node in op_graph.graph.nodes:
|
||||
if isinstance(node, ArbitraryFunction) and node.op_name == "TLU":
|
||||
table = node.op_kwargs["table"]
|
||||
bit_width = cast(Integer, node.inputs[0].data_type).bit_width
|
||||
expected_length = 2 ** bit_width
|
||||
if len(table) > expected_length:
|
||||
node.op_kwargs["table"] = table[:expected_length]
|
||||
else:
|
||||
repeat = expected_length // len(table)
|
||||
node.op_kwargs["table"] = (table * repeat)[:expected_length]
|
||||
|
||||
@@ -11,6 +11,7 @@ from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from ..common.mlir.utils import (
|
||||
extend_direct_lookup_tables,
|
||||
is_graph_values_compatible_with_mlir,
|
||||
update_bit_width_for_mlir,
|
||||
)
|
||||
@@ -133,6 +134,9 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
# Update bit_width for MLIR
|
||||
update_bit_width_for_mlir(op_graph)
|
||||
|
||||
# TODO: workaround extend LUT #359
|
||||
extend_direct_lookup_tables(op_graph)
|
||||
|
||||
return op_graph
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +67,20 @@ def lut(x):
|
||||
return table[x]
|
||||
|
||||
|
||||
# TODO: remove workaround #359
|
||||
def lut_more_bits_than_table_length(x, y):
|
||||
"""Test lookup table when bit_width support longer LUT"""
|
||||
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
|
||||
return table[x] + y
|
||||
|
||||
|
||||
# TODO: remove workaround #359
|
||||
def lut_less_bits_than_table_length(x):
|
||||
"""Test lookup table when bit_width support smaller LUT"""
|
||||
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7, 3, 6, 0, 2, 1, 4, 5, 7])
|
||||
return table[x]
|
||||
|
||||
|
||||
def dot(x, y):
|
||||
"""Test dot"""
|
||||
return numpy.dot(x, y)
|
||||
@@ -184,6 +198,21 @@ def datagen(*args):
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
(
|
||||
lut_more_bits_than_table_length,
|
||||
{
|
||||
"x": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
"y": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8), range(0, 16)),
|
||||
),
|
||||
(
|
||||
lut_less_bits_than_table_length,
|
||||
{
|
||||
"x": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
(
|
||||
dot,
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user