From 0cd33b6f67892eba22e5bbb7fc902872e34334e9 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Tue, 12 Oct 2021 17:32:15 +0200 Subject: [PATCH] feat(debugging): let's stop using custom_assert closes #637 --- .../bounds_measurement/inputset_eval.py | 4 ++-- concrete/common/common_helpers.py | 4 ++-- concrete/common/compilation/artifacts.py | 12 +++++----- concrete/common/data_types/dtypes_helpers.py | 16 +++++++------- concrete/common/data_types/floats.py | 4 ++-- concrete/common/data_types/integers.py | 4 ++-- concrete/common/debugging/__init__.py | 2 +- concrete/common/debugging/custom_assert.py | 8 +++---- concrete/common/debugging/drawing.py | 4 ++-- concrete/common/debugging/printing.py | 12 +++++----- concrete/common/mlir/converters.py | 22 +++++++++---------- concrete/common/mlir/mlir_converter.py | 4 ++-- concrete/common/operator_graph.py | 20 ++++++++--------- concrete/common/optimization/topological.py | 4 ++-- .../common/representation/intermediate.py | 14 ++++++------ concrete/common/tracing/base_tracer.py | 10 ++++----- concrete/common/tracing/tracing_helpers.py | 4 ++-- concrete/numpy/np_dtypes_helpers.py | 16 +++++++------- concrete/numpy/tracing.py | 22 +++++++++---------- 19 files changed, 92 insertions(+), 94 deletions(-) diff --git a/concrete/common/bounds_measurement/inputset_eval.py b/concrete/common/bounds_measurement/inputset_eval.py index 88051b441..904c0f009 100644 --- a/concrete/common/bounds_measurement/inputset_eval.py +++ b/concrete/common/bounds_measurement/inputset_eval.py @@ -8,7 +8,7 @@ from ..data_types.dtypes_helpers import ( get_base_value_for_python_constant_data, is_data_type_compatible_with, ) -from ..debugging import custom_assert +from ..debugging import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import IntermediateNode @@ -139,7 +139,7 @@ def eval_op_graph_bounds_on_inputset( """ def check_inputset_input_len_is_valid(data_to_check): - custom_assert( + assert_true( len(data_to_check) == len(op_graph.input_nodes), ( f"Got input data from inputset of len: {len(data_to_check)}, " diff --git a/concrete/common/common_helpers.py b/concrete/common/common_helpers.py index 53b3380c1..9ad5138d3 100644 --- a/concrete/common/common_helpers.py +++ b/concrete/common/common_helpers.py @@ -3,7 +3,7 @@ from typing import List, Optional from .data_types.integers import Integer -from .debugging import custom_assert +from .debugging import assert_true from .operator_graph import OPGraph from .representation.intermediate import IntermediateNode @@ -54,7 +54,7 @@ def check_op_graph_is_integer_program( """ offending_nodes = [] if offending_nodes_out is None else offending_nodes_out - custom_assert( + assert_true( isinstance(offending_nodes, list), f"offending_nodes_out must be a list, got {type(offending_nodes_out)}", ) diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py index 474a4b9b8..9c7178b00 100644 --- a/concrete/common/compilation/artifacts.py +++ b/concrete/common/compilation/artifacts.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, Union import networkx as nx from PIL import Image -from ..debugging import custom_assert, draw_graph, get_printable_graph +from ..debugging import assert_true, draw_graph, get_printable_graph from ..operator_graph import OPGraph from ..representation.intermediate import IntermediateNode from ..values import BaseValue @@ -102,7 +102,7 @@ class CompilationArtifacts: None """ - custom_assert(self.final_operation_graph is not None) + assert_true(self.final_operation_graph is not None) self.bounds_of_the_final_operation_graph = bounds def add_final_operation_graph_mlir(self, mlir: str): @@ -115,7 +115,7 @@ class CompilationArtifacts: None """ - custom_assert(self.final_operation_graph is not None) + assert_true(self.final_operation_graph is not None) self.mlir_of_the_final_operation_graph = mlir def export(self): @@ -188,7 +188,7 @@ class CompilationArtifacts: f.write(f"{representation}") if self.bounds_of_the_final_operation_graph is not None: - custom_assert(self.final_operation_graph is not None) + assert_true(self.final_operation_graph is not None) with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f: # TODO: # if nx.topological_sort is not deterministic between calls, @@ -196,11 +196,11 @@ class CompilationArtifacts: # thus, we may want to change this in the future for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)): bounds = self.bounds_of_the_final_operation_graph.get(node) - custom_assert(bounds is not None) + assert_true(bounds is not None) f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n") if self.mlir_of_the_final_operation_graph is not None: - custom_assert(self.final_operation_graph is not None) + assert_true(self.final_operation_graph is not None) with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f: f.write(self.mlir_of_the_final_operation_graph) diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 87cf75988..86eb7c24f 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -4,7 +4,7 @@ from copy import deepcopy from functools import partial from typing import Callable, Optional, Tuple, Union, cast -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue from .base import BaseDataType from .floats import Float @@ -146,8 +146,8 @@ def find_type_to_hold_both_lossy( Returns: BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2 """ - custom_assert(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}") - custom_assert(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}") + assert_true(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}") + assert_true(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}") type_to_return: BaseDataType @@ -205,15 +205,15 @@ def mix_tensor_values_determine_holding_dtype( value2 dtypes. """ - custom_assert( + assert_true( isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue" ) - custom_assert( + assert_true( isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue" ) resulting_shape = broadcast_shapes(value1.shape, value2.shape) - custom_assert( + assert_true( resulting_shape is not None, ( f"Tensors have incompatible shapes which is not supported.\n" @@ -250,7 +250,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> dtypes. """ - custom_assert( + assert_true( (value1.__class__ == value2.__class__), f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}", ) @@ -274,7 +274,7 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float] BaseDataType: The corresponding BaseDataType """ constant_data_type: BaseDataType - custom_assert( + assert_true( isinstance(constant_data, (int, float)), f"Unsupported constant data of type {type(constant_data)}", ) diff --git a/concrete/common/data_types/floats.py b/concrete/common/data_types/floats.py index a26c240b8..a537d28bb 100644 --- a/concrete/common/data_types/floats.py +++ b/concrete/common/data_types/floats.py @@ -2,7 +2,7 @@ from functools import partial -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from . import base @@ -15,7 +15,7 @@ class Float(base.BaseDataType): def __init__(self, bit_width: int) -> None: super().__init__() - custom_assert(bit_width in (32, 64), "Only 32 and 64 bits floats are supported") + assert_true(bit_width in (32, 64), "Only 32 and 64 bits floats are supported") self.bit_width = bit_width def __repr__(self) -> str: diff --git a/concrete/common/data_types/integers.py b/concrete/common/data_types/integers.py index 181a017b1..ed9654972 100644 --- a/concrete/common/data_types/integers.py +++ b/concrete/common/data_types/integers.py @@ -3,7 +3,7 @@ import math from typing import Any, Iterable -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from . import base @@ -15,7 +15,7 @@ class Integer(base.BaseDataType): def __init__(self, bit_width: int, is_signed: bool) -> None: super().__init__() - custom_assert(bit_width > 0, "bit_width must be > 0") + assert_true(bit_width > 0, "bit_width must be > 0") self.bit_width = bit_width self.is_signed = is_signed diff --git a/concrete/common/debugging/__init__.py b/concrete/common/debugging/__init__.py index c087039b5..811bf62da 100644 --- a/concrete/common/debugging/__init__.py +++ b/concrete/common/debugging/__init__.py @@ -1,4 +1,4 @@ """Module for debugging.""" -from .custom_assert import custom_assert +from .custom_assert import assert_true from .drawing import draw_graph from .printing import get_printable_graph diff --git a/concrete/common/debugging/custom_assert.py b/concrete/common/debugging/custom_assert.py index 71c88512f..1a639776c 100644 --- a/concrete/common/debugging/custom_assert.py +++ b/concrete/common/debugging/custom_assert.py @@ -1,7 +1,7 @@ """Provide some variants of assert.""" -def custom_assert(condition: bool, on_error_msg: str = "") -> None: +def _custom_assert(condition: bool, on_error_msg: str = "") -> None: """Provide a custom assert which is kept even if the optimized python mode is used. See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation @@ -25,7 +25,7 @@ def assert_true(condition: bool, on_error_msg: str = ""): on_error_msg(str): optional message for precising the error, in case of error """ - return custom_assert(condition, on_error_msg) + return _custom_assert(condition, on_error_msg) def assert_false(condition: bool, on_error_msg: str = ""): @@ -36,7 +36,7 @@ def assert_false(condition: bool, on_error_msg: str = ""): on_error_msg(str): optional message for precising the error, in case of error """ - return custom_assert(not condition, on_error_msg) + return _custom_assert(not condition, on_error_msg) def assert_not_reached(on_error_msg: str): @@ -46,4 +46,4 @@ def assert_not_reached(on_error_msg: str): on_error_msg(str): message for precising the error """ - return custom_assert(False, on_error_msg) + return _custom_assert(False, on_error_msg) diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index 65ce67628..372cbb79b 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -9,7 +9,7 @@ import matplotlib.pyplot as plt import networkx as nx from PIL import Image -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import ( ALL_IR_NODES, @@ -36,7 +36,7 @@ IR_NODE_COLOR_MAPPING = { } _missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys() -custom_assert( +assert_true( len(_missing_nodes_in_mapping) == 0, ( f"Missing IR node in IR_NODE_COLOR_MAPPING : " diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 0b0444e7f..5ecfbc95e 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -4,7 +4,7 @@ from typing import Any, Dict import networkx as nx -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import Constant, Input, UnivariateFunction @@ -50,7 +50,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: Returns: str: a string to print or save in a file """ - custom_assert(isinstance(opgraph, OPGraph)) + assert_true(isinstance(opgraph, OPGraph)) list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values()) graph = opgraph.graph @@ -64,7 +64,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: # This code doesn't work with more than a single output. For more outputs, # we would need to change the way the destination are created: currently, # they only are done by incrementing i - custom_assert(len(node.outputs) == 1) + assert_true(len(node.outputs) == 1) if isinstance(node, Input): what_to_print = node.input_name @@ -91,9 +91,9 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: list_of_arg_name += [(index["input_idx"], str(map_table[pred]))] # Some checks, because the previous algorithm is not clear - custom_assert(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))) + assert_true(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))) list_of_arg_name.sort() - custom_assert([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))) + assert_true([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))) prefix_to_add_to_what_to_print = "" suffix_to_add_to_what_to_print = "" @@ -105,7 +105,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: if node.op_attributes["in_which_input_is_constant"] == 0: prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, " else: - custom_assert( + assert_true( node.op_attributes["in_which_input_is_constant"] == 1, "'in_which_input_is_constant' should be a key of node.op_attributes", ) diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index 255a6720a..21d09c791 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -21,15 +21,15 @@ from ..data_types.dtypes_helpers import ( value_is_encrypted_tensor_integer, ) from ..data_types.integers import Integer -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, UnivariateFunction from ..values import TensorValue def add(node, preds, ir_to_mlir_node, ctx): """Convert an addition intermediate node.""" - custom_assert(len(node.inputs) == 2, "addition should have two inputs") - custom_assert(len(node.outputs) == 1, "addition should have a single output") + assert_true(len(node.inputs) == 2, "addition should have two inputs") + assert_true(len(node.outputs) == 1, "addition should have a single output") if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( node.inputs[1] ): @@ -72,8 +72,8 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx): def sub(node, preds, ir_to_mlir_node, ctx): """Convert a subtraction intermediate node.""" - custom_assert(len(node.inputs) == 2, "subtraction should have two inputs") - custom_assert(len(node.outputs) == 1, "subtraction should have a single output") + assert_true(len(node.inputs) == 2, "subtraction should have two inputs") + assert_true(len(node.outputs) == 1, "subtraction should have a single output") if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer( node.inputs[1] ): @@ -96,8 +96,8 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx): def mul(node, preds, ir_to_mlir_node, ctx): """Convert a multiplication intermediate node.""" - custom_assert(len(node.inputs) == 2, "multiplication should have two inputs") - custom_assert(len(node.outputs) == 1, "multiplication should have a single output") + assert_true(len(node.inputs) == 2, "multiplication should have two inputs") + assert_true(len(node.outputs) == 1, "multiplication should have a single output") if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( node.inputs[1] ): @@ -166,8 +166,8 @@ def constant(node, _, __, ctx): def apply_lut(node, preds, ir_to_mlir_node, ctx): """Convert a UnivariateFunction intermediate node.""" - custom_assert(len(node.inputs) == 1, "LUT should have a single input") - custom_assert(len(node.outputs) == 1, "LUT should have a single output") + assert_true(len(node.inputs) == 1, "LUT should have a single input") + assert_true(len(node.outputs) == 1, "LUT should have a single output") if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]): raise TypeError("Only support LUT with encrypted unsigned integers inputs") if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]): @@ -192,8 +192,8 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx): def dot(node, preds, ir_to_mlir_node, ctx): """Convert a dot intermediate node.""" - custom_assert(len(node.inputs) == 2, "Dot should have two inputs") - custom_assert(len(node.outputs) == 1, "Dot should have a single output") + assert_true(len(node.inputs) == 2, "Dot should have two inputs") + assert_true(len(node.outputs) == 1, "Dot should have a single output") if not ( ( value_is_encrypted_tensor_integer(node.inputs[0]) diff --git a/concrete/common/mlir/mlir_converter.py b/concrete/common/mlir/mlir_converter.py index 55af22122..984fc0fc3 100644 --- a/concrete/common/mlir/mlir_converter.py +++ b/concrete/common/mlir/mlir_converter.py @@ -17,7 +17,7 @@ from ..data_types.dtypes_helpers import ( value_is_encrypted_scalar_unsigned_integer, value_is_encrypted_tensor_unsigned_integer, ) -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import Input @@ -83,7 +83,7 @@ class MLIRConverter: if is_signed and not is_encrypted: # clear signed return IntegerType.get_signed(bit_width) # should be clear unsigned at this point - custom_assert(not is_signed and not is_encrypted) + assert_true(not is_signed and not is_encrypted) # unsigned integer are considered signless in the compiler return IntegerType.get_signless(bit_width) diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index ec3e06dc2..4c3498cb9 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -12,7 +12,7 @@ from .data_types.dtypes_helpers import ( ) from .data_types.floats import Float from .data_types.integers import Integer, make_integer_to_hold -from .debugging.custom_assert import custom_assert +from .debugging.custom_assert import assert_true from .representation.intermediate import Input, IntermediateNode from .tracing import BaseTracer from .tracing.tracing_helpers import create_graph_from_output_tracers @@ -31,14 +31,12 @@ class OPGraph: input_nodes: Dict[int, Input], output_nodes: Dict[int, IntermediateNode], ) -> None: - custom_assert( - len(input_nodes) > 0, "Got a graph without input nodes which is not supported" - ) - custom_assert( + assert_true(len(input_nodes) > 0, "Got a graph without input nodes which is not supported") + assert_true( all(isinstance(node, Input) for node in input_nodes.values()), "Got input nodes that were not Input, which is not supported", ) - custom_assert( + assert_true( all(isinstance(node, IntermediateNode) for node in output_nodes.values()), "Got output nodes which were not IntermediateNode, which is not supported", ) @@ -51,7 +49,7 @@ class OPGraph: def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]: inputs = dict(enumerate(args)) - custom_assert( + assert_true( len(inputs) == len(self.input_nodes), f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}", ) @@ -183,7 +181,7 @@ class OPGraph: min_data_type_constructor = get_type_constructor_for_constant_data(min_bound) max_data_type_constructor = get_type_constructor_for_constant_data(max_bound) - custom_assert( + assert_true( max_data_type_constructor == min_data_type_constructor, ( f"Got two different type constructors for min and max bound: " @@ -200,7 +198,7 @@ class OPGraph: (min_bound, max_bound), force_signed=False ) else: - custom_assert( + assert_true( isinstance(min_data_type, Float) and isinstance(max_data_type, Float), ( "min_bound and max_bound have different common types, " @@ -212,7 +210,7 @@ class OPGraph: output_value.dtype.underlying_type_constructor = data_type_constructor else: # Currently variable inputs are only allowed to be integers - custom_assert( + assert_true( isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), ( f"Inputs to a graph should be integers, got bounds that were float, \n" @@ -229,7 +227,7 @@ class OPGraph: # TODO: #57 manage multiple outputs from a node, probably requires an output_idx when # adding an edge - custom_assert(len(node.outputs) == 1) + assert_true(len(node.outputs) == 1) successors = self.graph.succ[node] for succ in successors: diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index b9941cc10..16d2d1df0 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -8,7 +8,7 @@ import networkx as nx from ..compilation.artifacts import CompilationArtifacts from ..data_types.floats import Float from ..data_types.integers import Integer -from ..debugging.custom_assert import assert_true, custom_assert +from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction from ..values import TensorValue @@ -119,7 +119,7 @@ def convert_float_subgraph_to_fused_node( variable_input_nodes = [ node for node in float_subgraph_start_nodes if not isinstance(node, Constant) ] - custom_assert(len(variable_input_nodes) == 1) + assert_true(len(variable_input_nodes) == 1) current_subgraph_variable_input = variable_input_nodes[0] new_input_value = deepcopy(current_subgraph_variable_input.outputs[0]) diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 2183a8597..144076ffc 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -12,7 +12,7 @@ from ..data_types.dtypes_helpers import ( mix_values_determine_holding_dtype, ) from ..data_types.integers import Integer -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from ..values import BaseValue, ClearScalar, EncryptedScalar, TensorValue IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" @@ -33,7 +33,7 @@ class IntermediateNode(ABC): **_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes ) -> None: self.inputs = list(inputs) - custom_assert(all(isinstance(x, BaseValue) for x in self.inputs)) + assert_true(all(isinstance(x, BaseValue) for x in self.inputs)) # Register all IR nodes def __init_subclass__(cls, **kwargs): @@ -49,7 +49,7 @@ class IntermediateNode(ABC): """__init__ for a binary operation, ie two inputs.""" IntermediateNode.__init__(self, inputs) - custom_assert(len(self.inputs) == 2) + assert_true(len(self.inputs) == 2) self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])] @@ -148,7 +148,7 @@ class Input(IntermediateNode): program_input_idx: int, ) -> None: super().__init__((input_value,)) - custom_assert(len(self.inputs) == 1) + assert_true(len(self.inputs) == 1) self.input_name = input_name self.program_input_idx = program_input_idx self.outputs = [deepcopy(self.inputs[0])] @@ -222,7 +222,7 @@ class UnivariateFunction(IntermediateNode): op_attributes: Optional[Dict[str, Any]] = None, ) -> None: super().__init__([input_base_value]) - custom_assert(len(self.inputs) == 1) + assert_true(len(self.inputs) == 1) self.arbitrary_func = arbitrary_func self.op_args = op_args if op_args is not None else () self.op_kwargs = op_kwargs if op_kwargs is not None else {} @@ -306,9 +306,9 @@ class Dot(IntermediateNode): ] = default_dot_evaluation_function, ) -> None: super().__init__(inputs) - custom_assert(len(self.inputs) == 2) + assert_true(len(self.inputs) == 2) - custom_assert( + assert_true( all( isinstance(input_value, TensorValue) and input_value.ndim == 1 for input_value in self.inputs diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 1bd9ad747..086b24139 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Iterable, List, Tuple, Type, Union -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true from ..representation.intermediate import ( IR_MIX_VALUES_FUNC_ARG_NAME, Add, @@ -111,7 +111,7 @@ class BaseTracer(ABC): Add, ) - custom_assert(len(result_tracer) == 1) + assert_true(len(result_tracer) == 1) return result_tracer[0] # With that is that x + 1 and 1 + x have the same graph. If we want to keep @@ -128,7 +128,7 @@ class BaseTracer(ABC): Sub, ) - custom_assert(len(result_tracer) == 1) + assert_true(len(result_tracer) == 1) return result_tracer[0] def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": @@ -140,7 +140,7 @@ class BaseTracer(ABC): Sub, ) - custom_assert(len(result_tracer) == 1) + assert_true(len(result_tracer) == 1) return result_tracer[0] def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": @@ -152,7 +152,7 @@ class BaseTracer(ABC): Mul, ) - custom_assert(len(result_tracer) == 1) + assert_true(len(result_tracer) == 1) return result_tracer[0] # With that is that x * 3 and 3 * x have the same graph. If we want to keep diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py index 2e3eea6a8..24b2f673f 100644 --- a/concrete/common/tracing/tracing_helpers.py +++ b/concrete/common/tracing/tracing_helpers.py @@ -6,7 +6,7 @@ from typing import Callable, Dict, Iterable, OrderedDict, Set, Type import networkx as nx from networkx.algorithms.dag import is_directed_acyclic_graph -from ..debugging.custom_assert import assert_true, custom_assert +from ..debugging.custom_assert import assert_true from ..representation.intermediate import Input from ..values import BaseValue from .base_tracer import BaseTracer @@ -124,7 +124,7 @@ def create_graph_from_output_tracers( current_tracers = next_tracers - custom_assert(is_directed_acyclic_graph(graph)) + assert_true(is_directed_acyclic_graph(graph)) # Check each edge is unique unique_edges = set( diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index dc5a0d3bf..599886068 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -17,7 +17,7 @@ from ..common.data_types.dtypes_helpers import ( ) from ..common.data_types.floats import Float from ..common.data_types.integers import Integer -from ..common.debugging.custom_assert import custom_assert +from ..common.debugging.custom_assert import assert_true from ..common.values import BaseValue, TensorValue NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { @@ -72,13 +72,13 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d Returns: numpy.dtype: The resulting numpy.dtype """ - custom_assert( + assert_true( isinstance(common_dtype, BASE_DATA_TYPES), f"Unsupported common_dtype: {type(common_dtype)}" ) type_to_return: numpy.dtype if isinstance(common_dtype, Float): - custom_assert( + assert_true( common_dtype.bit_width in ( 32, @@ -117,7 +117,7 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType: The corresponding BaseDataType """ base_dtype: BaseDataType - custom_assert( + assert_true( isinstance( constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) ), @@ -159,12 +159,12 @@ def get_base_value_for_numpy_or_python_constant_data( with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method). """ constant_data_value: Callable[..., BaseValue] - custom_assert( + assert_true( not isinstance(constant_data, list), "Unsupported constant data of type list " "(if you meant to use a list as an array, please use numpy.array instead)", ) - custom_assert( + assert_true( isinstance( constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES), @@ -198,7 +198,7 @@ def get_numpy_function_output_dtype( List[numpy.dtype]: The ordered numpy dtypes of the function outputs """ if isinstance(function, numpy.ufunc): - custom_assert( + assert_true( (len(input_dtypes) == function.nin), f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}", ) @@ -231,7 +231,7 @@ def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any): constant_data (Any): The data for which we want to determine the type constructor. """ - custom_assert( + assert_true( isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)), f"Unsupported constant data of type {type(constant_data)}", ) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index c9b5c6966..12ffb5f27 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -7,7 +7,7 @@ import numpy from numpy.typing import DTypeLike from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype -from ..common.debugging.custom_assert import assert_true, custom_assert +from ..common.debugging.custom_assert import assert_true from ..common.operator_graph import OPGraph from ..common.representation.intermediate import Constant, Dot, UnivariateFunction from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters @@ -41,7 +41,7 @@ class NPTracer(BaseTracer): """ if method == "__call__": tracing_func = self.get_tracing_func_for_np_function(ufunc) - custom_assert( + assert_true( (len(kwargs) == 0), f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}", ) @@ -58,7 +58,7 @@ class NPTracer(BaseTracer): Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch """ tracing_func = self.get_tracing_func_for_np_function(func) - custom_assert( + assert_true( (len(kwargs) == 0), f"**kwargs are currently not supported for numpy functions, func: {func}", ) @@ -77,10 +77,10 @@ class NPTracer(BaseTracer): Returns: NPTracer: The NPTracer representing the casting operation """ - custom_assert( + assert_true( len(args) == 0, f"astype currently only supports tracing without *args, got {args}" ) - custom_assert( + assert_true( (len(kwargs) == 0), f"astype currently only supports tracing without **kwargs, got {kwargs}", ) @@ -150,9 +150,9 @@ class NPTracer(BaseTracer): Returns: NPTracer: The output NPTracer containing the traced function """ - custom_assert(len(input_tracers) == 1) + assert_true(len(input_tracers) == 1) common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers) - custom_assert(len(common_output_dtypes) == 1) + assert_true(len(common_output_dtypes) == 1) traced_computation = UnivariateFunction( input_base_value=input_tracers[0].output, @@ -179,7 +179,7 @@ class NPTracer(BaseTracer): Returns: NPTracer: The output NPTracer containing the traced function """ - custom_assert(len(input_tracers) == 2) + assert_true(len(input_tracers) == 2) # One of the inputs has to be constant if isinstance(input_tracers[0].traced_computation, Constant): @@ -204,7 +204,7 @@ class NPTracer(BaseTracer): return binary_operator(x, baked_constant, **kwargs) common_output_dtypes = cls._manage_dtypes(binary_operator, *input_tracers) - custom_assert(len(common_output_dtypes) == 1) + assert_true(len(common_output_dtypes) == 1) op_kwargs = deepcopy(kwargs) op_kwargs["baked_constant"] = baked_constant @@ -242,7 +242,7 @@ class NPTracer(BaseTracer): assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}") common_output_dtypes = self._manage_dtypes(numpy.dot, *args) - custom_assert(len(common_output_dtypes) == 1) + assert_true(len(common_output_dtypes) == 1) traced_computation = Dot( [input_tracer.output for input_tracer in args], @@ -399,7 +399,7 @@ list_of_not_supported = [ if ufunc.nin not in [1, 2] ] -custom_assert(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}") +assert_true(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}") del list_of_not_supported # We are adding initial support for `np.array(...)` +,-,* `BaseTracer`