refactor: make GenericFunction accept several inputs
- remove baked constants
- manage table generation for the updated node

closes #600, closes #822
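To illustrate what this commit changes when tracing a binary numpy operation with a constant operand, here is a sketch (not part of the diff; the import paths are assumptions based on the concrete-numpy package layout of the time and may need adjusting):

import numpy

from concrete.common.data_types import Integer
from concrete.common.values import EncryptedScalar
from concrete.numpy import tracing


def f(x):
    # Before this commit, 42 was "baked" into the traced GenericFunction via
    # op_kwargs["baked_constant"]; now it is traced as a Constant predecessor
    # feeding input 1 of the node, and the node keeps both inputs.
    return numpy.arctan2(x, 42)


op_graph = tracing.trace_numpy_function(f, {"x": EncryptedScalar(Integer(2, is_signed=False))})
tlu_node = op_graph.output_nodes[0]
# Predecessors are now recovered in input order via the new OPGraph helper:
preds = op_graph.get_ordered_preds(tlu_node)  # [<x's input node>, <Constant(42)>]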
@@ -114,19 +114,6 @@ def get_printable_graph(
     prefix_to_add_to_what_to_print = ""
     suffix_to_add_to_what_to_print = ""

-    # Print constant that may be in the GenericFunction. For the moment, it considers
-    # there is a single constant maximally and that there is 2 inputs maximally
-    if isinstance(node, GenericFunction) and "baked_constant" in node.op_kwargs:
-        baked_constant = node.op_kwargs["baked_constant"]
-        if node.op_attributes["in_which_input_is_constant"] == 0:
-            prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, "
-        else:
-            assert_true(
-                node.op_attributes["in_which_input_is_constant"] == 1,
-                "'in_which_input_is_constant' should be a key of node.op_attributes",
-            )
-            suffix_to_add_to_what_to_print = f", {shorten_a_constant(baked_constant)}"
-
     # Then, just print the predecessors in the right order
     what_to_print += prefix_to_add_to_what_to_print
     what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name])
@@ -9,7 +9,7 @@ Converter functions all have the same signature `converter(node, preds, ir_to_mlir_node, ctx)`
 from typing import cast

 # pylint: disable=no-name-in-module,no-member
-import numpy as np
+import numpy
 from mlir.dialects import arith as arith_dialect
 from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
 from zamalang.dialects import hlfhe
@@ -163,12 +163,24 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None):

 def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
     """Convert a GenericFunction intermediate node."""
-    assert_true(len(node.inputs) == 1, "LUT should have a single input")
+    variable_input_indices = [
+        idx for idx, pred in enumerate(preds) if not isinstance(pred, Constant)
+    ]
+
+    assert_true(
+        (non_constant_pred_count := len(variable_input_indices)) == 1,
+        f"LUT should have a single variable input (got {non_constant_pred_count})",
+    )
+
+    variable_input_idx = variable_input_indices[0]
+    variable_input_value = node.inputs[variable_input_idx]

     assert_true(len(node.outputs) == 1, "LUT should have a single output")
-    if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
+    if not value_is_encrypted_scalar_unsigned_integer(variable_input_value):
         raise TypeError(
             f"Only support LUT with encrypted unsigned integers inputs "
-            f"(but {node.inputs[0]} is provided)"
+            f"(but {variable_input_value} is provided)"
         )
     if not value_is_encrypted_scalar_unsigned_integer(node.outputs[0]):
         raise TypeError(
@@ -176,7 +188,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
             f"(but {node.outputs[0]} is provided)"
         )

-    x_node = preds[0]
+    x_node = preds[variable_input_idx]
     x = ir_to_mlir_node[x_node]
     tables = additional_conversion_info["tables"][node]
@@ -192,7 +204,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):

     out_dtype = cast(Integer, node.outputs[0].dtype)
     # Create table
-    dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
+    dense_elem = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=ctx)
     tensor_lut = arith_dialect.ConstantOp(
         RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
         dense_elem,
@@ -171,14 +171,7 @@ class MLIRConverter(ABC):
                     f"we don't yet support conversion to MLIR of computations using"
                     f"{type(node)}"
                 )
-            # get sorted preds: sorted by their input index
-            # replication of pred is possible (e.g lambda x: x + x)
-            idx_to_pred = {}
-            for pred in op_graph.graph.pred[node]:
-                edge_data = op_graph.graph.get_edge_data(pred, node)
-                for data in edge_data.values():
-                    idx_to_pred[data["input_idx"]] = pred
-            preds = [idx_to_pred[i] for i in range(len(idx_to_pred))]
+            preds = op_graph.get_ordered_preds(node)
             # convert to mlir
             result = mlir_op(
                 node,
@@ -1,6 +1,8 @@
 """Utilities for MLIR conversion."""
 from typing import Dict, List, Optional, cast

+import networkx as nx
+
 from ..data_types import Integer
 from ..data_types.dtypes_helpers import (
     value_is_clear_scalar_integer,
@@ -21,11 +23,16 @@ from ..representation.intermediate import GenericFunction, IntermediateNode
 ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7


-def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) -> Optional[str]:
+def check_node_compatibility_with_mlir(
+    node: IntermediateNode,
+    nx_graph: nx.MultiDiGraph,
+    is_output: bool,
+) -> Optional[str]:
     """Check if node is compatible with MLIR.

     Args:
         node (IntermediateNode): node to check
+        nx_graph (nx.MultiDiGraph): the networkx graph to which node belongs
         is_output (bool): whether the node is an output node or not

     Returns:
@@ -66,7 +73,16 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)

     elif isinstance(node, intermediate.GenericFunction):  # constraints for univariate functions
         if node.op_kind == "TLU":
-            assert_true(len(inputs) == 1)
+            assert_true(
+                len(
+                    [
+                        pred_node
+                        for pred_node in nx_graph.pred[node]
+                        if not isinstance(pred_node, intermediate.Constant)
+                    ]
+                )
+                == 1
+            )
             if node.op_name == "MultiTLU":
                 return "direct multi table lookup is not supported for the time being"
             if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
@@ -124,7 +140,9 @@ def check_graph_values_compatibility_with_mlir(

     for node in op_graph.graph.nodes:
         is_output = node in op_graph.output_nodes.values()
-        if (reason := check_node_compatibility_with_mlir(node, is_output)) is not None:
+        if (
+            reason := check_node_compatibility_with_mlir(node, op_graph.graph, is_output)
+        ) is not None:
             offending_nodes[node] = [reason]

     return None if len(offending_nodes) == 0 else offending_nodes
@@ -115,6 +115,22 @@ class OPGraph:
         """
         return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]

+    def get_ordered_preds(self, node: IntermediateNode) -> List[IntermediateNode]:
+        """Get node predecessors ordered by their indices.
+
+        Args:
+            node (IntermediateNode): The node for which we want the ordered predecessors.
+
+        Returns:
+            List[IntermediateNode]: The list of predecessors ordered by input index.
+        """
+        # Replication of pred is managed e.g. x + x will yield the proper pred x twice
+        idx_to_pred: Dict[int, IntermediateNode] = {}
+        for pred in self.graph.pred[node]:
+            edge_data = self.graph.get_edge_data(pred, node)
+            idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
+        return [idx_to_pred[i] for i in range(len(idx_to_pred))]
+
     def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]:
         """Evaluate a graph and get intermediate values for all nodes.
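The dict-based reconstruction in get_ordered_preds matters because one predecessor can feed several inputs of the same node (e.g. lambda x: x + x). A standalone sketch of the same logic on a bare networkx MultiDiGraph (hypothetical string nodes, outside the OPGraph class):

import networkx as nx

graph = nx.MultiDiGraph()
x, add = "x", "add"
graph.add_edge(x, add, input_idx=0)  # x feeds input 0 of add
graph.add_edge(x, add, input_idx=1)  # ...and also input 1 (i.e. x + x)

idx_to_pred = {}
for pred in graph.pred[add]:
    edge_data = graph.get_edge_data(pred, add)
    idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())

# The same predecessor is correctly yielded once per input slot
assert [idx_to_pred[i] for i in range(len(idx_to_pred))] == [x, x]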
@@ -323,20 +323,10 @@ def subgraph_nodes_and_values_allow_fusing(
     if len(explicitely_non_fusable) > 0:
         return False

-    # Some GenericFunction nodes have baked constants that need to be taken into account for the
-    # max size computation
-    baked_constants_ir_nodes = [
-        baked_constant_ir_node
-        for node in subgraph_all_nodes
-        if isinstance(node, GenericFunction)
-        if (baked_constant_ir_node := node.op_attributes.get("baked_constant_ir_node", None))
-        is not None
-    ]
-
     all_values_are_tensors = all(
         all(isinstance(input_, TensorValue) for input_ in node.inputs)
         and all(isinstance(output, TensorValue) for output in node.outputs)
-        for node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
+        for node in subgraph_all_nodes
     )

     if not all_values_are_tensors:
@@ -360,8 +350,14 @@ def subgraph_nodes_and_values_allow_fusing(
         variable_input_node_output.shape,
     )
     max_inputs_size = max(
-        cast(TensorValue, input_node.outputs[0]).size
-        for input_node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
+        itertools.chain(
+            (variable_input_node_output_size,),
+            (
+                cast(TensorValue, constant_input_node.outputs[0]).size
+                for constant_input_node in subgraph_all_nodes
+                if isinstance(constant_input_node, Constant)
+            ),
+        )
     )

     if variable_input_node_output_size < max_inputs_size:
@@ -322,7 +322,6 @@ class GenericFunction(IntermediateNode):
     ) -> None:
         super().__init__(inputs)
         self._n_in = len(self.inputs)
-        assert_true(self._n_in == 1)  # TODO: remove in later parts of refactoring of #600
         self.arbitrary_func = arbitrary_func
         self.op_kind = GenericFunctionKind(op_kind)
         self.op_args = op_args if op_args is not None else ()
@@ -344,22 +343,42 @@ class GenericFunction(IntermediateNode):
     def label(self) -> str:
         return self.op_name

-    def get_table(self) -> List[Any]:
+    def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
         """Get the table for the current input value of this GenericFunction.

-        This function only works if the GenericFunction input value is an unsigned Integer.
+        This function only works if the GenericFunction variable input value is an unsigned Integer.
+        It only works if there is a single variable input node among ordered_preds.
+
+        Args:
+            ordered_preds (List[IntermediateNode]): List of predecessors of the node. This list must
+                contain a single non constant node and any number of Constant nodes.

         Returns:
             List[Any]: The table.
         """

-        input_dtype = self.inputs[0].dtype
+        variable_input_indices = [
+            idx for idx, pred in enumerate(ordered_preds) if not isinstance(pred, Constant)
+        ]
+
+        assert_true(
+            (non_constant_pred_count := len(variable_input_indices)) == 1,
+            f"Can only have 1 non constant predecessor in {self.get_table.__name__}, "
+            f"got {non_constant_pred_count}",
+        )
+
+        variable_input_idx = variable_input_indices[0]
+        variable_input_dtype = self.inputs[variable_input_idx].dtype
         # Check the input is an unsigned integer to be able to build a table
         assert_true(
-            isinstance(input_dtype, Integer), "get_table only works for an unsigned Integer input"
+            isinstance(variable_input_dtype, Integer),
+            f"{self.get_table.__name__} only works for an unsigned Integer input",
         )
-        input_dtype = cast(Integer, input_dtype)
-        assert_true(not input_dtype.is_signed, "get_table only works for an unsigned Integer input")
+        variable_input_dtype = cast(Integer, variable_input_dtype)
+        assert_true(
+            not variable_input_dtype.is_signed,
+            f"{self.get_table.__name__} only works for an unsigned Integer input",
+        )

         input_value_constructor = self.inputs[0].underlying_constructor
         if input_value_constructor is None:
@@ -368,8 +387,8 @@ class GenericFunction(IntermediateNode):
             )
             input_value_constructor = int

-        min_input_range = input_dtype.min_value()
-        max_input_range = input_dtype.max_value() + 1
+        min_input_range = variable_input_dtype.min_value()
+        max_input_range = variable_input_dtype.max_value() + 1

         def catch(func, *args, **kwargs):
             try:
@@ -378,8 +397,22 @@ class GenericFunction(IntermediateNode):
             except Exception:  # pragma: no cover # pylint: disable=broad-except
                 return None

+        template_input_dict = {
+            idx: node.evaluate({}) if isinstance(node, Constant) else None
+            for idx, node in enumerate(ordered_preds)
+        }
+
+        def update_and_return_dict(dict_to_update: dict, update_values):
+            dict_to_update.update(update_values)
+            return dict_to_update
+
         table = [
-            catch(self.evaluate, {0: input_value_constructor(input_value)})
+            catch(
+                self.evaluate,
+                update_and_return_dict(
+                    template_input_dict, {variable_input_idx: input_value_constructor(input_value)}
+                ),
+            )
            for input_value in range(min_input_range, max_input_range)
        ]
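The reworked get_table evaluates the Constant predecessors once into a template dict and then sweeps only the variable slot over its input range. A self-contained sketch of that sweep outside the tracing machinery (the evaluate stand-in and the 2-bit range are assumptions for illustration):

import numpy


def evaluate(inputs):
    # Stand-in for IntermediateNode.evaluate: here, x << c
    return numpy.left_shift(inputs[0], inputs[1])


# Input 1 is a Constant predecessor, evaluated once into the template;
# input 0 is the variable slot, left as None until the sweep fills it.
template_input_dict = {0: None, 1: 1}


def update_and_return_dict(dict_to_update, update_values):
    dict_to_update.update(update_values)
    return dict_to_update


# Sweep the single variable input (index 0) over its 2-bit unsigned range
table = [
    evaluate(update_and_return_dict(template_input_dict, {0: input_value}))
    for input_value in range(0, 4)
]
assert table == [0, 2, 4, 6]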
@@ -183,11 +183,11 @@ def get_base_value_for_numpy_or_python_constant_data(
     return constant_data_value


-def get_numpy_function_output_dtype_from_input_dtypes(
+def get_numpy_function_output_dtype_and_shape_from_input_dtypes(
     function: Union[numpy.ufunc, Callable],
     input_dtypes: List[BaseDataType],
     input_shapes: List[Tuple[int, ...]],
-) -> List[numpy.dtype]:
+) -> List[Tuple[numpy.dtype, Tuple[int, ...]]]:
     """Record the output dtype of a numpy function given some input types.

     Args:
@@ -199,7 +199,8 @@ def get_numpy_function_output_dtype_from_input_dtypes(
             the function inputs

     Returns:
-        List[numpy.dtype]: The ordered numpy dtypes of the function outputs
+        List[Tuple[numpy.dtype, Tuple[int, ...]]]: appropriate (numpy.dtype, shape) tuple for each
+            output of the function
     """
     if isinstance(function, numpy.ufunc):
         assert_true(
@@ -226,14 +227,14 @@ def get_numpy_function_output_dtype_from_input_dtypes(
     if not isinstance(outputs, tuple):
         outputs = (outputs,)

-    return [output.dtype for output in outputs]
+    return [(output.dtype, output.shape) for output in outputs]


-def get_numpy_function_output_dtype_from_input_tracers(
+def get_numpy_function_output_dtype_and_shape_from_input_tracers(
     func: Union[numpy.ufunc, Callable],
     *input_tracers: BaseTracer,
-) -> List[BaseDataType]:
-    """Determine output dtypes for a numpy function.
+) -> List[Tuple[BaseDataType, Tuple[int, ...]]]:
+    """Determine output dtypes and shapes for a numpy function.

     This function is responsible for determining the output dtype
     of a numpy function after inputs with specific dtypes are passed to it.
@@ -243,19 +244,23 @@ def get_numpy_function_output_dtype_from_input_tracers(
         *input_tracers (BaseTracer): inputs to the function

     Returns:
-        List[numpy.dtype]: appropriate BaseDataType for each output of the function
+        List[Tuple[BaseDataType, Tuple[int, ...]]]: appropriate (BaseDataType, shape) tuple for each
+            output of the function
     """

     input_shapes = [
         input_tracer.output.shape if isinstance(input_tracer.output, TensorValue) else ()
         for input_tracer in input_tracers
     ]
-    output_dtypes = get_numpy_function_output_dtype_from_input_dtypes(
+    output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_dtypes(
         func,
         [input_tracer.output.dtype for input_tracer in input_tracers],
         input_shapes,
     )
-    common_output_dtypes = [convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes]
+    common_output_dtypes = [
+        (convert_numpy_dtype_to_base_data_type(dtype), shape)
+        for dtype, shape in output_dtypes_and_shapes
+    ]
     return common_output_dtypes
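Recording shapes alongside dtypes is what lets the tracer build the output TensorValue directly: with broadcasting, the output shape need not match any single input. A plain numpy sketch of why that matters (hypothetical shapes):

import numpy

# (1, 10) and (10, 10) inputs broadcast to a (10, 10) float64 output
inputs = [numpy.zeros((1, 10), dtype=numpy.int64), numpy.zeros((10, 10), dtype=numpy.float64)]
outputs = numpy.arctan2(*inputs)
# One (dtype, shape) pair per output, as the renamed helper now returns
print([(outputs.dtype, outputs.shape)])  # [(dtype('float64'), (10, 10))]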
@@ -10,7 +10,7 @@ import numpy

 from ..common.debugging import assert_true
 from ..common.mlir.mlir_converter import MLIRConverter
 from ..common.operator_graph import OPGraph
-from ..common.representation.intermediate import GenericFunction
+from ..common.representation.intermediate import GenericFunction, IntermediateNode


 class HashableNPArray:
@@ -33,12 +33,13 @@ class HashableNPArray:


 def generate_deduplicated_tables(
-    node: GenericFunction,
+    node: GenericFunction, ordered_preds: List[IntermediateNode]
 ) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
     """Deduplicate the tables for the different cells of a tensor if needed.

     Args:
-        node (GenericFunction): the node for which to deduplicate the table
+        node (GenericFunction): the node for which to deduplicate the table.
+        ordered_preds (List[IntermediateNode]): ordered list of predecessors of the node.

     Returns:
         Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose
@@ -47,7 +48,7 @@ def generate_deduplicated_tables(
     """
     # This is the tensor containing the tables for each cell of the tensor for node
     node_complete_table = numpy.concatenate(
-        tuple(numpy.expand_dims(array, -1) for array in node.get_table()), axis=-1
+        tuple(numpy.expand_dims(array, -1) for array in node.get_table(ordered_preds)), axis=-1
     )

     all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
@@ -85,7 +86,7 @@ class NPMLIRConverter(MLIRConverter):
         # Disable numpy warnings during conversion to avoid issues during TLU generation
         with numpy.errstate(all="ignore"):
             additional_conversion_info["tables"] = {
-                node: generate_deduplicated_tables(node)
+                node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
                 for node in op_graph.graph.nodes()
                 if isinstance(node, GenericFunction)
             }
@@ -16,7 +16,7 @@ from .np_dtypes_helpers import (
     SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
     convert_numpy_dtype_to_base_data_type,
     get_base_value_for_numpy_or_python_constant_data,
-    get_numpy_function_output_dtype_from_input_tracers,
+    get_numpy_function_output_dtype_and_shape_from_input_tracers,
 )
 from .np_indexing_helpers import process_indexing_element
@@ -161,14 +161,19 @@ class NPTracer(BaseTracer):
             NPTracer: The output NPTracer containing the traced function
         """
         assert_true(len(input_tracers) == 1)
-        common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
-            unary_operator,
-            *input_tracers,
+        common_output_dtypes_and_shapes = (
+            get_numpy_function_output_dtype_and_shape_from_input_tracers(
+                unary_operator,
+                *input_tracers,
+            )
         )
-        assert_true(len(common_output_dtypes) == 1)
+        assert_true(len(common_output_dtypes_and_shapes) == 1)

-        generic_function_output_value = deepcopy(input_tracers[0].output)
-        generic_function_output_value.dtype = common_output_dtypes[0]
+        output_dtype, output_shape = common_output_dtypes_and_shapes[0]
+
+        generic_function_output_value = TensorValue(
+            output_dtype, input_tracers[0].output.is_encrypted, output_shape
+        )

         traced_computation = GenericFunction(
             inputs=[deepcopy(input_tracers[0].output)],
@@ -201,58 +206,40 @@ class NPTracer(BaseTracer):
         # One of the inputs has to be constant
         if isinstance(input_tracers[0].traced_computation, Constant):
             in_which_input_is_constant = 0
-            baked_constant = deepcopy(input_tracers[0].traced_computation.constant_data)
         elif isinstance(input_tracers[1].traced_computation, Constant):
             in_which_input_is_constant = 1
-            baked_constant = deepcopy(input_tracers[1].traced_computation.constant_data)
         else:
             raise NotImplementedError(f"Can't manage binary operator {binary_operator}")

         in_which_input_is_variable = 1 - in_which_input_is_constant

-        if in_which_input_is_constant == 0:
-
-            def arbitrary_func(x, baked_constant, **kwargs):
-                return binary_operator(baked_constant, x, **kwargs)
-
-        else:
-
-            def arbitrary_func(x, baked_constant, **kwargs):
-                return binary_operator(x, baked_constant, **kwargs)
-
-        common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
-            binary_operator,
-            *input_tracers,
+        common_output_dtypes_and_shapes = (
+            get_numpy_function_output_dtype_and_shape_from_input_tracers(
+                binary_operator,
+                *input_tracers,
+            )
         )
-        assert_true(len(common_output_dtypes) == 1)
+        assert_true(len(common_output_dtypes_and_shapes) == 1)
+
+        output_dtype, output_shape = common_output_dtypes_and_shapes[0]
+
+        generic_function_output_value = TensorValue(
+            output_dtype,
+            input_tracers[in_which_input_is_variable].output.is_encrypted,
+            output_shape,
+        )

         op_kwargs = deepcopy(kwargs)
-        op_kwargs["baked_constant"] = baked_constant
-        # Store info on the operation being treated
-        # Currently: the base value and type corresponding to the baked constant and which input idx
-        # it was feeding
-        op_attributes = {
-            "baked_constant_ir_node": deepcopy(
-                input_tracers[in_which_input_is_constant].traced_computation
-            ),
-            "in_which_input_is_constant": in_which_input_is_constant,
-        }
-
-        generic_function_output_value = deepcopy(input_tracers[in_which_input_is_variable].output)
-        generic_function_output_value.dtype = common_output_dtypes[0]

-        # TODO: update inputs for #600 refactor
         traced_computation = GenericFunction(
-            inputs=[deepcopy(input_tracers[in_which_input_is_variable].output)],
-            arbitrary_func=arbitrary_func,
+            inputs=[deepcopy(input_tracer.output) for input_tracer in input_tracers],
+            arbitrary_func=binary_operator,
             output_value=generic_function_output_value,
             op_kind="TLU",
             op_kwargs=op_kwargs,
             op_name=binary_operator_string,
-            op_attributes=op_attributes,
         )
         output_tracer = cls(
-            (input_tracers[in_which_input_is_variable],),
+            input_tracers,
             traced_computation=traced_computation,
             output_idx=0,
         )
@@ -266,12 +253,14 @@ class NPTracer(BaseTracer):
         """
         assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")

-        common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(numpy.dot, *args)
-        assert_true(len(common_output_dtypes) == 1)
+        common_output_dtypes_and_shapes = (
+            get_numpy_function_output_dtype_and_shape_from_input_tracers(numpy.dot, *args)
+        )
+        assert_true(len(common_output_dtypes_and_shapes) == 1)

         traced_computation = Dot(
             [input_tracer.output for input_tracer in args],
-            common_output_dtypes[0],
+            common_output_dtypes_and_shapes[0][0],
             delegate_evaluation_function=numpy.dot,
         )
@@ -638,14 +627,14 @@ def _on_numpy_multiply(lhs, rhs):


 def _on_numpy_matmul(lhs, rhs):
-    common_output_dtypes = get_numpy_function_output_dtype_from_input_tracers(
+    common_output_dtypes_and_shapes = get_numpy_function_output_dtype_and_shape_from_input_tracers(
         numpy.matmul, lhs, rhs
     )
-    assert_true(len(common_output_dtypes) == 1)
+    assert_true(len(common_output_dtypes_and_shapes) == 1)

     traced_computation = MatMul(
         [lhs.output, rhs.output],
-        common_output_dtypes[0],
+        common_output_dtypes_and_shapes[0][0],
     )
     return NPTracer([lhs, rhs], traced_computation, output_idx=0)
@@ -45,7 +45,9 @@ def test_lookup_table_encrypted_lookup(test_helpers):
     x = EncryptedScalar(Integer(2, is_signed=False))
     op_graph = tracing.trace_numpy_function(f, {"x": x})

-    assert op_graph.output_nodes[0].get_table() == [3, 6, 0, 2]
+    table_node = op_graph.output_nodes[0]
+
+    assert table_node.get_table(op_graph.get_ordered_preds(table_node)) == [3, 6, 0, 2]

     ref_graph = nx.MultiDiGraph()
     # Here is the ASCII drawing of the expected graph:
@@ -114,11 +114,23 @@ def issue_130_c(x, y):
         ),
         (
             lambda x, y: numpy.arctan2(x, 42) + y,
-            "%0 = y\n%1 = x\n%2 = np.arctan2(%1, 42)\n%3 = Add(%2, %0)\nreturn(%3)\n",
+            """%0 = y
+%1 = x
+%2 = Constant(42)
+%3 = np.arctan2(%1, %2)
+%4 = Add(%3, %0)
+return(%4)
+""",
         ),
         (
             lambda x, y: numpy.arctan2(43, x) + y,
-            "%0 = y\n%1 = x\n%2 = np.arctan2(43, %1)\n%3 = Add(%2, %0)\nreturn(%3)\n",
+            """%0 = y
+%1 = Constant(43)
+%2 = x
+%3 = np.arctan2(%1, %2)
+%4 = Add(%3, %0)
+return(%4)
+""",
         ),
     ],
 )
@@ -416,14 +428,22 @@ def test_numpy_long_constant():
     )

     expected = """
-%0 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
-%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
-%2 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
-%3 = Add(%1, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
-%4 = Sub(%3, %0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
-%5 = np.arctan2([[10 11 12 ... 17 18 19]], %4) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
-%6 = np.arctan2([[100 101 ... 198 199]], %5) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
-return(%6)
+%0 = Constant([[100 101 ... 198 199]]) # ClearTensor<Integer<unsigned, 8 bits>, shape=(10, 10)>
+%1 = Constant([[10 11 12 ... 17 18 19]]) # ClearTensor<Integer<unsigned, 5 bits>, shape=(1, 10)>
+%2 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
+%3 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
+%4 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
+%5 = Add(%3, %4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
+%6 = Sub(%5, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
+%7 = np.arctan2(%1, %6) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
+%8 = np.arctan2(%0, %7) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
+return(%8)
 """.lstrip()  # noqa: E501

-    assert get_printable_graph(op_graph, show_data_types=True) == expected
+    str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
+
+    assert str_of_the_graph == expected, (
+        f"\n==================\nGot \n{str_of_the_graph}"
+        f"==================\nExpected \n{expected}"
+        f"==================\n"
+    )
@@ -57,7 +57,9 @@ def test_generate_deduplicated_tables(

     tlu_node = univariate_function_nodes[0]

-    deduplication_result = generate_deduplicated_tables(tlu_node)
+    deduplication_result = generate_deduplicated_tables(
+        tlu_node, op_graph.get_ordered_preds(tlu_node)
+    )

     assert len(deduplication_result) == expected_number_of_tables
@@ -82,7 +84,9 @@ def test_deduplicated_tables_correctness(default_compilation_configuration):

     tlu_node = univariate_function_nodes[0]

-    deduplication_result = generate_deduplicated_tables(tlu_node)
+    deduplication_result = generate_deduplicated_tables(
+        tlu_node, op_graph.get_ordered_preds(tlu_node)
+    )

     expected_result = tuple(
         (