refactor: replace if-elif-else chain with mapping in node converter

This commit is contained in:
Umut
2022-11-02 15:15:15 +01:00
parent cb9cbb05ab
commit 79951b51b7

View File

@@ -145,82 +145,41 @@ class NodeConverter:
in-memory MLIR representation corresponding to `self.node`
"""
# pylint: disable=too-many-branches,too-many-statements
if self.node.operation == Operation.Constant:
result = self._convert_constant()
else:
assert_that(self.node.operation == Operation.Generic)
return self._convert_constant()
name = self.node.properties["name"]
assert_that(self.node.operation == Operation.Generic)
if name == "add":
result = self._convert_add()
name = self.node.properties["name"]
converters = {
"add": self._convert_add,
"array": self._convert_array,
"assign.static": self._convert_static_assignment,
"broadcast_to": self._convert_broadcast_to,
"concatenate": self._convert_concat,
"conv1d": self._convert_conv1d,
"conv2d": self._convert_conv2d,
"conv3d": self._convert_conv3d,
"dot": self._convert_dot,
"expand_dims": self._convert_reshape,
"index.static": self._convert_static_indexing,
"matmul": self._convert_matmul,
"maxpool": self._convert_maxpool,
"multiply": self._convert_mul,
"negative": self._convert_neg,
"ones": self._convert_ones,
"reshape": self._convert_reshape,
"subtract": self._convert_sub,
"sum": self._convert_sum,
"transpose": self._convert_transpose,
"zeros": self._convert_zeros,
}
elif name == "assign.static":
result = self._convert_static_assignment()
if name in converters:
return converters[name]()
elif name == "array":
result = self._convert_array()
elif name == "broadcast_to":
result = self._convert_broadcast_to()
elif name == "concatenate":
result = self._convert_concat()
elif name == "conv1d":
result = self._convert_conv1d()
elif name == "conv2d":
result = self._convert_conv2d()
elif name == "conv3d":
result = self._convert_conv3d()
elif name == "dot":
result = self._convert_dot()
elif name == "index.static":
result = self._convert_static_indexing()
elif name == "matmul":
result = self._convert_matmul()
elif name == "maxpool":
result = self._convert_maxpool()
elif name == "multiply":
result = self._convert_mul()
elif name == "negative":
result = self._convert_neg()
elif name == "ones":
result = self._convert_ones()
elif name in ["reshape", "expand_dims"]:
result = self._convert_reshape()
elif name == "subtract":
result = self._convert_sub()
elif name == "sum":
result = self._convert_sum()
elif name == "transpose":
result = self._convert_transpose()
elif name == "zeros":
result = self._convert_zeros()
else:
assert_that(self.node.converted_to_table_lookup)
result = self._convert_tlu()
return result
# pylint: enable=too-many-branches
assert_that(self.node.converted_to_table_lookup)
return self._convert_tlu()
def _convert_add(self) -> OpResult:
"""