From 3629cafd8200b49884b89381dd4d23764c4a45ba Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 3 May 2022 12:15:49 +0200 Subject: [PATCH] fix: don't fuse some operations as their semantics doesn't allow fusing --- concrete/numpy/compilation/utils.py | 2 +- concrete/numpy/representation/node.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/concrete/numpy/compilation/utils.py b/concrete/numpy/compilation/utils.py index 2d7e19ddf..66e5182e8 100644 --- a/concrete/numpy/compilation/utils.py +++ b/concrete/numpy/compilation/utils.py @@ -389,7 +389,7 @@ def subgraph_can_be_fused( non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant) for node in non_constant_nodes: - if node.output.shape != variable_input_node.output.shape: + if not node.is_fusable or node.output.shape != variable_input_node.output.shape: return False return True diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index de7001b2e..00c4172a0 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -317,3 +317,21 @@ class Node: "sum", "transpose", ] + + @property + def is_fusable(self) -> bool: + """ + Get whether the node is can be fused into a table lookup. + + Returns: + bool: + True if the node can be fused into a table lookup, False otherwise + """ + + if self.operation != Operation.Generic: + return True + + if self.converted_to_table_lookup: + return True + + return self.properties["name"] in ["add", "multiply", "negative", "subtract"]