fix: don't fuse some operations as their semantics doesn't allow fusing

This commit is contained in:
Umut
2022-05-03 12:15:49 +02:00
parent 057232666a
commit 3629cafd82
2 changed files with 19 additions and 1 deletions

View File

@@ -389,7 +389,7 @@ def subgraph_can_be_fused(
non_constant_nodes = (node for node in all_nodes if node.operation != Operation.Constant)
for node in non_constant_nodes:
if node.output.shape != variable_input_node.output.shape:
if not node.is_fusable or node.output.shape != variable_input_node.output.shape:
return False
return True

View File

@@ -317,3 +317,21 @@ class Node:
"sum",
"transpose",
]
@property
def is_fusable(self) -> bool:
"""
Get whether the node is can be fused into a table lookup.
Returns:
bool:
True if the node can be fused into a table lookup, False otherwise
"""
if self.operation != Operation.Generic:
return True
if self.converted_to_table_lookup:
return True
return self.properties["name"] in ["add", "multiply", "negative", "subtract"]