From fae89bd452af16775e253389d302d2c2d5d93147 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 4 Nov 2021 15:16:27 +0100 Subject: [PATCH] refactor: add python generic helpers - move catch and update_and_return_dict there --- concrete/common/helpers/__init__.py | 2 +- concrete/common/helpers/python_helpers.py | 35 +++++++++++++++++++ .../common/representation/intermediate.py | 12 +------ tests/common/helpers/test_python_helpers.py | 24 +++++++++++++ 4 files changed, 61 insertions(+), 12 deletions(-) create mode 100644 concrete/common/helpers/python_helpers.py create mode 100644 tests/common/helpers/test_python_helpers.py diff --git a/concrete/common/helpers/__init__.py b/concrete/common/helpers/__init__.py index 908c72ca0..8680796ce 100644 --- a/concrete/common/helpers/__init__.py +++ b/concrete/common/helpers/__init__.py @@ -1,3 +1,3 @@ """Helpers for all kinds of tasks.""" -from . import indexing_helpers +from . import indexing_helpers, python_helpers diff --git a/concrete/common/helpers/python_helpers.py b/concrete/common/helpers/python_helpers.py new file mode 100644 index 000000000..e7c7bbed2 --- /dev/null +++ b/concrete/common/helpers/python_helpers.py @@ -0,0 +1,35 @@ +"""Common python helpers.""" + +from typing import Any, Callable, Iterable, Mapping, Tuple, Union + + +def update_and_return_dict( + dict_to_update: dict, update_values: Union[Mapping, Iterable[Tuple[Any, Any]]] +) -> dict: + """Update a dictionary and return the ref to the dictionary that was updated. + + Args: + dict_to_update (dict): the dict to update + update_values (Union[Mapping, Iterable[Tuple[Any, Any]]]): the values to update the dict + with + + Returns: + dict: the dict that was just updated. + """ + dict_to_update.update(update_values) + return dict_to_update + + +def catch(func: Callable, *args, **kwargs) -> Union[Any, None]: + """Execute func by passing args and kwargs. Catch exceptions and return None in case of failure. + + Args: + func (Callable): function to execute and catch exceptions from + + Returns: + Union[Any, None]: the function result if there was no exception, None otherwise. + """ + try: + return func(*args, **kwargs) + except Exception: # pylint: disable=broad-except + return None diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 445ec6e27..fd0bdeffb 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -16,6 +16,7 @@ from ..data_types.dtypes_helpers import ( from ..data_types.integers import Integer from ..debugging.custom_assert import assert_true from ..helpers import indexing_helpers +from ..helpers.python_helpers import catch, update_and_return_dict from ..values import ( BaseValue, ClearScalar, @@ -390,22 +391,11 @@ class GenericFunction(IntermediateNode): min_input_range = variable_input_dtype.min_value() max_input_range = variable_input_dtype.max_value() + 1 - def catch(func, *args, **kwargs): - try: - return func(*args, **kwargs) - # We currently cannot trigger exceptions in the code during evaluation - except Exception: # pragma: no cover # pylint: disable=broad-except - return None - template_input_dict = { idx: node.evaluate({}) if isinstance(node, Constant) else None for idx, node in enumerate(ordered_preds) } - def update_and_return_dict(dict_to_update: dict, update_values): - dict_to_update.update(update_values) - return dict_to_update - table = [ catch( self.evaluate, diff --git a/tests/common/helpers/test_python_helpers.py b/tests/common/helpers/test_python_helpers.py new file mode 100644 index 000000000..2b219b9b6 --- /dev/null +++ b/tests/common/helpers/test_python_helpers.py @@ -0,0 +1,24 @@ +"""Test file for common python helpers""" + +from concrete.common.helpers.python_helpers import catch + + +def test_catch_failure(): + """Test case for when the function called with catch raises an exception.""" + + def f_fail(): + return 1 / 0 + + assert catch(f_fail) is None + + +def test_catch(): + """Test case for catch""" + + def f(*args, **kwargs): + return *args, dict(**kwargs) + + assert catch(f, (1, 2, 3,), **{"one": 1, "two": 2, "three": 3}) == ( + (1, 2, 3), + {"one": 1, "two": 2, "three": 3}, + )