From f530a0b7396fd032bd09da8114ecac7401d04ec0 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 3 Nov 2021 16:14:10 +0100 Subject: [PATCH] refactor: make GenericFunction accept several inputs - remove baked constants - manage table generation for the updated node closes #600 closes #822 --- concrete/common/debugging/printing.py | 13 --- concrete/common/mlir/converters.py | 24 ++++-- concrete/common/mlir/mlir_converter.py | 9 +- concrete/common/mlir/utils.py | 24 +++++- concrete/common/operator_graph.py | 16 ++++ concrete/common/optimization/topological.py | 22 ++--- .../common/representation/intermediate.py | 53 +++++++++--- concrete/numpy/np_dtypes_helpers.py | 25 +++--- concrete/numpy/np_mlir_converter.py | 11 +-- concrete/numpy/tracing.py | 85 ++++++++----------- tests/common/extensions/test_table.py | 4 +- tests/numpy/test_debugging.py | 42 ++++++--- tests/numpy/test_np_mlir_converter.py | 8 +- 13 files changed, 206 insertions(+), 130 deletions(-) diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 04c800da9..33ba7b0c9 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -114,19 +114,6 @@ def get_printable_graph( prefix_to_add_to_what_to_print = "" suffix_to_add_to_what_to_print = "" - # Print constant that may be in the GenericFunction. For the moment, it considers - # there is a single constant maximally and that there is 2 inputs maximally - if isinstance(node, GenericFunction) and "baked_constant" in node.op_kwargs: - baked_constant = node.op_kwargs["baked_constant"] - if node.op_attributes["in_which_input_is_constant"] == 0: - prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, " - else: - assert_true( - node.op_attributes["in_which_input_is_constant"] == 1, - "'in_which_input_is_constant' should be a key of node.op_attributes", - ) - suffix_to_add_to_what_to_print = f", {shorten_a_constant(baked_constant)}" - # Then, just print the predecessors in the right order what_to_print += prefix_to_add_to_what_to_print what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name]) diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index 66e4eb53e..36ed34d2d 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -9,7 +9,7 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml from typing import cast # pylint: disable=no-name-in-module,no-member -import numpy as np +import numpy from mlir.dialects import arith as arith_dialect from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType from zamalang.dialects import hlfhe @@ -163,12 +163,24 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info): """Convert a GenericFunction intermediate node.""" - assert_true(len(node.inputs) == 1, "LUT should have a single input") + + variable_input_indices = [ + idx for idx, pred in enumerate(preds) if not isinstance(pred, Constant) + ] + + assert_true( + (non_constant_pred_count := len(variable_input_indices)) == 1, + f"LUT should have a single variable input (got {non_constant_pred_count})", + ) + + variable_input_idx = variable_input_indices[0] + variable_input_value = node.inputs[variable_input_idx] + assert_true(len(node.outputs) == 1, "LUT should have a single output") - if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]): + if not value_is_encrypted_scalar_unsigned_integer(variable_input_value): raise TypeError( f"Only support LUT with encrypted unsigned integers inputs " - f"(but {node.inputs[0]} is provided)" + f"(but {variable_input_value} is provided)" ) if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]): raise TypeError( @@ -176,7 +188,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info): f"(but {node.outputs[0]} is provided)" ) - x_node = preds[0] + x_node = preds[variable_input_idx] x = ir_to_mlir_node[x_node] tables = additional_conversion_info["tables"][node] @@ -192,7 +204,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info): out_dtype = cast(Integer, node.outputs[0].dtype) # Create table - dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx) + dense_elem = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=ctx) tensor_lut = arith_dialect.ConstantOp( RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)), dense_elem, diff --git a/concrete/common/mlir/mlir_converter.py b/concrete/common/mlir/mlir_converter.py index ddb3a8fe2..b132b1220 100644 --- a/concrete/common/mlir/mlir_converter.py +++ b/concrete/common/mlir/mlir_converter.py @@ -171,14 +171,7 @@ class MLIRConverter(ABC): f"we don't yet support conversion to MLIR of computations using" f"{type(node)}" ) - # get sorted preds: sorted by their input index - # replication of pred is possible (e.g lambda x: x + x) - idx_to_pred = {} - for pred in op_graph.graph.pred[node]: - edge_data = op_graph.graph.get_edge_data(pred, node) - for data in edge_data.values(): - idx_to_pred[data["input_idx"]] = pred - preds = [idx_to_pred[i] for i in range(len(idx_to_pred))] + preds = op_graph.get_ordered_preds(node) # convert to mlir result = mlir_op( node, diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 2a7b566a6..c2b7baebf 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -1,6 +1,8 @@ """Utilities for MLIR conversion.""" from typing import Dict, List, Optional, cast +import networkx as nx + from ..data_types import Integer from ..data_types.dtypes_helpers import ( value_is_clear_scalar_integer, @@ -21,11 +23,16 @@ from ..representation.intermediate import GenericFunction, IntermediateNode ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7 -def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) -> Optional[str]: +def check_node_compatibility_with_mlir( + node: IntermediateNode, + nx_graph: nx.MultiDiGraph, + is_output: bool, +) -> Optional[str]: """Check if node is compatible with MLIR. Args: node (IntermediateNode): node to check + nx_graph (nx.MultiDiGraph): the networkx graph to which node belongs is_output (bool): whether the node is an output node or not Returns: @@ -66,7 +73,16 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) elif isinstance(node, intermediate.GenericFunction): # constraints for univariate functions if node.op_kind == "TLU": - assert_true(len(inputs) == 1) + assert_true( + len( + [ + pred_node + for pred_node in nx_graph.pred[node] + if not isinstance(pred_node, intermediate.Constant) + ] + ) + == 1 + ) if node.op_name == "MultiTLU": return "direct multi table lookup is not supported for the time being" if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]): @@ -124,7 +140,9 @@ def check_graph_values_compatibility_with_mlir( for node in op_graph.graph.nodes: is_output = node in op_graph.output_nodes.values() - if (reason := check_node_compatibility_with_mlir(node, is_output)) is not None: + if ( + reason := check_node_compatibility_with_mlir(node, op_graph.graph, is_output) + ) is not None: offending_nodes[node] = [reason] return None if len(offending_nodes) == 0 else offending_nodes diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index 6222db791..a6c27a46c 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -115,6 +115,22 @@ class OPGraph: """ return [self.output_nodes[idx] for idx in range(len(self.output_nodes))] + def get_ordered_preds(self, node: IntermediateNode) -> List[IntermediateNode]: + """Get node predecessors ordered by their indices. + + Args: + node (IntermediateNode): The node for which we want the ordered predecessors. + + Returns: + List[IntermediateNode]: The list of predecessors ordered by input index. + """ + # Replication of pred is managed e.g. x + x will yield the proper pred x twice + idx_to_pred: Dict[int, IntermediateNode] = {} + for pred in self.graph.pred[node]: + edge_data = self.graph.get_edge_data(pred, node) + idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values()) + return [idx_to_pred[i] for i in range(len(idx_to_pred))] + def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]: """Evaluate a graph and get intermediate values for all nodes. diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 5a755c258..26c7a86c3 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -323,20 +323,10 @@ def subgraph_nodes_and_values_allow_fusing( if len(explicitely_non_fusable) > 0: return False - # Some GenericFunction nodes have baked constants that need to be taken into account for the - # max size computation - baked_constants_ir_nodes = [ - baked_constant_ir_node - for node in subgraph_all_nodes - if isinstance(node, GenericFunction) - if (baked_constant_ir_node := node.op_attributes.get("baked_constant_ir_node", None)) - is not None - ] - all_values_are_tensors = all( all(isinstance(input_, TensorValue) for input_ in node.inputs) and all(isinstance(output, TensorValue) for output in node.outputs) - for node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes) + for node in subgraph_all_nodes ) if not all_values_are_tensors: @@ -360,8 +350,14 @@ def subgraph_nodes_and_values_allow_fusing( variable_input_node_output.shape, ) max_inputs_size = max( - cast(TensorValue, input_node.outputs[0]).size - for input_node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes) + itertools.chain( + (variable_input_node_output_size,), + ( + cast(TensorValue, constant_input_node.outputs[0]).size + for constant_input_node in subgraph_all_nodes + if isinstance(constant_input_node, Constant) + ), + ) ) if variable_input_node_output_size < max_inputs_size: diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index b0ee33ee0..445ec6e27 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -322,7 +322,6 @@ class GenericFunction(IntermediateNode): ) -> None: super().__init__(inputs) self._n_in = len(self.inputs) - assert_true(self._n_in == 1) # TODO: remove in later parts of refactoring of #600 self.arbitrary_func = arbitrary_func self.op_kind = GenericFunctionKind(op_kind) self.op_args = op_args if op_args is not None else () @@ -344,22 +343,42 @@ class GenericFunction(IntermediateNode): def label(self) -> str: return self.op_name - def get_table(self) -> List[Any]: + def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]: """Get the table for the current input value of this GenericFunction. - This function only works if the GenericFunction input value is an unsigned Integer. + This function only works if the GenericFunction variable input value is an unsigned Integer. + It only works if there is a single variable input node among ordered_preds. + + Args: + ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must + contain a single non constant node and any number of Constant nodes. Returns: List[Any]: The table. """ - input_dtype = self.inputs[0].dtype + variable_input_indices = [ + idx for idx, pred in enumerate(ordered_preds) if not isinstance(pred, Constant) + ] + + assert_true( + (non_constant_pred_count := len(variable_input_indices)) == 1, + f"Can only have 1 non constant predecessor in {self.get_table.__name__}, " + f"got {non_constant_pred_count}", + ) + + variable_input_idx = variable_input_indices[0] + variable_input_dtype = self.inputs[variable_input_idx].dtype # Check the input is an unsigned integer to be able to build a table assert_true( - isinstance(input_dtype, Integer), "get_table only works for an unsigned Integer input" + isinstance(variable_input_dtype, Integer), + f"{self.get_table.__name__} only works for an unsigned Integer input", + ) + variable_input_dtype = cast(Integer, variable_input_dtype) + assert_true( + not variable_input_dtype.is_signed, + f"{self.get_table.__name__} only works for an unsigned Integer input", ) - input_dtype = cast(Integer, input_dtype) - assert_true(not input_dtype.is_signed, "get_table only works for an unsigned Integer input") input_value_constructor = self.inputs[0].underlying_constructor if input_value_constructor is None: @@ -368,8 +387,8 @@ class GenericFunction(IntermediateNode): ) input_value_constructor = int - min_input_range = input_dtype.min_value() - max_input_range = input_dtype.max_value() + 1 + min_input_range = variable_input_dtype.min_value() + max_input_range = variable_input_dtype.max_value() + 1 def catch(func, *args, **kwargs): try: @@ -378,8 +397,22 @@ class GenericFunction(IntermediateNode): except Exception: # pragma: no cover # pylint: disable=broad-except return None + template_input_dict = { + idx: node.evaluate({}) if isinstance(node, Constant) else None + for idx, node in enumerate(ordered_preds) + } + + def update_and_return_dict(dict_to_update: dict, update_values): + dict_to_update.update(update_values) + return dict_to_update + table = [ - catch(self.evaluate, {0: input_value_constructor(input_value)}) + catch( + self.evaluate, + update_and_return_dict( + template_input_dict, {variable_input_idx: input_value_constructor(input_value)} + ), + ) for input_value in range(min_input_range, max_input_range) ] diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index ecf51d31a..0e942ec8f 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -183,11 +183,11 @@ def get_base_value_for_numpy_or_python_constant_data( return constant_data_value -def get_numpy_function_output_dtype_from_input_dtypes( +def get_numpy_function_output_dtype_and_shape_from_input_dtypes( function: Union[numpy.ufunc, Callable], input_dtypes: List[BaseDataType], input_shapes: List[Tuple[int, ...]], -) -> List[numpy.dtype]: +) -> List[Tuple[numpy.dtype, Tuple[int, ...]]]: """Record the output dtype of a numpy function given some input types. Args: @@ -199,7 +199,8 @@ def get_numpy_function_output_dtype_from_input_dtypes( the function inputs Returns: - List[numpy.dtype]: The ordered numpy dtypes of the function outputs + List[Tuple[numpy.dtype, Tuple[int, ...]]]: appropriate (numpy.dtype, shape) tuple for each + output of the function """ if isinstance(function, numpy.ufunc): assert_true( @@ -226,14 +227,14 @@ def get_numpy_function_output_dtype_from_input_dtypes( if not isinstance(outputs, tuple): outputs = (outputs,) - return [output.dtype for output in outputs] + return [(output.dtype, output.shape) for output in outputs] -def get_numpy_function_output_dtype_from_input_tracers( +def get_numpy_function_output_dtype_and_shape_from_input_tracers( func: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer, -) -> List[BaseDataType]: - """Determine output dtypes for a numpy function. +) -> List[Tuple[BaseDataType, Tuple[int, ...]]]: + """Determine output dtypes and shapes for a numpy function. This function is responsible for determining the output dtype of a numpy function after inputs with specific dtypes are passed to it. @@ -243,19 +244,23 @@ def get_numpy_function_output_dtype_from_input_tracers( *input_tracers (BaseTracer): inputs to the function Returns: - List[numpy.dtype]: appropriate BaseDataType for each output of the function + List[Tuple[BaseDataType, Tuple[int, ...]]]: appropriate (BaseDataType, shape) tuple for each + output of the function """ input_shapes = [ input_tracer.output.shape if isinstance(input_tracer.output, TensorValue) else () for input_tracer in input_tracers ] - output_dtypes = get_numpy_function_output_dtype_from_input_dtypes( + output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_dtypes( func, [input_tracer.output.dtype for input_tracer in input_tracers], input_shapes, ) - common_output_dtypes = [convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes] + common_output_dtypes = [ + (convert_numpy_dtype_to_base_data_type(dtype), shape) + for dtype, shape in output_dtypes_and_shapes + ] return common_output_dtypes diff --git a/concrete/numpy/np_mlir_converter.py b/concrete/numpy/np_mlir_converter.py index 5326d1fc5..5b2f7e668 100644 --- a/concrete/numpy/np_mlir_converter.py +++ b/concrete/numpy/np_mlir_converter.py @@ -10,7 +10,7 @@ import numpy from ..common.debugging import assert_true from ..common.mlir.mlir_converter import MLIRConverter from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import GenericFunction +from ..common.representation.intermediate import GenericFunction, IntermediateNode class HashableNPArray: @@ -33,12 +33,13 @@ class HashableNPArray: def generate_deduplicated_tables( - node: GenericFunction, + node: GenericFunction, ordered_preds: List[IntermediateNode] ) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: """Deduplicate the tables for the different cells of a tensor if needed. Args: - node (GenericFunction): the node for which to deduplicate the table + node (GenericFunction): the node for which to deduplicate the table. + ordered_preds (List[IntermediateNode]): ordered list of predecessors of the node. Returns: Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose @@ -47,7 +48,7 @@ def generate_deduplicated_tables( """ # This is the tensor containing the tables for each cell of the tensor for node node_complete_table = numpy.concatenate( - tuple(numpy.expand_dims(array, -1) for array in node.get_table()), axis=-1 + tuple(numpy.expand_dims(array, -1) for array in node.get_table(ordered_preds)), axis=-1 ) all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1])) @@ -85,7 +86,7 @@ class NPMLIRConverter(MLIRConverter): # Disable numpy warnings during conversion to avoid issues during TLU generation with numpy.errstate(all="ignore"): additional_conversion_info["tables"] = { - node: generate_deduplicated_tables(node) + node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node)) for node in op_graph.graph.nodes() if isinstance(node, GenericFunction) } diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 27edfc554..f395448e8 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -16,7 +16,7 @@ from .np_dtypes_helpers import ( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES, convert_numpy_dtype_to_base_data_type, get_base_value_for_numpy_or_python_constant_data, - get_numpy_function_output_dtype_from_input_tracers, + get_numpy_function_output_dtype_and_shape_from_input_tracers, ) from .np_indexing_helpers import process_indexing_element @@ -161,14 +161,19 @@ class NPTracer(BaseTracer): NPTracer: The output NPTracer containing the traced function """ assert_true(len(input_tracers) == 1) - common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers( - unary_operator, - *input_tracers, + common_output_dtypes_and_shapes = ( + get_numpy_function_output_dtype_and_shape_from_input_tracers( + unary_operator, + *input_tracers, + ) ) - assert_true(len(common_output_dtypes) == 1) + assert_true(len(common_output_dtypes_and_shapes) == 1) - generic_function_output_value = deepcopy(input_tracers[0].output) - generic_function_output_value.dtype = common_output_dtypes[0] + output_dtype, output_shape = common_output_dtypes_and_shapes[0] + + generic_function_output_value = TensorValue( + output_dtype, input_tracers[0].output.is_encrypted, output_shape + ) traced_computation = GenericFunction( inputs=[deepcopy(input_tracers[0].output)], @@ -201,58 +206,40 @@ class NPTracer(BaseTracer): # One of the inputs has to be constant if isinstance(input_tracers[0].traced_computation, Constant): in_which_input_is_constant = 0 - baked_constant = deepcopy(input_tracers[0].traced_computation.constant_data) elif isinstance(input_tracers[1].traced_computation, Constant): in_which_input_is_constant = 1 - baked_constant = deepcopy(input_tracers[1].traced_computation.constant_data) else: raise NotImplementedError(f"Can't manage binary operator {binary_operator}") in_which_input_is_variable = 1 - in_which_input_is_constant - - if in_which_input_is_constant == 0: - - def arbitrary_func(x, baked_constant, **kwargs): - return binary_operator(baked_constant, x, **kwargs) - - else: - - def arbitrary_func(x, baked_constant, **kwargs): - return binary_operator(x, baked_constant, **kwargs) - - common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers( - binary_operator, - *input_tracers, + common_output_dtypes_and_shapes = ( + get_numpy_function_output_dtype_and_shape_from_input_tracers( + binary_operator, + *input_tracers, + ) + ) + assert_true(len(common_output_dtypes_and_shapes) == 1) + + output_dtype, output_shape = common_output_dtypes_and_shapes[0] + + generic_function_output_value = TensorValue( + output_dtype, + input_tracers[in_which_input_is_variable].output.is_encrypted, + output_shape, ) - assert_true(len(common_output_dtypes) == 1) op_kwargs = deepcopy(kwargs) - op_kwargs["baked_constant"] = baked_constant - # Store info on the operation being treated - # Currently: the base value and type corresponding to the baked constant and which input idx - # it was feeding - op_attributes = { - "baked_constant_ir_node": deepcopy( - input_tracers[in_which_input_is_constant].traced_computation - ), - "in_which_input_is_constant": in_which_input_is_constant, - } - generic_function_output_value = deepcopy(input_tracers[in_which_input_is_variable].output) - generic_function_output_value.dtype = common_output_dtypes[0] - - # TODO: update inputs for #600 refactor traced_computation = GenericFunction( - inputs=[deepcopy(input_tracers[in_which_input_is_variable].output)], - arbitrary_func=arbitrary_func, + inputs=[deepcopy(input_tracer.output) for input_tracer in input_tracers], + arbitrary_func=binary_operator, output_value=generic_function_output_value, op_kind="TLU", op_kwargs=op_kwargs, op_name=binary_operator_string, - op_attributes=op_attributes, ) output_tracer = cls( - (input_tracers[in_which_input_is_variable],), + input_tracers, traced_computation=traced_computation, output_idx=0, ) @@ -266,12 +253,14 @@ class NPTracer(BaseTracer): """ assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}") - common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.dot, *args) - assert_true(len(common_output_dtypes) == 1) + common_output_dtypes_and_shapes = ( + get_numpy_function_output_dtype_and_shape_from_input_tracers(numpy.dot, *args) + ) + assert_true(len(common_output_dtypes_and_shapes) == 1) traced_computation = Dot( [input_tracer.output for input_tracer in args], - common_output_dtypes[0], + common_output_dtypes_and_shapes[0][0], delegate_evaluation_function=numpy.dot, ) @@ -638,14 +627,14 @@ def _on_numpy_multiply(lhs, rhs): def _on_numpy_matmul(lhs, rhs): - common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers( + common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers( numpy.matmul, lhs, rhs ) - assert_true(len(common_output_dtypes) == 1) + assert_true(len(common_output_dtypes_and_shapes) == 1) traced_computation = MatMul( [lhs.output, rhs.output], - common_output_dtypes[0], + common_output_dtypes_and_shapes[0][0], ) return NPTracer([lhs, rhs], traced_computation, output_idx=0) diff --git a/tests/common/extensions/test_table.py b/tests/common/extensions/test_table.py index a8aaccd31..82ca10374 100644 --- a/tests/common/extensions/test_table.py +++ b/tests/common/extensions/test_table.py @@ -45,7 +45,9 @@ def test_lookup_table_encrypted_lookup(test_helpers): x = EncryptedScalar(Integer(2, is_signed=False)) op_graph = tracing.trace_numpy_function(f, {"x": x}) - assert op_graph.output_nodes[0].get_table() == [3, 6, 0, 2] + table_node = op_graph.output_nodes[0] + + assert table_node.get_table(op_graph.get_ordered_preds(table_node)) == [3, 6, 0, 2] ref_graph = nx.MultiDiGraph() # Here is the ASCII drawing of the expected graph: diff --git a/tests/numpy/test_debugging.py b/tests/numpy/test_debugging.py index f79aecab7..d7542ca3f 100644 --- a/tests/numpy/test_debugging.py +++ b/tests/numpy/test_debugging.py @@ -114,11 +114,23 @@ def issue_130_c(x, y): ), ( lambda x, y: numpy.arctan2(x, 42) + y, - "%0 = y\n%1 = x\n%2 = np.arctan2(%1, 42)\n%3 = Add(%2, %0)\nreturn(%3)\n", + """%0 = y +%1 = x +%2 = Constant(42) +%3 = np.arctan2(%1, %2) +%4 = Add(%3, %0) +return(%4) +""", ), ( lambda x, y: numpy.arctan2(43, x) + y, - "%0 = y\n%1 = x\n%2 = np.arctan2(43, %1)\n%3 = Add(%2, %0)\nreturn(%3)\n", + """%0 = y +%1 = Constant(43) +%2 = x +%3 = np.arctan2(%1, %2) +%4 = Add(%3, %0) +return(%4) +""", ), ], ) @@ -416,14 +428,22 @@ def test_numpy_long_constant(): ) expected = """ -%0 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor, shape=(1, 10)> -%1 = x # EncryptedTensor, shape=(10, 10)> -%2 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor, shape=(10, 10)> -%3 = Add(%1, %2) # EncryptedTensor, shape=(10, 10)> -%4 = Sub(%3, %0) # EncryptedTensor, shape=(10, 10)> -%5 = np.arctan2([[10 11 12 ... 17 18 19]], %4) # EncryptedTensor, shape=(10, 10)> -%6 = np.arctan2([[100 101 ... 198 199]], %5) # EncryptedTensor, shape=(10, 10)> -return(%6) +%0 = Constant([[100 101 ... 198 199]]) # ClearTensor, shape=(10, 10)> +%1 = Constant([[10 11 12 ... 17 18 19]]) # ClearTensor, shape=(1, 10)> +%2 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor, shape=(1, 10)> +%3 = x # EncryptedTensor, shape=(10, 10)> +%4 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor, shape=(10, 10)> +%5 = Add(%3, %4) # EncryptedTensor, shape=(10, 10)> +%6 = Sub(%5, %2) # EncryptedTensor, shape=(10, 10)> +%7 = np.arctan2(%1, %6) # EncryptedTensor, shape=(10, 10)> +%8 = np.arctan2(%0, %7) # EncryptedTensor, shape=(10, 10)> +return(%8) """.lstrip() # noqa: E501 - assert get_printable_graph(op_graph, show_data_types=True) == expected + str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + + assert str_of_the_graph == expected, ( + f"\n==================\nGot \n{str_of_the_graph}" + f"==================\nExpected \n{expected}" + f"==================\n" + ) diff --git a/tests/numpy/test_np_mlir_converter.py b/tests/numpy/test_np_mlir_converter.py index c63b7c659..c986def80 100644 --- a/tests/numpy/test_np_mlir_converter.py +++ b/tests/numpy/test_np_mlir_converter.py @@ -57,7 +57,9 @@ def test_generate_deduplicated_tables( tlu_node = univariate_function_nodes[0] - deduplication_result = generate_deduplicated_tables(tlu_node) + deduplication_result = generate_deduplicated_tables( + tlu_node, op_graph.get_ordered_preds(tlu_node) + ) assert len(deduplication_result) == expected_number_of_tables @@ -82,7 +84,9 @@ def test_deduplicated_tables_correctness(default_compilation_configuration): tlu_node = univariate_function_nodes[0] - deduplication_result = generate_deduplicated_tables(tlu_node) + deduplication_result = generate_deduplicated_tables( + tlu_node, op_graph.get_ordered_preds(tlu_node) + ) expected_result = tuple( (