mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
committed by
Benoit Chevallier
parent
9ffe9b667a
commit
cfe48cca15
@@ -33,3 +33,11 @@ myst-parser = "^0.15.1"
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
filterwarnings = [
|
||||
"error",
|
||||
"ignore::UserWarning",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.optimization.topological import fuse_float_operations
|
||||
from hdk.common.values import EncryptedScalar
|
||||
from hdk.numpy import tracing
|
||||
from hdk.numpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
@@ -36,7 +37,7 @@ def simple_fuse_output(x):
|
||||
return x.astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
|
||||
def complex_fuse_indirect_input(x, y):
|
||||
def complex_fuse_indirect_input(function, x, y):
|
||||
"""Complex fuse"""
|
||||
intermediate = x + y
|
||||
intermediate = intermediate + 2
|
||||
@@ -44,7 +45,7 @@ def complex_fuse_indirect_input(x, y):
|
||||
intermediate = intermediate.astype(numpy.int32)
|
||||
x_p_1 = intermediate + 1.5
|
||||
x_p_2 = intermediate + 2.7
|
||||
x_p_3 = numpy.rint(x_p_1 + x_p_2)
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
@@ -55,11 +56,11 @@ def complex_fuse_indirect_input(x, y):
|
||||
)
|
||||
|
||||
|
||||
def complex_fuse_direct_input(x, y):
|
||||
def complex_fuse_direct_input(function, x, y):
|
||||
"""Complex fuse"""
|
||||
x_p_1 = x + 1.5
|
||||
x_p_2 = x + 2.7
|
||||
x_p_3 = numpy.rint(x_p_1 + x_p_2)
|
||||
x_p_1 = x + 0.1
|
||||
x_p_2 = x + 0.2
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
@@ -77,8 +78,16 @@ def complex_fuse_direct_input(x, y):
|
||||
pytest.param(no_fuse_unhandled, False, id="no_fuse_unhandled"),
|
||||
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
|
||||
pytest.param(simple_fuse_output, True, id="no_fuse"),
|
||||
pytest.param(complex_fuse_indirect_input, True, id="complex_fuse_indirect_input"),
|
||||
pytest.param(complex_fuse_direct_input, True, id="complex_fuse_direct_input"),
|
||||
pytest.param(
|
||||
lambda x, y: complex_fuse_indirect_input(numpy.rint, x, y),
|
||||
True,
|
||||
id="complex_fuse_indirect_input_with_rint",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: complex_fuse_direct_input(numpy.rint, x, y),
|
||||
True,
|
||||
id="complex_fuse_direct_input_with_rint",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("input_", [0, 2, 42, 44])
|
||||
@@ -105,3 +114,47 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
def test_fuse_float_operations_correctness():
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
with fuse_float_operations."""
|
||||
|
||||
for fun in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
|
||||
if fun == numpy.arccosh:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
else:
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
|
||||
for input_ in input_list:
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: super_fun(fun, x, y)
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
Reference in New Issue
Block a user