fix(mlir): use arith dialect instead of std during MLIR conversion

This commit is contained in:
Umut
2021-11-02 12:12:06 +03:00
committed by Arthur Meyre
parent 9dedf1abc6
commit 759914dca6

View File

@@ -10,7 +10,7 @@ from typing import cast
# pylint: disable=no-name-in-module,no-member
import numpy as np
from mlir.dialects import std as std_dialect
from mlir.dialects import arith as arith_dialect
from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
from zamalang.dialects import hlfhe
@@ -135,7 +135,7 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
data = node.constant_data
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result
return arith_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result
if value_is_clear_tensor_integer(value):
value = cast(TensorValue, value)
@@ -156,7 +156,7 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
# we use `Attribute.parse` to let the underlying library do it by itself
value_attr = Attribute.parse(f"dense<{str(data.tolist())}> : {vec_type}")
return std_dialect.ConstantOp(vec_type, value_attr).result
return arith_dialect.ConstantOp(vec_type, value_attr).result
raise TypeError(f"Don't support {value} constants")
@@ -193,7 +193,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
out_dtype = cast(Integer, node.outputs[0].dtype)
# Create table
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
tensor_lut = std_dialect.ConstantOp(
tensor_lut = arith_dialect.ConstantOp(
RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
dense_elem,
).result