refactor(compilation): remove unnecessary check in compile.py

refs #645
This commit is contained in:
Arthur Meyre
2021-10-25 12:06:24 +02:00
parent c46d96aabf
commit 624143106f
3 changed files with 79 additions and 73 deletions

View File

@@ -2,7 +2,7 @@
import sys
import traceback
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
import numpy
from zamalang import CompilerEngine
@@ -21,7 +21,6 @@ from ..common.mlir.utils import (
)
from ..common.operator_graph import OPGraph
from ..common.optimization.topological import fuse_float_operations
from ..common.representation.intermediate import IntermediateNode
from ..common.values import BaseValue
from ..numpy.tracing import trace_numpy_function
from .np_dtypes_helpers import (
@@ -102,16 +101,6 @@ def _compile_numpy_function_into_op_graph_internal(
if not check_op_graph_is_integer_program(op_graph):
fuse_float_operations(op_graph, compilation_artifacts)
# TODO: To be removed once we support more than integers
offending_non_integer_nodes: List[IntermediateNode] = []
op_grap_is_int_prog = check_op_graph_is_integer_program(op_graph, offending_non_integer_nodes)
if not op_grap_is_int_prog:
raise ValueError(
f"{function_to_compile.__name__} cannot be compiled as it has nodes with either float"
f" inputs or outputs.\nOffending nodes : "
f"{', '.join(str(node) for node in offending_non_integer_nodes)}"
)
# Find bounds with the inputset
inputset_size, node_bounds_and_samples = eval_op_graph_bounds_on_inputset(
op_graph,

View File

@@ -35,8 +35,6 @@ def simple_fuse_not_output(x):
simple_fuse_not_output,
True,
id="simple_fuse_not_output",
marks=pytest.mark.xfail(strict=True),
# fails because it connot be compiled without topological optimizations
),
],
)

View File

@@ -448,12 +448,6 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration):
((4, 8), (3, 4), (0, 4)),
["x", "y", "z"],
),
pytest.param(
no_fuse_unhandled,
((-2, 2), (-2, 2)),
["x", "y"],
marks=pytest.mark.xfail(strict=True, raises=ValueError),
),
pytest.param(complicated_topology, ((0, 10),), ["x"]),
],
)
@@ -735,6 +729,7 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
)
# pylint: disable=line-too-long,unnecessary-lambda
@pytest.mark.parametrize(
"function,parameters,inputset,match",
[
@@ -745,10 +740,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
"%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>\n" # noqa: E501
"%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501
"return(%2)\n"
),
),
@@ -759,10 +754,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = Constant(-1) # ClearScalar<Integer<signed, 2 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%1 = Constant(-1) # ClearScalar<Integer<signed, 2 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
"%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"return(%2)\n"
),
),
@@ -773,10 +768,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501
"return(%2)\n"
),
),
@@ -787,10 +782,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501
"return(%2)\n"
),
),
@@ -801,10 +796,10 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = Mul(%0, %1) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
"%2 = Mul(%0, %1) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported\n" # noqa: E501
"return(%2)\n"
),
),
@@ -815,15 +810,15 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = Constant(127) # ClearScalar<Integer<unsigned, 7 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = Sub(%0, %1) # EncryptedTensor<Integer<unsigned, 7 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = Constant(127) # ClearScalar<Integer<unsigned, 7 bits>>\n" # noqa: E501
"%1 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501
"%2 = Sub(%0, %1) # EncryptedTensor<Integer<unsigned, 7 bits>, shape=(2, 2)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported\n" # noqa: E501
"return(%2)\n"
),
),
pytest.param(
lambda x, y: numpy.dot(x, y), # pylint: disable=unnecessary-lambda
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)),
"y": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)),
@@ -843,12 +838,12 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = y # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = Dot(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = x # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501
"%1 = y # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501
"%2 = Dot(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported\n" # noqa: E501
"return(%2)\n"
),
),
@@ -866,22 +861,44 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
"return(%1)\n"
),
),
pytest.param(
no_fuse_unhandled,
{"x": EncryptedScalar(Integer(2, False)), "y": EncryptedScalar(Integer(2, False))},
[(i, i) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n\n"
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%1 = Constant(2.8) # ClearScalar<Float<64 bits>>\n"
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
"%2 = y # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%3 = Constant(9.3) # ClearScalar<Float<64 bits>>\n"
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
"%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"%7 = astype(int32)(%6) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501
"return(%7)\n"
),
),
],
)
# pylint: enable=line-too-long,unnecessary-lambda
def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration):
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
with pytest.raises(RuntimeError):
try:
compile_numpy_function(
function,
parameters,
inputset,
default_compilation_configuration,
)
except RuntimeError as error:
assert str(error) == match
raise
with pytest.raises(RuntimeError) as excinfo:
compile_numpy_function(
function,
parameters,
inputset,
default_compilation_configuration,
)
assert str(excinfo.value) == match
def test_fail_with_intermediate_signed_values(default_compilation_configuration):
@@ -905,22 +922,24 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration)
show_mlir=True,
)
except RuntimeError as error:
# pylint: disable=line-too-long
match = (
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = y # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%1 = Constant(10) # ClearScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%2 = x # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501 # pylint: disable=line-too-long
"%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
"%0 = y # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501
"%1 = Constant(10) # ClearScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%2 = x # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501
"%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501
"%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
"%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>\n" # noqa: E501
"return(%7)\n"
)
# pylint: enable=line-too-long
assert str(error) == match
raise