mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test: check correctness of ufunc's
for the moment: - it has no hard check of correctness for now - some functions are not managed ref #551
This commit is contained in:
committed by
Benoit Chevallier
parent
67f50fb8ce
commit
12465f86ac
@@ -10,6 +10,7 @@ from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
from concrete.numpy.compile import compile_numpy_function, compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
@@ -52,6 +53,252 @@ def complicated_topology(x):
|
||||
)
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_f(func, x, y):
|
||||
"""Create an upper function to test `func`"""
|
||||
z = numpy.abs(10 * func(x))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_f_with_float_inputs(func, x, y):
|
||||
"""Create an upper function to test `func`, with inputs which are forced to be floats"""
|
||||
z = numpy.abs(10 * func(x + 0.1))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_f_with_integer_inputs(func, x, y):
|
||||
"""Create an upper function to test `func`, with inputs which are forced to be integers but
|
||||
in a way which is fusable into a TLU"""
|
||||
a = x + 0.1
|
||||
a = numpy.rint(a).astype(numpy.int32)
|
||||
z = numpy.abs(10 * func(a))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_f_which_expects_small_inputs(func, x, y):
|
||||
"""Create an upper function to test `func`, which expects small values to not use too much
|
||||
precision"""
|
||||
a = numpy.abs(0.77 * numpy.sin(x))
|
||||
z = numpy.abs(3 * func(a))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_f_which_has_large_outputs(func, x, y):
|
||||
"""Create an upper function to test `func`, which outputs large values"""
|
||||
a = numpy.abs(2 * numpy.sin(x))
|
||||
z = numpy.abs(func(a) * 0.131)
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_f_avoid_0_input(func, x, y):
|
||||
"""Create an upper function to test `func`, which makes that inputs are not 0"""
|
||||
a = numpy.abs(7 * numpy.sin(x)) + 1
|
||||
z = numpy.abs(5 * func(a))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_binary_f_one(func, c, x, y):
|
||||
"""Create an upper function to test `func`"""
|
||||
z = numpy.abs(func(x, c) + 1)
|
||||
z = z.astype(numpy.uint32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_binary_f_two(func, c, x, y):
|
||||
"""Create an upper function to test `func`"""
|
||||
z = numpy.abs(func(c, x) + 1)
|
||||
z = z.astype(numpy.uint32) + y
|
||||
return z
|
||||
|
||||
|
||||
def mix_x_and_y_and_call_binary_f_two_avoid_0_input(func, c, x, y):
|
||||
"""Create an upper function to test `func`"""
|
||||
z = numpy.abs(func(c, x + 1) + 1)
|
||||
z = z.astype(numpy.uint32) + y
|
||||
return z
|
||||
|
||||
|
||||
def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input_ranges):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def get_function(ufunc, upper_function):
|
||||
return lambda x, y: upper_function(ufunc, x, y)
|
||||
|
||||
function = get_function(ufunc, upper_function)
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {arg_name: EncryptedScalar(Integer(64, False)) for arg_name in ["x", "y"]}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
|
||||
# TODO: fix the check
|
||||
# assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
if compiler_engine.run(*args) != function(*args):
|
||||
print("Warning, bad computation")
|
||||
|
||||
|
||||
def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, input_ranges):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def get_function(ufunc, upper_function):
|
||||
return lambda x, y: upper_function(ufunc, c, x, y)
|
||||
|
||||
function = get_function(ufunc, upper_function)
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {arg_name: EncryptedScalar(Integer(64, False)) for arg_name in ["x", "y"]}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
|
||||
# TODO: fix the check
|
||||
# assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
if compiler_engine.run(*args) != function(*args):
|
||||
print("Warning, bad computation")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ufunc",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 2],
|
||||
)
|
||||
def test_binary_ufunc_operations(ufunc):
|
||||
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
if ufunc in [numpy.power, numpy.float_power]:
|
||||
# Need small constants to keep results really small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 4), (0, 5))
|
||||
)
|
||||
elif ufunc in [numpy.lcm, numpy.left_shift]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 5), (0, 5))
|
||||
)
|
||||
else:
|
||||
# General case
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 41, ((0, 5), (0, 5))
|
||||
)
|
||||
|
||||
if ufunc in [numpy.power, numpy.float_power]:
|
||||
# Need small constants to keep results really small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 4), (0, 5))
|
||||
)
|
||||
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
|
||||
# 0 not in the domain of definition
|
||||
# Can't make it work, #649
|
||||
# TODO: fixme
|
||||
pass
|
||||
# subtest_compile_and_run_binary_ufunc_correctness(
|
||||
# ufunc, mix_x_and_y_and_call_binary_f_two_avoid_0_input, 31, ((1, 5), (1, 5))
|
||||
# )
|
||||
elif ufunc in [numpy.lcm, numpy.left_shift]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
|
||||
)
|
||||
elif ufunc in [numpy.ldexp]:
|
||||
# Can't make it work
|
||||
# TODO: fixme
|
||||
pass
|
||||
|
||||
# Need small constants to keep results sufficiently small
|
||||
# subtest_compile_and_run_binary_ufunc_correctness(
|
||||
# ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
|
||||
# )
|
||||
else:
|
||||
# General case
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 42, ((0, 5), (0, 5))
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ufunc", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1]
|
||||
)
|
||||
def test_unary_ufunc_operations(ufunc):
|
||||
"""Test unary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
if ufunc in [
|
||||
numpy.degrees,
|
||||
numpy.rad2deg,
|
||||
]:
|
||||
# Need to reduce the output value, to avoid to need too much precision
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_which_has_large_outputs, ((0, 5), (0, 5))
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.negative,
|
||||
]:
|
||||
# Need to turn the input into a float
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_with_float_inputs, ((0, 5), (0, 5))
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.invert,
|
||||
]:
|
||||
# Can't make it work, to have a fusable function
|
||||
# TODO: fixme
|
||||
pass
|
||||
# subtest_compile_and_run_unary_ufunc_correctness(
|
||||
# ufunc, mix_x_and_y_and_call_f_with_integer_inputs, ((0, 5), (0, 5))
|
||||
# )
|
||||
elif ufunc in [
|
||||
numpy.arccosh,
|
||||
numpy.log,
|
||||
numpy.log2,
|
||||
numpy.log10,
|
||||
numpy.reciprocal,
|
||||
]:
|
||||
# No 0 in the domain of definition
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_avoid_0_input, ((1, 5), (1, 5))
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.cosh,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.square,
|
||||
numpy.arccos,
|
||||
numpy.arcsin,
|
||||
numpy.arctanh,
|
||||
numpy.sinh,
|
||||
]:
|
||||
# Need a small range of inputs, to avoid to need too much precision
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_which_expects_small_inputs, ((0, 5), (0, 5))
|
||||
)
|
||||
else:
|
||||
# Regular case for univariate functions
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f, ((0, 5), (0, 5))
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user