feat: cache constant operations during MLIR conversion

Umut
2022-07-18 16:38:51 +02:00
parent 84bdb65529
commit a60891292b
3 changed files with 57 additions and 6 deletions
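
In short, the change is a memoization of MLIR constant materialization: instead of emitting a fresh arith.ConstantOp for every request, the converter keeps a per-conversion dictionary keyed by the constant's (type, attribute) pair and reuses the operation it already built. The following minimal, dependency-free sketch only illustrates that pattern in isolation; ConstantBuilder and emit_constant are hypothetical names for this illustration, not part of the project's API.

    from typing import Any, Callable, Dict, Tuple


    class ConstantBuilder:
        """Sketch of the caching idea: one materialized constant per (type, attribute) key."""

        def __init__(self, emit_constant: Callable[[Any, Any], Any]):
            # emit_constant actually builds the constant operation; in the real
            # converter this role is played by arith.ConstantOp.
            self._emit_constant = emit_constant
            self._cache: Dict[Tuple[Any, Any], Any] = {}

        def get(self, type_: Any, attribute: Any) -> Any:
            # Both key components must be hashable and compare by value for the
            # lookup to hit on "the same" constant.
            key = (type_, attribute)
            result = self._cache.get(key)
            if result is None:
                result = self._emit_constant(type_, attribute)
                self._cache[key] = result
            return result


    if __name__ == "__main__":
        builder = ConstantBuilder(lambda t, a: object())  # stand-in for op creation
        assert builder.get("i4", 1) is builder.get("i4", 1)  # second request is a cache hit
        assert builder.get("i4", 1) is not builder.get("i4", 2)  # distinct constants stay distinct

The diffs below thread exactly such a dictionary (constant_cache) from GraphConverter into every NodeConverter and route constant creation through a new _create_constant helper.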

View File

@@ -473,6 +473,7 @@ class GraphConverter:
         for arg_num, node in graph.input_nodes.items():
             ir_to_mlir[node] = sanitized_args[arg_num]
 
+        constant_cache = {}
         for node in nx.topological_sort(graph.graph):
             if node.operation == Operation.Input:
                 continue
@@ -483,6 +484,7 @@
                 graph,
                 node,
                 preds,
+                constant_cache,
                 direct_replacements,
             )
             ir_to_mlir[node] = node_converter.convert()

View File

@@ -51,6 +51,7 @@ class NodeConverter:
     all_of_the_inputs_are_tensors: bool
     one_of_the_inputs_is_a_tensor: bool
+    constant_cache: Dict[Tuple[Type, Attribute], OpResult]
     direct_replacements: Dict[str, str]
 
     # pylint: enable=too-many-instance-attributes
 
@@ -112,6 +113,7 @@
         graph: Graph,
         node: Node,
         preds: List[OpResult],
+        constant_cache: Dict[Tuple[Type, Attribute], OpResult],
         direct_replacements: Dict[str, str],
     ):
         self.ctx = ctx
@@ -132,6 +134,7 @@
         else:
             self.one_of_the_inputs_is_a_tensor = True
 
+        self.constant_cache = constant_cache
         self.direct_replacements = direct_replacements
 
     def convert(self) -> OpResult:
@@ -380,7 +383,7 @@
         attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")
 
-        return arith.ConstantOp(resulting_type, attr).result
+        return self._create_constant(resulting_type, attr).result
 
     def _convert_conv1d(self) -> OpResult:
         """
@@ -556,7 +559,7 @@
             constant_attr = IntegerAttr.get(constant_type, 1)
 
             zero = fhe.ZeroEintOp(resulting_type).result
-            one = arith.ConstantOp(constant_type, constant_attr).result
+            one = self._create_constant(constant_type, constant_attr).result
 
             result = fhe.AddEintIntOp(resulting_type, zero, one).result
         else:
@@ -569,7 +572,7 @@
             constant_attr = Attribute.parse(f"dense<[1]> : {constant_type}")
 
             zeros = fhe.ZeroTensorOp(resulting_type).result
-            ones = arith.ConstantOp(constant_type, constant_attr).result
+            ones = self._create_constant(constant_type, constant_attr).result
 
             result = fhelinalg.AddEintIntOp(resulting_type, zeros, ones).result
@@ -718,7 +721,7 @@
             for value, dimension_size in zip(index, input_shape):
                 value = int(value)
                 attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
-                indices.append(arith.ConstantOp(index_type, attr).result)
+                indices.append(self._create_constant(index_type, attr).result)
 
             return tensor.ExtractOp(output_type, self.preds[0], indices).result
 
         offsets = []
@@ -920,7 +923,7 @@
         lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
         lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
-        lut = arith.ConstantOp(lut_type, lut_attr).result
+        lut = self._create_constant(lut_type, lut_attr).result
 
         resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
         pred = self.preds[variable_input_index]
@@ -940,7 +943,7 @@
                 resulting_type,
                 pred,
                 lut,
-                arith.ConstantOp(map_type, map_attr).result,
+                self._create_constant(map_type, map_attr).result,
             ).result
         else:
             result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
@@ -978,3 +981,10 @@
             result = fhe.ZeroTensorOp(resulting_type).result
 
         return result
+
+    def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute):
+        result = self.constant_cache.get((mlir_type, mlir_attribute))
+        if result is None:
+            result = arith.ConstantOp(mlir_type, mlir_attribute)
+            self.constant_cache[(mlir_type, mlir_attribute)] = result
+        return result
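
A short note on the helper above: the cache key is the (mlir_type, mlir_attribute) tuple and the cached value is the constant operation itself, so call sites keep their existing .result access. MLIR types and attributes are uniqued per context (and the Python bindings compare and hash them by value, which this change relies on), so a second request for the same constant returns the operation created for the first, and the emitted module ends up with one arith.constant per distinct value, as the test below demonstrates.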

View File

@@ -412,4 +412,43 @@ def test_graph_converter_bad_convert(
     helpers.check_str(expected_message, str(excinfo.value))
+
+
+@pytest.mark.parametrize(
+    "function,inputset,expected_mlir",
+    [
+        pytest.param(
+            lambda x: 1 + cnp.LookupTable([4, 1, 2, 3])[x] + cnp.LookupTable([4, 1, 2, 3])[x + 1],
+            range(3),
+            """
+
+module {
+  func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
+    %c1_i4 = arith.constant 1 : i4
+    %cst = arith.constant dense<[4, 1, 2, 3, 3, 3, 3, 3]> : tensor<8xi64>
+    %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
+    %1 = "FHE.add_eint_int"(%arg0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
+    %2 = "FHE.add_eint_int"(%0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
+    %3 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
+    %4 = "FHE.add_eint"(%2, %3) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3>
+    return %4 : !FHE.eint<3>
+  }
+}
+
+            """,  # noqa: E501
+            # Notice that there is only a single 1 and a single table cst above
+        ),
+    ],
+)
+def test_constant_cache(function, inputset, expected_mlir, helpers):
+    """
+    Test caching MLIR constants.
+    """
+
+    configuration = helpers.configuration()
+
+    compiler = cnp.Compiler(function, {"x": "encrypted"})
+    circuit = compiler.compile(inputset, configuration)
+
+    helpers.check_str(expected_mlir, circuit.mlir)
 
 # pylint: enable=line-too-long