feat(debugging): let's stop using custom_assert

closes #637
Author: Benoit Chevallier-Mames
Date: 2021-10-12 17:32:15 +02:00
Committed by: Benoit Chevallier
parent 1c935f2d92
commit 0cd33b6f67
19 changed files with 92 additions and 94 deletions
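
The change is mechanical: every call site that used custom_assert(condition, message) now goes through one of the named wrappers (assert_true here, and assert_false or assert_not_reached where appropriate), while custom_assert itself becomes the private helper _custom_assert in the debugging module. A minimal before/after sketch of a typical call site follows; the surrounding function check_bit_width is hypothetical and only for illustration.

# Before this commit
from ..debugging import custom_assert

def check_bit_width(bit_width: int) -> None:
    custom_assert(bit_width > 0, "bit_width must be > 0")

# After this commit
from ..debugging import assert_true

def check_bit_width(bit_width: int) -> None:
    assert_true(bit_width > 0, "bit_width must be > 0")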


@@ -8,7 +8,7 @@ from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
is_data_type_compatible_with,
)
from ..debugging import custom_assert
from ..debugging import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import IntermediateNode
@@ -139,7 +139,7 @@ def eval_op_graph_bounds_on_inputset(
"""
def check_inputset_input_len_is_valid(data_to_check):
custom_assert(
assert_true(
len(data_to_check) == len(op_graph.input_nodes),
(
f"Got input data from inputset of len: {len(data_to_check)}, "


@@ -3,7 +3,7 @@
from typing import List, Optional
from .data_types.integers import Integer
from .debugging import custom_assert
from .debugging import assert_true
from .operator_graph import OPGraph
from .representation.intermediate import IntermediateNode
@@ -54,7 +54,7 @@ def check_op_graph_is_integer_program(
"""
offending_nodes = [] if offending_nodes_out is None else offending_nodes_out
custom_assert(
assert_true(
isinstance(offending_nodes, list),
f"offending_nodes_out must be a list, got {type(offending_nodes_out)}",
)


@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, Union
import networkx as nx
from PIL import Image
from ..debugging import custom_assert, draw_graph, get_printable_graph
from ..debugging import assert_true, draw_graph, get_printable_graph
from ..operator_graph import OPGraph
from ..representation.intermediate import IntermediateNode
from ..values import BaseValue
@@ -102,7 +102,7 @@ class CompilationArtifacts:
None
"""
custom_assert(self.final_operation_graph is not None)
assert_true(self.final_operation_graph is not None)
self.bounds_of_the_final_operation_graph = bounds
def add_final_operation_graph_mlir(self, mlir: str):
@@ -115,7 +115,7 @@ class CompilationArtifacts:
None
"""
custom_assert(self.final_operation_graph is not None)
assert_true(self.final_operation_graph is not None)
self.mlir_of_the_final_operation_graph = mlir
def export(self):
@@ -188,7 +188,7 @@ class CompilationArtifacts:
f.write(f"{representation}")
if self.bounds_of_the_final_operation_graph is not None:
custom_assert(self.final_operation_graph is not None)
assert_true(self.final_operation_graph is not None)
with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
# TODO:
# if nx.topological_sort is not deterministic between calls,
@@ -196,11 +196,11 @@ class CompilationArtifacts:
# thus, we may want to change this in the future
for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)):
bounds = self.bounds_of_the_final_operation_graph.get(node)
custom_assert(bounds is not None)
assert_true(bounds is not None)
f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n")
if self.mlir_of_the_final_operation_graph is not None:
custom_assert(self.final_operation_graph is not None)
assert_true(self.final_operation_graph is not None)
with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
f.write(self.mlir_of_the_final_operation_graph)


@@ -4,7 +4,7 @@ from copy import deepcopy
from functools import partial
from typing import Callable, Optional, Tuple, Union, cast
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
from .base import BaseDataType
from .floats import Float
@@ -146,8 +146,8 @@ def find_type_to_hold_both_lossy(
Returns:
BaseDataType: The dtype able to hold (potentially lossy) dtype1 and dtype2
"""
custom_assert(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}")
custom_assert(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}")
assert_true(isinstance(dtype1, BASE_DATA_TYPES), f"Unsupported dtype1: {type(dtype1)}")
assert_true(isinstance(dtype2, BASE_DATA_TYPES), f"Unsupported dtype2: {type(dtype2)}")
type_to_return: BaseDataType
@@ -205,15 +205,15 @@ def mix_tensor_values_determine_holding_dtype(
value2 dtypes.
"""
custom_assert(
assert_true(
isinstance(value1, TensorValue), f"Unsupported value1: {value1}, expected TensorValue"
)
custom_assert(
assert_true(
isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
)
resulting_shape = broadcast_shapes(value1.shape, value2.shape)
custom_assert(
assert_true(
resulting_shape is not None,
(
f"Tensors have incompatible shapes which is not supported.\n"
@@ -250,7 +250,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
dtypes.
"""
custom_assert(
assert_true(
(value1.__class__ == value2.__class__),
f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}",
)
@@ -274,7 +274,7 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
BaseDataType: The corresponding BaseDataType
"""
constant_data_type: BaseDataType
custom_assert(
assert_true(
isinstance(constant_data, (int, float)),
f"Unsupported constant data of type {type(constant_data)}",
)


@@ -2,7 +2,7 @@
from functools import partial
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from . import base
@@ -15,7 +15,7 @@ class Float(base.BaseDataType):
def __init__(self, bit_width: int) -> None:
super().__init__()
custom_assert(bit_width in (32, 64), "Only 32 and 64 bits floats are supported")
assert_true(bit_width in (32, 64), "Only 32 and 64 bits floats are supported")
self.bit_width = bit_width
def __repr__(self) -> str:


@@ -3,7 +3,7 @@
import math
from typing import Any, Iterable
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from . import base
@@ -15,7 +15,7 @@ class Integer(base.BaseDataType):
def __init__(self, bit_width: int, is_signed: bool) -> None:
super().__init__()
custom_assert(bit_width > 0, "bit_width must be > 0")
assert_true(bit_width > 0, "bit_width must be > 0")
self.bit_width = bit_width
self.is_signed = is_signed


@@ -1,4 +1,4 @@
"""Module for debugging."""
from .custom_assert import custom_assert
from .custom_assert import assert_true
from .drawing import draw_graph
from .printing import get_printable_graph


@@ -1,7 +1,7 @@
"""Provide some variants of assert."""
def custom_assert(condition: bool, on_error_msg: str = "") -> None:
def _custom_assert(condition: bool, on_error_msg: str = "") -> None:
"""Provide a custom assert which is kept even if the optimized python mode is used.
See https://docs.python.org/3/reference/simple_stmts.html#assert for the documentation
@@ -25,7 +25,7 @@ def assert_true(condition: bool, on_error_msg: str = ""):
on_error_msg(str): optional message for precising the error, in case of error
"""
return custom_assert(condition, on_error_msg)
return _custom_assert(condition, on_error_msg)
def assert_false(condition: bool, on_error_msg: str = ""):
@@ -36,7 +36,7 @@ def assert_false(condition: bool, on_error_msg: str = ""):
on_error_msg(str): optional message for precising the error, in case of error
"""
return custom_assert(not condition, on_error_msg)
return _custom_assert(not condition, on_error_msg)
def assert_not_reached(on_error_msg: str):
@@ -46,4 +46,4 @@ def assert_not_reached(on_error_msg: str):
on_error_msg(str): message for precising the error
"""
return custom_assert(False, on_error_msg)
return _custom_assert(False, on_error_msg)
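
For readability, here is a sketch of the debugging assert module as it presumably stands after this commit, reconstructed from the hunks above. The diff does not show the body of _custom_assert or the full wrapper docstrings, so those parts are assumptions rather than the verbatim file contents.

"""Provide some variants of assert."""


def _custom_assert(condition: bool, on_error_msg: str = "") -> None:
    """Provide a custom assert which is kept even if the optimized python mode is used."""
    # Assumed body (not visible in this diff): raise when the condition does not hold.
    if not condition:
        raise AssertionError(on_error_msg)


def assert_true(condition: bool, on_error_msg: str = ""):
    """Assert that the condition is true."""
    return _custom_assert(condition, on_error_msg)


def assert_false(condition: bool, on_error_msg: str = ""):
    """Assert that the condition is false."""
    return _custom_assert(not condition, on_error_msg)


def assert_not_reached(on_error_msg: str):
    """Assert that this line of code is never reached."""
    return _custom_assert(False, on_error_msg)

Keeping _custom_assert private forces call sites to go through the named wrappers, so each assertion reads as intent (assert_true, assert_false, assert_not_reached) rather than as a bare condition check.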


@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import (
ALL_IR_NODES,
@@ -36,7 +36,7 @@ IR_NODE_COLOR_MAPPING = {
}
_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys()
custom_assert(
assert_true(
len(_missing_nodes_in_mapping) == 0,
(
f"Missing IR node in IR_NODE_COLOR_MAPPING : "


@@ -4,7 +4,7 @@ from typing import Any, Dict
import networkx as nx
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, Input, UnivariateFunction
@@ -50,7 +50,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
Returns:
str: a string to print or save in a file
"""
custom_assert(isinstance(opgraph, OPGraph))
assert_true(isinstance(opgraph, OPGraph))
list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
graph = opgraph.graph
@@ -64,7 +64,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
# This code doesn't work with more than a single output. For more outputs,
# we would need to change the way the destination are created: currently,
# they only are done by incrementing i
custom_assert(len(node.outputs) == 1)
assert_true(len(node.outputs) == 1)
if isinstance(node, Input):
what_to_print = node.input_name
@@ -91,9 +91,9 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
# Some checks, because the previous algorithm is not clear
custom_assert(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name)))
assert_true(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name)))
list_of_arg_name.sort()
custom_assert([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))))
assert_true([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))))
prefix_to_add_to_what_to_print = ""
suffix_to_add_to_what_to_print = ""
@@ -105,7 +105,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
if node.op_attributes["in_which_input_is_constant"] == 0:
prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, "
else:
custom_assert(
assert_true(
node.op_attributes["in_which_input_is_constant"] == 1,
"'in_which_input_is_constant' should be a key of node.op_attributes",
)


@@ -21,15 +21,15 @@ from ..data_types.dtypes_helpers import (
value_is_encrypted_tensor_integer,
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, UnivariateFunction
from ..values import TensorValue
def add(node, preds, ir_to_mlir_node, ctx):
"""Convert an addition intermediate node."""
custom_assert(len(node.inputs) == 2, "addition should have two inputs")
custom_assert(len(node.outputs) == 1, "addition should have a single output")
assert_true(len(node.inputs) == 2, "addition should have two inputs")
assert_true(len(node.outputs) == 1, "addition should have a single output")
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
node.inputs[1]
):
@@ -72,8 +72,8 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
def sub(node, preds, ir_to_mlir_node, ctx):
"""Convert a subtraction intermediate node."""
custom_assert(len(node.inputs) == 2, "subtraction should have two inputs")
custom_assert(len(node.outputs) == 1, "subtraction should have a single output")
assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
assert_true(len(node.outputs) == 1, "subtraction should have a single output")
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer(
node.inputs[1]
):
@@ -96,8 +96,8 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
def mul(node, preds, ir_to_mlir_node, ctx):
"""Convert a multiplication intermediate node."""
custom_assert(len(node.inputs) == 2, "multiplication should have two inputs")
custom_assert(len(node.outputs) == 1, "multiplication should have a single output")
assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
assert_true(len(node.outputs) == 1, "multiplication should have a single output")
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
node.inputs[1]
):
@@ -166,8 +166,8 @@ def constant(node, _, __, ctx):
def apply_lut(node, preds, ir_to_mlir_node, ctx):
"""Convert a UnivariateFunction intermediate node."""
custom_assert(len(node.inputs) == 1, "LUT should have a single input")
custom_assert(len(node.outputs) == 1, "LUT should have a single output")
assert_true(len(node.inputs) == 1, "LUT should have a single input")
assert_true(len(node.outputs) == 1, "LUT should have a single output")
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
raise TypeError("Only support LUT with encrypted unsigned integers inputs")
if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]):
@@ -192,8 +192,8 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
def dot(node, preds, ir_to_mlir_node, ctx):
"""Convert a dot intermediate node."""
custom_assert(len(node.inputs) == 2, "Dot should have two inputs")
custom_assert(len(node.outputs) == 1, "Dot should have a single output")
assert_true(len(node.inputs) == 2, "Dot should have two inputs")
assert_true(len(node.outputs) == 1, "Dot should have a single output")
if not (
(
value_is_encrypted_tensor_integer(node.inputs[0])


@@ -17,7 +17,7 @@ from ..data_types.dtypes_helpers import (
value_is_encrypted_scalar_unsigned_integer,
value_is_encrypted_tensor_unsigned_integer,
)
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Input
@@ -83,7 +83,7 @@ class MLIRConverter:
if is_signed and not is_encrypted: # clear signed
return IntegerType.get_signed(bit_width)
# should be clear unsigned at this point
custom_assert(not is_signed and not is_encrypted)
assert_true(not is_signed and not is_encrypted)
# unsigned integer are considered signless in the compiler
return IntegerType.get_signless(bit_width)


@@ -12,7 +12,7 @@ from .data_types.dtypes_helpers import (
)
from .data_types.floats import Float
from .data_types.integers import Integer, make_integer_to_hold
from .debugging.custom_assert import custom_assert
from .debugging.custom_assert import assert_true
from .representation.intermediate import Input, IntermediateNode
from .tracing import BaseTracer
from .tracing.tracing_helpers import create_graph_from_output_tracers
@@ -31,14 +31,12 @@ class OPGraph:
input_nodes: Dict[int, Input],
output_nodes: Dict[int, IntermediateNode],
) -> None:
custom_assert(
len(input_nodes) > 0, "Got a graph without input nodes which is not supported"
)
custom_assert(
assert_true(len(input_nodes) > 0, "Got a graph without input nodes which is not supported")
assert_true(
all(isinstance(node, Input) for node in input_nodes.values()),
"Got input nodes that were not Input, which is not supported",
)
custom_assert(
assert_true(
all(isinstance(node, IntermediateNode) for node in output_nodes.values()),
"Got output nodes which were not IntermediateNode, which is not supported",
)
@@ -51,7 +49,7 @@ class OPGraph:
def __call__(self, *args) -> Union[Any, Tuple[Any, ...]]:
inputs = dict(enumerate(args))
custom_assert(
assert_true(
len(inputs) == len(self.input_nodes),
f"Expected {len(self.input_nodes)} arguments, got {len(inputs)} : {args}",
)
@@ -183,7 +181,7 @@ class OPGraph:
min_data_type_constructor = get_type_constructor_for_constant_data(min_bound)
max_data_type_constructor = get_type_constructor_for_constant_data(max_bound)
custom_assert(
assert_true(
max_data_type_constructor == min_data_type_constructor,
(
f"Got two different type constructors for min and max bound: "
@@ -200,7 +198,7 @@ class OPGraph:
(min_bound, max_bound), force_signed=False
)
else:
custom_assert(
assert_true(
isinstance(min_data_type, Float) and isinstance(max_data_type, Float),
(
"min_bound and max_bound have different common types, "
@@ -212,7 +210,7 @@ class OPGraph:
output_value.dtype.underlying_type_constructor = data_type_constructor
else:
# Currently variable inputs are only allowed to be integers
custom_assert(
assert_true(
isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer),
(
f"Inputs to a graph should be integers, got bounds that were float, \n"
@@ -229,7 +227,7 @@ class OPGraph:
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
# adding an edge
custom_assert(len(node.outputs) == 1)
assert_true(len(node.outputs) == 1)
successors = self.graph.succ[node]
for succ in successors:


@@ -8,7 +8,7 @@ import networkx as nx
from ..compilation.artifacts import CompilationArtifacts
from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true, custom_assert
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
from ..values import TensorValue
@@ -119,7 +119,7 @@ def convert_float_subgraph_to_fused_node(
variable_input_nodes = [
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
]
custom_assert(len(variable_input_nodes) == 1)
assert_true(len(variable_input_nodes) == 1)
current_subgraph_variable_input = variable_input_nodes[0]
new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])


@@ -12,7 +12,7 @@ from ..data_types.dtypes_helpers import (
mix_values_determine_holding_dtype,
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from ..values import BaseValue, ClearScalar, EncryptedScalar, TensorValue
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
@@ -33,7 +33,7 @@ class IntermediateNode(ABC):
**_kwargs, # This is to be able to feed arbitrary arguments to IntermediateNodes
) -> None:
self.inputs = list(inputs)
custom_assert(all(isinstance(x, BaseValue) for x in self.inputs))
assert_true(all(isinstance(x, BaseValue) for x in self.inputs))
# Register all IR nodes
def __init_subclass__(cls, **kwargs):
@@ -49,7 +49,7 @@ class IntermediateNode(ABC):
"""__init__ for a binary operation, ie two inputs."""
IntermediateNode.__init__(self, inputs)
custom_assert(len(self.inputs) == 2)
assert_true(len(self.inputs) == 2)
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
@@ -148,7 +148,7 @@ class Input(IntermediateNode):
program_input_idx: int,
) -> None:
super().__init__((input_value,))
custom_assert(len(self.inputs) == 1)
assert_true(len(self.inputs) == 1)
self.input_name = input_name
self.program_input_idx = program_input_idx
self.outputs = [deepcopy(self.inputs[0])]
@@ -222,7 +222,7 @@ class UnivariateFunction(IntermediateNode):
op_attributes: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__([input_base_value])
custom_assert(len(self.inputs) == 1)
assert_true(len(self.inputs) == 1)
self.arbitrary_func = arbitrary_func
self.op_args = op_args if op_args is not None else ()
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
@@ -306,9 +306,9 @@ class Dot(IntermediateNode):
] = default_dot_evaluation_function,
) -> None:
super().__init__(inputs)
custom_assert(len(self.inputs) == 2)
assert_true(len(self.inputs) == 2)
custom_assert(
assert_true(
all(
isinstance(input_value, TensorValue) and input_value.ndim == 1
for input_value in self.inputs


@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import (
IR_MIX_VALUES_FUNC_ARG_NAME,
Add,
@@ -111,7 +111,7 @@ class BaseTracer(ABC):
Add,
)
custom_assert(len(result_tracer) == 1)
assert_true(len(result_tracer) == 1)
return result_tracer[0]
# With that is that x + 1 and 1 + x have the same graph. If we want to keep
@@ -128,7 +128,7 @@ class BaseTracer(ABC):
Sub,
)
custom_assert(len(result_tracer) == 1)
assert_true(len(result_tracer) == 1)
return result_tracer[0]
def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
@@ -140,7 +140,7 @@ class BaseTracer(ABC):
Sub,
)
custom_assert(len(result_tracer) == 1)
assert_true(len(result_tracer) == 1)
return result_tracer[0]
def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
@@ -152,7 +152,7 @@ class BaseTracer(ABC):
Mul,
)
custom_assert(len(result_tracer) == 1)
assert_true(len(result_tracer) == 1)
return result_tracer[0]
# With that is that x * 3 and 3 * x have the same graph. If we want to keep


@@ -6,7 +6,7 @@ from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
import networkx as nx
from networkx.algorithms.dag import is_directed_acyclic_graph
from ..debugging.custom_assert import assert_true, custom_assert
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import Input
from ..values import BaseValue
from .base_tracer import BaseTracer
@@ -124,7 +124,7 @@ def create_graph_from_output_tracers(
current_tracers = next_tracers
custom_assert(is_directed_acyclic_graph(graph))
assert_true(is_directed_acyclic_graph(graph))
# Check each edge is unique
unique_edges = set(


@@ -17,7 +17,7 @@ from ..common.data_types.dtypes_helpers import (
)
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
from ..common.debugging.custom_assert import custom_assert
from ..common.debugging.custom_assert import assert_true
from ..common.values import BaseValue, TensorValue
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
@@ -72,13 +72,13 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d
Returns:
numpy.dtype: The resulting numpy.dtype
"""
custom_assert(
assert_true(
isinstance(common_dtype, BASE_DATA_TYPES), f"Unsupported common_dtype: {type(common_dtype)}"
)
type_to_return: numpy.dtype
if isinstance(common_dtype, Float):
custom_assert(
assert_true(
common_dtype.bit_width
in (
32,
@@ -117,7 +117,7 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) ->
BaseDataType: The corresponding BaseDataType
"""
base_dtype: BaseDataType
custom_assert(
assert_true(
isinstance(
constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
),
@@ -159,12 +159,12 @@ def get_base_value_for_numpy_or_python_constant_data(
with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method).
"""
constant_data_value: Callable[..., BaseValue]
custom_assert(
assert_true(
not isinstance(constant_data, list),
"Unsupported constant data of type list "
"(if you meant to use a list as an array, please use numpy.array instead)",
)
custom_assert(
assert_true(
isinstance(
constant_data,
(int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
@@ -198,7 +198,7 @@ def get_numpy_function_output_dtype(
List[numpy.dtype]: The ordered numpy dtypes of the function outputs
"""
if isinstance(function, numpy.ufunc):
custom_assert(
assert_true(
(len(input_dtypes) == function.nin),
f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}",
)
@@ -231,7 +231,7 @@ def get_type_constructor_for_numpy_or_python_constant_data(constant_data: Any):
constant_data (Any): The data for which we want to determine the type constructor.
"""
custom_assert(
assert_true(
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
f"Unsupported constant data of type {type(constant_data)}",
)


@@ -7,7 +7,7 @@ import numpy
from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import assert_true, custom_assert
from ..common.debugging.custom_assert import assert_true
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import Constant, Dot, UnivariateFunction
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
@@ -41,7 +41,7 @@ class NPTracer(BaseTracer):
"""
if method == "__call__":
tracing_func = self.get_tracing_func_for_np_function(ufunc)
custom_assert(
assert_true(
(len(kwargs) == 0),
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}",
)
@@ -58,7 +58,7 @@ class NPTracer(BaseTracer):
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
"""
tracing_func = self.get_tracing_func_for_np_function(func)
custom_assert(
assert_true(
(len(kwargs) == 0),
f"**kwargs are currently not supported for numpy functions, func: {func}",
)
@@ -77,10 +77,10 @@ class NPTracer(BaseTracer):
Returns:
NPTracer: The NPTracer representing the casting operation
"""
custom_assert(
assert_true(
len(args) == 0, f"astype currently only supports tracing without *args, got {args}"
)
custom_assert(
assert_true(
(len(kwargs) == 0),
f"astype currently only supports tracing without **kwargs, got {kwargs}",
)
@@ -150,9 +150,9 @@ class NPTracer(BaseTracer):
Returns:
NPTracer: The output NPTracer containing the traced function
"""
custom_assert(len(input_tracers) == 1)
assert_true(len(input_tracers) == 1)
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
custom_assert(len(common_output_dtypes) == 1)
assert_true(len(common_output_dtypes) == 1)
traced_computation = UnivariateFunction(
input_base_value=input_tracers[0].output,
@@ -179,7 +179,7 @@ class NPTracer(BaseTracer):
Returns:
NPTracer: The output NPTracer containing the traced function
"""
custom_assert(len(input_tracers) == 2)
assert_true(len(input_tracers) == 2)
# One of the inputs has to be constant
if isinstance(input_tracers[0].traced_computation, Constant):
@@ -204,7 +204,7 @@ class NPTracer(BaseTracer):
return binary_operator(x, baked_constant, **kwargs)
common_output_dtypes = cls._manage_dtypes(binary_operator, *input_tracers)
custom_assert(len(common_output_dtypes) == 1)
assert_true(len(common_output_dtypes) == 1)
op_kwargs = deepcopy(kwargs)
op_kwargs["baked_constant"] = baked_constant
@@ -242,7 +242,7 @@ class NPTracer(BaseTracer):
assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
common_output_dtypes = self._manage_dtypes(numpy.dot, *args)
custom_assert(len(common_output_dtypes) == 1)
assert_true(len(common_output_dtypes) == 1)
traced_computation = Dot(
[input_tracer.output for input_tracer in args],
@@ -399,7 +399,7 @@ list_of_not_supported = [
if ufunc.nin not in [1, 2]
]
custom_assert(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}")
assert_true(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}")
del list_of_not_supported
# We are adding initial support for `np.array(...)` +,-,* `BaseTracer`