fix: resolve integration issues with the new compiler

This commit is contained in:
Umut
2022-08-08 16:17:35 +02:00
parent ac426c5f31
commit 6c6e657b6e
7 changed files with 1010 additions and 261 deletions

View File

@@ -19,6 +19,7 @@ from mlir.ir import (
IntegerType,
Location,
Module,
OpResult,
RankedTensorType,
)
@@ -460,8 +461,7 @@ class GraphConverter:
GraphConverter._offset_negative_lookup_table_inputs(graph)
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
# { "%0": "tensor.from_elements ..." } == we need to convert the part after "=" for %0
direct_replacements: Dict[str, str] = {}
from_elements_operations: Dict[OpResult, List[OpResult]] = {}
with Context() as ctx, Location.unknown():
concretelang.register_dialects(ctx)
@@ -493,13 +493,19 @@ class GraphConverter:
node,
preds,
constant_cache,
direct_replacements,
from_elements_operations,
)
ir_to_mlir[node] = node_converter.convert()
results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
return results
direct_replacements = {}
for placeholder, elements in from_elements_operations.items():
element_names = [NodeConverter.mlir_name(element) for element in elements]
actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
direct_replacements[NodeConverter.mlir_name(placeholder)] = actual_value
module_lines_after_hacks_are_applied = []
for line in str(module).split("\n"):
mlir_name = line.split("=")[0].strip()

View File

@@ -4,7 +4,6 @@ Declaration of `NodeConverter` class.
# pylint: disable=no-member,no-name-in-module,too-many-lines
from copy import deepcopy
from typing import Dict, List, Tuple
import numpy as np
@@ -52,7 +51,7 @@ class NodeConverter:
one_of_the_inputs_is_a_tensor: bool
constant_cache: Dict[Tuple[Type, Attribute], OpResult]
direct_replacements: Dict[str, str]
from_elements_operations: Dict[OpResult, List[OpResult]]
# pylint: enable=too-many-instance-attributes
@@ -114,7 +113,7 @@ class NodeConverter:
node: Node,
preds: List[OpResult],
constant_cache: Dict[Tuple[Type, Attribute], OpResult],
direct_replacements: Dict[str, str],
from_elements_operations: Dict[OpResult, List[OpResult]],
):
self.ctx = ctx
self.graph = graph
@@ -135,7 +134,7 @@ class NodeConverter:
self.one_of_the_inputs_is_a_tensor = True
self.constant_cache = constant_cache
self.direct_replacements = direct_replacements
self.from_elements_operations = from_elements_operations
def convert(self) -> OpResult:
"""
@@ -276,17 +275,10 @@ class NodeConverter:
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
number_of_values = len(preds)
intermediate_value = deepcopy(self.node.output)
intermediate_value.shape = (number_of_values,)
intermediate_type = NodeConverter.value_to_mlir_type(self.ctx, intermediate_value)
pred_names = []
processed_preds = []
for pred, value in zip(preds, self.node.inputs):
if value.is_encrypted or self.node.output.is_clear:
pred_names.append(NodeConverter.mlir_name(pred))
processed_preds.append(pred)
continue
assert isinstance(value.dtype, Integer)
@@ -296,35 +288,22 @@ class NodeConverter:
zero = fhe.ZeroEintOp(zero_type).result
encrypted_pred = fhe.AddEintIntOp(zero_type, zero, pred).result
pred_names.append(NodeConverter.mlir_name(encrypted_pred))
processed_preds.append(encrypted_pred)
# `placeholder_result` will be replaced textually by `actual_value` below in graph converter
# `tensor.from_elements` cannot be created from python bindings
# that's why we use placeholder values and text manipulation
placeholder_result = fhe.ZeroTensorOp(intermediate_type).result
placeholder_result_name = NodeConverter.mlir_name(placeholder_result)
if self.node.output.is_clear:
attribute = Attribute.parse(f"dense<0> : {resulting_type}")
# pylint: disable=too-many-function-args
placeholder_result = arith.ConstantOp(resulting_type, attribute).result
# pylint: enable=too-many-function-args
else:
placeholder_result = fhe.ZeroTensorOp(resulting_type).result
actual_value = f"tensor.from_elements {', '.join(pred_names)} : {intermediate_type}"
self.direct_replacements[placeholder_result_name] = actual_value
if self.node.output.shape == (number_of_values,):
return placeholder_result
return tensor.ExpandShapeOp(
resulting_type,
placeholder_result,
ArrayAttr.get(
[
ArrayAttr.get(
[
IntegerAttr.get(IntegerType.get_signless(64), i)
for i in range(len(self.node.output.shape))
]
)
]
),
).result
self.from_elements_operations[placeholder_result] = processed_preds
return placeholder_result
def _convert_concat(self) -> OpResult:
"""

View File

@@ -1,19 +1,19 @@
Name Version License
Pillow 9.2.0 Historical Permission Notice and Disclaimer (HPND)
PyYAML 6.0 MIT License
concrete-compiler 0.13.0 BSD-3
concrete-compiler 0.15.0 BSD-3
cycler 0.11.0 BSD License
fonttools 4.34.4 MIT License
kiwisolver 1.4.4 BSD License
matplotlib 3.5.2 Python Software Foundation License
matplotlib 3.5.3 Python Software Foundation License
networkx 2.8.5 BSD License
numpy 1.23.1 BSD License
packaging 21.3 Apache Software License; BSD License
pygraphviz 1.9 BSD License
pyparsing 3.0.9 MIT License
python-dateutil 2.8.2 Apache Software License; BSD License
setuptools-scm 7.0.5 MIT License
setuptools-scm 6.4.2 MIT License
six 1.16.0 MIT License
tomli 2.0.1 MIT License
torch 1.12.0 BSD License
torch 1.12.1 BSD License
typing-extensions 4.3.0 Python Software Foundation License

1192
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -43,7 +43,7 @@ matplotlib = "^3.5.1"
numpy = "^1.22.0"
pygraphviz = { version = "^1.7", optional = true }
Pillow = "^9.0.0"
concrete-compiler = "^0.13.0"
concrete-compiler = "^0.15.0"
torch = "^1.10.2"
[tool.poetry.extras]

View File

@@ -520,7 +520,7 @@ def deterministic_unary_function(x):
{
"x": {"status": "encrypted", "range": [0, 15], "shape": (3, 2)},
},
id="x + shape[0] + x.ndim + x.size",
id="x + x.shape[0] + x.ndim + x.size",
),
pytest.param(
lambda x: (50 * np.sin(x.transpose())).astype(np.int64),

View File

@@ -436,8 +436,8 @@ def test_graph_converter_bad_convert(
range(3),
"""
module {
func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
module {
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%c1_i4 = arith.constant 1 : i4
%cst = arith.constant dense<[4, 1, 2, 3, 3, 3, 3, 3]> : tensor<8xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>