diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index b6297eef8..406d034e6 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -71,6 +71,13 @@ def check_node_compatibility_with_mlir( return "only integer constants are supported" # pragma: no cover elif isinstance(node, intermediate.GenericFunction): # constraints for univariate functions + for inp in inputs: + if not value_is_integer(inp): + return ( + f"{node.op_name} with floating-point inputs " + f"is required to be fused to be supported" + ) + if node.op_kind == "TLU": assert_true( len( @@ -90,9 +97,6 @@ def check_node_compatibility_with_mlir( if node.op_name == "TLU": # pragma: no cover return "only unsigned integer lookup tables are supported" - if node.op_name.startswith("astype"): - return f"{node.op_name} is not supported without fusing" - # e.g., `np.absolute is not supported for the time being` return f"{node.op_name} is not supported for the time being" else: diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index c62c004f7..91ae343f5 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1101,7 +1101,7 @@ function you are trying to compile isn't supported for MLIR lowering %8 = mul(%7, %0) # EncryptedScalar ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported %9 = astype(%8, dtype=int32) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype is not supported without fusing +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype with floating-point inputs is required to be fused to be supported return %9 """.strip() # noqa: E501