diff --git a/concrete/common/bounds_measurement/dataset_eval.py b/concrete/common/bounds_measurement/dataset_eval.py index b47889f98..e8662f462 100644 --- a/concrete/common/bounds_measurement/dataset_eval.py +++ b/concrete/common/bounds_measurement/dataset_eval.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, Iterator, Tuple +from ..debugging import custom_assert from ..operator_graph import OPGraph from ..representation.intermediate import IntermediateNode @@ -35,10 +36,13 @@ def eval_op_graph_bounds_on_dataset( """ def check_dataset_input_len_is_valid(data_to_check): - assert len(data_to_check) == len(op_graph.input_nodes), ( - f"Got input data from dataset of len: {len(data_to_check)}, " - f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make " - f"sure your data generator returns valid tuples of input values" + custom_assert( + len(data_to_check) == len(op_graph.input_nodes), + ( + f"Got input data from dataset of len: {len(data_to_check)}, " + f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make " + f"sure your data generator returns valid tuples of input values" + ), ) # TODO: do we want to check coherence between the input data type and the corresponding Input ir diff --git a/concrete/common/common_helpers.py b/concrete/common/common_helpers.py index 71dac78be..53b0989f8 100644 --- a/concrete/common/common_helpers.py +++ b/concrete/common/common_helpers.py @@ -3,6 +3,7 @@ from typing import List, Optional from .data_types.integers import Integer +from .debugging import custom_assert from .operator_graph import OPGraph from .representation import intermediate as ir @@ -53,9 +54,10 @@ def check_op_graph_is_integer_program( """ offending_nodes = [] if offending_nodes_out is None else offending_nodes_out - assert isinstance( - offending_nodes, list - ), f"offending_nodes_out must be a list, got {type(offending_nodes_out)}" + custom_assert( + isinstance(offending_nodes, list), + f"offending_nodes_out must be a list, got {type(offending_nodes_out)}", + ) offending_nodes.clear() offending_nodes.extend( diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py index 152b48b62..f35347dd1 100644 --- a/concrete/common/compilation/artifacts.py +++ b/concrete/common/compilation/artifacts.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, Union import networkx as nx from PIL import Image -from ..debugging import draw_graph, get_printable_graph +from ..debugging import custom_assert, draw_graph, get_printable_graph from ..operator_graph import OPGraph from ..representation import intermediate as ir from ..values import BaseValue @@ -102,7 +102,7 @@ class CompilationArtifacts: None """ - assert self.final_operation_graph is not None + custom_assert(self.final_operation_graph is not None) self.bounds_of_the_final_operation_graph = bounds def add_final_operation_graph_mlir(self, mlir: str): @@ -115,7 +115,7 @@ class CompilationArtifacts: None """ - assert self.final_operation_graph is not None + custom_assert(self.final_operation_graph is not None) self.mlir_of_the_final_operation_graph = mlir def export(self): @@ -186,7 +186,7 @@ class CompilationArtifacts: f.write(f"{representation}\n") if self.bounds_of_the_final_operation_graph is not None: - assert self.final_operation_graph is not None + custom_assert(self.final_operation_graph is not None) with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f: # TODO: # if nx.topological_sort is not deterministic between calls, @@ -194,11 +194,11 @@ class CompilationArtifacts: # thus, we may want to change this in the future for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)): bounds = self.bounds_of_the_final_operation_graph.get(node) - assert bounds is not None + custom_assert(bounds is not None) f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n") if self.mlir_of_the_final_operation_graph is not None: - assert self.final_operation_graph is not None + custom_assert(self.final_operation_graph is not None) with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f: f.write(self.mlir_of_the_final_operation_graph) diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 83b4a5084..49c096384 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -4,6 +4,7 @@ from copy import deepcopy from functools import partial from typing import Callable, Union, cast +from ..debugging.custom_assert import custom_assert from ..values import ( BaseValue, ClearScalar, @@ -149,8 +150,8 @@ def find_type_to_hold_both_lossy( Returns: BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2 """ - assert isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}" - assert isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}" + custom_assert(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}") + custom_assert(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}") type_to_return: BaseDataType @@ -208,8 +209,12 @@ def mix_scalar_values_determine_holding_dtype( value2 dtypes. """ - assert isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue" - assert isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue" + custom_assert( + isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue" + ) + custom_assert( + isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue" + ) holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type) mixed_value: ScalarValue @@ -241,12 +246,19 @@ def mix_tensor_values_determine_holding_dtype( value2 dtypes. """ - assert isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue" - assert isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue" + custom_assert( + isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue" + ) + custom_assert( + isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue" + ) - assert value1.shape == value2.shape, ( - f"Tensors have different shapes which is not supported.\n" - f"value1: {value1.shape}, value2: {value2.shape}" + custom_assert( + value1.shape == value2.shape, + ( + f"Tensors have different shapes which is not supported.\n" + f"value1: {value1.shape}, value2: {value2.shape}" + ), ) holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type) @@ -279,9 +291,10 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> dtypes. """ - assert ( - value1.__class__ == value2.__class__ - ), f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}" + custom_assert( + (value1.__class__ == value2.__class__), + f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}", + ) if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue): return mix_scalar_values_determine_holding_dtype(value1, value2) @@ -304,9 +317,10 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float] BaseDataType: The corresponding BaseDataType """ constant_data_type: BaseDataType - assert isinstance( - constant_data, (int, float) - ), f"Unsupported constant data of type {type(constant_data)}" + custom_assert( + isinstance(constant_data, (int, float)), + f"Unsupported constant data of type {type(constant_data)}", + ) if isinstance(constant_data, int): is_signed = constant_data < 0 constant_data_type = Integer( diff --git a/concrete/common/data_types/floats.py b/concrete/common/data_types/floats.py index 63b52b3b3..a26c240b8 100644 --- a/concrete/common/data_types/floats.py +++ b/concrete/common/data_types/floats.py @@ -2,6 +2,7 @@ from functools import partial +from ..debugging.custom_assert import custom_assert from . import base @@ -14,7 +15,7 @@ class Float(base.BaseDataType): def __init__(self, bit_width: int) -> None: super().__init__() - assert bit_width in (32, 64), "Only 32 and 64 bits floats are supported" + custom_assert(bit_width in (32, 64), "Only 32 and 64 bits floats are supported") self.bit_width = bit_width def __repr__(self) -> str: diff --git a/concrete/common/data_types/integers.py b/concrete/common/data_types/integers.py index 2cbe83560..181a017b1 100644 --- a/concrete/common/data_types/integers.py +++ b/concrete/common/data_types/integers.py @@ -3,6 +3,7 @@ import math from typing import Any, Iterable +from ..debugging.custom_assert import custom_assert from . import base @@ -14,7 +15,7 @@ class Integer(base.BaseDataType): def __init__(self, bit_width: int, is_signed: bool) -> None: super().__init__() - assert bit_width > 0, "bit_width must be > 0" + custom_assert(bit_width > 0, "bit_width must be > 0") self.bit_width = bit_width self.is_signed = is_signed diff --git a/concrete/common/debugging/__init__.py b/concrete/common/debugging/__init__.py index a5253afdc..c087039b5 100644 --- a/concrete/common/debugging/__init__.py +++ b/concrete/common/debugging/__init__.py @@ -1,3 +1,4 @@ """Module for debugging.""" +from .custom_assert import custom_assert from .drawing import draw_graph from .printing import get_printable_graph diff --git a/concrete/common/debugging/custom_assert.py b/concrete/common/debugging/custom_assert.py new file mode 100644 index 000000000..71c88512f --- /dev/null +++ b/concrete/common/debugging/custom_assert.py @@ -0,0 +1,49 @@ +"""Provide some variants of assert.""" + + +def custom_assert(condition: bool, on_error_msg: str = "") -> None: + """Provide a custom assert which is kept even if the optimized python mode is used. + + See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation + on the classical assert function + + Args: + condition(bool): the condition. If False, raise AssertionError + on_error_msg(str): optional message for precising the error, in case of error + + """ + + if not condition: + raise AssertionError(on_error_msg) + + +def assert_true(condition: bool, on_error_msg: str = ""): + """Provide a custom assert to check that the condition is True. + + Args: + condition(bool): the condition. If False, raise AssertionError + on_error_msg(str): optional message for precising the error, in case of error + + """ + return custom_assert(condition, on_error_msg) + + +def assert_false(condition: bool, on_error_msg: str = ""): + """Provide a custom assert to check that the condition is False. + + Args: + condition(bool): the condition. If True, raise AssertionError + on_error_msg(str): optional message for precising the error, in case of error + + """ + return custom_assert(not condition, on_error_msg) + + +def assert_not_reached(on_error_msg: str): + """Provide a custom assert to check that a piece of code is never reached. + + Args: + on_error_msg(str): message for precising the error + + """ + return custom_assert(False, on_error_msg) diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index a4662da8b..bcca75469 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import networkx as nx from PIL import Image +from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph from ..representation import intermediate as ir from ..representation.intermediate import ALL_IR_NODES @@ -26,9 +27,12 @@ IR_NODE_COLOR_MAPPING = { } _missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys() -assert len(_missing_nodes_in_mapping) == 0, ( - f"Missing IR node in IR_NODE_COLOR_MAPPING : " - f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}" +custom_assert( + len(_missing_nodes_in_mapping) == 0, + ( + f"Missing IR node in IR_NODE_COLOR_MAPPING : " + f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}" + ), ) del _missing_nodes_in_mapping diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 8cce142b1..3b3123f82 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -4,6 +4,7 @@ from typing import Any, Dict import networkx as nx +from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph from ..representation import intermediate as ir @@ -32,7 +33,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: Returns: str: a string to print or save in a file """ - assert isinstance(opgraph, OPGraph) + custom_assert(isinstance(opgraph, OPGraph)) list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values()) graph = opgraph.graph @@ -46,7 +47,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: # This code doesn't work with more than a single output. For more outputs, # we would need to change the way the destination are created: currently, # they only are done by incrementing i - assert len(node.outputs) == 1 + custom_assert(len(node.outputs) == 1) if isinstance(node, ir.Input): what_to_print = node.input_name @@ -72,9 +73,9 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: list_of_arg_name += [(index["input_idx"], str(map_table[pred]))] # Some checks, because the previous algorithm is not clear - assert len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name)) + custom_assert(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))) list_of_arg_name.sort() - assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))) + custom_assert([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))) # Then, just print the predecessors in the right order what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")" diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index cab84d863..20d7703e2 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -21,13 +21,14 @@ from ..data_types.dtypes_helpers import ( value_is_encrypted_scalar_unsigned_integer, value_is_encrypted_tensor_integer, ) +from ..debugging.custom_assert import custom_assert from ..representation import intermediate as ir def add(node, preds, ir_to_mlir_node, ctx): """Convert an addition intermediate node.""" - assert len(node.inputs) == 2, "addition should have two inputs" - assert len(node.outputs) == 1, "addition should have a single output" + custom_assert(len(node.inputs) == 2, "addition should have two inputs") + custom_assert(len(node.outputs) == 1, "addition should have a single output") if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( node.inputs[1] ): @@ -70,8 +71,8 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx): def sub(node, preds, ir_to_mlir_node, ctx): """Convert a subtraction intermediate node.""" - assert len(node.inputs) == 2, "subtraction should have two inputs" - assert len(node.outputs) == 1, "subtraction should have a single output" + custom_assert(len(node.inputs) == 2, "subtraction should have two inputs") + custom_assert(len(node.outputs) == 1, "subtraction should have a single output") if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer( node.inputs[1] ): @@ -94,8 +95,8 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx): def mul(node, preds, ir_to_mlir_node, ctx): """Convert a multiplication intermediate node.""" - assert len(node.inputs) == 2, "multiplication should have two inputs" - assert len(node.outputs) == 1, "multiplication should have a single output" + custom_assert(len(node.inputs) == 2, "multiplication should have two inputs") + custom_assert(len(node.outputs) == 1, "multiplication should have a single output") if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( node.inputs[1] ): @@ -134,8 +135,8 @@ def constant(node, _, __, ctx): def apply_lut(node, preds, ir_to_mlir_node, ctx): """Convert an arbitrary function intermediate node.""" - assert len(node.inputs) == 1, "LUT should have a single input" - assert len(node.outputs) == 1, "LUT should have a single output" + custom_assert(len(node.inputs) == 1, "LUT should have a single input") + custom_assert(len(node.outputs) == 1, "LUT should have a single output") if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]): raise TypeError("Only support LUT with encrypted unsigned integers inputs") if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]): @@ -160,8 +161,8 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx): def dot(node, preds, ir_to_mlir_node, ctx): """Convert a dot intermediate node.""" - assert len(node.inputs) == 2, "Dot should have two inputs" - assert len(node.outputs) == 1, "Dot should have a single output" + custom_assert(len(node.inputs) == 2, "Dot should have two inputs") + custom_assert(len(node.outputs) == 1, "Dot should have a single output") if not ( ( value_is_encrypted_tensor_integer(node.inputs[0]) diff --git a/concrete/common/mlir/mlir_converter.py b/concrete/common/mlir/mlir_converter.py index 352927ba7..2aa328ac3 100644 --- a/concrete/common/mlir/mlir_converter.py +++ b/concrete/common/mlir/mlir_converter.py @@ -25,6 +25,7 @@ from ..data_types.dtypes_helpers import ( value_is_encrypted_scalar_unsigned_integer, value_is_encrypted_tensor_unsigned_integer, ) +from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph from ..representation import intermediate as ir @@ -93,7 +94,7 @@ class MLIRConverter: if is_signed and not is_encrypted: # clear signed return IntegerType.get_signed(bit_width) # should be clear unsigned at this point - assert not is_signed and not is_encrypted + custom_assert(not is_signed and not is_encrypted) # unsigned integer are considered signless in the compiler return IntegerType.get_signless(bit_width) diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index b19c42140..1c5ee104e 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -12,6 +12,7 @@ from .data_types.dtypes_helpers import ( ) from .data_types.floats import Float from .data_types.integers import Integer, make_integer_to_hold +from .debugging.custom_assert import custom_assert from .representation import intermediate as ir from .tracing import BaseTracer from .tracing.tracing_helpers import create_graph_from_output_tracers @@ -30,13 +31,17 @@ class OPGraph: input_nodes: Dict[int, ir.Input], output_nodes: Dict[int, ir.IntermediateNode], ) -> None: - assert len(input_nodes) > 0, "Got a graph without input nodes which is not supported" - assert all( - isinstance(node, ir.Input) for node in input_nodes.values() - ), "Got input nodes that were not ir.Input, which is not supported" - assert all( - isinstance(node, ir.IntermediateNode) for node in output_nodes.values() - ), "Got output nodes which were not ir.IntermediateNode, which is not supported" + custom_assert( + len(input_nodes) > 0, "Got a graph without input nodes which is not supported" + ) + custom_assert( + all(isinstance(node, ir.Input) for node in input_nodes.values()), + "Got input nodes that were not ir.Input, which is not supported", + ) + custom_assert( + all(isinstance(node, ir.IntermediateNode) for node in output_nodes.values()), + "Got output nodes which were not ir.IntermediateNode, which is not supported", + ) self.graph = graph self.input_nodes = input_nodes @@ -46,9 +51,10 @@ class OPGraph: def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]: inputs = dict(enumerate(args)) - assert len(inputs) == len( - self.input_nodes - ), f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}" + custom_assert( + len(inputs) == len(self.input_nodes), + f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}", + ) results = self.evaluate(inputs) tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs()) @@ -177,9 +183,12 @@ class OPGraph: min_data_type_constructor = get_type_constructor_for_constant_data(min_bound) max_data_type_constructor = get_type_constructor_for_constant_data(max_bound) - assert max_data_type_constructor == min_data_type_constructor, ( - f"Got two different type constructors for min and max bound: " - f"{min_data_type_constructor}, {max_data_type_constructor}" + custom_assert( + max_data_type_constructor == min_data_type_constructor, + ( + f"Got two different type constructors for min and max bound: " + f"{min_data_type_constructor}, {max_data_type_constructor}" + ), ) data_type_constructor = max_data_type_constructor @@ -191,20 +200,25 @@ class OPGraph: (min_bound, max_bound), force_signed=False ) else: - assert isinstance(min_data_type, Float) and isinstance( - max_data_type, Float - ), ( - "min_bound and max_bound have different common types, " - "this should never happen.\n" - f"min_bound: {min_data_type}, max_bound: {max_data_type}" + custom_assert( + isinstance(min_data_type, Float) and isinstance(max_data_type, Float), + ( + "min_bound and max_bound have different common types, " + "this should never happen.\n" + f"min_bound: {min_data_type}, max_bound: {max_data_type}" + ), ) output_value.data_type = Float(64) output_value.data_type.underlying_type_constructor = data_type_constructor else: # Currently variable inputs are only allowed to be integers - assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), ( - f"Inputs to a graph should be integers, got bounds that were float, \n" - f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})" + custom_assert( + isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), + ( + f"Inputs to a graph should be integers, got bounds that were float, \n" + f"min: {min_bound} ({type(min_bound)}), " + f"max: {max_bound} ({type(max_bound)})" + ), ) node.inputs[0].data_type = make_integer_to_hold( (min_bound, max_bound), force_signed=False @@ -215,7 +229,7 @@ class OPGraph: # TODO: #57 manage multiple outputs from a node, probably requires an output_idx when # adding an edge - assert len(node.outputs) == 1 + custom_assert(len(node.outputs) == 1) successors = self.graph.succ[node] for succ in successors: diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 0638b1e16..7646c38b3 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -7,6 +7,7 @@ import networkx as nx from ..compilation.artifacts import CompilationArtifacts from ..data_types.floats import Float from ..data_types.integers import Integer +from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph from ..representation import intermediate as ir @@ -112,7 +113,7 @@ def convert_float_subgraph_to_fused_node( non_constant_start_nodes = [ node for node in float_subgraph_start_nodes if not isinstance(node, ir.Constant) ] - assert len(non_constant_start_nodes) == 1 + custom_assert(len(non_constant_start_nodes) == 1) current_subgraph_variable_input = non_constant_start_nodes[0] new_input_value = deepcopy(current_subgraph_variable_input.outputs[0]) diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index ed6325510..e42fc4b27 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -12,6 +12,7 @@ from ..data_types.dtypes_helpers import ( mix_scalar_values_determine_holding_dtype, ) from ..data_types.integers import Integer +from ..debugging.custom_assert import custom_assert from ..values import BaseValue, ClearScalar, EncryptedScalar, TensorValue IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" @@ -32,7 +33,7 @@ class IntermediateNode(ABC): **_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes ) -> None: self.inputs = list(inputs) - assert all(isinstance(x, BaseValue) for x in self.inputs) + custom_assert(all(isinstance(x, BaseValue) for x in self.inputs)) # Register all IR nodes def __init_subclass__(cls, **kwargs): @@ -48,7 +49,7 @@ class IntermediateNode(ABC): """__init__ for a binary operation, ie two inputs.""" IntermediateNode.__init__(self, inputs) - assert len(self.inputs) == 2 + custom_assert(len(self.inputs) == 2) self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])] @@ -147,7 +148,7 @@ class Input(IntermediateNode): program_input_idx: int, ) -> None: super().__init__((input_value,)) - assert len(self.inputs) == 1 + custom_assert(len(self.inputs) == 1) self.input_name = input_name self.program_input_idx = program_input_idx self.outputs = [deepcopy(self.inputs[0])] @@ -216,7 +217,7 @@ class ArbitraryFunction(IntermediateNode): op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__([input_base_value]) - assert len(self.inputs) == 1 + custom_assert(len(self.inputs) == 1) self.arbitrary_func = arbitrary_func self.op_args = op_args if op_args is not None else () self.op_kwargs = op_kwargs if op_kwargs is not None else {} @@ -295,12 +296,15 @@ class Dot(IntermediateNode): ] = default_dot_evaluation_function, ) -> None: super().__init__(inputs) - assert len(self.inputs) == 2 + custom_assert(len(self.inputs) == 2) - assert all( - isinstance(input_value, TensorValue) and input_value.ndim == 1 - for input_value in self.inputs - ), f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)" + custom_assert( + all( + isinstance(input_value, TensorValue) and input_value.ndim == 1 + for input_value in self.inputs + ), + f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)", + ) output_scalar_value = ( EncryptedScalar diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index dfaa12548..094c8f80d 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Iterable, List, Tuple, Type, Union +from ..debugging.custom_assert import custom_assert from ..representation import intermediate as ir from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME from ..values import BaseValue @@ -105,7 +106,7 @@ class BaseTracer(ABC): ir.Add, ) - assert len(result_tracer) == 1 + custom_assert(len(result_tracer) == 1) return result_tracer[0] # With that is that x + 1 and 1 + x have the same graph. If we want to keep @@ -122,7 +123,7 @@ class BaseTracer(ABC): ir.Sub, ) - assert len(result_tracer) == 1 + custom_assert(len(result_tracer) == 1) return result_tracer[0] def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": @@ -134,7 +135,7 @@ class BaseTracer(ABC): ir.Sub, ) - assert len(result_tracer) == 1 + custom_assert(len(result_tracer) == 1) return result_tracer[0] def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": @@ -146,7 +147,7 @@ class BaseTracer(ABC): ir.Mul, ) - assert len(result_tracer) == 1 + custom_assert(len(result_tracer) == 1) return result_tracer[0] # With that is that x * 3 and 3 * x have the same graph. If we want to keep diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py index bf8748fe2..712926270 100644 --- a/concrete/common/tracing/tracing_helpers.py +++ b/concrete/common/tracing/tracing_helpers.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, Iterable, OrderedDict, Set, Type import networkx as nx from networkx.algorithms.dag import is_directed_acyclic_graph +from ..debugging.custom_assert import custom_assert from ..representation import intermediate as ir from ..values import BaseValue from .base_tracer import BaseTracer @@ -121,6 +122,6 @@ def create_graph_from_output_tracers( current_tracers = next_tracers - assert is_directed_acyclic_graph(graph) + custom_assert(is_directed_acyclic_graph(graph)) return graph diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 1158925c4..2d2c4eab7 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -16,6 +16,7 @@ from ..common.data_types.dtypes_helpers import ( ) from ..common.data_types.floats import Float from ..common.data_types.integers import Integer +from ..common.debugging.custom_assert import custom_assert from ..common.values import BaseValue, ScalarValue, TensorValue NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { @@ -69,16 +70,20 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d Returns: numpy.dtype: The resulting numpy.dtype """ - assert isinstance( - common_dtype, BASE_DATA_TYPES - ), f"Unsupported common_dtype: {type(common_dtype)}" + custom_assert( + isinstance(common_dtype, BASE_DATA_TYPES), f"Unsupported common_dtype: {type(common_dtype)}" + ) type_to_return: numpy.dtype if isinstance(common_dtype, Float): - assert common_dtype.bit_width in ( - 32, - 64, - ), "Only converting Float(32) or Float(64) is supported" + custom_assert( + common_dtype.bit_width + in ( + 32, + 64, + ), + "Only converting Float(32) or Float(64) is supported", + ) type_to_return = ( numpy.dtype(numpy.float64) if common_dtype.bit_width == 64 @@ -110,9 +115,10 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType: The corresponding BaseDataType """ base_dtype: BaseDataType - assert isinstance( - constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) - ), f"Unsupported constant data of type {type(constant_data)}" + custom_assert( + isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)), + f"Unsupported constant data of type {type(constant_data)}", + ) if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): # numpy base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype) @@ -141,9 +147,10 @@ def get_base_value_for_numpy_or_python_constant_data( with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method). """ constant_data_value: Callable[..., BaseValue] - assert isinstance( - constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) - ), f"Unsupported constant data of type {type(constant_data)}" + custom_assert( + isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)), + f"Unsupported constant data of type {type(constant_data)}", + ) base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data) if isinstance(constant_data, numpy.ndarray): @@ -171,9 +178,10 @@ def get_numpy_function_output_dtype( List[numpy.dtype]: The ordered numpy dtypes of the function outputs """ if isinstance(function, numpy.ufunc): - assert ( - len(input_dtypes) == function.nin - ), f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}" + custom_assert( + (len(input_dtypes) == function.nin), + f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}", + ) input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes] @@ -203,9 +211,10 @@ def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any): constant_data (Any): The data for which we want to determine the type constructor. """ - assert isinstance( - constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) - ), f"Unsupported constant data of type {type(constant_data)}" + custom_assert( + isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)), + f"Unsupported constant data of type {type(constant_data)}", + ) scalar_constructor: Type diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 0c1739852..7376a63f0 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -7,6 +7,7 @@ import numpy from numpy.typing import DTypeLike from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype +from ..common.debugging.custom_assert import custom_assert from ..common.operator_graph import OPGraph from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters @@ -40,9 +41,10 @@ class NPTracer(BaseTracer): """ if method == "__call__": tracing_func = self.get_tracing_func_for_np_function(ufunc) - assert ( - len(kwargs) == 0 - ), f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}" + custom_assert( + (len(kwargs) == 0), + f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}", + ) return tracing_func(*input_tracers, **kwargs) raise NotImplementedError("Only __call__ method is supported currently") @@ -52,9 +54,10 @@ class NPTracer(BaseTracer): Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch """ tracing_func = self.get_tracing_func_for_np_function(func) - assert ( - len(kwargs) == 0 - ), f"**kwargs are currently not supported for numpy functions, func: {func}" + custom_assert( + (len(kwargs) == 0), + f"**kwargs are currently not supported for numpy functions, func: {func}", + ) return tracing_func(*args, **kwargs) def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer": @@ -69,10 +72,13 @@ class NPTracer(BaseTracer): Returns: NPTracer: The NPTracer representing the casting operation """ - assert len(args) == 0, f"astype currently only supports tracing without *args, got {args}" - assert ( - len(kwargs) == 0 - ), f"astype currently only supports tracing without **kwargs, got {kwargs}" + custom_assert( + len(args) == 0, f"astype currently only supports tracing without *args, got {args}" + ) + custom_assert( + (len(kwargs) == 0), + f"astype currently only supports tracing without **kwargs, got {kwargs}", + ) normalized_numpy_dtype = numpy.dtype(numpy_dtype) output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype) @@ -139,9 +145,9 @@ class NPTracer(BaseTracer): Returns: NPTracer: The output NPTracer containing the traced function """ - assert len(input_tracers) == 1 + custom_assert(len(input_tracers) == 1) common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers) - assert len(common_output_dtypes) == 1 + custom_assert(len(common_output_dtypes) == 1) traced_computation = ArbitraryFunction( input_base_value=input_tracers[0].output, @@ -167,7 +173,7 @@ class NPTracer(BaseTracer): dot_inputs = (self, self._sanitize(other_tracer)) common_output_dtypes = self._manage_dtypes(numpy.dot, *dot_inputs) - assert len(common_output_dtypes) == 1 + custom_assert(len(common_output_dtypes) == 1) traced_computation = Dot( [input_tracer.output for input_tracer in dot_inputs], diff --git a/tests/common/debugging/test_custom_assert.py b/tests/common/debugging/test_custom_assert.py new file mode 100644 index 000000000..aa34b7a85 --- /dev/null +++ b/tests/common/debugging/test_custom_assert.py @@ -0,0 +1,29 @@ +"""Test custom assert functions.""" +import pytest + +from concrete.common.debugging.custom_assert import ( + assert_false, + assert_not_reached, + assert_true, +) + + +def test_assert_not_functions(): + """Test custom assert functions""" + assert_true(True, "one check") + assert_false(False, "another check") + + with pytest.raises(AssertionError) as excinfo: + assert_not_reached("yet another one") + + assert "yet another one" in str(excinfo.value) + + with pytest.raises(AssertionError) as excinfo: + assert_true(False, "one failing check") + + assert "one failing check" in str(excinfo.value) + + with pytest.raises(AssertionError) as excinfo: + assert_false(True, "another failing check") + + assert "another failing check" in str(excinfo.value)