refactor: replace UnivariateFunction with GenericFunction

- add a fusable attribute, set to False for the operations from the original
addition of GenericFunction that should not be fused
- add an op_kind instance attribute to GenericFunction to differentiate
between TLU and memory operations

refs #600
Arthur Meyre
2021-11-02 11:34:14 +01:00
parent fed3342c5f
commit d2faa90106
21 changed files with 330 additions and 208 deletions
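
For orientation, a minimal sketch of the node construction this commit introduces, mirroring the constructor keywords and the test values that appear in the diff below (import paths are taken from the diff; the exact values are illustrative):

```python
from concrete.common.data_types.integers import Integer
from concrete.common.representation.intermediate import GenericFunction
from concrete.common.values import EncryptedScalar, EncryptedTensor

# A TLU-kind node: op_attributes defaults to {"fusable": True}
tlu_node = GenericFunction(
    input_base_value=EncryptedScalar(Integer(7, False)),
    arbitrary_func=lambda x: x + 3,
    output_value=EncryptedScalar(Integer(7, False)),
    op_kind="TLU",
)
assert tlu_node.op_attributes["fusable"]

# A Memory-kind node, explicitly marked non-fusable, as done for
# np.transpose / np.ravel / np.reshape in the tracer changes below
memory_node = GenericFunction(
    input_base_value=EncryptedTensor(Integer(32, False), shape=(3, 5)),
    arbitrary_func=lambda x: x.transpose(),
    output_value=EncryptedTensor(Integer(32, False), shape=(5, 3)),
    op_kind="Memory",
    op_name="np.transpose",
    op_attributes={"fusable": False},
)
assert not memory_node.op_attributes["fusable"]
```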


@@ -22,7 +22,6 @@ from ..representation.intermediate import (
MatMul,
Mul,
Sub,
UnivariateFunction,
)
IR_NODE_COLOR_MAPPING = {
@@ -31,12 +30,10 @@ IR_NODE_COLOR_MAPPING = {
Add: "red",
Sub: "yellow",
Mul: "green",
UnivariateFunction: "orange",
GenericFunction: "orange",
IndexConstant: "black",
Dot: "purple",
MatMul: "brown",
"UnivariateFunction": "orange",
"GenericFunction": "orange",
"TLU": "grey",
"output": "magenta",
@@ -78,7 +75,7 @@ def draw_graph(
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
if node in output_nodes:
value_to_return = IR_NODE_COLOR_MAPPING["output"]
elif isinstance(node, UnivariateFunction):
elif isinstance(node, GenericFunction):
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
return value_to_return


@@ -12,7 +12,6 @@ from ..representation.intermediate import (
IndexConstant,
Input,
IntermediateNode,
UnivariateFunction,
)
@@ -92,7 +91,7 @@ def get_printable_graph(
base_name = node.__class__.__name__
if isinstance(node, (UnivariateFunction, GenericFunction)):
if isinstance(node, GenericFunction):
base_name = node.op_name
what_to_print = base_name + "("
@@ -115,9 +114,9 @@ def get_printable_graph(
prefix_to_add_to_what_to_print = ""
suffix_to_add_to_what_to_print = ""
# Print constant that may be in the UnivariateFunction. For the moment, it considers
# Print constant that may be in the GenericFunction. For the moment, it considers
# there is at most a single constant and at most 2 inputs
if isinstance(node, UnivariateFunction) and "baked_constant" in node.op_kwargs:
if isinstance(node, GenericFunction) and "baked_constant" in node.op_kwargs:
baked_constant = node.op_kwargs["baked_constant"]
if node.op_attributes["in_which_input_is_constant"] == 0:
prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, "


@@ -6,8 +6,9 @@ from typing import List, Tuple, Union
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy
from ..representation.intermediate import UnivariateFunction
from ..representation.intermediate import GenericFunction
from ..tracing.base_tracer import BaseTracer
from ..values import ClearTensor, EncryptedTensor
from .table import LookupTable
@@ -93,10 +94,20 @@ class MultiLookupTable:
def __getitem__(self, key: Union[int, BaseTracer]):
# this branch is used during tracing and the regular flow is used during evaluation
if isinstance(key, BaseTracer):
traced_computation = UnivariateFunction(
out_dtype = deepcopy(key.output.dtype)
out_shape = deepcopy(self.input_shape)
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if key.output.is_encrypted
else ClearTensor(out_dtype, out_shape)
)
traced_computation = GenericFunction(
input_base_value=key.output,
arbitrary_func=MultiLookupTable._checked_indexing,
output_dtype=self.output_dtype,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs={
"input_shape": deepcopy(self.input_shape),
"tables": deepcopy(self.tables),


@@ -6,7 +6,7 @@ from typing import Iterable, Tuple, Union
from ..common_helpers import is_a_power_of_2
from ..data_types.base import BaseDataType
from ..data_types.integers import make_integer_to_hold
from ..representation.intermediate import UnivariateFunction
from ..representation.intermediate import GenericFunction
from ..tracing.base_tracer import BaseTracer
@@ -32,13 +32,17 @@ class LookupTable:
def __getitem__(self, key: Union[int, BaseTracer]):
# if a tracer is used for indexing,
# we need to create an `UnivariateFunction` node
# we need to create a `GenericFunction` node
# because the result will be determined at runtime
if isinstance(key, BaseTracer):
traced_computation = UnivariateFunction(
generic_function_output_value = deepcopy(key.output)
generic_function_output_value.dtype = self.output_dtype
traced_computation = GenericFunction(
input_base_value=key.output,
arbitrary_func=LookupTable._checked_indexing,
output_dtype=self.output_dtype,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs={"table": deepcopy(self.table)},
op_name="TLU",
)
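
A hedged usage sketch of the two paths described in the comment above; constructing a LookupTable directly from a Python list is an assumption here, and only the indexing behavior is taken from the code:

```python
table = LookupTable([2, 1, 3, 0])

# regular flow: an integer key is resolved immediately during evaluation
assert table[0] == 2

# tracing flow: indexing with a BaseTracer instead returns a new tracer whose
# traced computation is the GenericFunction(op_kind="TLU", op_name="TLU") node
# built above, carrying a deep copy of the table in op_kwargs
```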


@@ -23,7 +23,7 @@ from ..data_types.dtypes_helpers import (
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, UnivariateFunction
from ..representation.intermediate import Add, Constant, Dot, GenericFunction, Mul, Sub
from ..values import TensorValue
@@ -162,7 +162,7 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
"""Convert a UnivariateFunction intermediate node."""
"""Convert a GenericFunction intermediate node."""
assert_true(len(node.inputs) == 1, "LUT should have a single input")
assert_true(len(node.outputs) == 1, "LUT should have a single output")
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
@@ -238,7 +238,7 @@ V0_OPSET_CONVERSION_FUNCTIONS = {
Sub: sub,
Mul: mul,
Constant: constant,
UnivariateFunction: apply_lut,
GenericFunction: apply_lut,
Dot: dot,
}


@@ -15,7 +15,7 @@ from ..debugging import get_printable_graph
from ..debugging.custom_assert import assert_not_reached, assert_true
from ..operator_graph import OPGraph
from ..representation import intermediate
from ..representation.intermediate import IntermediateNode, UnivariateFunction
from ..representation.intermediate import GenericFunction, IntermediateNode
# TODO: should come from compiler, through an API, #402
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
@@ -64,15 +64,17 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
if not value_is_integer(outputs[0]):
return "only integer constants are supported" # pragma: no cover
elif isinstance(node, intermediate.UnivariateFunction): # constraints for univariate functions
assert_true(len(inputs) == 1)
if node.op_name == "MultiTLU":
return "direct multi table lookup is not supported for the time being"
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
return "only unsigned integer scalar lookup tables are supported"
elif isinstance(node, intermediate.GenericFunction): # constraints for generic functions
return f"{node.op_name} is not supported for the time being" # pragma: no cover
elif isinstance(node, intermediate.GenericFunction): # constraints for generic functions
if node.op_kind == "TLU":
assert_true(len(inputs) == 1)
if node.op_name == "MultiTLU":
return "direct multi table lookup is not supported for the time being"
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
return "only unsigned integer scalar lookup tables are supported"
else:
return (
f"{node.op_name} of kind {node.op_kind.value} is not supported for the time being"
)
elif isinstance(node, intermediate.Dot): # constraints for dot product
assert_true(len(inputs) == 2)
@@ -192,7 +194,7 @@ def extend_direct_lookup_tables(op_graph: OPGraph):
op_graph: graph to update lookup tables for
"""
for node in op_graph.graph.nodes:
if isinstance(node, UnivariateFunction) and node.op_name == "TLU":
if isinstance(node, GenericFunction) and node.op_name == "TLU":
table = node.op_kwargs["table"]
bit_width = cast(Integer, node.inputs[0].dtype).bit_width
expected_length = 2 ** bit_width


@@ -190,7 +190,7 @@ class OPGraph:
node_bounds_and_samples (dict): Dictionary with nodes as keys, holding dicts with a
'min', 'max' and 'sample' keys. Those bounds will be taken as the data range to be
represented, per node. The sample allows determining the data constructors to
prepare the UnivariateFunction nodes for table generation.
prepare the GenericFunction nodes for table generation.
get_base_data_type_for_constant_data (Callable[ [Any], BaseDataType ], optional): This
is a callback function to convert data encountered during value updates to
BaseDataType. This allows managing data coming from foreign frameworks without


@@ -13,7 +13,7 @@ from ..data_types.integers import Integer
from ..debugging import get_printable_graph
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
from ..representation.intermediate import Constant, GenericFunction, Input, IntermediateNode
from ..values import TensorValue
@@ -21,7 +21,7 @@ def fuse_float_operations(
op_graph: OPGraph,
compilation_artifacts: Optional[CompilationArtifacts] = None,
):
"""Find and fuse float domains into single Integer to Integer UnivariateFunction.
"""Find and fuse float domains into single Integer to Integer GenericFunction.
Args:
op_graph (OPGraph): The OPGraph to simplify
@@ -76,7 +76,7 @@ def fuse_float_operations(
succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
for edge_key, edge_data in succ_edge_data.items():
nx_graph.remove_edge(terminal_node, succ, key=edge_key)
# fused_node is always a UnivariateFunction so output_idx == 0 always
# fused_node is always a GenericFunction so output_idx == 0 always
new_edge_data = deepcopy(edge_data)
new_edge_data["output_idx"] = 0
nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)
@@ -99,8 +99,8 @@ def convert_float_subgraph_to_fused_node(
float_subgraph_start_nodes: Set[IntermediateNode],
terminal_node: IntermediateNode,
subgraph_all_nodes: Set[IntermediateNode],
) -> Optional[Tuple[UnivariateFunction, IntermediateNode]]:
"""Convert a float subgraph to an equivalent fused UnivariateFunction node.
) -> Optional[Tuple[GenericFunction, IntermediateNode]]:
"""Convert a float subgraph to an equivalent fused GenericFunction node.
Args:
op_graph (OPGraph): The OPGraph the float subgraph is part of.
@@ -110,7 +110,7 @@ def convert_float_subgraph_to_fused_node(
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
Returns:
Optional[Tuple[UnivariateFunction, IntermediateNode]]: None if the float subgraph
Optional[Tuple[GenericFunction, IntermediateNode]]: None if the float subgraph
cannot be fused, otherwise returns a tuple containing the fused node and the node whose
output must be plugged as the input to the subgraph.
"""
@@ -123,7 +123,7 @@ def convert_float_subgraph_to_fused_node(
if subgraph_can_be_fused:
# subgraph_values_allow_fusing can be called iff the subgraph has a unique variable input
subgraph_can_be_fused = subgraph_values_allow_fusing(
subgraph_can_be_fused = subgraph_nodes_and_values_allow_fusing(
float_subgraph_start_nodes, subgraph_all_nodes, node_with_issues_for_fusing
)
@@ -193,12 +193,13 @@ def convert_float_subgraph_to_fused_node(
assert_true(len(terminal_node.outputs) == 1)
# Create fused_node
fused_node = UnivariateFunction(
fused_node = GenericFunction(
deepcopy(new_subgraph_variable_input.inputs[0]),
lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
terminal_node
],
deepcopy(terminal_node.outputs[0].dtype),
terminal_node.outputs[0],
op_kind="TLU",
op_kwargs={
"float_op_subgraph": float_op_subgraph,
"terminal_node": terminal_node,
@@ -277,14 +278,14 @@ def find_float_subgraph_with_unique_terminal_node(
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
def subgraph_values_allow_fusing(
def subgraph_nodes_and_values_allow_fusing(
float_subgraph_start_nodes: Set[IntermediateNode],
subgraph_all_nodes: Set[IntermediateNode],
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
) -> bool:
"""Check if a subgraph's values are compatible with fusing.
A fused subgraph for example only works on an input tensor if the resulting UnivariateFunction
A fused subgraph for example only works on an input tensor if the resulting GenericFunction
can be applied per cell, hence shuffling or tensor shape changes make fusing impossible.
Args:
@@ -298,22 +299,36 @@ def subgraph_values_allow_fusing(
i.e. outputs have the same shapes equal to the variable input.
"""
node: IntermediateNode
variable_input_nodes = [
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
]
assert_true(
(num_variable_input_nodes := len(variable_input_nodes)) == 1,
f"{subgraph_values_allow_fusing.__name__} "
f"{subgraph_nodes_and_values_allow_fusing.__name__} "
f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}",
)
# Some UnivariateFunction nodes have baked constants that need to be taken into account for the
explicitely_non_fusable = [
node
for node in subgraph_all_nodes
if isinstance(node, GenericFunction) and not node.op_attributes["fusable"]
]
for node in explicitely_non_fusable:
node_with_issues_for_fusing[node].append(
"this node is explicitely marked by the package as non-fusable"
)
if len(explicitely_non_fusable) > 0:
return False
# Some GenericFunction nodes have baked constants that need to be taken into account for the
# max size computation
baked_constants_ir_nodes = [
baked_constant_ir_node
for node in subgraph_all_nodes
if isinstance(node, UnivariateFunction)
if isinstance(node, GenericFunction)
if (baked_constant_ir_node := node.op_attributes.get("baked_constant_ir_node", None))
is not None
]
@@ -332,7 +347,7 @@ def subgraph_values_allow_fusing(
# A cheap check is that the variable input node must have the biggest size, i.e. have the most
# elements, meaning all constants will broadcast to its shape. This is because the
# UnivariateFunction input and output must have the same shape so that it can be applied to each
# GenericFunction input and output must have the same shape so that it can be applied to each
# of the input tensor cells.
# There *may* be a way to manage the other case by simulating the broadcast of the smaller input
# array and then concatenating/stacking the results. This is not currently doable as we don't
@@ -413,7 +428,7 @@ def subgraph_has_unique_variable_input(
variable_inputs_num = len(variable_inputs_list)
# Only one input to the subgraph where computations are done in floats can be variable, this
# is the only case we can manage with UnivariateFunction fusing
# is the only case we can manage with GenericFunction fusing
has_unique_variable_input = variable_inputs_num == 1
if not has_unique_variable_input:


@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from enum import Enum, unique
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
from loguru import logger
@@ -280,7 +281,15 @@ def flood_replace_none_values(table: list):
assert_true(all(value is not None for value in table))
class UnivariateFunction(IntermediateNode):
@unique
class GenericFunctionKind(str, Enum):
"""Enum to validate GenericFunction op_kind."""
TLU = "TLU"
MEMORY = "Memory"
class GenericFunction(IntermediateNode):
"""Node representing an univariate arbitrary function, e.g. sin(x)."""
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
@@ -289,17 +298,23 @@ class UnivariateFunction(IntermediateNode):
# be the first argument passed to it. You can add other constant arguments needed for the proper
# execution of the function through op_args and op_kwargs.
arbitrary_func: Optional[Callable]
op_kind: GenericFunctionKind
op_name: str
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
op_attributes: Dict[str, Any]
_n_in: int = 1
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/798 have a proper attribute
# system
DEFAULT_OP_ATTRIBUTES: Dict[str, Any] = {"fusable": True}
def __init__(
self,
input_base_value: BaseValue,
arbitrary_func: Callable,
output_dtype: BaseDataType,
output_value: BaseValue,
op_kind: Union[str, GenericFunctionKind],
op_name: Optional[str] = None,
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
@@ -308,13 +323,14 @@ class UnivariateFunction(IntermediateNode):
super().__init__([input_base_value])
assert_true(len(self.inputs) == 1)
self.arbitrary_func = arbitrary_func
self.op_kind = GenericFunctionKind(op_kind)
self.op_args = op_args if op_args is not None else ()
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
self.op_attributes = op_attributes if op_attributes is not None else {}
self.op_attributes = deepcopy(self.DEFAULT_OP_ATTRIBUTES)
if op_attributes is not None:
self.op_attributes.update(op_attributes)
output = deepcopy(input_base_value)
output.dtype = output_dtype
self.outputs = [output]
self.outputs = [deepcopy(output_value)]
self.op_name = op_name if op_name is not None else self.__class__.__name__
@@ -327,9 +343,9 @@ class UnivariateFunction(IntermediateNode):
return self.op_name
def get_table(self) -> List[Any]:
"""Get the table for the current input value of this UnivariateFunction.
"""Get the table for the current input value of this GenericFunction.
This function only works if the UnivariateFunction input value is an unsigned Integer.
This function only works if the GenericFunction input value is an unsigned Integer.
Returns:
List[Any]: The table.
@@ -385,7 +401,7 @@ class Dot(IntermediateNode):
"""Return the node representing a dot product."""
_n_in: int = 2
# Optional, same issue as in UnivariateFunction for mypy
# Optional, same issue as in GenericFunction for mypy
evaluation_function: Optional[Callable[[Any, Any], Any]]
# Allows using specialized implementations from e.g. numpy
@@ -475,52 +491,3 @@ class MatMul(IntermediateNode):
def label(self) -> str:
return "@"
class GenericFunction(IntermediateNode):
"""Return the node representing a generic function."""
# The arbitrary_func is not optional but mypy has a long standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
# arbitrary_func can take more than one argument but during evaluation the input variable will
# be the first argument passed to it. You can add other constant arguments needed for the proper
# execution of the function through op_args and op_kwargs.
arbitrary_func: Optional[Callable]
op_name: str
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
op_attributes: Dict[str, Any]
_n_in: int = 1
def __init__(
self,
input_base_value: TensorValue,
arbitrary_func: Callable,
output_dtype: BaseDataType,
output_shape: Tuple,
op_name: Optional[str] = None,
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
op_attributes: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__([input_base_value])
assert_true(len(self.inputs) == 1)
self.arbitrary_func = arbitrary_func
self.op_args = op_args if op_args is not None else ()
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
self.op_attributes = op_attributes if op_attributes is not None else {}
self.outputs = [
EncryptedTensor(output_dtype, output_shape)
if self.inputs[0].is_encrypted
else ClearTensor(output_dtype, output_shape)
]
self.op_name = op_name if op_name is not None else self.__class__.__name__
def evaluate(self, inputs: Dict[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.arbitrary_func is not None
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
def label(self) -> str:
return self.op_name
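
A short sketch of the constructor behavior defined above: op_kind is normalized through GenericFunctionKind, and user-supplied op_attributes are merged over a deep copy of DEFAULT_OP_ATTRIBUTES (the extra attribute name below is made up for illustration):

```python
from concrete.common.data_types.integers import Integer
from concrete.common.representation.intermediate import (
    GenericFunction,
    GenericFunctionKind,
)
from concrete.common.values import EncryptedScalar

node = GenericFunction(
    EncryptedScalar(Integer(7, False)),
    lambda x: x,
    EncryptedScalar(Integer(7, False)),
    op_kind="TLU",  # the string is validated and converted by the enum
    op_attributes={"example_flag": True},  # hypothetical attribute
)
assert node.op_kind is GenericFunctionKind.TLU
assert node.op_attributes == {"fusable": True, "example_flag": True}
```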


@@ -98,7 +98,7 @@ def _compile_numpy_function_into_op_graph_internal(
# Apply topological optimizations if they are enabled
if compilation_configuration.enable_topological_optimizations:
# Fuse float operations to have int to int UnivariateFunction
# Fuse float operations to have int to int GenericFunction
if not check_op_graph_is_integer_program(op_graph):
fuse_float_operations(op_graph, compilation_artifacts)


@@ -10,7 +10,7 @@ import numpy
from ..common.debugging import assert_true
from ..common.mlir.mlir_converter import MLIRConverter
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import UnivariateFunction
from ..common.representation.intermediate import GenericFunction
class HashableNPArray:
@@ -33,12 +33,12 @@ class HashableNPArray:
def generate_deduplicated_tables(
node: UnivariateFunction,
node: GenericFunction,
) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
"""Deduplicate the tables for the different cells of a tensor if needed.
Args:
node (UnivariateFunction): the node for which to deduplicate the table
node (GenericFunction): the node for which to deduplicate the table
Returns:
Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose
@@ -87,7 +87,7 @@ class NPMLIRConverter(MLIRConverter):
additional_conversion_info["tables"] = {
node: generate_deduplicated_tables(node)
for node in op_graph.graph.nodes()
if isinstance(node, UnivariateFunction)
if isinstance(node, GenericFunction)
}
return additional_conversion_info
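
To make the documented return shape concrete, a hypothetical deduplication result for a 2x2 tensor in which three cells share one table and the fourth uses another:

```python
import numpy

# each entry pairs one table with the list of tensor cell indices using it
deduplicated_tables = (
    (numpy.array([0, 1, 4, 9]), [(0, 0), (0, 1), (1, 1)]),
    (numpy.array([9, 4, 1, 0]), [(1, 0)]),
)
```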


@@ -9,15 +9,9 @@ from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.debugging.custom_assert import assert_false, assert_true
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import (
Constant,
Dot,
GenericFunction,
MatMul,
UnivariateFunction,
)
from ..common.representation.intermediate import Constant, Dot, GenericFunction, MatMul
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.values import BaseValue, TensorValue
from ..common.values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
@@ -102,10 +96,13 @@ class NPTracer(BaseTracer):
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
traced_computation = UnivariateFunction(
generic_function_output_value = deepcopy(self.output)
generic_function_output_value.dtype = output_dtype
traced_computation = GenericFunction(
input_base_value=self.output,
arbitrary_func=lambda x, dtype: x.astype(dtype),
output_dtype=output_dtype,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs={"dtype": normalized_numpy_dtype.type},
op_name=f"astype({normalized_numpy_dtype})",
)
@@ -170,10 +167,14 @@ class NPTracer(BaseTracer):
)
assert_true(len(common_output_dtypes) == 1)
traced_computation = UnivariateFunction(
generic_function_output_value = deepcopy(input_tracers[0].output)
generic_function_output_value.dtype = common_output_dtypes[0]
traced_computation = GenericFunction(
input_base_value=input_tracers[0].output,
arbitrary_func=unary_operator,
output_dtype=common_output_dtypes[0],
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs=deepcopy(kwargs),
op_name=unary_operator_string,
)
@@ -237,10 +238,14 @@ class NPTracer(BaseTracer):
"in_which_input_is_constant": in_which_input_is_constant,
}
traced_computation = UnivariateFunction(
generic_function_output_value = deepcopy(input_tracers[in_which_input_is_variable].output)
generic_function_output_value.dtype = common_output_dtypes[0]
traced_computation = GenericFunction(
input_base_value=input_tracers[in_which_input_is_variable].output,
arbitrary_func=arbitrary_func,
output_dtype=common_output_dtypes[0],
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs=op_kwargs,
op_name=binary_operator_string,
op_attributes=op_attributes,
@@ -309,13 +314,23 @@ class NPTracer(BaseTracer):
first_arg_output = cast(TensorValue, first_arg_output)
assert_false(first_arg_output.is_scalar)
out_dtype = first_arg_output.dtype
out_shape = first_arg_output.shape[::-1]
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if first_arg_output.is_encrypted
else ClearTensor(out_dtype, out_shape)
)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
arbitrary_func=numpy.transpose,
output_dtype=first_arg_output.dtype,
output_shape=first_arg_output.shape[::-1],
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs=deepcopy(kwargs),
op_name="np.transpose",
op_attributes={"fusable": False},
)
output_tracer = self.__class__(
args,
@@ -345,13 +360,23 @@ class NPTracer(BaseTracer):
first_arg_output = cast(TensorValue, first_arg_output)
assert_false(first_arg_output.is_scalar)
out_dtype = first_arg_output.dtype
out_shape = (numpy.product(first_arg_output.shape),)
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if first_arg_output.is_encrypted
else ClearTensor(out_dtype, out_shape)
)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
arbitrary_func=numpy.ravel,
output_dtype=first_arg_output.dtype,
output_shape=(numpy.product(first_arg_output.shape),),
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs=deepcopy(kwargs),
op_name="np.ravel",
op_attributes={"fusable": False},
)
output_tracer = self.__class__(
args,
@@ -401,13 +426,23 @@ class NPTracer(BaseTracer):
f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
)
out_dtype = first_arg_output.dtype
out_shape = newshape
generic_function_output_value = (
EncryptedTensor(out_dtype, out_shape)
if first_arg_output.is_encrypted
else ClearTensor(out_dtype, out_shape)
)
traced_computation = GenericFunction(
input_base_value=first_arg_output,
arbitrary_func=numpy.reshape,
output_dtype=first_arg_output.dtype,
output_shape=newshape,
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs={"newshape": newshape},
op_name="np.reshape",
op_attributes={"fusable": False},
)
output_tracer = self.__class__(
[arg0],


@@ -31,7 +31,7 @@ The float subgraph that was detected:
![](../../_static/float_fusing_example/subgraph.png)
The simplified graph of operations with the float subgraph condensed in an `UnivariateFunction` node:
The simplified graph of operations with the float subgraph condensed into a `GenericFunction` node:
![](../../_static/float_fusing_example/after.png)
@@ -39,7 +39,7 @@ The simplified graph of operations with the float subgraph condensed in an `Univ
The first step consists of detecting where we go from floating point computation back to integers. This identifies the potential terminal node of the float subgraph we are going to fuse.
From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee the identified float subgraph has a single variable integer input then we can replace it by an equivalent UnivariateFunction node.
From the terminal node, we go back up through the nodes until we find nodes that go from integers to floats. If we can guarantee that the identified float subgraph has a single variable integer input, then we can replace it with an equivalent GenericFunction node.
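
As a quick illustration (a sketch in the spirit of this commit's tests, not taken from these docs), a computation that can be fused, since a single integer input enters the float domain and a single integer output leaves it:

```python
import numpy

def fusable(x):
    # int -> float -> int with one variable input: the float section can be
    # condensed into a single GenericFunction of kind TLU
    intermediate = x.astype(numpy.float64)
    intermediate = numpy.rint(intermediate / 2)
    return intermediate.astype(numpy.int32)
```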
An example of a non-fusable computation with that technique is:
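
A sketch mirroring the no_fuse_unhandled case in this commit's tests, where two variable integer inputs reach the float subgraph:

```python
import numpy

def not_fusable(x, y):
    # two variable inputs enter the float domain; fusing requires exactly one
    return (x.astype(numpy.float64) + y.astype(numpy.float64)).astype(numpy.int32)
```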


@@ -83,7 +83,7 @@ Traceback (most recent call last):
File "/src/concrete/numpy/compile.py", line 103, in _compile_numpy_function_into_op_graph_internal
raise ValueError(
ValueError: <lambda> cannot be compiled as it has nodes with either float inputs or outputs.
Offending nodes : <concrete.common.representation.intermediate.UnivariateFunction object at 0x7f6689fd37f0>
Offending nodes : <concrete.common.representation.intermediate.GenericFunction object at 0x7f6689fd37f0>
```
## Manual export


@@ -54,12 +54,16 @@ def test_lookup_table_encrypted_lookup(test_helpers):
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
ref_graph.add_node(input_x)
generic_function_output_value = deepcopy(x)
generic_function_output_value.dtype = table.output_dtype
# pylint: disable=protected-access
# Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction
output_arbitrary_function = ir.UnivariateFunction(
# Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
output_arbitrary_function = ir.GenericFunction(
input_base_value=x,
arbitrary_func=LookupTable._checked_indexing,
output_dtype=table.output_dtype,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs={"table": deepcopy(table.table)},
op_name="TLU",
)
@@ -68,7 +72,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0, output_idx=0)
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
# TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
@@ -94,12 +98,16 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
input_x = ir.Input(input_value=x, input_name="x", program_input_idx=0)
ref_graph.add_node(input_x)
generic_function_output_value = deepcopy(x)
generic_function_output_value.dtype = table.output_dtype
# pylint: disable=protected-access
# Need access to _checked_indexing to have is_equivalent_to work for ir.UnivariateFunction
intermediate_arbitrary_function = ir.UnivariateFunction(
# Need access to _checked_indexing to have is_equivalent_to work for ir.GenericFunction
intermediate_arbitrary_function = ir.GenericFunction(
input_base_value=x,
arbitrary_func=LookupTable._checked_indexing,
output_dtype=table.output_dtype,
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs={"table": deepcopy(table.table)},
op_name="TLU",
)
@@ -117,5 +125,5 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0, output_idx=0)
ref_graph.add_edge(constant_3, output_add, input_idx=1, output_idx=0)
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
# TODO: discuss if this check is enough as == is not overloaded properly for GenericFunction
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)


@@ -32,6 +32,26 @@ def no_fuse_dot(x):
return numpy.dot(x, numpy.full((10,), 1.33, dtype=numpy.float64)).astype(numpy.int32)
def no_fuse_explicitely(f, x):
"""No fuse because the function is explicitely marked as unfusable in our code."""
return f(x.astype(numpy.float64)).astype(numpy.int32)
def no_fuse_explicitely_ravel(x):
"""No fuse ravel"""
return no_fuse_explicitely(numpy.ravel, x)
def no_fuse_explicitely_transpose(x):
"""No fuse transpose"""
return no_fuse_explicitely(numpy.transpose, x)
def no_fuse_explicitely_reshape(x):
"""No fuse reshape"""
return no_fuse_explicitely(lambda x: numpy.reshape(x, (1,)), x)
def simple_fuse_not_output(x):
"""Simple fuse not output"""
intermediate = x.astype(numpy.float64)
@@ -112,11 +132,13 @@ def mix_x_and_y_into_integer_and_call_f(function, x, y):
)
def get_func_params_scalar_int32(func):
def get_func_params_int32(func, scalar=True):
"""Returns a dict with parameters as scalar int32"""
return {
param_name: EncryptedScalar(Integer(32, True))
if scalar
else EncryptedTensor(Integer(32, True), (1,))
for param_name in signature(func).parameters.keys()
}
@@ -124,11 +146,11 @@ def get_func_params_scalar_int32(func):
@pytest.mark.parametrize(
"function_to_trace,fused,params,warning_message",
[
pytest.param(no_fuse, False, get_func_params_scalar_int32(no_fuse), "", id="no_fuse"),
pytest.param(no_fuse, False, get_func_params_int32(no_fuse), "", id="no_fuse"),
pytest.param(
no_fuse_unhandled,
False,
get_func_params_scalar_int32(no_fuse_unhandled),
get_func_params_int32(no_fuse_unhandled),
"""The following subgraph is not fusable:
%0 = x # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
@@ -159,31 +181,70 @@ return(%7)""", # noqa: E501 # pylint: disable=line-too-long
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_dot",
),
pytest.param(
no_fuse_explicitely_ravel,
False,
get_func_params_int32(no_fuse_explicitely_ravel, scalar=False),
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_ravel",
),
pytest.param(
no_fuse_explicitely_transpose,
False,
get_func_params_int32(no_fuse_explicitely_transpose, scalar=False),
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_transpose",
),
pytest.param(
no_fuse_explicitely_reshape,
False,
get_func_params_int32(no_fuse_explicitely_reshape, scalar=False),
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(1,)>
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(1,)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_reshape",
),
pytest.param(
simple_fuse_not_output,
True,
get_func_params_scalar_int32(simple_fuse_not_output),
get_func_params_int32(simple_fuse_not_output),
None,
id="simple_fuse_not_output",
),
pytest.param(
simple_fuse_output,
True,
get_func_params_scalar_int32(simple_fuse_output),
get_func_params_int32(simple_fuse_output),
None,
id="simple_fuse_output",
),
pytest.param(
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
True,
get_func_params_scalar_int32(lambda x, y: None),
get_func_params_int32(lambda x, y: None),
None,
id="mix_x_and_y_intricately_and_call_f_with_rint",
),
pytest.param(
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
True,
get_func_params_scalar_int32(lambda x, y: None),
get_func_params_int32(lambda x, y: None),
None,
id="mix_x_and_y_and_call_f_with_rint",
),


@@ -35,45 +35,51 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
pytest.param(ir.Constant(42), None, 42, id="Constant"),
pytest.param(ir.Constant(-42), None, -42, id="Constant"),
pytest.param(
ir.UnivariateFunction(
EncryptedScalar(Integer(7, False)), lambda x: x + 3, Integer(7, False)
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
lambda x: x + 3,
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
),
[10],
13,
id="UnivariateFunction, x + 3",
id="GenericFunction, x + 3",
),
pytest.param(
ir.UnivariateFunction(
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
lambda x, y: x + y,
Integer(7, False),
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
op_kwargs={"y": 3},
),
[10],
13,
id="UnivariateFunction, (x, y) -> x + y, where y is constant == 3",
id="GenericFunction, (x, y) -> x + y, where y is constant == 3",
),
pytest.param(
ir.UnivariateFunction(
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
lambda x, y: y[x],
Integer(7, False),
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
op_kwargs={"y": (1, 2, 3, 4)},
),
[2],
3,
id="UnivariateFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
id="GenericFunction, (x, y) -> y[x], where y is constant == (1, 2, 3, 4)",
),
pytest.param(
ir.UnivariateFunction(
ir.GenericFunction(
EncryptedScalar(Integer(7, False)),
lambda x, y: y[3],
Integer(7, False),
EncryptedScalar(Integer(7, False)),
op_kind="TLU",
op_kwargs={"y": (1, 2, 3, 4)},
),
[2],
4,
id="UnivariateFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
id="GenericFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
),
pytest.param(
ir.Dot(
@@ -179,8 +185,8 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
lambda x: numpy.transpose(x),
Integer(32, False),
output_shape=(5, 3),
EncryptedTensor(Integer(32, False), shape=(5, 3)),
op_kind="Memory",
),
[numpy.arange(15).reshape(3, 5)],
numpy.array([[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]]),
@@ -190,8 +196,8 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
lambda x: numpy.ravel(x),
Integer(32, False),
output_shape=(5, 3),
EncryptedTensor(Integer(32, False), shape=(5, 3)),
op_kind="Memory",
),
[numpy.arange(15).reshape(3, 5)],
numpy.arange(15),
@@ -201,8 +207,8 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En
ir.GenericFunction(
EncryptedTensor(Integer(32, False), shape=(3, 5)),
lambda x: numpy.reshape(x, (5, 3)),
Integer(32, False),
output_shape=(5, 3),
output_value=EncryptedTensor(Integer(32, False), shape=(5, 3)),
op_kind="Memory",
),
[numpy.arange(15).reshape(3, 5)],
numpy.arange(15).reshape(5, 3),
@@ -306,35 +312,49 @@ def test_evaluate(
False,
),
(
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
),
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
),
True,
),
(
ir.UnivariateFunction(
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
Integer(8, False),
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
op_args=(1, 2, 3),
),
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
),
False,
),
(
ir.UnivariateFunction(
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
Integer(8, False),
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
op_kwargs={"tuple": (1, 2, 3)},
),
ir.UnivariateFunction(
EncryptedScalar(Integer(8, False)), lambda x: x, Integer(8, False)
ir.GenericFunction(
EncryptedScalar(Integer(8, False)),
lambda x: x,
EncryptedScalar(Integer(8, False)),
op_kind="TLU",
),
False,
),


@@ -22,7 +22,6 @@ from concrete.common.representation.intermediate import (
MatMul,
Mul,
Sub,
UnivariateFunction,
)
@@ -111,13 +110,15 @@ def python_functions_are_equal_or_equivalent(lhs: object, rhs: object) -> bool:
return False
def is_equivalent_arbitrary_function(lhs: UnivariateFunction, rhs: object) -> bool:
"""Helper function to check if an UnivariateFunction node is equivalent to an other object."""
def is_equivalent_arbitrary_function(lhs: GenericFunction, rhs: object) -> bool:
"""Helper function to check if an GenericFunction node is equivalent to an other object."""
return (
isinstance(rhs, UnivariateFunction)
isinstance(rhs, GenericFunction)
and python_functions_are_equal_or_equivalent(lhs.arbitrary_func, rhs.arbitrary_func)
and lhs.op_kind == rhs.op_kind
and lhs.op_args == rhs.op_args
and lhs.op_kwargs == rhs.op_kwargs
and lhs.op_attributes == rhs.op_attributes
and lhs.op_name == rhs.op_name
and is_equivalent_intermediate_node(lhs, rhs)
)
@@ -186,7 +187,6 @@ def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
Add: is_equivalent_add,
UnivariateFunction: is_equivalent_arbitrary_function,
GenericFunction: is_equivalent_arbitrary_function,
Constant: is_equivalent_constant,
Dot: is_equivalent_dot,


@@ -972,12 +972,13 @@ return(%1)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = np.transpose(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 3)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.transpose is not supported for the time being\n" # noqa: E501
"return(%1)\n"
"""function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>
%1 = np.transpose(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.transpose of kind Memory is not supported for the time being
return(%1)
""" # noqa: E501
),
),
pytest.param(
@@ -985,12 +986,13 @@ return(%1)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = np.ravel(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(6,)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.ravel is not supported for the time being\n" # noqa: E501
"return(%1)\n"
"""function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>
%1 = np.ravel(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(6,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.ravel of kind Memory is not supported for the time being
return(%1)
""" # noqa: E501
),
),
pytest.param(
@@ -998,12 +1000,13 @@ return(%1)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 4)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 4)>\n" # noqa: E501
"%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 6)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.reshape is not supported for the time being\n" # noqa: E501
"return(%1)\n"
"""function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 4)>
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 6)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.reshape of kind Memory is not supported for the time being
return(%1)
""" # noqa: E501
),
),
],


@@ -6,7 +6,7 @@ import numpy
import pytest
import concrete.numpy as hnp
from concrete.common.representation.intermediate import UnivariateFunction
from concrete.common.representation.intermediate import GenericFunction
from concrete.numpy.np_mlir_converter import generate_deduplicated_tables
@@ -50,7 +50,7 @@ def test_generate_deduplicated_tables(
)
univariate_function_nodes = [
node for node in op_graph.graph.nodes() if isinstance(node, UnivariateFunction)
node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction)
]
assert len(univariate_function_nodes) == 1
@@ -75,7 +75,7 @@ def test_deduplicated_tables_correctness(default_compilation_configuration):
)
univariate_function_nodes = [
node for node in op_graph.graph.nodes() if isinstance(node, UnivariateFunction)
node for node in op_graph.graph.nodes() if isinstance(node, GenericFunction)
]
assert len(univariate_function_nodes) == 1


@@ -425,24 +425,24 @@ def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
[
pytest.param(
{"x": EncryptedScalar(Integer(7, is_signed=False))},
ir.UnivariateFunction,
ir.GenericFunction,
),
pytest.param(
{"x": EncryptedScalar(Integer(32, is_signed=True))},
ir.UnivariateFunction,
ir.GenericFunction,
),
pytest.param(
{"x": EncryptedScalar(Integer(64, is_signed=True))},
ir.UnivariateFunction,
ir.GenericFunction,
),
pytest.param(
{"x": EncryptedScalar(Integer(128, is_signed=True))},
ir.UnivariateFunction,
ir.GenericFunction,
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
pytest.param(
{"x": EncryptedScalar(Float(64))},
ir.UnivariateFunction,
ir.GenericFunction,
),
],
)