mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: cache constant operations during MLIR conversion
This commit is contained in:
@@ -473,6 +473,7 @@ class GraphConverter:
|
||||
for arg_num, node in graph.input_nodes.items():
|
||||
ir_to_mlir[node] = sanitized_args[arg_num]
|
||||
|
||||
constant_cache = {}
|
||||
for node in nx.topological_sort(graph.graph):
|
||||
if node.operation == Operation.Input:
|
||||
continue
|
||||
@@ -483,6 +484,7 @@ class GraphConverter:
|
||||
graph,
|
||||
node,
|
||||
preds,
|
||||
constant_cache,
|
||||
direct_replacements,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert()
|
||||
|
||||
@@ -51,6 +51,7 @@ class NodeConverter:
|
||||
all_of_the_inputs_are_tensors: bool
|
||||
one_of_the_inputs_is_a_tensor: bool
|
||||
|
||||
constant_cache: Dict[Tuple[Type, Attribute], OpResult]
|
||||
direct_replacements: Dict[str, str]
|
||||
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
@@ -112,6 +113,7 @@ class NodeConverter:
|
||||
graph: Graph,
|
||||
node: Node,
|
||||
preds: List[OpResult],
|
||||
constant_cache: Dict[Tuple[Type, Attribute], OpResult],
|
||||
direct_replacements: Dict[str, str],
|
||||
):
|
||||
self.ctx = ctx
|
||||
@@ -132,6 +134,7 @@ class NodeConverter:
|
||||
else:
|
||||
self.one_of_the_inputs_is_a_tensor = True
|
||||
|
||||
self.constant_cache = constant_cache
|
||||
self.direct_replacements = direct_replacements
|
||||
|
||||
def convert(self) -> OpResult:
|
||||
@@ -380,7 +383,7 @@ class NodeConverter:
|
||||
|
||||
attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")
|
||||
|
||||
return arith.ConstantOp(resulting_type, attr).result
|
||||
return self._create_constant(resulting_type, attr).result
|
||||
|
||||
def _convert_conv1d(self) -> OpResult:
|
||||
"""
|
||||
@@ -556,7 +559,7 @@ class NodeConverter:
|
||||
constant_attr = IntegerAttr.get(constant_type, 1)
|
||||
|
||||
zero = fhe.ZeroEintOp(resulting_type).result
|
||||
one = arith.ConstantOp(constant_type, constant_attr).result
|
||||
one = self._create_constant(constant_type, constant_attr).result
|
||||
|
||||
result = fhe.AddEintIntOp(resulting_type, zero, one).result
|
||||
else:
|
||||
@@ -569,7 +572,7 @@ class NodeConverter:
|
||||
constant_attr = Attribute.parse(f"dense<[1]> : {constant_type}")
|
||||
|
||||
zeros = fhe.ZeroTensorOp(resulting_type).result
|
||||
ones = arith.ConstantOp(constant_type, constant_attr).result
|
||||
ones = self._create_constant(constant_type, constant_attr).result
|
||||
|
||||
result = fhelinalg.AddEintIntOp(resulting_type, zeros, ones).result
|
||||
|
||||
@@ -718,7 +721,7 @@ class NodeConverter:
|
||||
for value, dimension_size in zip(index, input_shape):
|
||||
value = int(value)
|
||||
attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
|
||||
indices.append(arith.ConstantOp(index_type, attr).result)
|
||||
indices.append(self._create_constant(index_type, attr).result)
|
||||
return tensor.ExtractOp(output_type, self.preds[0], indices).result
|
||||
|
||||
offsets = []
|
||||
@@ -920,7 +923,7 @@ class NodeConverter:
|
||||
|
||||
lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
|
||||
lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
|
||||
lut = arith.ConstantOp(lut_type, lut_attr).result
|
||||
lut = self._create_constant(lut_type, lut_attr).result
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
pred = self.preds[variable_input_index]
|
||||
@@ -940,7 +943,7 @@ class NodeConverter:
|
||||
resulting_type,
|
||||
pred,
|
||||
lut,
|
||||
arith.ConstantOp(map_type, map_attr).result,
|
||||
self._create_constant(map_type, map_attr).result,
|
||||
).result
|
||||
else:
|
||||
result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
@@ -978,3 +981,10 @@ class NodeConverter:
|
||||
result = fhe.ZeroTensorOp(resulting_type).result
|
||||
|
||||
return result
|
||||
|
||||
def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute):
|
||||
result = self.constant_cache.get((mlir_type, mlir_attribute))
|
||||
if result is None:
|
||||
result = arith.ConstantOp(mlir_type, mlir_attribute)
|
||||
self.constant_cache[(mlir_type, mlir_attribute)] = result
|
||||
return result
|
||||
|
||||
@@ -412,4 +412,43 @@ def test_graph_converter_bad_convert(
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,inputset,expected_mlir",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: 1 + cnp.LookupTable([4, 1, 2, 3])[x] + cnp.LookupTable([4, 1, 2, 3])[x + 1],
|
||||
range(3),
|
||||
"""
|
||||
|
||||
module {
|
||||
func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%c1_i4 = arith.constant 1 : i4
|
||||
%cst = arith.constant dense<[4, 1, 2, 3, 3, 3, 3, 3]> : tensor<8xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
%1 = "FHE.add_eint_int"(%arg0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
|
||||
%2 = "FHE.add_eint_int"(%0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
|
||||
%3 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
%4 = "FHE.add_eint"(%2, %3) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
|
||||
return %4 : !FHE.eint<3>
|
||||
}
|
||||
}
|
||||
|
||||
""", # noqa: E501
|
||||
# Notice that there is only a single 1 and a single table cst above
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_cache(function, inputset, expected_mlir, helpers):
|
||||
"""
|
||||
Test caching MLIR constants.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_mlir, circuit.mlir)
|
||||
|
||||
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
Reference in New Issue
Block a user