dev(NPTracer): add op_name for traced functions, deepcopy kwargs

- ir.ArbitraryFunction does not deepcopy op_args and op_kwargs by default
anymore to let the control to the developer instantiating it
This commit is contained in:
Arthur Meyre
2021-08-16 12:35:26 +02:00
parent 0eebbfcd26
commit 825d6422d0
2 changed files with 6 additions and 3 deletions

View File

@@ -200,8 +200,8 @@ class ArbitraryFunction(IntermediateNode):
super().__init__([input_base_value])
assert len(self.inputs) == 1
self.arbitrary_func = arbitrary_func
self.op_args = deepcopy(op_args) if op_args is not None else ()
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
self.op_args = op_args if op_args is not None else ()
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
# TLU/PBS has an encrypted output
self.outputs = [EncryptedValue(output_dtype)]
self.op_name = op_name if op_name is not None else self.__class__.__name__

View File

@@ -1,4 +1,5 @@
"""hnumpy tracing utilities."""
from copy import deepcopy
from typing import Callable, Dict, Mapping
import numpy
@@ -53,6 +54,7 @@ class NPTracer(BaseTracer):
input_base_value=self.output,
arbitrary_func=normalized_numpy_dtype.type,
output_dtype=output_dtype,
op_name=f"astype({normalized_numpy_dtype})",
)
output_tracer = self.__class__(
[self], traced_computation=traced_computation, output_index=0
@@ -103,7 +105,8 @@ class NPTracer(BaseTracer):
input_base_value=input_tracers[0].output,
arbitrary_func=numpy.rint,
output_dtype=common_output_dtypes[0],
op_kwargs=kwargs,
op_kwargs=deepcopy(kwargs),
op_name="numpy.rint",
)
output_tracer = self.__class__(
input_tracers, traced_computation=traced_computation, output_index=0