fix(representation): handle failure in UnivariateFunction get_table

- if only some values are problematic, flood fill the resulting table with
other valid values
- if no valid value was generated an AssertionError will be thrown
This commit is contained in:
Arthur Meyre
2021-10-15 09:33:19 +02:00
parent 1711663f67
commit 93e39e58f7
4 changed files with 84 additions and 31 deletions

View File

@@ -1,6 +1,7 @@
"""File containing code to represent source programs operations."""
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
@@ -196,6 +197,31 @@ class Constant(IntermediateNode):
return str(self.constant_data)
def flood_replace_none_values(table: list):
"""Use a flooding algorithm to replace None values.
Args:
table (list): the list in which there are None values that need to be replaced by copies of
the closest non None data from the list.
"""
assert_true(any(value is not None for value in table))
not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None)
while not_none_values_idx:
current_idx = not_none_values_idx.popleft()
current_value = table[current_idx]
previous_idx = current_idx - 1
next_idx = current_idx + 1
if previous_idx >= 0 and table[previous_idx] is None:
table[previous_idx] = deepcopy(current_value)
not_none_values_idx.append(previous_idx)
if next_idx < len(table) and table[next_idx] is None:
table[next_idx] = deepcopy(current_value)
not_none_values_idx.append(next_idx)
assert_true(all(value is not None for value in table))
class UnivariateFunction(IntermediateNode):
"""Node representing an univariate arbitrary function, e.g. sin(x)."""
@@ -267,11 +293,20 @@ class UnivariateFunction(IntermediateNode):
min_input_range = input_dtype.min_value()
max_input_range = input_dtype.max_value() + 1
def catch(func, *args, **kwargs):
try:
return func(*args, **kwargs)
# We currently cannot trigger exceptions in the code during evaluation
except Exception: # pragma: no cover # pylint: disable=broad-except
return None
table = [
self.evaluate({0: input_value_constructor(input_value)})
catch(self.evaluate, {0: input_value_constructor(input_value)})
for input_value in range(min_input_range, max_input_range)
]
flood_replace_none_values(table)
return table

View File

@@ -217,13 +217,15 @@ def compile_numpy_function_into_op_graph(
# Try to compile the function and save partial artifacts on failure
try:
return _compile_numpy_function_into_op_graph_internal(
function_to_compile,
function_parameters,
inputset,
compilation_configuration,
compilation_artifacts,
)
# Use context manager to restore numpy error handling
with numpy.errstate(**numpy.geterr()):
return _compile_numpy_function_into_op_graph_internal(
function_to_compile,
function_parameters,
inputset,
compilation_configuration,
compilation_artifacts,
)
except Exception: # pragma: no cover
# This branch is reserved for unexpected issues and hence it shouldn't be tested.
# If it could be tested, we would have fixed the underlying issue.
@@ -280,7 +282,10 @@ def _compile_numpy_function_internal(
# Convert graph to an MLIR representation
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(op_graph)
# Disable numpy warnings during conversion to avoid issues during TLU generation
with numpy.errstate(all="ignore"):
mlir_result = converter.convert(op_graph)
# Show MLIR representation if requested
if show_mlir:
@@ -337,14 +342,16 @@ def compile_numpy_function(
# Try to compile the function and save partial artifacts on failure
try:
return _compile_numpy_function_internal(
function_to_compile,
function_parameters,
inputset,
compilation_configuration,
compilation_artifacts,
show_mlir,
)
# Use context manager to restore numpy error handling
with numpy.errstate(**numpy.geterr()):
return _compile_numpy_function_internal(
function_to_compile,
function_parameters,
inputset,
compilation_configuration,
compilation_artifacts,
show_mlir,
)
except Exception: # pragma: no cover
# This branch is reserved for unexpected issues and hence it shouldn't be tested.
# If it could be tested, we would have fixed the underlying issue.

View File

@@ -1,5 +1,7 @@
"""Test file for intermediate representation"""
from copy import deepcopy
import numpy
import pytest
@@ -292,3 +294,23 @@ def test_is_equivalent_to(
== test_helpers.nodes_are_equivalent(node2, node1)
== expected_result
)
@pytest.mark.parametrize(
"list_to_fill,expected_list",
[
pytest.param([None, 1, 2, 3, None, None], [1, 1, 2, 3, 3, 3]),
pytest.param([None], None, marks=pytest.mark.xfail(strict=True)),
pytest.param([None, None, None, None, 7, None, None, None], [7, 7, 7, 7, 7, 7, 7, 7]),
pytest.param([None, None, 3, None, None, None, 2, None], [3, 3, 3, 3, 3, 2, 2, 2]),
],
)
def test_flood_replace_none_values(list_to_fill: list, expected_list: list):
"""Unit test for flood_replace_none_values"""
# avoid modifying the test input
list_to_fill_copy = deepcopy(list_to_fill)
ir.flood_replace_none_values(list_to_fill_copy)
assert all(value is not None for value in list_to_fill_copy)
assert list_to_fill_copy == expected_list

View File

@@ -116,13 +116,6 @@ def mix_x_and_y_and_call_binary_f_two(func, c, x, y):
return z
def mix_x_and_y_and_call_binary_f_two_avoid_0_input(func, c, x, y):
"""Create an upper function to test `func`"""
z = numpy.abs(func(c, x + 1) + 1)
z = z.astype(numpy.uint32) + y
return z
def check_is_good_execution(compiler_engine, function, args):
"""Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
return an error. One can set the expected probability of success of one execution and the
@@ -223,13 +216,9 @@ def test_binary_ufunc_operations(ufunc):
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 4), (0, 5))
)
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
# 0 not in the domain of definition
# Can't make it work, #649
# TODO: fixme
pass
# subtest_compile_and_run_binary_ufunc_correctness(
# ufunc, mix_x_and_y_and_call_binary_f_two_avoid_0_input, 31, ((1, 5), (1, 5))
# )
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_two, 31, ((1, 5), (1, 5))
)
elif ufunc in [numpy.lcm, numpy.left_shift]:
# Need small constants to keep results sufficiently small
subtest_compile_and_run_binary_ufunc_correctness(