refactor: move the deepcopy under the hood

closes #914
This commit is contained in:
Benoit Chevallier-Mames
2021-11-23 09:35:27 +01:00
committed by Benoit Chevallier
parent fe90b35392
commit 0a87c26b64
6 changed files with 12 additions and 12 deletions

View File

@@ -104,7 +104,7 @@ class MultiLookupTable:
)
traced_computation = GenericFunction(
inputs=[deepcopy(key.output)],
inputs=[key.output],
arbitrary_func=MultiLookupTable._checked_indexing,
output_value=generic_function_output_value,
op_kind="TLU",

View File

@@ -42,7 +42,7 @@ class LookupTable:
generic_function_output_value.dtype = self.output_dtype
traced_computation = GenericFunction(
inputs=[deepcopy(key.output)],
inputs=[key.output],
arbitrary_func=LookupTable._checked_indexing,
output_value=generic_function_output_value,
op_kind="TLU",

View File

@@ -192,7 +192,7 @@ def convert_float_subgraph_to_fused_node(
# Create fused_node
fused_node = GenericFunction(
inputs=[deepcopy(new_subgraph_variable_input.inputs[0])],
inputs=[new_subgraph_variable_input.inputs[0]],
arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate(
{0: x}
)[terminal_node],

View File

@@ -333,7 +333,7 @@ class GenericFunction(IntermediateNode):
op_kwargs: Optional[Dict[str, Any]] = None,
op_attributes: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(inputs)
super().__init__([deepcopy(i) for i in inputs])
self._n_in = len(self.inputs)
self.arbitrary_func = arbitrary_func
self.op_kind = GenericFunctionKind(op_kind)

View File

@@ -143,7 +143,7 @@ class BaseTracer(ABC):
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
inputs=[first_arg_output],
arbitrary_func=op_lambda,
output_value=generic_function_output_value,
op_kind="TLU",

View File

@@ -100,7 +100,7 @@ class NPTracer(BaseTracer):
generic_function_output_value = deepcopy(self.output)
generic_function_output_value.dtype = output_dtype
traced_computation = GenericFunction(
inputs=[deepcopy(self.output)],
inputs=[self.output],
arbitrary_func=lambda x, dtype: x.astype(dtype),
output_value=generic_function_output_value,
op_kind="TLU",
@@ -177,7 +177,7 @@ class NPTracer(BaseTracer):
)
traced_computation = GenericFunction(
inputs=[deepcopy(input_tracers[0].output)],
inputs=[input_tracers[0].output],
arbitrary_func=unary_operator,
output_value=generic_function_output_value,
op_kind="TLU",
@@ -232,7 +232,7 @@ class NPTracer(BaseTracer):
op_kwargs = deepcopy(kwargs)
traced_computation = GenericFunction(
inputs=[deepcopy(input_tracer.output) for input_tracer in input_tracers],
inputs=[input_tracer.output for input_tracer in input_tracers],
arbitrary_func=binary_operator,
output_value=generic_function_output_value,
op_kind="TLU",
@@ -316,7 +316,7 @@ class NPTracer(BaseTracer):
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
inputs=[first_arg_output],
arbitrary_func=numpy.transpose,
output_value=generic_function_output_value,
op_kind="Memory",
@@ -363,7 +363,7 @@ class NPTracer(BaseTracer):
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
inputs=[first_arg_output],
arbitrary_func=numpy.ravel,
output_value=generic_function_output_value,
op_kind="Memory",
@@ -428,7 +428,7 @@ class NPTracer(BaseTracer):
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
inputs=[first_arg_output],
arbitrary_func=numpy.reshape,
output_value=generic_function_output_value,
op_kind="Memory",
@@ -467,7 +467,7 @@ class NPTracer(BaseTracer):
)
traced_computation = GenericFunction(
inputs=[deepcopy(first_arg_output)],
inputs=[first_arg_output],
arbitrary_func=lambda x: x.flatten(),
output_value=generic_function_output_value,
op_kind="Memory",