mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: add python generic helpers
- move catch and update_and_return_dict there
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
"""Helpers for all kinds of tasks."""
|
||||
|
||||
from . import indexing_helpers
|
||||
from . import indexing_helpers, python_helpers
|
||||
|
||||
35
concrete/common/helpers/python_helpers.py
Normal file
35
concrete/common/helpers/python_helpers.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Common python helpers."""
|
||||
|
||||
from typing import Any, Callable, Iterable, Mapping, Tuple, Union
|
||||
|
||||
|
||||
def update_and_return_dict(
|
||||
dict_to_update: dict, update_values: Union[Mapping, Iterable[Tuple[Any, Any]]]
|
||||
) -> dict:
|
||||
"""Update a dictionary and return the ref to the dictionary that was updated.
|
||||
|
||||
Args:
|
||||
dict_to_update (dict): the dict to update
|
||||
update_values (Union[Mapping, Iterable[Tuple[Any, Any]]]): the values to update the dict
|
||||
with
|
||||
|
||||
Returns:
|
||||
dict: the dict that was just updated.
|
||||
"""
|
||||
dict_to_update.update(update_values)
|
||||
return dict_to_update
|
||||
|
||||
|
||||
def catch(func: Callable, *args, **kwargs) -> Union[Any, None]:
|
||||
"""Execute func by passing args and kwargs. Catch exceptions and return None in case of failure.
|
||||
|
||||
Args:
|
||||
func (Callable): function to execute and catch exceptions from
|
||||
|
||||
Returns:
|
||||
Union[Any, None]: the function result if there was no exception, None otherwise.
|
||||
"""
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return None
|
||||
@@ -16,6 +16,7 @@ from ..data_types.dtypes_helpers import (
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..helpers import indexing_helpers
|
||||
from ..helpers.python_helpers import catch, update_and_return_dict
|
||||
from ..values import (
|
||||
BaseValue,
|
||||
ClearScalar,
|
||||
@@ -390,22 +391,11 @@ class GenericFunction(IntermediateNode):
|
||||
min_input_range = variable_input_dtype.min_value()
|
||||
max_input_range = variable_input_dtype.max_value() + 1
|
||||
|
||||
def catch(func, *args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
# We currently cannot trigger exceptions in the code during evaluation
|
||||
except Exception: # pragma: no cover # pylint: disable=broad-except
|
||||
return None
|
||||
|
||||
template_input_dict = {
|
||||
idx: node.evaluate({}) if isinstance(node, Constant) else None
|
||||
for idx, node in enumerate(ordered_preds)
|
||||
}
|
||||
|
||||
def update_and_return_dict(dict_to_update: dict, update_values):
|
||||
dict_to_update.update(update_values)
|
||||
return dict_to_update
|
||||
|
||||
table = [
|
||||
catch(
|
||||
self.evaluate,
|
||||
|
||||
24
tests/common/helpers/test_python_helpers.py
Normal file
24
tests/common/helpers/test_python_helpers.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Test file for common python helpers"""
|
||||
|
||||
from concrete.common.helpers.python_helpers import catch
|
||||
|
||||
|
||||
def test_catch_failure():
|
||||
"""Test case for when the function called with catch raises an exception."""
|
||||
|
||||
def f_fail():
|
||||
return 1 / 0
|
||||
|
||||
assert catch(f_fail) is None
|
||||
|
||||
|
||||
def test_catch():
|
||||
"""Test case for catch"""
|
||||
|
||||
def f(*args, **kwargs):
|
||||
return *args, dict(**kwargs)
|
||||
|
||||
assert catch(f, (1, 2, 3,), **{"one": 1, "two": 2, "three": 3}) == (
|
||||
(1, 2, 3),
|
||||
{"one": 1, "two": 2, "three": 3},
|
||||
)
|
||||
Reference in New Issue
Block a user