refactor: make GenericFunction accept several inputs

- remove baked constants
- manage table generation for the updated node

closes #600
closes #822
This commit is contained in:
Arthur Meyre
2021-11-03 16:14:10 +01:00
parent 7f32cf7965
commit f530a0b739
13 changed files with 206 additions and 130 deletions

View File

@@ -114,19 +114,6 @@ def get_printable_graph(
prefix_to_add_to_what_to_print = ""
suffix_to_add_to_what_to_print = ""
# Print constant that may be in the GenericFunction. For the moment, it considers
# there is a single constant maximally and that there is 2 inputs maximally
if isinstance(node, GenericFunction) and "baked_constant" in node.op_kwargs:
baked_constant = node.op_kwargs["baked_constant"]
if node.op_attributes["in_which_input_is_constant"] == 0:
prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, "
else:
assert_true(
node.op_attributes["in_which_input_is_constant"] == 1,
"'in_which_input_is_constant' should be a key of node.op_attributes",
)
suffix_to_add_to_what_to_print = f", {shorten_a_constant(baked_constant)}"
# Then, just print the predecessors in the right order
what_to_print += prefix_to_add_to_what_to_print
what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name])

View File

@@ -9,7 +9,7 @@ Converter functions all have the same signature `converter(node, preds, ir_to_ml
from typing import cast
# pylint: disable=no-name-in-module,no-member
import numpy as np
import numpy
from mlir.dialects import arith as arith_dialect
from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
from zamalang.dialects import hlfhe
@@ -163,12 +163,24 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
"""Convert a GenericFunction intermediate node."""
assert_true(len(node.inputs) == 1, "LUT should have a single input")
variable_input_indices = [
idx for idx, pred in enumerate(preds) if not isinstance(pred, Constant)
]
assert_true(
(non_constant_pred_count := len(variable_input_indices)) == 1,
f"LUT should have a single variable input (got {non_constant_pred_count})",
)
variable_input_idx = variable_input_indices[0]
variable_input_value = node.inputs[variable_input_idx]
assert_true(len(node.outputs) == 1, "LUT should have a single output")
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
if not value_is_encrypted_scalar_unsigned_integer(variable_input_value):
raise TypeError(
f"Only support LUT with encrypted unsigned integers inputs "
f"(but {node.inputs[0]} is provided)"
f"(but {variable_input_value} is provided)"
)
if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]):
raise TypeError(
@@ -176,7 +188,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
f"(but {node.outputs[0]} is provided)"
)
x_node = preds[0]
x_node = preds[variable_input_idx]
x = ir_to_mlir_node[x_node]
tables = additional_conversion_info["tables"][node]
@@ -192,7 +204,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
out_dtype = cast(Integer, node.outputs[0].dtype)
# Create table
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
dense_elem = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=ctx)
tensor_lut = arith_dialect.ConstantOp(
RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
dense_elem,

View File

@@ -171,14 +171,7 @@ class MLIRConverter(ABC):
f"we don't yet support conversion to MLIR of computations using"
f"{type(node)}"
)
# get sorted preds: sorted by their input index
# replication of pred is possible (e.g lambda x: x + x)
idx_to_pred = {}
for pred in op_graph.graph.pred[node]:
edge_data = op_graph.graph.get_edge_data(pred, node)
for data in edge_data.values():
idx_to_pred[data["input_idx"]] = pred
preds = [idx_to_pred[i] for i in range(len(idx_to_pred))]
preds = op_graph.get_ordered_preds(node)
# convert to mlir
result = mlir_op(
node,

View File

@@ -1,6 +1,8 @@
"""Utilities for MLIR conversion."""
from typing import Dict, List, Optional, cast
import networkx as nx
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
@@ -21,11 +23,16 @@ from ..representation.intermediate import GenericFunction, IntermediateNode
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) -> Optional[str]:
def check_node_compatibility_with_mlir(
node: IntermediateNode,
nx_graph: nx.MultiDiGraph,
is_output: bool,
) -> Optional[str]:
"""Check if node is compatible with MLIR.
Args:
node (IntermediateNode): node to check
nx_graph (nx.MultiDiGraph): the networkx graph to which node belongs
is_output (bool): whether the node is an output node or not
Returns:
@@ -66,7 +73,16 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
elif isinstance(node, intermediate.GenericFunction): # constraints for univariate functions
if node.op_kind == "TLU":
assert_true(len(inputs) == 1)
assert_true(
len(
[
pred_node
for pred_node in nx_graph.pred[node]
if not isinstance(pred_node, intermediate.Constant)
]
)
== 1
)
if node.op_name == "MultiTLU":
return "direct multi table lookup is not supported for the time being"
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
@@ -124,7 +140,9 @@ def check_graph_values_compatibility_with_mlir(
for node in op_graph.graph.nodes:
is_output = node in op_graph.output_nodes.values()
if (reason := check_node_compatibility_with_mlir(node, is_output)) is not None:
if (
reason := check_node_compatibility_with_mlir(node, op_graph.graph, is_output)
) is not None:
offending_nodes[node] = [reason]
return None if len(offending_nodes) == 0 else offending_nodes

View File

@@ -115,6 +115,22 @@ class OPGraph:
"""
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
def get_ordered_preds(self, node: IntermediateNode) -> List[IntermediateNode]:
"""Get node predecessors ordered by their indices.
Args:
node (IntermediateNode): The node for which we want the ordered predecessors.
Returns:
List[IntermediateNode]: The list of predecessors ordered by input index.
"""
# Replication of pred is managed e.g. x + x will yield the proper pred x twice
idx_to_pred: Dict[int, IntermediateNode] = {}
for pred in self.graph.pred[node]:
edge_data = self.graph.get_edge_data(pred, node)
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]:
"""Evaluate a graph and get intermediate values for all nodes.

View File

@@ -323,20 +323,10 @@ def subgraph_nodes_and_values_allow_fusing(
if len(explicitely_non_fusable) > 0:
return False
# Some GenericFunction nodes have baked constants that need to be taken into account for the
# max size computation
baked_constants_ir_nodes = [
baked_constant_ir_node
for node in subgraph_all_nodes
if isinstance(node, GenericFunction)
if (baked_constant_ir_node := node.op_attributes.get("baked_constant_ir_node", None))
is not None
]
all_values_are_tensors = all(
all(isinstance(input_, TensorValue) for input_ in node.inputs)
and all(isinstance(output, TensorValue) for output in node.outputs)
for node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
for node in subgraph_all_nodes
)
if not all_values_are_tensors:
@@ -360,8 +350,14 @@ def subgraph_nodes_and_values_allow_fusing(
variable_input_node_output.shape,
)
max_inputs_size = max(
cast(TensorValue, input_node.outputs[0]).size
for input_node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
itertools.chain(
(variable_input_node_output_size,),
(
cast(TensorValue, constant_input_node.outputs[0]).size
for constant_input_node in subgraph_all_nodes
if isinstance(constant_input_node, Constant)
),
)
)
if variable_input_node_output_size < max_inputs_size:

View File

@@ -322,7 +322,6 @@ class GenericFunction(IntermediateNode):
) -> None:
super().__init__(inputs)
self._n_in = len(self.inputs)
assert_true(self._n_in == 1) # TODO: remove in later parts of refactoring of #600
self.arbitrary_func = arbitrary_func
self.op_kind = GenericFunctionKind(op_kind)
self.op_args = op_args if op_args is not None else ()
@@ -344,22 +343,42 @@ class GenericFunction(IntermediateNode):
def label(self) -> str:
return self.op_name
def get_table(self) -> List[Any]:
def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
"""Get the table for the current input value of this GenericFunction.
This function only works if the GenericFunction input value is an unsigned Integer.
This function only works if the GenericFunction variable input value is an unsigned Integer.
It only works if there is a single variable input node among ordered_preds.
Args:
ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must
contain a single non constant node and any number of Constant nodes.
Returns:
List[Any]: The table.
"""
input_dtype = self.inputs[0].dtype
variable_input_indices = [
idx for idx, pred in enumerate(ordered_preds) if not isinstance(pred, Constant)
]
assert_true(
(non_constant_pred_count := len(variable_input_indices)) == 1,
f"Can only have 1 non constant predecessor in {self.get_table.__name__}, "
f"got {non_constant_pred_count}",
)
variable_input_idx = variable_input_indices[0]
variable_input_dtype = self.inputs[variable_input_idx].dtype
# Check the input is an unsigned integer to be able to build a table
assert_true(
isinstance(input_dtype, Integer), "get_table only works for an unsigned Integer input"
isinstance(variable_input_dtype, Integer),
f"{self.get_table.__name__} only works for an unsigned Integer input",
)
variable_input_dtype = cast(Integer, variable_input_dtype)
assert_true(
not variable_input_dtype.is_signed,
f"{self.get_table.__name__} only works for an unsigned Integer input",
)
input_dtype = cast(Integer, input_dtype)
assert_true(not input_dtype.is_signed, "get_table only works for an unsigned Integer input")
input_value_constructor = self.inputs[0].underlying_constructor
if input_value_constructor is None:
@@ -368,8 +387,8 @@ class GenericFunction(IntermediateNode):
)
input_value_constructor = int
min_input_range = input_dtype.min_value()
max_input_range = input_dtype.max_value() + 1
min_input_range = variable_input_dtype.min_value()
max_input_range = variable_input_dtype.max_value() + 1
def catch(func, *args, **kwargs):
try:
@@ -378,8 +397,22 @@ class GenericFunction(IntermediateNode):
except Exception: # pragma: no cover # pylint: disable=broad-except
return None
template_input_dict = {
idx: node.evaluate({}) if isinstance(node, Constant) else None
for idx, node in enumerate(ordered_preds)
}
def update_and_return_dict(dict_to_update: dict, update_values):
dict_to_update.update(update_values)
return dict_to_update
table = [
catch(self.evaluate, {0: input_value_constructor(input_value)})
catch(
self.evaluate,
update_and_return_dict(
template_input_dict, {variable_input_idx: input_value_constructor(input_value)}
),
)
for input_value in range(min_input_range, max_input_range)
]

View File

@@ -183,11 +183,11 @@ def get_base_value_for_numpy_or_python_constant_data(
return constant_data_value
def get_numpy_function_output_dtype_from_input_dtypes(
def get_numpy_function_output_dtype_and_shape_from_input_dtypes(
function: Union[numpy.ufunc, Callable],
input_dtypes: List[BaseDataType],
input_shapes: List[Tuple[int, ...]],
) -> List[numpy.dtype]:
) -> List[Tuple[numpy.dtype, Tuple[int, ...]]]:
"""Record the output dtype of a numpy function given some input types.
Args:
@@ -199,7 +199,8 @@ def get_numpy_function_output_dtype_from_input_dtypes(
the function inputs
Returns:
List[numpy.dtype]: The ordered numpy dtypes of the function outputs
List[Tuple[numpy.dtype, Tuple[int, ...]]]: appropriate (numpy.dtype, shape) tuple for each
output of the function
"""
if isinstance(function, numpy.ufunc):
assert_true(
@@ -226,14 +227,14 @@ def get_numpy_function_output_dtype_from_input_dtypes(
if not isinstance(outputs, tuple):
outputs = (outputs,)
return [output.dtype for output in outputs]
return [(output.dtype, output.shape) for output in outputs]
def get_numpy_function_output_dtype_from_input_tracers(
def get_numpy_function_output_dtype_and_shape_from_input_tracers(
func: Union[numpy.ufunc, Callable],
*input_tracers: BaseTracer,
) -> List[BaseDataType]:
"""Determine output dtypes for a numpy function.
) -> List[Tuple[BaseDataType, Tuple[int, ...]]]:
"""Determine output dtypes and shapes for a numpy function.
This function is responsible for determining the output dtype
of a numpy function after inputs with specific dtypes are passed to it.
@@ -243,19 +244,23 @@ def get_numpy_function_output_dtype_from_input_tracers(
*input_tracers (BaseTracer): inputs to the function
Returns:
List[numpy.dtype]: appropriate BaseDataType for each output of the function
List[Tuple[BaseDataType, Tuple[int, ...]]]: appropriate (BaseDataType, shape) tuple for each
output of the function
"""
input_shapes = [
input_tracer.output.shape if isinstance(input_tracer.output, TensorValue) else ()
for input_tracer in input_tracers
]
output_dtypes = get_numpy_function_output_dtype_from_input_dtypes(
output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_dtypes(
func,
[input_tracer.output.dtype for input_tracer in input_tracers],
input_shapes,
)
common_output_dtypes = [convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes]
common_output_dtypes = [
(convert_numpy_dtype_to_base_data_type(dtype), shape)
for dtype, shape in output_dtypes_and_shapes
]
return common_output_dtypes

View File

@@ -10,7 +10,7 @@ import numpy
from ..common.debugging import assert_true
from ..common.mlir.mlir_converter import MLIRConverter
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import GenericFunction
from ..common.representation.intermediate import GenericFunction, IntermediateNode
class HashableNPArray:
@@ -33,12 +33,13 @@ class HashableNPArray:
def generate_deduplicated_tables(
node: GenericFunction,
node: GenericFunction, ordered_preds: List[IntermediateNode]
) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
"""Deduplicate the tables for the different cells of a tensor if needed.
Args:
node (GenericFunction): the node for which to deduplicate the table
node (GenericFunction): the node for which to deduplicate the table.
ordered_preds (List[IntermediateNode]): ordered list of predecessors of the node.
Returns:
Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose
@@ -47,7 +48,7 @@ def generate_deduplicated_tables(
"""
# This is the tensor containing the tables for each cell of the tensor for node
node_complete_table = numpy.concatenate(
tuple(numpy.expand_dims(array, -1) for array in node.get_table()), axis=-1
tuple(numpy.expand_dims(array, -1) for array in node.get_table(ordered_preds)), axis=-1
)
all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
@@ -85,7 +86,7 @@ class NPMLIRConverter(MLIRConverter):
# Disable numpy warnings during conversion to avoid issues during TLU generation
with numpy.errstate(all="ignore"):
additional_conversion_info["tables"] = {
node: generate_deduplicated_tables(node)
node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
for node in op_graph.graph.nodes()
if isinstance(node, GenericFunction)
}

View File

@@ -16,7 +16,7 @@ from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_numpy_function_output_dtype_from_input_tracers,
get_numpy_function_output_dtype_and_shape_from_input_tracers,
)
from .np_indexing_helpers import process_indexing_element
@@ -161,14 +161,19 @@ class NPTracer(BaseTracer):
NPTracer: The output NPTracer containing the traced function
"""
assert_true(len(input_tracers) == 1)
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
unary_operator,
*input_tracers,
common_output_dtypes_and_shapes = (
get_numpy_function_output_dtype_and_shape_from_input_tracers(
unary_operator,
*input_tracers,
)
)
assert_true(len(common_output_dtypes) == 1)
assert_true(len(common_output_dtypes_and_shapes) == 1)
generic_function_output_value = deepcopy(input_tracers[0].output)
generic_function_output_value.dtype = common_output_dtypes[0]
output_dtype, output_shape = common_output_dtypes_and_shapes[0]
generic_function_output_value = TensorValue(
output_dtype, input_tracers[0].output.is_encrypted, output_shape
)
traced_computation = GenericFunction(
inputs=[deepcopy(input_tracers[0].output)],
@@ -201,58 +206,40 @@ class NPTracer(BaseTracer):
# One of the inputs has to be constant
if isinstance(input_tracers[0].traced_computation, Constant):
in_which_input_is_constant = 0
baked_constant = deepcopy(input_tracers[0].traced_computation.constant_data)
elif isinstance(input_tracers[1].traced_computation, Constant):
in_which_input_is_constant = 1
baked_constant = deepcopy(input_tracers[1].traced_computation.constant_data)
else:
raise NotImplementedError(f"Can't manage binary operator {binary_operator}")
in_which_input_is_variable = 1 - in_which_input_is_constant
if in_which_input_is_constant == 0:
def arbitrary_func(x, baked_constant, **kwargs):
return binary_operator(baked_constant, x, **kwargs)
else:
def arbitrary_func(x, baked_constant, **kwargs):
return binary_operator(x, baked_constant, **kwargs)
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
binary_operator,
*input_tracers,
common_output_dtypes_and_shapes = (
get_numpy_function_output_dtype_and_shape_from_input_tracers(
binary_operator,
*input_tracers,
)
)
assert_true(len(common_output_dtypes_and_shapes) == 1)
output_dtype, output_shape = common_output_dtypes_and_shapes[0]
generic_function_output_value = TensorValue(
output_dtype,
input_tracers[in_which_input_is_variable].output.is_encrypted,
output_shape,
)
assert_true(len(common_output_dtypes) == 1)
op_kwargs = deepcopy(kwargs)
op_kwargs["baked_constant"] = baked_constant
# Store info on the operation being treated
# Currently: the base value and type corresponding to the baked constant and which input idx
# it was feeding
op_attributes = {
"baked_constant_ir_node": deepcopy(
input_tracers[in_which_input_is_constant].traced_computation
),
"in_which_input_is_constant": in_which_input_is_constant,
}
generic_function_output_value = deepcopy(input_tracers[in_which_input_is_variable].output)
generic_function_output_value.dtype = common_output_dtypes[0]
# TODO: update inputs for #600 refactor
traced_computation = GenericFunction(
inputs=[deepcopy(input_tracers[in_which_input_is_variable].output)],
arbitrary_func=arbitrary_func,
inputs=[deepcopy(input_tracer.output) for input_tracer in input_tracers],
arbitrary_func=binary_operator,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs=op_kwargs,
op_name=binary_operator_string,
op_attributes=op_attributes,
)
output_tracer = cls(
(input_tracers[in_which_input_is_variable],),
input_tracers,
traced_computation=traced_computation,
output_idx=0,
)
@@ -266,12 +253,14 @@ class NPTracer(BaseTracer):
"""
assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.dot, *args)
assert_true(len(common_output_dtypes) == 1)
common_output_dtypes_and_shapes = (
get_numpy_function_output_dtype_and_shape_from_input_tracers(numpy.dot, *args)
)
assert_true(len(common_output_dtypes_and_shapes) == 1)
traced_computation = Dot(
[input_tracer.output for input_tracer in args],
common_output_dtypes[0],
common_output_dtypes_and_shapes[0][0],
delegate_evaluation_function=numpy.dot,
)
@@ -638,14 +627,14 @@ def _on_numpy_multiply(lhs, rhs):
def _on_numpy_matmul(lhs, rhs):
common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers(
numpy.matmul, lhs, rhs
)
assert_true(len(common_output_dtypes) == 1)
assert_true(len(common_output_dtypes_and_shapes) == 1)
traced_computation = MatMul(
[lhs.output, rhs.output],
common_output_dtypes[0],
common_output_dtypes_and_shapes[0][0],
)
return NPTracer([lhs, rhs], traced_computation, output_idx=0)

View File

@@ -45,7 +45,9 @@ def test_lookup_table_encrypted_lookup(test_helpers):
x = EncryptedScalar(Integer(2, is_signed=False))
op_graph = tracing.trace_numpy_function(f, {"x": x})
assert op_graph.output_nodes[0].get_table() == [3, 6, 0, 2]
table_node = op_graph.output_nodes[0]
assert table_node.get_table(op_graph.get_ordered_preds(table_node)) == [3, 6, 0, 2]
ref_graph = nx.MultiDiGraph()
# Here is the ASCII drawing of the expected graph:

View File

@@ -114,11 +114,23 @@ def issue_130_c(x, y):
),
(
lambda x, y: numpy.arctan2(x, 42) + y,
"%0 = y\n%1 = x\n%2 = np.arctan2(%1, 42)\n%3 = Add(%2, %0)\nreturn(%3)\n",
"""%0 = y
%1 = x
%2 = Constant(42)
%3 = np.arctan2(%1, %2)
%4 = Add(%3, %0)
return(%4)
""",
),
(
lambda x, y: numpy.arctan2(43, x) + y,
"%0 = y\n%1 = x\n%2 = np.arctan2(43, %1)\n%3 = Add(%2, %0)\nreturn(%3)\n",
"""%0 = y
%1 = Constant(43)
%2 = x
%3 = np.arctan2(%1, %2)
%4 = Add(%3, %0)
return(%4)
""",
),
],
)
@@ -416,14 +428,22 @@ def test_numpy_long_constant():
)
expected = """
%0 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%2 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
%3 = Add(%1, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%4 = Sub(%3, %0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%5 = np.arctan2([[10 11 12 ... 17 18 19]], %4) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%6 = np.arctan2([[100 101 ... 198 199]], %5) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
return(%6)
%0 = Constant([[100 101 ... 198 199]]) # ClearTensor<Integer<unsigned, 8 bits>, shape=(10, 10)>
%1 = Constant([[10 11 12 ... 17 18 19]]) # ClearTensor<Integer<unsigned, 5 bits>, shape=(1, 10)>
%2 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
%3 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%4 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
%5 = Add(%3, %4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%6 = Sub(%5, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%7 = np.arctan2(%1, %6) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%8 = np.arctan2(%0, %7) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
return(%8)
""".lstrip() # noqa: E501
assert get_printable_graph(op_graph, show_data_types=True) == expected
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
assert str_of_the_graph == expected, (
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{expected}"
f"==================\n"
)

View File

@@ -57,7 +57,9 @@ def test_generate_deduplicated_tables(
tlu_node = univariate_function_nodes[0]
deduplication_result = generate_deduplicated_tables(tlu_node)
deduplication_result = generate_deduplicated_tables(
tlu_node, op_graph.get_ordered_preds(tlu_node)
)
assert len(deduplication_result) == expected_number_of_tables
@@ -82,7 +84,9 @@ def test_deduplicated_tables_correctness(default_compilation_configuration):
tlu_node = univariate_function_nodes[0]
deduplication_result = generate_deduplicated_tables(tlu_node)
deduplication_result = generate_deduplicated_tables(
tlu_node, op_graph.get_ordered_preds(tlu_node)
)
expected_result = tuple(
(