mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
fe90b35392
commit
0a87c26b64
@@ -104,7 +104,7 @@ class MultiLookupTable:
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(key.output)],
|
||||
inputs=[key.output],
|
||||
arbitrary_func=MultiLookupTable._checked_indexing,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
|
||||
@@ -42,7 +42,7 @@ class LookupTable:
|
||||
generic_function_output_value.dtype = self.output_dtype
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(key.output)],
|
||||
inputs=[key.output],
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
|
||||
@@ -192,7 +192,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
|
||||
# Create fused_node
|
||||
fused_node = GenericFunction(
|
||||
inputs=[deepcopy(new_subgraph_variable_input.inputs[0])],
|
||||
inputs=[new_subgraph_variable_input.inputs[0]],
|
||||
arbitrary_func=lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate(
|
||||
{0: x}
|
||||
)[terminal_node],
|
||||
|
||||
@@ -333,7 +333,7 @@ class GenericFunction(IntermediateNode):
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
op_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(inputs)
|
||||
super().__init__([deepcopy(i) for i in inputs])
|
||||
self._n_in = len(self.inputs)
|
||||
self.arbitrary_func = arbitrary_func
|
||||
self.op_kind = GenericFunctionKind(op_kind)
|
||||
|
||||
@@ -143,7 +143,7 @@ class BaseTracer(ABC):
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(first_arg_output)],
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=op_lambda,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
|
||||
@@ -100,7 +100,7 @@ class NPTracer(BaseTracer):
|
||||
generic_function_output_value = deepcopy(self.output)
|
||||
generic_function_output_value.dtype = output_dtype
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(self.output)],
|
||||
inputs=[self.output],
|
||||
arbitrary_func=lambda x, dtype: x.astype(dtype),
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
@@ -177,7 +177,7 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(input_tracers[0].output)],
|
||||
inputs=[input_tracers[0].output],
|
||||
arbitrary_func=unary_operator,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
@@ -232,7 +232,7 @@ class NPTracer(BaseTracer):
|
||||
op_kwargs = deepcopy(kwargs)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(input_tracer.output) for input_tracer in input_tracers],
|
||||
inputs=[input_tracer.output for input_tracer in input_tracers],
|
||||
arbitrary_func=binary_operator,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
@@ -316,7 +316,7 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(first_arg_output)],
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=numpy.transpose,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
@@ -363,7 +363,7 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(first_arg_output)],
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=numpy.ravel,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
@@ -428,7 +428,7 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(first_arg_output)],
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=numpy.reshape,
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
@@ -467,7 +467,7 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
|
||||
traced_computation = GenericFunction(
|
||||
inputs=[deepcopy(first_arg_output)],
|
||||
inputs=[first_arg_output],
|
||||
arbitrary_func=lambda x: x.flatten(),
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
|
||||
Reference in New Issue
Block a user