diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 0c98d8eb6..e257556d5 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -200,8 +200,8 @@ class ArbitraryFunction(IntermediateNode): super().__init__([input_base_value]) assert len(self.inputs) == 1 self.arbitrary_func = arbitrary_func - self.op_args = deepcopy(op_args) if op_args is not None else () - self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {} + self.op_args = op_args if op_args is not None else () + self.op_kwargs = op_kwargs if op_kwargs is not None else {} # TLU/PBS has an encrypted output self.outputs = [EncryptedValue(output_dtype)] self.op_name = op_name if op_name is not None else self.__class__.__name__ diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index a06125399..1a4996061 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -1,4 +1,5 @@ """hnumpy tracing utilities.""" +from copy import deepcopy from typing import Callable, Dict, Mapping import numpy @@ -53,6 +54,7 @@ class NPTracer(BaseTracer): input_base_value=self.output, arbitrary_func=normalized_numpy_dtype.type, output_dtype=output_dtype, + op_name=f"astype({normalized_numpy_dtype})", ) output_tracer = self.__class__( [self], traced_computation=traced_computation, output_index=0 @@ -103,7 +105,8 @@ class NPTracer(BaseTracer): input_base_value=input_tracers[0].output, arbitrary_func=numpy.rint, output_dtype=common_output_dtypes[0], - op_kwargs=kwargs, + op_kwargs=deepcopy(kwargs), + op_name="numpy.rint", ) output_tracer = self.__class__( input_tracers, traced_computation=traced_computation, output_index=0