mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: manage sanitization properly for numpy functions
- allows to use dot with constant inputs
This commit is contained in:
@@ -7,7 +7,7 @@ import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..common.debugging.custom_assert import custom_assert
|
||||
from ..common.debugging.custom_assert import assert_true, custom_assert
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
@@ -53,7 +53,7 @@ class NPTracer(BaseTracer):
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
def __array_function__(self, func, _types, args, kwargs):
|
||||
"""Catch calls to numpy function in routes them to hnp functions if supported.
|
||||
"""Catch calls to numpy function in routes them to tracing functions if supported.
|
||||
|
||||
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
"""
|
||||
@@ -62,7 +62,8 @@ class NPTracer(BaseTracer):
|
||||
(len(kwargs) == 0),
|
||||
f"**kwargs are currently not supported for numpy functions, func: {func}",
|
||||
)
|
||||
return tracing_func(*args, **kwargs)
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
return tracing_func(self, *sanitized_args, **kwargs)
|
||||
|
||||
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
|
||||
r"""Support numpy astype feature.
|
||||
@@ -222,26 +223,25 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
def dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Trace numpy.dot.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
# input_tracers contains the other tracer of the dot product
|
||||
dot_inputs = (self, self._sanitize(other_tracer))
|
||||
assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}")
|
||||
|
||||
common_output_dtypes = self._manage_dtypes(numpy.dot, *dot_inputs)
|
||||
common_output_dtypes = self._manage_dtypes(numpy.dot, *args)
|
||||
custom_assert(len(common_output_dtypes) == 1)
|
||||
|
||||
traced_computation = Dot(
|
||||
[input_tracer.output for input_tracer in dot_inputs],
|
||||
[input_tracer.output for input_tracer in args],
|
||||
common_output_dtypes[0],
|
||||
delegate_evaluation_function=numpy.dot,
|
||||
)
|
||||
|
||||
output_tracer = self.__class__(
|
||||
dot_inputs,
|
||||
args,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user