mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: resolve integration issues with the new compiler
This commit is contained in:
@@ -19,6 +19,7 @@ from mlir.ir import (
|
||||
IntegerType,
|
||||
Location,
|
||||
Module,
|
||||
OpResult,
|
||||
RankedTensorType,
|
||||
)
|
||||
|
||||
@@ -460,8 +461,7 @@ class GraphConverter:
|
||||
GraphConverter._offset_negative_lookup_table_inputs(graph)
|
||||
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
|
||||
|
||||
# { "%0": "tensor.from_elements ..." } == we need to convert the part after "=" for %0
|
||||
direct_replacements: Dict[str, str] = {}
|
||||
from_elements_operations: Dict[OpResult, List[OpResult]] = {}
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
@@ -493,13 +493,19 @@ class GraphConverter:
|
||||
node,
|
||||
preds,
|
||||
constant_cache,
|
||||
direct_replacements,
|
||||
from_elements_operations,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert()
|
||||
|
||||
results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
|
||||
return results
|
||||
|
||||
direct_replacements = {}
|
||||
for placeholder, elements in from_elements_operations.items():
|
||||
element_names = [NodeConverter.mlir_name(element) for element in elements]
|
||||
actual_value = f"tensor.from_elements {', '.join(element_names)} : {placeholder.type}"
|
||||
direct_replacements[NodeConverter.mlir_name(placeholder)] = actual_value
|
||||
|
||||
module_lines_after_hacks_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
|
||||
@@ -4,7 +4,6 @@ Declaration of `NodeConverter` class.
|
||||
|
||||
# pylint: disable=no-member,no-name-in-module,too-many-lines
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -52,7 +51,7 @@ class NodeConverter:
|
||||
one_of_the_inputs_is_a_tensor: bool
|
||||
|
||||
constant_cache: Dict[Tuple[Type, Attribute], OpResult]
|
||||
direct_replacements: Dict[str, str]
|
||||
from_elements_operations: Dict[OpResult, List[OpResult]]
|
||||
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
|
||||
@@ -114,7 +113,7 @@ class NodeConverter:
|
||||
node: Node,
|
||||
preds: List[OpResult],
|
||||
constant_cache: Dict[Tuple[Type, Attribute], OpResult],
|
||||
direct_replacements: Dict[str, str],
|
||||
from_elements_operations: Dict[OpResult, List[OpResult]],
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.graph = graph
|
||||
@@ -135,7 +134,7 @@ class NodeConverter:
|
||||
self.one_of_the_inputs_is_a_tensor = True
|
||||
|
||||
self.constant_cache = constant_cache
|
||||
self.direct_replacements = direct_replacements
|
||||
self.from_elements_operations = from_elements_operations
|
||||
|
||||
def convert(self) -> OpResult:
|
||||
"""
|
||||
@@ -276,17 +275,10 @@ class NodeConverter:
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
preds = self.preds
|
||||
|
||||
number_of_values = len(preds)
|
||||
|
||||
intermediate_value = deepcopy(self.node.output)
|
||||
intermediate_value.shape = (number_of_values,)
|
||||
|
||||
intermediate_type = NodeConverter.value_to_mlir_type(self.ctx, intermediate_value)
|
||||
|
||||
pred_names = []
|
||||
processed_preds = []
|
||||
for pred, value in zip(preds, self.node.inputs):
|
||||
if value.is_encrypted or self.node.output.is_clear:
|
||||
pred_names.append(NodeConverter.mlir_name(pred))
|
||||
processed_preds.append(pred)
|
||||
continue
|
||||
|
||||
assert isinstance(value.dtype, Integer)
|
||||
@@ -296,35 +288,22 @@ class NodeConverter:
|
||||
zero = fhe.ZeroEintOp(zero_type).result
|
||||
|
||||
encrypted_pred = fhe.AddEintIntOp(zero_type, zero, pred).result
|
||||
pred_names.append(NodeConverter.mlir_name(encrypted_pred))
|
||||
processed_preds.append(encrypted_pred)
|
||||
|
||||
# `placeholder_result` will be replaced textually by `actual_value` below in graph converter
|
||||
# `tensor.from_elements` cannot be created from python bindings
|
||||
# that's why we use placeholder values and text manipulation
|
||||
|
||||
placeholder_result = fhe.ZeroTensorOp(intermediate_type).result
|
||||
placeholder_result_name = NodeConverter.mlir_name(placeholder_result)
|
||||
if self.node.output.is_clear:
|
||||
attribute = Attribute.parse(f"dense<0> : {resulting_type}")
|
||||
# pylint: disable=too-many-function-args
|
||||
placeholder_result = arith.ConstantOp(resulting_type, attribute).result
|
||||
# pylint: enable=too-many-function-args
|
||||
else:
|
||||
placeholder_result = fhe.ZeroTensorOp(resulting_type).result
|
||||
|
||||
actual_value = f"tensor.from_elements {', '.join(pred_names)} : {intermediate_type}"
|
||||
self.direct_replacements[placeholder_result_name] = actual_value
|
||||
|
||||
if self.node.output.shape == (number_of_values,):
|
||||
return placeholder_result
|
||||
|
||||
return tensor.ExpandShapeOp(
|
||||
resulting_type,
|
||||
placeholder_result,
|
||||
ArrayAttr.get(
|
||||
[
|
||||
ArrayAttr.get(
|
||||
[
|
||||
IntegerAttr.get(IntegerType.get_signless(64), i)
|
||||
for i in range(len(self.node.output.shape))
|
||||
]
|
||||
)
|
||||
]
|
||||
),
|
||||
).result
|
||||
self.from_elements_operations[placeholder_result] = processed_preds
|
||||
return placeholder_result
|
||||
|
||||
def _convert_concat(self) -> OpResult:
|
||||
"""
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
Name Version License
|
||||
Pillow 9.2.0 Historical Permission Notice and Disclaimer (HPND)
|
||||
PyYAML 6.0 MIT License
|
||||
concrete-compiler 0.13.0 BSD-3
|
||||
concrete-compiler 0.15.0 BSD-3
|
||||
cycler 0.11.0 BSD License
|
||||
fonttools 4.34.4 MIT License
|
||||
kiwisolver 1.4.4 BSD License
|
||||
matplotlib 3.5.2 Python Software Foundation License
|
||||
matplotlib 3.5.3 Python Software Foundation License
|
||||
networkx 2.8.5 BSD License
|
||||
numpy 1.23.1 BSD License
|
||||
packaging 21.3 Apache Software License; BSD License
|
||||
pygraphviz 1.9 BSD License
|
||||
pyparsing 3.0.9 MIT License
|
||||
python-dateutil 2.8.2 Apache Software License; BSD License
|
||||
setuptools-scm 7.0.5 MIT License
|
||||
setuptools-scm 6.4.2 MIT License
|
||||
six 1.16.0 MIT License
|
||||
tomli 2.0.1 MIT License
|
||||
torch 1.12.0 BSD License
|
||||
torch 1.12.1 BSD License
|
||||
typing-extensions 4.3.0 Python Software Foundation License
|
||||
|
||||
1192
poetry.lock
generated
1192
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -43,7 +43,7 @@ matplotlib = "^3.5.1"
|
||||
numpy = "^1.22.0"
|
||||
pygraphviz = { version = "^1.7", optional = true }
|
||||
Pillow = "^9.0.0"
|
||||
concrete-compiler = "^0.13.0"
|
||||
concrete-compiler = "^0.15.0"
|
||||
torch = "^1.10.2"
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -520,7 +520,7 @@ def deterministic_unary_function(x):
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 15], "shape": (3, 2)},
|
||||
},
|
||||
id="x + shape[0] + x.ndim + x.size",
|
||||
id="x + x.shape[0] + x.ndim + x.size",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (50 * np.sin(x.transpose())).astype(np.int64),
|
||||
|
||||
@@ -436,8 +436,8 @@ def test_graph_converter_bad_convert(
|
||||
range(3),
|
||||
"""
|
||||
|
||||
module {
|
||||
func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
%c1_i4 = arith.constant 1 : i4
|
||||
%cst = arith.constant dense<[4, 1, 2, 3, 3, 3, 3, 3]> : tensor<8xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3>
|
||||
|
||||
Reference in New Issue
Block a user