mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(nptracer): add dot tracing abilities
- remove no cover from Dot.label - small refactor of BaseTracer to make _sanitize a class method - small refactor of get_ufunc_numpy_output_dtype to manage funcs and ufuncs - add function routing to NPTracer - add dot tracing to NPTracer - small refactor to get tracing functions for numpy funcs and ufuncs
This commit is contained in:
@@ -351,6 +351,5 @@ class Dot(IntermediateNode):
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
# TODO: Coverage will come with the ability to trace the operator in a subsequent PR
|
||||
def label(self) -> str: # pragma: no cover
|
||||
def label(self) -> str:
|
||||
return "dot"
|
||||
|
||||
@@ -53,6 +53,11 @@ class BaseTracer(ABC):
|
||||
def _get_mix_values_func(cls):
|
||||
return cls._mix_values_func
|
||||
|
||||
def _sanitize(self, inp) -> "BaseTracer":
|
||||
if not isinstance(inp, BaseTracer):
|
||||
return self._make_const_input_tracer(inp)
|
||||
return inp
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: Iterable[Union["BaseTracer", Any]],
|
||||
@@ -69,13 +74,9 @@ class BaseTracer(ABC):
|
||||
Returns:
|
||||
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
|
||||
"""
|
||||
# For inputs which are actually constant, first convert into a tracer
|
||||
def sanitize(inp):
|
||||
if not isinstance(inp, BaseTracer):
|
||||
return self._make_const_input_tracer(inp)
|
||||
return inp
|
||||
|
||||
sanitized_inputs = [sanitize(inp) for inp in inputs]
|
||||
# For inputs which are actually constant, first convert into a tracer
|
||||
sanitized_inputs = [self._sanitize(inp) for inp in inputs]
|
||||
|
||||
additional_parameters = (
|
||||
{IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -154,23 +154,25 @@ def get_base_value_for_numpy_or_python_constant_data(
|
||||
return constant_data_value
|
||||
|
||||
|
||||
def get_ufunc_numpy_output_dtype(
|
||||
ufunc: numpy.ufunc,
|
||||
def get_numpy_function_output_dtype(
|
||||
function: Union[numpy.ufunc, Callable],
|
||||
input_dtypes: List[BaseDataType],
|
||||
) -> List[numpy.dtype]:
|
||||
"""Function to record the output dtype of a numpy.ufunc given some input types.
|
||||
"""Function to record the output dtype of a numpy function given some input types.
|
||||
|
||||
Args:
|
||||
ufunc (numpy.ufunc): The numpy.ufunc whose output types need to be recorded
|
||||
input_dtypes (List[BaseDataType]): Common dtypes in the same order as they will be used with
|
||||
the ufunc inputs
|
||||
function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to
|
||||
be recorded
|
||||
input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with
|
||||
the function inputs
|
||||
|
||||
Returns:
|
||||
List[numpy.dtype]: The ordered numpy dtypes of the ufunc outputs
|
||||
List[numpy.dtype]: The ordered numpy dtypes of the function outputs
|
||||
"""
|
||||
assert (
|
||||
len(input_dtypes) == ufunc.nin
|
||||
), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}"
|
||||
if isinstance(function, numpy.ufunc):
|
||||
assert (
|
||||
len(input_dtypes) == function.nin
|
||||
), f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}"
|
||||
|
||||
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
|
||||
|
||||
@@ -183,7 +185,7 @@ def get_ufunc_numpy_output_dtype(
|
||||
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes
|
||||
)
|
||||
|
||||
outputs = ufunc(*dummy_inputs)
|
||||
outputs = function(*dummy_inputs)
|
||||
if not isinstance(outputs, tuple):
|
||||
outputs = (outputs,)
|
||||
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
"""hnumpy tracing utilities."""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import ArbitraryFunction, Constant
|
||||
from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from ..common.values import BaseValue
|
||||
from .np_dtypes_helpers import (
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_ufunc_numpy_output_dtype,
|
||||
get_numpy_function_output_dtype,
|
||||
)
|
||||
|
||||
SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
|
||||
@@ -39,13 +39,24 @@ class NPTracer(BaseTracer):
|
||||
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
"""
|
||||
if method == "__call__":
|
||||
tracing_func = self.get_tracing_func_for_np_ufunc(ufunc)
|
||||
tracing_func = self.get_tracing_func_for_np_function(ufunc)
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}"
|
||||
return tracing_func(self, *input_tracers, **kwargs)
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
def __array_function__(self, func, _types, args, kwargs):
|
||||
"""Catch calls to numpy function in routes them to hnp functions if supported.
|
||||
|
||||
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
"""
|
||||
tracing_func = self.get_tracing_func_for_np_function(func)
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"hnumpy does not support **kwargs currently for numpy functions, func: {func}"
|
||||
return tracing_func(*args, **kwargs)
|
||||
|
||||
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
|
||||
r"""Support numpy astype feature.
|
||||
|
||||
@@ -77,22 +88,27 @@ class NPTracer(BaseTracer):
|
||||
return output_tracer
|
||||
|
||||
@staticmethod
|
||||
def get_tracing_func_for_np_ufunc(ufunc: numpy.ufunc) -> Callable:
|
||||
"""Get the tracing function for a numpy ufunc.
|
||||
def get_tracing_func_for_np_function(func: Union[numpy.ufunc, Callable]) -> Callable:
|
||||
"""Get the tracing function for a numpy function.
|
||||
|
||||
Args:
|
||||
ufunc (numpy.ufunc): The numpy ufunc that will be traced
|
||||
func (Union[numpy.ufunc, Callable]): The numpy function that will be traced
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Raised if the passed ufunc is not supported by NPTracer
|
||||
NotImplementedError: Raised if the passed function is not supported by NPTracer
|
||||
|
||||
Returns:
|
||||
Callable: the tracing function that needs to be called to trace ufunc
|
||||
Callable: the tracing function that needs to be called to trace func
|
||||
"""
|
||||
tracing_func = NPTracer.UFUNC_ROUTING.get(ufunc, None)
|
||||
tracing_func: Optional[Callable]
|
||||
if isinstance(func, numpy.ufunc):
|
||||
tracing_func = NPTracer.UFUNC_ROUTING.get(func, None)
|
||||
else:
|
||||
tracing_func = NPTracer.FUNC_ROUTING.get(func, None)
|
||||
|
||||
if tracing_func is None:
|
||||
raise NotImplementedError(
|
||||
f"NPTracer does not yet manage the following ufunc: {ufunc.__name__}"
|
||||
f"NPTracer does not yet manage the following func: {func.__name__}"
|
||||
)
|
||||
return tracing_func
|
||||
|
||||
@@ -105,8 +121,8 @@ class NPTracer(BaseTracer):
|
||||
return self.__class__([], NPConstant(constant_data), 0)
|
||||
|
||||
@staticmethod
|
||||
def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"):
|
||||
output_dtypes = get_ufunc_numpy_output_dtype(
|
||||
def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer):
|
||||
output_dtypes = get_numpy_function_output_dtype(
|
||||
ufunc, [input_tracer.output.data_type for input_tracer in input_tracers]
|
||||
)
|
||||
common_output_dtypes = [
|
||||
@@ -132,7 +148,9 @@ class NPTracer(BaseTracer):
|
||||
op_name="np.rint",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers, traced_computation=traced_computation, output_index=0
|
||||
input_tracers,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
@@ -154,7 +172,34 @@ class NPTracer(BaseTracer):
|
||||
op_name="np.sin",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers, traced_computation=traced_computation, output_index=0
|
||||
input_tracers,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.dot.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
# input_tracers contains the other tracer of the dot product
|
||||
dot_inputs = (self, self._sanitize(other_tracer))
|
||||
|
||||
common_output_dtypes = self._manage_dtypes(numpy.dot, *dot_inputs)
|
||||
assert len(common_output_dtypes) == 1
|
||||
|
||||
traced_computation = Dot(
|
||||
[input_tracer.output for input_tracer in dot_inputs],
|
||||
common_output_dtypes[0],
|
||||
delegate_evaluation_function=numpy.dot,
|
||||
)
|
||||
|
||||
output_tracer = self.__class__(
|
||||
dot_inputs,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
@@ -163,6 +208,10 @@ class NPTracer(BaseTracer):
|
||||
numpy.sin: sin,
|
||||
}
|
||||
|
||||
FUNC_ROUTING: Dict[Callable, Callable] = {
|
||||
numpy.dot: dot,
|
||||
}
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""Test file for hnumpy debugging functions"""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.common.values import ClearValue, EncryptedValue
|
||||
from hdk.common.values import ClearValue, EncryptedTensor, EncryptedValue
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
|
||||
@@ -178,6 +179,36 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
# pylint: disable=unnecessary-lambda
|
||||
(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
|
||||
"y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
|
||||
},
|
||||
"\n%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)",
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
"Test hnumpy get_printable_graph and draw_graph on graphs with dot"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot {str_of_the_graph}"
|
||||
f"\n==================\nExpected {ref_graph_str}"
|
||||
f"\n==================\n"
|
||||
)
|
||||
|
||||
|
||||
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
|
||||
# returning 23b), since they are replaced later by the real bitwidths computed on the
|
||||
# dataset
|
||||
|
||||
@@ -291,10 +291,64 @@ def test_trace_hnumpy_supported_ufuncs(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"np_ufunc,expected_tracing_func",
|
||||
"function_to_trace,inputs,expected_output_node,expected_output_value",
|
||||
[
|
||||
# pylint: disable=unnecessary-lambda
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
|
||||
"y": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
|
||||
},
|
||||
ir.Dot,
|
||||
EncryptedValue(Integer(32, False)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Float(64), shape=(42,)),
|
||||
"y": EncryptedTensor(Float(64), shape=(10,)),
|
||||
},
|
||||
ir.Dot,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
|
||||
"y": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
|
||||
},
|
||||
ir.Dot,
|
||||
ClearValue(Integer(64, is_signed=True)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.dot(x, numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)),
|
||||
},
|
||||
ir.Dot,
|
||||
EncryptedValue(Integer(64, True)),
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value):
|
||||
"""Function to test dot tracing"""
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"np_function,expected_tracing_func",
|
||||
[
|
||||
pytest.param(numpy.rint, tracing.NPTracer.rint),
|
||||
pytest.param(numpy.sin, tracing.NPTracer.sin),
|
||||
pytest.param(numpy.dot, tracing.NPTracer.dot),
|
||||
# There is a need to test the case where the function fails, I chose numpy.conjugate which
|
||||
# works on complex types, as we don't talk about complex types for now this looks like a
|
||||
# good long term candidate to check for an unsupported function
|
||||
@@ -303,9 +357,9 @@ def test_trace_hnumpy_supported_ufuncs(
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nptracer_get_tracing_func_for_np_ufunc(np_ufunc, expected_tracing_func):
|
||||
"""Test NPTracer get_tracing_func_for_np_ufunc"""
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_ufunc(np_ufunc) == expected_tracing_func
|
||||
def test_nptracer_get_tracing_func_for_np_functions(np_function, expected_tracing_func):
|
||||
"""Test NPTracer get_tracing_func_for_np_function"""
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user