diff --git a/concrete/common/extensions/multi_table.py b/concrete/common/extensions/multi_table.py index 64fdedefd..73ac38191 100644 --- a/concrete/common/extensions/multi_table.py +++ b/concrete/common/extensions/multi_table.py @@ -104,7 +104,7 @@ class MultiLookupTable: ) traced_computation = GenericFunction( - inputs=[deepcopy(key.output)], + inputs=[key.output], arbitrary_func=MultiLookupTable._checked_indexing, output_value=generic_function_output_value, op_kind="TLU", diff --git a/concrete/common/extensions/table.py b/concrete/common/extensions/table.py index e87fb2a05..981ebe7ce 100644 --- a/concrete/common/extensions/table.py +++ b/concrete/common/extensions/table.py @@ -42,7 +42,7 @@ class LookupTable: generic_function_output_value.dtype = self.output_dtype traced_computation = GenericFunction( - inputs=[deepcopy(key.output)], + inputs=[key.output], arbitrary_func=LookupTable._checked_indexing, output_value=generic_function_output_value, op_kind="TLU", diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index a2e6d435d..18bd39bd9 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -192,7 +192,7 @@ def convert_float_subgraph_to_fused_node( # Create fused_node fused_node = GenericFunction( - inputs=[deepcopy(new_subgraph_variable_input.inputs[0])], + inputs=[new_subgraph_variable_input.inputs[0]], arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate( {0: x} )[terminal_node], diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 0c6d851f7..7970ee188 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -333,7 +333,7 @@ class GenericFunction(IntermediateNode): op_kwargs: Optional[Dict[str, Any]] = None, op_attributes: Optional[Dict[str, Any]] = None, ) -> None: - super().__init__(inputs) + super().__init__([deepcopy(i) for i in inputs]) self._n_in = len(self.inputs) self.arbitrary_func = arbitrary_func self.op_kind = GenericFunctionKind(op_kind) diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 4b183e882..0f7785b8d 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -143,7 +143,7 @@ class BaseTracer(ABC): ) traced_computation = GenericFunction( - inputs=[deepcopy(first_arg_output)], + inputs=[first_arg_output], arbitrary_func=op_lambda, output_value=generic_function_output_value, op_kind="TLU", diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 069e082e4..7c4144413 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -100,7 +100,7 @@ class NPTracer(BaseTracer): generic_function_output_value = deepcopy(self.output) generic_function_output_value.dtype = output_dtype traced_computation = GenericFunction( - inputs=[deepcopy(self.output)], + inputs=[self.output], arbitrary_func=lambda x, dtype: x.astype(dtype), output_value=generic_function_output_value, op_kind="TLU", @@ -177,7 +177,7 @@ class NPTracer(BaseTracer): ) traced_computation = GenericFunction( - inputs=[deepcopy(input_tracers[0].output)], + inputs=[input_tracers[0].output], arbitrary_func=unary_operator, output_value=generic_function_output_value, op_kind="TLU", @@ -232,7 +232,7 @@ class NPTracer(BaseTracer): op_kwargs = deepcopy(kwargs) traced_computation = GenericFunction( - inputs=[deepcopy(input_tracer.output) for input_tracer in input_tracers], + inputs=[input_tracer.output for input_tracer in input_tracers], arbitrary_func=binary_operator, output_value=generic_function_output_value, op_kind="TLU", @@ -316,7 +316,7 @@ class NPTracer(BaseTracer): ) traced_computation = GenericFunction( - inputs=[deepcopy(first_arg_output)], + inputs=[first_arg_output], arbitrary_func=numpy.transpose, output_value=generic_function_output_value, op_kind="Memory", @@ -363,7 +363,7 @@ class NPTracer(BaseTracer): ) traced_computation = GenericFunction( - inputs=[deepcopy(first_arg_output)], + inputs=[first_arg_output], arbitrary_func=numpy.ravel, output_value=generic_function_output_value, op_kind="Memory", @@ -428,7 +428,7 @@ class NPTracer(BaseTracer): ) traced_computation = GenericFunction( - inputs=[deepcopy(first_arg_output)], + inputs=[first_arg_output], arbitrary_func=numpy.reshape, output_value=generic_function_output_value, op_kind="Memory", @@ -467,7 +467,7 @@ class NPTracer(BaseTracer): ) traced_computation = GenericFunction( - inputs=[deepcopy(first_arg_output)], + inputs=[first_arg_output], arbitrary_func=lambda x: x.flatten(), output_value=generic_function_output_value, op_kind="Memory",