diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index 569fa84ad..f4672ba02 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -10,7 +10,7 @@ from typing import cast # pylint: disable=no-name-in-module,no-member import numpy as np -from mlir.dialects import std as std_dialect +from mlir.dialects import arith as arith_dialect from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType from zamalang.dialects import hlfhe @@ -135,7 +135,7 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No data = node.constant_data int_type = IntegerType.get_signless(dtype.bit_width, context=ctx) - return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result + return arith_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result if value_is_clear_tensor_integer(value): value = cast(TensorValue, value) @@ -156,7 +156,7 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No # we use `Attribute.parse` to let the underlying library do it by itself value_attr = Attribute.parse(f"dense<{str(data.tolist())}> : {vec_type}") - return std_dialect.ConstantOp(vec_type, value_attr).result + return arith_dialect.ConstantOp(vec_type, value_attr).result raise TypeError(f"Don't support {value} constants") @@ -193,7 +193,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info): out_dtype = cast(Integer, node.outputs[0].dtype) # Create table dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx) - tensor_lut = std_dialect.ConstantOp( + tensor_lut = arith_dialect.ConstantOp( RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)), dense_elem, ).result