mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add more bivariate operations
which are supported if one of the operands is a constant refs #126
This commit is contained in:
committed by
Benoit Chevallier
parent
ab1f0f3c4a
commit
2da3895f1a
@@ -259,13 +259,13 @@ class NPTracer(BaseTracer):
|
||||
numpy.arctan,
|
||||
numpy.arctan2,
|
||||
numpy.arctanh,
|
||||
# numpy.bitwise_and,
|
||||
# numpy.bitwise_or,
|
||||
# numpy.bitwise_xor,
|
||||
numpy.bitwise_and,
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
# numpy.conjugate,
|
||||
# numpy.copysign,
|
||||
numpy.copysign,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
@@ -278,51 +278,51 @@ class NPTracer(BaseTracer):
|
||||
numpy.fabs,
|
||||
numpy.float_power,
|
||||
numpy.floor,
|
||||
# numpy.floor_divide,
|
||||
# numpy.fmax,
|
||||
# numpy.fmin,
|
||||
# numpy.fmod,
|
||||
numpy.floor_divide,
|
||||
numpy.fmax,
|
||||
numpy.fmin,
|
||||
numpy.fmod,
|
||||
# numpy.frexp,
|
||||
# numpy.gcd,
|
||||
numpy.gcd,
|
||||
# numpy.greater,
|
||||
# numpy.greater_equal,
|
||||
# numpy.heaviside,
|
||||
# numpy.hypot,
|
||||
numpy.heaviside,
|
||||
numpy.hypot,
|
||||
# numpy.invert,
|
||||
numpy.isfinite,
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
# numpy.isnat,
|
||||
# numpy.lcm,
|
||||
# numpy.ldexp,
|
||||
# numpy.left_shift,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
# numpy.less,
|
||||
# numpy.less_equal,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
numpy.log2,
|
||||
# numpy.logaddexp,
|
||||
# numpy.logaddexp2,
|
||||
# numpy.logical_and,
|
||||
# numpy.logical_not,
|
||||
# numpy.logical_or,
|
||||
# numpy.logical_xor,
|
||||
numpy.logaddexp,
|
||||
numpy.logaddexp2,
|
||||
numpy.logical_and,
|
||||
numpy.logical_not,
|
||||
numpy.logical_or,
|
||||
numpy.logical_xor,
|
||||
# numpy.matmul,
|
||||
# numpy.maximum,
|
||||
# numpy.minimum,
|
||||
numpy.maximum,
|
||||
numpy.minimum,
|
||||
# numpy.modf,
|
||||
# numpy.multiply,
|
||||
numpy.negative,
|
||||
# numpy.nextafter,
|
||||
numpy.nextafter,
|
||||
# numpy.not_equal,
|
||||
numpy.positive,
|
||||
# numpy.power,
|
||||
numpy.power,
|
||||
numpy.rad2deg,
|
||||
numpy.radians,
|
||||
numpy.reciprocal,
|
||||
# numpy.remainder,
|
||||
# numpy.right_shift,
|
||||
numpy.remainder,
|
||||
numpy.right_shift,
|
||||
numpy.rint,
|
||||
numpy.sign,
|
||||
numpy.signbit,
|
||||
@@ -334,7 +334,7 @@ class NPTracer(BaseTracer):
|
||||
# numpy.subtract,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
# numpy.true_divide,
|
||||
numpy.true_divide,
|
||||
numpy.trunc,
|
||||
]
|
||||
|
||||
@@ -378,9 +378,18 @@ NPTracer.UFUNC_ROUTING = {
|
||||
fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1
|
||||
}
|
||||
|
||||
NPTracer.UFUNC_ROUTING[numpy.arctan2] = _get_binary_fun(numpy.arctan2)
|
||||
NPTracer.UFUNC_ROUTING[numpy.float_power] = _get_binary_fun(numpy.float_power)
|
||||
NPTracer.UFUNC_ROUTING.update(
|
||||
{fun: _get_binary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 2}
|
||||
)
|
||||
|
||||
list_of_not_supported = [
|
||||
(ufunc.__name__, ufunc.nin)
|
||||
for ufunc in NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
if ufunc.nin not in [1, 2]
|
||||
]
|
||||
|
||||
custom_assert(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}")
|
||||
del list_of_not_supported
|
||||
|
||||
# We are adding initial support for `np.array(...)` +,-,* `BaseTracer`
|
||||
# (note that this is not the proper complete handling of these functions)
|
||||
|
||||
@@ -189,16 +189,42 @@ def subtest_fuse_float_unary_operations_correctness(fun):
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
|
||||
numpy.bitwise_and,
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.gcd,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
numpy.logical_and,
|
||||
numpy.logical_not,
|
||||
numpy.logical_or,
|
||||
numpy.logical_xor,
|
||||
numpy.remainder,
|
||||
numpy.right_shift,
|
||||
}
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
|
||||
|
||||
for i in range(4):
|
||||
|
||||
# Know if the function is defined for integer inputs
|
||||
if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
|
||||
if i not in [0, 2]:
|
||||
continue
|
||||
|
||||
# The .astype(numpy.float64) that we have in cases 0 and 2 is here to force
|
||||
# a float output even for functions which return an integer (eg, XOR), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(3, x + y).astype(numpy.int32)
|
||||
return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
elif i == 1:
|
||||
# With a float in first position
|
||||
@@ -208,7 +234,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 4).astype(numpy.int32)
|
||||
return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
else:
|
||||
# With a float in second position
|
||||
@@ -217,6 +243,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
|
||||
input_list = [0, 2, 42, 44]
|
||||
|
||||
# Domain of definition
|
||||
if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]:
|
||||
input_list = [2, 42, 44]
|
||||
|
||||
for input_ in input_list:
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
@@ -406,8 +406,11 @@ def test_tracing_astype(
|
||||
),
|
||||
],
|
||||
)
|
||||
# numpy.logical_not is removed from the following test since it is expecting inputs which are
|
||||
# integer only, as opposed to other functions in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace_def", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1]
|
||||
"function_to_trace_def",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1 if f != numpy.logical_not],
|
||||
)
|
||||
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
Reference in New Issue
Block a user