fix(representation): handle failure in UnivariateFunction get_table
- If only some values are problematic, flood-fill the resulting table with copies of the closest valid values.
- If no valid value was generated at all, an AssertionError is raised.
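As a rough illustration of the new behavior, here is a minimal standalone sketch (hypothetical names; the real implementation is `flood_replace_none_values` in the diff below and uses the project's `assert_true` helper rather than a bare assert):

from collections import deque
from copy import deepcopy

def flood_fill_none(table):
    # An all-None table cannot be repaired, so fail loudly.
    assert any(value is not None for value in table)
    # Breadth-first flood: every valid entry copies itself into adjacent None slots.
    queue = deque(i for i, value in enumerate(table) if value is not None)
    while queue:
        i = queue.popleft()
        for j in (i - 1, i + 1):
            if 0 <= j < len(table) and table[j] is None:
                table[j] = deepcopy(table[i])
                queue.append(j)

# Matches one of the new unit-test cases:
values = [None, 1, 2, 3, None, None]
flood_fill_none(values)
assert values == [1, 1, 2, 3, 3, 3]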
@@ -1,6 +1,7 @@
 """File containing code to represent source programs operations."""
 
 from abc import ABC, abstractmethod
+from collections import deque
 from copy import deepcopy
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
@@ -196,6 +197,31 @@ class Constant(IntermediateNode):
         return str(self.constant_data)
 
 
+def flood_replace_none_values(table: list):
+    """Use a flooding algorithm to replace None values.
+
+    Args:
+        table (list): the list in which there are None values that need to be replaced by copies of
+            the closest non None data from the list.
+    """
+    assert_true(any(value is not None for value in table))
+
+    not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None)
+    while not_none_values_idx:
+        current_idx = not_none_values_idx.popleft()
+        current_value = table[current_idx]
+        previous_idx = current_idx - 1
+        next_idx = current_idx + 1
+        if previous_idx >= 0 and table[previous_idx] is None:
+            table[previous_idx] = deepcopy(current_value)
+            not_none_values_idx.append(previous_idx)
+        if next_idx < len(table) and table[next_idx] is None:
+            table[next_idx] = deepcopy(current_value)
+            not_none_values_idx.append(next_idx)
+
+    assert_true(all(value is not None for value in table))
+
+
 class UnivariateFunction(IntermediateNode):
     """Node representing an univariate arbitrary function, e.g. sin(x)."""
@@ -267,11 +293,20 @@ class UnivariateFunction(IntermediateNode):
         min_input_range = input_dtype.min_value()
         max_input_range = input_dtype.max_value() + 1
 
+        def catch(func, *args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            # We currently cannot trigger exceptions in the code during evaluation
+            except Exception: # pragma: no cover # pylint: disable=broad-except
+                return None
+
         table = [
-            self.evaluate({0: input_value_constructor(input_value)})
+            catch(self.evaluate, {0: input_value_constructor(input_value)})
             for input_value in range(min_input_range, max_input_range)
         ]
 
+        flood_replace_none_values(table)
+
         return table
@@ -217,13 +217,15 @@ def compile_numpy_function_into_op_graph(
 
     # Try to compile the function and save partial artifacts on failure
     try:
-        return _compile_numpy_function_into_op_graph_internal(
-            function_to_compile,
-            function_parameters,
-            inputset,
-            compilation_configuration,
-            compilation_artifacts,
-        )
+        # Use context manager to restore numpy error handling
+        with numpy.errstate(**numpy.geterr()):
+            return _compile_numpy_function_into_op_graph_internal(
+                function_to_compile,
+                function_parameters,
+                inputset,
+                compilation_configuration,
+                compilation_artifacts,
+            )
     except Exception: # pragma: no cover
         # This branch is reserved for unexpected issues and hence it shouldn't be tested.
         # If it could be tested, we would have fixed the underlying issue.
@@ -280,7 +282,10 @@ def _compile_numpy_function_internal(
 
     # Convert graph to an MLIR representation
     converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
-    mlir_result = converter.convert(op_graph)
+
+    # Disable numpy warnings during conversion to avoid issues during TLU generation
+    with numpy.errstate(all="ignore"):
+        mlir_result = converter.convert(op_graph)
 
     # Show MLIR representation if requested
     if show_mlir:
@@ -337,14 +342,16 @@ def compile_numpy_function(
 
     # Try to compile the function and save partial artifacts on failure
    try:
-        return _compile_numpy_function_internal(
-            function_to_compile,
-            function_parameters,
-            inputset,
-            compilation_configuration,
-            compilation_artifacts,
-            show_mlir,
-        )
+        # Use context manager to restore numpy error handling
+        with numpy.errstate(**numpy.geterr()):
+            return _compile_numpy_function_internal(
+                function_to_compile,
+                function_parameters,
+                inputset,
+                compilation_configuration,
+                compilation_artifacts,
+                show_mlir,
+            )
     except Exception: # pragma: no cover
         # This branch is reserved for unexpected issues and hence it shouldn't be tested.
         # If it could be tested, we would have fixed the underlying issue.
@@ -1,5 +1,7 @@
 """Test file for intermediate representation"""
 
+from copy import deepcopy
+
 import numpy
 import pytest
@@ -292,3 +294,23 @@ def test_is_equivalent_to(
         == test_helpers.nodes_are_equivalent(node2, node1)
         == expected_result
     )
+
+
+@pytest.mark.parametrize(
+    "list_to_fill,expected_list",
+    [
+        pytest.param([None, 1, 2, 3, None, None], [1, 1, 2, 3, 3, 3]),
+        pytest.param([None], None, marks=pytest.mark.xfail(strict=True)),
+        pytest.param([None, None, None, None, 7, None, None, None], [7, 7, 7, 7, 7, 7, 7, 7]),
+        pytest.param([None, None, 3, None, None, None, 2, None], [3, 3, 3, 3, 3, 2, 2, 2]),
+    ],
+)
+def test_flood_replace_none_values(list_to_fill: list, expected_list: list):
+    """Unit test for flood_replace_none_values"""
+
+    # avoid modifying the test input
+    list_to_fill_copy = deepcopy(list_to_fill)
+    ir.flood_replace_none_values(list_to_fill_copy)
+
+    assert all(value is not None for value in list_to_fill_copy)
+    assert list_to_fill_copy == expected_list
@@ -116,13 +116,6 @@ def mix_x_and_y_and_call_binary_f_two(func, c, x, y):
     return z
 
 
-def mix_x_and_y_and_call_binary_f_two_avoid_0_input(func, c, x, y):
-    """Create an upper function to test `func`"""
-    z = numpy.abs(func(c, x + 1) + 1)
-    z = z.astype(numpy.uint32) + y
-    return z
-
-
 def check_is_good_execution(compiler_engine, function, args):
     """Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
     return an error. One can set the expected probability of success of one execution and the
@@ -223,13 +216,9 @@ def test_binary_ufunc_operations(ufunc):
             ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 4), (0, 5))
         )
     elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
         # 0 not in the domain of definition
-        # Can't make it work, #649
-        # TODO: fixme
-        pass
-        # subtest_compile_and_run_binary_ufunc_correctness(
-        #     ufunc, mix_x_and_y_and_call_binary_f_two_avoid_0_input, 31, ((1, 5), (1, 5))
-        # )
+        subtest_compile_and_run_binary_ufunc_correctness(
+            ufunc, mix_x_and_y_and_call_binary_f_two, 31, ((1, 5), (1, 5))
+        )
     elif ufunc in [numpy.lcm, numpy.left_shift]:
         # Need small constants to keep results sufficiently small
         subtest_compile_and_run_binary_ufunc_correctness(