mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
585de78081
commit
e78086eefa
@@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, Tuple
|
||||
|
||||
from ..debugging import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
|
||||
@@ -35,10 +36,13 @@ def eval_op_graph_bounds_on_dataset(
|
||||
"""
|
||||
|
||||
def check_dataset_input_len_is_valid(data_to_check):
|
||||
assert len(data_to_check) == len(op_graph.input_nodes), (
|
||||
f"Got input data from dataset of len: {len(data_to_check)}, "
|
||||
f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make "
|
||||
f"sure your data generator returns valid tuples of input values"
|
||||
custom_assert(
|
||||
len(data_to_check) == len(op_graph.input_nodes),
|
||||
(
|
||||
f"Got input data from dataset of len: {len(data_to_check)}, "
|
||||
f"function being evaluated has {len(op_graph.input_nodes)} inputs, please make "
|
||||
f"sure your data generator returns valid tuples of input values"
|
||||
),
|
||||
)
|
||||
|
||||
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from .data_types.integers import Integer
|
||||
from .debugging import custom_assert
|
||||
from .operator_graph import OPGraph
|
||||
from .representation import intermediate as ir
|
||||
|
||||
@@ -53,9 +54,10 @@ def check_op_graph_is_integer_program(
|
||||
"""
|
||||
offending_nodes = [] if offending_nodes_out is None else offending_nodes_out
|
||||
|
||||
assert isinstance(
|
||||
offending_nodes, list
|
||||
), f"offending_nodes_out must be a list, got {type(offending_nodes_out)}"
|
||||
custom_assert(
|
||||
isinstance(offending_nodes, list),
|
||||
f"offending_nodes_out must be a list, got {type(offending_nodes_out)}",
|
||||
)
|
||||
|
||||
offending_nodes.clear()
|
||||
offending_nodes.extend(
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, Union
|
||||
import networkx as nx
|
||||
from PIL import Image
|
||||
|
||||
from ..debugging import draw_graph, get_printable_graph
|
||||
from ..debugging import custom_assert, draw_graph, get_printable_graph
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
from ..values import BaseValue
|
||||
@@ -102,7 +102,7 @@ class CompilationArtifacts:
|
||||
None
|
||||
"""
|
||||
|
||||
assert self.final_operation_graph is not None
|
||||
custom_assert(self.final_operation_graph is not None)
|
||||
self.bounds_of_the_final_operation_graph = bounds
|
||||
|
||||
def add_final_operation_graph_mlir(self, mlir: str):
|
||||
@@ -115,7 +115,7 @@ class CompilationArtifacts:
|
||||
None
|
||||
"""
|
||||
|
||||
assert self.final_operation_graph is not None
|
||||
custom_assert(self.final_operation_graph is not None)
|
||||
self.mlir_of_the_final_operation_graph = mlir
|
||||
|
||||
def export(self):
|
||||
@@ -186,7 +186,7 @@ class CompilationArtifacts:
|
||||
f.write(f"{representation}\n")
|
||||
|
||||
if self.bounds_of_the_final_operation_graph is not None:
|
||||
assert self.final_operation_graph is not None
|
||||
custom_assert(self.final_operation_graph is not None)
|
||||
with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
|
||||
# TODO:
|
||||
# if nx.topological_sort is not deterministic between calls,
|
||||
@@ -194,11 +194,11 @@ class CompilationArtifacts:
|
||||
# thus, we may want to change this in the future
|
||||
for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)):
|
||||
bounds = self.bounds_of_the_final_operation_graph.get(node)
|
||||
assert bounds is not None
|
||||
custom_assert(bounds is not None)
|
||||
f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n")
|
||||
|
||||
if self.mlir_of_the_final_operation_graph is not None:
|
||||
assert self.final_operation_graph is not None
|
||||
custom_assert(self.final_operation_graph is not None)
|
||||
with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(self.mlir_of_the_final_operation_graph)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Union, cast
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..values import (
|
||||
BaseValue,
|
||||
ClearScalar,
|
||||
@@ -149,8 +150,8 @@ def find_type_to_hold_both_lossy(
|
||||
Returns:
|
||||
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
|
||||
"""
|
||||
assert isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}"
|
||||
assert isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}"
|
||||
custom_assert(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}")
|
||||
custom_assert(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}")
|
||||
|
||||
type_to_return: BaseDataType
|
||||
|
||||
@@ -208,8 +209,12 @@ def mix_scalar_values_determine_holding_dtype(
|
||||
value2 dtypes.
|
||||
"""
|
||||
|
||||
assert isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue"
|
||||
assert isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
|
||||
custom_assert(
|
||||
isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue"
|
||||
)
|
||||
custom_assert(
|
||||
isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
|
||||
)
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
|
||||
mixed_value: ScalarValue
|
||||
@@ -241,12 +246,19 @@ def mix_tensor_values_determine_holding_dtype(
|
||||
value2 dtypes.
|
||||
"""
|
||||
|
||||
assert isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue"
|
||||
assert isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
|
||||
custom_assert(
|
||||
isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue"
|
||||
)
|
||||
custom_assert(
|
||||
isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
|
||||
)
|
||||
|
||||
assert value1.shape == value2.shape, (
|
||||
f"Tensors have different shapes which is not supported.\n"
|
||||
f"value1: {value1.shape}, value2: {value2.shape}"
|
||||
custom_assert(
|
||||
value1.shape == value2.shape,
|
||||
(
|
||||
f"Tensors have different shapes which is not supported.\n"
|
||||
f"value1: {value1.shape}, value2: {value2.shape}"
|
||||
),
|
||||
)
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
|
||||
@@ -279,9 +291,10 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
|
||||
dtypes.
|
||||
"""
|
||||
|
||||
assert (
|
||||
value1.__class__ == value2.__class__
|
||||
), f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}"
|
||||
custom_assert(
|
||||
(value1.__class__ == value2.__class__),
|
||||
f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}",
|
||||
)
|
||||
|
||||
if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue):
|
||||
return mix_scalar_values_determine_holding_dtype(value1, value2)
|
||||
@@ -304,9 +317,10 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
|
||||
BaseDataType: The corresponding BaseDataType
|
||||
"""
|
||||
constant_data_type: BaseDataType
|
||||
assert isinstance(
|
||||
constant_data, (int, float)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float)),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
if isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
constant_data_type = Integer(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from functools import partial
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from . import base
|
||||
|
||||
|
||||
@@ -14,7 +15,7 @@ class Float(base.BaseDataType):
|
||||
|
||||
def __init__(self, bit_width: int) -> None:
|
||||
super().__init__()
|
||||
assert bit_width in (32, 64), "Only 32 and 64 bits floats are supported"
|
||||
custom_assert(bit_width in (32, 64), "Only 32 and 64 bits floats are supported")
|
||||
self.bit_width = bit_width
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import math
|
||||
from typing import Any, Iterable
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from . import base
|
||||
|
||||
|
||||
@@ -14,7 +15,7 @@ class Integer(base.BaseDataType):
|
||||
|
||||
def __init__(self, bit_width: int, is_signed: bool) -> None:
|
||||
super().__init__()
|
||||
assert bit_width > 0, "bit_width must be > 0"
|
||||
custom_assert(bit_width > 0, "bit_width must be > 0")
|
||||
self.bit_width = bit_width
|
||||
self.is_signed = is_signed
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""Module for debugging."""
|
||||
from .custom_assert import custom_assert
|
||||
from .drawing import draw_graph
|
||||
from .printing import get_printable_graph
|
||||
|
||||
49
concrete/common/debugging/custom_assert.py
Normal file
49
concrete/common/debugging/custom_assert.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Provide some variants of assert."""
|
||||
|
||||
|
||||
def custom_assert(condition: bool, on_error_msg: str = "") -> None:
|
||||
"""Provide a custom assert which is kept even if the optimized python mode is used.
|
||||
|
||||
See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation
|
||||
on the classical assert function
|
||||
|
||||
Args:
|
||||
condition(bool): the condition. If False, raise AssertionError
|
||||
on_error_msg(str): optional message for precising the error, in case of error
|
||||
|
||||
"""
|
||||
|
||||
if not condition:
|
||||
raise AssertionError(on_error_msg)
|
||||
|
||||
|
||||
def assert_true(condition: bool, on_error_msg: str = ""):
|
||||
"""Provide a custom assert to check that the condition is True.
|
||||
|
||||
Args:
|
||||
condition(bool): the condition. If False, raise AssertionError
|
||||
on_error_msg(str): optional message for precising the error, in case of error
|
||||
|
||||
"""
|
||||
return custom_assert(condition, on_error_msg)
|
||||
|
||||
|
||||
def assert_false(condition: bool, on_error_msg: str = ""):
|
||||
"""Provide a custom assert to check that the condition is False.
|
||||
|
||||
Args:
|
||||
condition(bool): the condition. If True, raise AssertionError
|
||||
on_error_msg(str): optional message for precising the error, in case of error
|
||||
|
||||
"""
|
||||
return custom_assert(not condition, on_error_msg)
|
||||
|
||||
|
||||
def assert_not_reached(on_error_msg: str):
|
||||
"""Provide a custom assert to check that a piece of code is never reached.
|
||||
|
||||
Args:
|
||||
on_error_msg(str): message for precising the error
|
||||
|
||||
"""
|
||||
return custom_assert(False, on_error_msg)
|
||||
@@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from PIL import Image
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import ALL_IR_NODES
|
||||
@@ -26,9 +27,12 @@ IR_NODE_COLOR_MAPPING = {
|
||||
}
|
||||
|
||||
_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys()
|
||||
assert len(_missing_nodes_in_mapping) == 0, (
|
||||
f"Missing IR node in IR_NODE_COLOR_MAPPING : "
|
||||
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
|
||||
custom_assert(
|
||||
len(_missing_nodes_in_mapping) == 0,
|
||||
(
|
||||
f"Missing IR node in IR_NODE_COLOR_MAPPING : "
|
||||
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
|
||||
),
|
||||
)
|
||||
|
||||
del _missing_nodes_in_mapping
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
@@ -32,7 +33,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
Returns:
|
||||
str: a string to print or save in a file
|
||||
"""
|
||||
assert isinstance(opgraph, OPGraph)
|
||||
custom_assert(isinstance(opgraph, OPGraph))
|
||||
list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
|
||||
graph = opgraph.graph
|
||||
|
||||
@@ -46,7 +47,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
# This code doesn't work with more than a single output. For more outputs,
|
||||
# we would need to change the way the destination are created: currently,
|
||||
# they only are done by incrementing i
|
||||
assert len(node.outputs) == 1
|
||||
custom_assert(len(node.outputs) == 1)
|
||||
|
||||
if isinstance(node, ir.Input):
|
||||
what_to_print = node.input_name
|
||||
@@ -72,9 +73,9 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
|
||||
|
||||
# Some checks, because the previous algorithm is not clear
|
||||
assert len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))
|
||||
custom_assert(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name)))
|
||||
list_of_arg_name.sort()
|
||||
assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))
|
||||
custom_assert([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))))
|
||||
|
||||
# Then, just print the predecessors in the right order
|
||||
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
|
||||
|
||||
@@ -21,13 +21,14 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_encrypted_scalar_unsigned_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
)
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
def add(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert an addition intermediate node."""
|
||||
assert len(node.inputs) == 2, "addition should have two inputs"
|
||||
assert len(node.outputs) == 1, "addition should have a single output"
|
||||
custom_assert(len(node.inputs) == 2, "addition should have two inputs")
|
||||
custom_assert(len(node.outputs) == 1, "addition should have a single output")
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
@@ -70,8 +71,8 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
def sub(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert a subtraction intermediate node."""
|
||||
assert len(node.inputs) == 2, "subtraction should have two inputs"
|
||||
assert len(node.outputs) == 1, "subtraction should have a single output"
|
||||
custom_assert(len(node.inputs) == 2, "subtraction should have two inputs")
|
||||
custom_assert(len(node.outputs) == 1, "subtraction should have a single output")
|
||||
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
@@ -94,8 +95,8 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
def mul(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert a multiplication intermediate node."""
|
||||
assert len(node.inputs) == 2, "multiplication should have two inputs"
|
||||
assert len(node.outputs) == 1, "multiplication should have a single output"
|
||||
custom_assert(len(node.inputs) == 2, "multiplication should have two inputs")
|
||||
custom_assert(len(node.outputs) == 1, "multiplication should have a single output")
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
@@ -134,8 +135,8 @@ def constant(node, _, __, ctx):
|
||||
|
||||
def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert an arbitrary function intermediate node."""
|
||||
assert len(node.inputs) == 1, "LUT should have a single input"
|
||||
assert len(node.outputs) == 1, "LUT should have a single output"
|
||||
custom_assert(len(node.inputs) == 1, "LUT should have a single input")
|
||||
custom_assert(len(node.outputs) == 1, "LUT should have a single output")
|
||||
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
|
||||
raise TypeError("Only support LUT with encrypted unsigned integers inputs")
|
||||
if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]):
|
||||
@@ -160,8 +161,8 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
def dot(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Convert a dot intermediate node."""
|
||||
assert len(node.inputs) == 2, "Dot should have two inputs"
|
||||
assert len(node.outputs) == 1, "Dot should have a single output"
|
||||
custom_assert(len(node.inputs) == 2, "Dot should have two inputs")
|
||||
custom_assert(len(node.outputs) == 1, "Dot should have a single output")
|
||||
if not (
|
||||
(
|
||||
value_is_encrypted_tensor_integer(node.inputs[0])
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_encrypted_scalar_unsigned_integer,
|
||||
value_is_encrypted_tensor_unsigned_integer,
|
||||
)
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
@@ -93,7 +94,7 @@ class MLIRConverter:
|
||||
if is_signed and not is_encrypted: # clear signed
|
||||
return IntegerType.get_signed(bit_width)
|
||||
# should be clear unsigned at this point
|
||||
assert not is_signed and not is_encrypted
|
||||
custom_assert(not is_signed and not is_encrypted)
|
||||
# unsigned integer are considered signless in the compiler
|
||||
return IntegerType.get_signless(bit_width)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from .data_types.dtypes_helpers import (
|
||||
)
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import Integer, make_integer_to_hold
|
||||
from .debugging.custom_assert import custom_assert
|
||||
from .representation import intermediate as ir
|
||||
from .tracing import BaseTracer
|
||||
from .tracing.tracing_helpers import create_graph_from_output_tracers
|
||||
@@ -30,13 +31,17 @@ class OPGraph:
|
||||
input_nodes: Dict[int, ir.Input],
|
||||
output_nodes: Dict[int, ir.IntermediateNode],
|
||||
) -> None:
|
||||
assert len(input_nodes) > 0, "Got a graph without input nodes which is not supported"
|
||||
assert all(
|
||||
isinstance(node, ir.Input) for node in input_nodes.values()
|
||||
), "Got input nodes that were not ir.Input, which is not supported"
|
||||
assert all(
|
||||
isinstance(node, ir.IntermediateNode) for node in output_nodes.values()
|
||||
), "Got output nodes which were not ir.IntermediateNode, which is not supported"
|
||||
custom_assert(
|
||||
len(input_nodes) > 0, "Got a graph without input nodes which is not supported"
|
||||
)
|
||||
custom_assert(
|
||||
all(isinstance(node, ir.Input) for node in input_nodes.values()),
|
||||
"Got input nodes that were not ir.Input, which is not supported",
|
||||
)
|
||||
custom_assert(
|
||||
all(isinstance(node, ir.IntermediateNode) for node in output_nodes.values()),
|
||||
"Got output nodes which were not ir.IntermediateNode, which is not supported",
|
||||
)
|
||||
|
||||
self.graph = graph
|
||||
self.input_nodes = input_nodes
|
||||
@@ -46,9 +51,10 @@ class OPGraph:
|
||||
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
|
||||
inputs = dict(enumerate(args))
|
||||
|
||||
assert len(inputs) == len(
|
||||
self.input_nodes
|
||||
), f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}"
|
||||
custom_assert(
|
||||
len(inputs) == len(self.input_nodes),
|
||||
f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}",
|
||||
)
|
||||
|
||||
results = self.evaluate(inputs)
|
||||
tuple_result = tuple(results[output_node] for output_node in self.get_ordered_outputs())
|
||||
@@ -177,9 +183,12 @@ class OPGraph:
|
||||
min_data_type_constructor = get_type_constructor_for_constant_data(min_bound)
|
||||
max_data_type_constructor = get_type_constructor_for_constant_data(max_bound)
|
||||
|
||||
assert max_data_type_constructor == min_data_type_constructor, (
|
||||
f"Got two different type constructors for min and max bound: "
|
||||
f"{min_data_type_constructor}, {max_data_type_constructor}"
|
||||
custom_assert(
|
||||
max_data_type_constructor == min_data_type_constructor,
|
||||
(
|
||||
f"Got two different type constructors for min and max bound: "
|
||||
f"{min_data_type_constructor}, {max_data_type_constructor}"
|
||||
),
|
||||
)
|
||||
|
||||
data_type_constructor = max_data_type_constructor
|
||||
@@ -191,20 +200,25 @@ class OPGraph:
|
||||
(min_bound, max_bound), force_signed=False
|
||||
)
|
||||
else:
|
||||
assert isinstance(min_data_type, Float) and isinstance(
|
||||
max_data_type, Float
|
||||
), (
|
||||
"min_bound and max_bound have different common types, "
|
||||
"this should never happen.\n"
|
||||
f"min_bound: {min_data_type}, max_bound: {max_data_type}"
|
||||
custom_assert(
|
||||
isinstance(min_data_type, Float) and isinstance(max_data_type, Float),
|
||||
(
|
||||
"min_bound and max_bound have different common types, "
|
||||
"this should never happen.\n"
|
||||
f"min_bound: {min_data_type}, max_bound: {max_data_type}"
|
||||
),
|
||||
)
|
||||
output_value.data_type = Float(64)
|
||||
output_value.data_type.underlying_type_constructor = data_type_constructor
|
||||
else:
|
||||
# Currently variable inputs are only allowed to be integers
|
||||
assert isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer), (
|
||||
f"Inputs to a graph should be integers, got bounds that were float, \n"
|
||||
f"min: {min_bound} ({type(min_bound)}), max: {max_bound} ({type(max_bound)})"
|
||||
custom_assert(
|
||||
isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer),
|
||||
(
|
||||
f"Inputs to a graph should be integers, got bounds that were float, \n"
|
||||
f"min: {min_bound} ({type(min_bound)}), "
|
||||
f"max: {max_bound} ({type(max_bound)})"
|
||||
),
|
||||
)
|
||||
node.inputs[0].data_type = make_integer_to_hold(
|
||||
(min_bound, max_bound), force_signed=False
|
||||
@@ -215,7 +229,7 @@ class OPGraph:
|
||||
|
||||
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
|
||||
# adding an edge
|
||||
assert len(node.outputs) == 1
|
||||
custom_assert(len(node.outputs) == 1)
|
||||
|
||||
successors = self.graph.succ[node]
|
||||
for succ in successors:
|
||||
|
||||
@@ -7,6 +7,7 @@ import networkx as nx
|
||||
from ..compilation.artifacts import CompilationArtifacts
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
@@ -112,7 +113,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
non_constant_start_nodes = [
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, ir.Constant)
|
||||
]
|
||||
assert len(non_constant_start_nodes) == 1
|
||||
custom_assert(len(non_constant_start_nodes) == 1)
|
||||
|
||||
current_subgraph_variable_input = non_constant_start_nodes[0]
|
||||
new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])
|
||||
|
||||
@@ -12,6 +12,7 @@ from ..data_types.dtypes_helpers import (
|
||||
mix_scalar_values_determine_holding_dtype,
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..values import BaseValue, ClearScalar, EncryptedScalar, TensorValue
|
||||
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
|
||||
@@ -32,7 +33,7 @@ class IntermediateNode(ABC):
|
||||
**_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
assert all(isinstance(x, BaseValue) for x in self.inputs)
|
||||
custom_assert(all(isinstance(x, BaseValue) for x in self.inputs))
|
||||
|
||||
# Register all IR nodes
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
@@ -48,7 +49,7 @@ class IntermediateNode(ABC):
|
||||
"""__init__ for a binary operation, ie two inputs."""
|
||||
IntermediateNode.__init__(self, inputs)
|
||||
|
||||
assert len(self.inputs) == 2
|
||||
custom_assert(len(self.inputs) == 2)
|
||||
|
||||
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
|
||||
|
||||
@@ -147,7 +148,7 @@ class Input(IntermediateNode):
|
||||
program_input_idx: int,
|
||||
) -> None:
|
||||
super().__init__((input_value,))
|
||||
assert len(self.inputs) == 1
|
||||
custom_assert(len(self.inputs) == 1)
|
||||
self.input_name = input_name
|
||||
self.program_input_idx = program_input_idx
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
@@ -216,7 +217,7 @@ class ArbitraryFunction(IntermediateNode):
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__([input_base_value])
|
||||
assert len(self.inputs) == 1
|
||||
custom_assert(len(self.inputs) == 1)
|
||||
self.arbitrary_func = arbitrary_func
|
||||
self.op_args = op_args if op_args is not None else ()
|
||||
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
|
||||
@@ -295,12 +296,15 @@ class Dot(IntermediateNode):
|
||||
] = default_dot_evaluation_function,
|
||||
) -> None:
|
||||
super().__init__(inputs)
|
||||
assert len(self.inputs) == 2
|
||||
custom_assert(len(self.inputs) == 2)
|
||||
|
||||
assert all(
|
||||
isinstance(input_value, TensorValue) and input_value.ndim == 1
|
||||
for input_value in self.inputs
|
||||
), f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)"
|
||||
custom_assert(
|
||||
all(
|
||||
isinstance(input_value, TensorValue) and input_value.ndim == 1
|
||||
for input_value in self.inputs
|
||||
),
|
||||
f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)",
|
||||
)
|
||||
|
||||
output_scalar_value = (
|
||||
EncryptedScalar
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME
|
||||
from ..values import BaseValue
|
||||
@@ -105,7 +106,7 @@ class BaseTracer(ABC):
|
||||
ir.Add,
|
||||
)
|
||||
|
||||
assert len(result_tracer) == 1
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
# With that is that x + 1 and 1 + x have the same graph. If we want to keep
|
||||
@@ -122,7 +123,7 @@ class BaseTracer(ABC):
|
||||
ir.Sub,
|
||||
)
|
||||
|
||||
assert len(result_tracer) == 1
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
@@ -134,7 +135,7 @@ class BaseTracer(ABC):
|
||||
ir.Sub,
|
||||
)
|
||||
|
||||
assert len(result_tracer) == 1
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
@@ -146,7 +147,7 @@ class BaseTracer(ABC):
|
||||
ir.Mul,
|
||||
)
|
||||
|
||||
assert len(result_tracer) == 1
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
return result_tracer[0]
|
||||
|
||||
# With that is that x * 3 and 3 * x have the same graph. If we want to keep
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
|
||||
import networkx as nx
|
||||
from networkx.algorithms.dag import is_directed_acyclic_graph
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation import intermediate as ir
|
||||
from ..values import BaseValue
|
||||
from .base_tracer import BaseTracer
|
||||
@@ -121,6 +122,6 @@ def create_graph_from_output_tracers(
|
||||
|
||||
current_tracers = next_tracers
|
||||
|
||||
assert is_directed_acyclic_graph(graph)
|
||||
custom_assert(is_directed_acyclic_graph(graph))
|
||||
|
||||
return graph
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..common.data_types.dtypes_helpers import (
|
||||
)
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
from ..common.debugging.custom_assert import custom_assert
|
||||
from ..common.values import BaseValue, ScalarValue, TensorValue
|
||||
|
||||
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
|
||||
@@ -69,16 +70,20 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d
|
||||
Returns:
|
||||
numpy.dtype: The resulting numpy.dtype
|
||||
"""
|
||||
assert isinstance(
|
||||
common_dtype, BASE_DATA_TYPES
|
||||
), f"Unsupported common_dtype: {type(common_dtype)}"
|
||||
custom_assert(
|
||||
isinstance(common_dtype, BASE_DATA_TYPES), f"Unsupported common_dtype: {type(common_dtype)}"
|
||||
)
|
||||
type_to_return: numpy.dtype
|
||||
|
||||
if isinstance(common_dtype, Float):
|
||||
assert common_dtype.bit_width in (
|
||||
32,
|
||||
64,
|
||||
), "Only converting Float(32) or Float(64) is supported"
|
||||
custom_assert(
|
||||
common_dtype.bit_width
|
||||
in (
|
||||
32,
|
||||
64,
|
||||
),
|
||||
"Only converting Float(32) or Float(64) is supported",
|
||||
)
|
||||
type_to_return = (
|
||||
numpy.dtype(numpy.float64)
|
||||
if common_dtype.bit_width == 64
|
||||
@@ -110,9 +115,10 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) ->
|
||||
BaseDataType: The corresponding BaseDataType
|
||||
"""
|
||||
base_dtype: BaseDataType
|
||||
assert isinstance(
|
||||
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
# numpy
|
||||
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data.dtype)
|
||||
@@ -141,9 +147,10 @@ def get_base_value_for_numpy_or_python_constant_data(
|
||||
with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method).
|
||||
"""
|
||||
constant_data_value: Callable[..., BaseValue]
|
||||
assert isinstance(
|
||||
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
|
||||
if isinstance(constant_data, numpy.ndarray):
|
||||
@@ -171,9 +178,10 @@ def get_numpy_function_output_dtype(
|
||||
List[numpy.dtype]: The ordered numpy dtypes of the function outputs
|
||||
"""
|
||||
if isinstance(function, numpy.ufunc):
|
||||
assert (
|
||||
len(input_dtypes) == function.nin
|
||||
), f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}"
|
||||
custom_assert(
|
||||
(len(input_dtypes) == function.nin),
|
||||
f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}",
|
||||
)
|
||||
|
||||
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
|
||||
|
||||
@@ -203,9 +211,10 @@ def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
constant_data (Any): The data for which we want to determine the type constructor.
|
||||
"""
|
||||
|
||||
assert isinstance(
|
||||
constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
scalar_constructor: Type
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.debugging.custom_assert import custom_assert
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
@@ -40,9 +41,10 @@ class NPTracer(BaseTracer):
|
||||
"""
|
||||
if method == "__call__":
|
||||
tracing_func = self.get_tracing_func_for_np_function(ufunc)
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}"
|
||||
custom_assert(
|
||||
(len(kwargs) == 0),
|
||||
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}",
|
||||
)
|
||||
return tracing_func(*input_tracers, **kwargs)
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
@@ -52,9 +54,10 @@ class NPTracer(BaseTracer):
|
||||
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
"""
|
||||
tracing_func = self.get_tracing_func_for_np_function(func)
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"**kwargs are currently not supported for numpy functions, func: {func}"
|
||||
custom_assert(
|
||||
(len(kwargs) == 0),
|
||||
f"**kwargs are currently not supported for numpy functions, func: {func}",
|
||||
)
|
||||
return tracing_func(*args, **kwargs)
|
||||
|
||||
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
|
||||
@@ -69,10 +72,13 @@ class NPTracer(BaseTracer):
|
||||
Returns:
|
||||
NPTracer: The NPTracer representing the casting operation
|
||||
"""
|
||||
assert len(args) == 0, f"astype currently only supports tracing without *args, got {args}"
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"astype currently only supports tracing without **kwargs, got {kwargs}"
|
||||
custom_assert(
|
||||
len(args) == 0, f"astype currently only supports tracing without *args, got {args}"
|
||||
)
|
||||
custom_assert(
|
||||
(len(kwargs) == 0),
|
||||
f"astype currently only supports tracing without **kwargs, got {kwargs}",
|
||||
)
|
||||
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
|
||||
@@ -139,9 +145,9 @@ class NPTracer(BaseTracer):
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert len(input_tracers) == 1
|
||||
custom_assert(len(input_tracers) == 1)
|
||||
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
custom_assert(len(common_output_dtypes) == 1)
|
||||
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
@@ -167,7 +173,7 @@ class NPTracer(BaseTracer):
|
||||
dot_inputs = (self, self._sanitize(other_tracer))
|
||||
|
||||
common_output_dtypes = self._manage_dtypes(numpy.dot, *dot_inputs)
|
||||
assert len(common_output_dtypes) == 1
|
||||
custom_assert(len(common_output_dtypes) == 1)
|
||||
|
||||
traced_computation = Dot(
|
||||
[input_tracer.output for input_tracer in dot_inputs],
|
||||
|
||||
29
tests/common/debugging/test_custom_assert.py
Normal file
29
tests/common/debugging/test_custom_assert.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Test custom assert functions."""
|
||||
import pytest
|
||||
|
||||
from concrete.common.debugging.custom_assert import (
|
||||
assert_false,
|
||||
assert_not_reached,
|
||||
assert_true,
|
||||
)
|
||||
|
||||
|
||||
def test_assert_not_functions():
|
||||
"""Test custom assert functions"""
|
||||
assert_true(True, "one check")
|
||||
assert_false(False, "another check")
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
assert_not_reached("yet another one")
|
||||
|
||||
assert "yet another one" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
assert_true(False, "one failing check")
|
||||
|
||||
assert "one failing check" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
assert_false(True, "another failing check")
|
||||
|
||||
assert "another failing check" in str(excinfo.value)
|
||||
Reference in New Issue
Block a user