From 8fbe5dab4d7bec2455f1418652f517a8b1ad80d8 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 16 Aug 2021 09:36:48 +0100 Subject: [PATCH] fix(mlir): unsigned are considered signless in compiler + changed the name of compiled func to main, as it's the default name to be executed later --- hdk/common/mlir/mlir_converter.py | 5 +++-- tests/common/mlir/test_mlir_converter.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hdk/common/mlir/mlir_converter.py b/hdk/common/mlir/mlir_converter.py index 2e443b785..d1f0d8668 100644 --- a/hdk/common/mlir/mlir_converter.py +++ b/hdk/common/mlir/mlir_converter.py @@ -58,7 +58,8 @@ class MLIRConverter: dtype = cast(Integer, value.data_type) if dtype.is_signed: return IntegerType.get_signed(dtype.bit_width, context=self.context) - return IntegerType.get_unsigned(dtype.bit_width, context=self.context) + # unsigned integer are considered signless in the compiler + return IntegerType.get_signless(dtype.bit_width, context=self.context) raise TypeError(f"can't convert value of type {type(value)} to MLIR type") def convert(self, op_graph: OPGraph) -> str: @@ -80,7 +81,7 @@ class MLIRConverter: ] @builtin.FuncOp.from_py_func(*func_types) - def fhe_circuit(*arg): + def main(*arg): ir_to_mlir_node = {} for arg_num, node in op_graph.input_nodes.items(): ir_to_mlir_node[node] = arg[arg_num] diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index 851405072..1537d0cf2 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -157,7 +157,7 @@ def test_hdk_clear_integer_to_mlir_type(is_signed): if is_signed: assert int_mlir == IntegerType.get_signed(5) else: - assert int_mlir == IntegerType.get_unsigned(5) + assert int_mlir == IntegerType.get_signless(5) def test_failing_hdk_to_mlir_type():