diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 4af5fd446..5cd78bc4f 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -7,7 +7,7 @@ import numpy from numpy.typing import DTypeLike from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype -from ..common.debugging.custom_assert import custom_assert +from ..common.debugging.custom_assert import assert_true, custom_assert from ..common.operator_graph import OPGraph from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters @@ -53,7 +53,7 @@ class NPTracer(BaseTracer): raise NotImplementedError("Only __call__ method is supported currently") def __array_function__(self, func, _types, args, kwargs): - """Catch calls to numpy function in routes them to hnp functions if supported. + """Catch calls to numpy function in routes them to tracing functions if supported. Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch """ @@ -62,7 +62,8 @@ class NPTracer(BaseTracer): (len(kwargs) == 0), f"**kwargs are currently not supported for numpy functions, func: {func}", ) - return tracing_func(*args, **kwargs) + sanitized_args = [self._sanitize(arg) for arg in args] + return tracing_func(self, *sanitized_args, **kwargs) def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer": r"""Support numpy astype feature. @@ -222,26 +223,25 @@ class NPTracer(BaseTracer): ) return output_tracer - def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer": + def dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer": """Trace numpy.dot. Returns: NPTracer: The output NPTracer containing the traced function """ - # input_tracers contains the other tracer of the dot product - dot_inputs = (self, self._sanitize(other_tracer)) + assert_true((num_args := len(args)) == 2, f"dot expects 2 inputs got {num_args}") - common_output_dtypes = self._manage_dtypes(numpy.dot, *dot_inputs) + common_output_dtypes = self._manage_dtypes(numpy.dot, *args) custom_assert(len(common_output_dtypes) == 1) traced_computation = Dot( - [input_tracer.output for input_tracer in dot_inputs], + [input_tracer.output for input_tracer in args], common_output_dtypes[0], delegate_evaluation_function=numpy.dot, ) output_tracer = self.__class__( - dot_inputs, + args, traced_computation=traced_computation, output_index=0, )