mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(NPTracer): add op_name for traced functions, deepcopy kwargs
- ir.ArbitraryFunction does not deepcopy op_args and op_kwargs by default anymore to let the control to the developer instantiating it
This commit is contained in:
@@ -200,8 +200,8 @@ class ArbitraryFunction(IntermediateNode):
|
||||
super().__init__([input_base_value])
|
||||
assert len(self.inputs) == 1
|
||||
self.arbitrary_func = arbitrary_func
|
||||
self.op_args = deepcopy(op_args) if op_args is not None else ()
|
||||
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
|
||||
self.op_args = op_args if op_args is not None else ()
|
||||
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
|
||||
# TLU/PBS has an encrypted output
|
||||
self.outputs = [EncryptedValue(output_dtype)]
|
||||
self.op_name = op_name if op_name is not None else self.__class__.__name__
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""hnumpy tracing utilities."""
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Dict, Mapping
|
||||
|
||||
import numpy
|
||||
@@ -53,6 +54,7 @@ class NPTracer(BaseTracer):
|
||||
input_base_value=self.output,
|
||||
arbitrary_func=normalized_numpy_dtype.type,
|
||||
output_dtype=output_dtype,
|
||||
op_name=f"astype({normalized_numpy_dtype})",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
[self], traced_computation=traced_computation, output_index=0
|
||||
@@ -103,7 +105,8 @@ class NPTracer(BaseTracer):
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=numpy.rint,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=kwargs,
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="numpy.rint",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers, traced_computation=traced_computation, output_index=0
|
||||
|
||||
Reference in New Issue
Block a user