From 79951b51b7a405c689134a88a8e4b430672aa82e Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 2 Nov 2022 15:15:15 +0100 Subject: [PATCH] refactor: replace if-elif-else chain with mapping in node converter --- concrete/numpy/mlir/node_converter.py | 101 ++++++++------------------ 1 file changed, 30 insertions(+), 71 deletions(-) diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index d1c527aaa..f24345251 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -145,82 +145,41 @@ class NodeConverter: in-memory MLIR representation corresponding to `self.node` """ - # pylint: disable=too-many-branches,too-many-statements - if self.node.operation == Operation.Constant: - result = self._convert_constant() - else: - assert_that(self.node.operation == Operation.Generic) + return self._convert_constant() - name = self.node.properties["name"] + assert_that(self.node.operation == Operation.Generic) - if name == "add": - result = self._convert_add() + name = self.node.properties["name"] + converters = { + "add": self._convert_add, + "array": self._convert_array, + "assign.static": self._convert_static_assignment, + "broadcast_to": self._convert_broadcast_to, + "concatenate": self._convert_concat, + "conv1d": self._convert_conv1d, + "conv2d": self._convert_conv2d, + "conv3d": self._convert_conv3d, + "dot": self._convert_dot, + "expand_dims": self._convert_reshape, + "index.static": self._convert_static_indexing, + "matmul": self._convert_matmul, + "maxpool": self._convert_maxpool, + "multiply": self._convert_mul, + "negative": self._convert_neg, + "ones": self._convert_ones, + "reshape": self._convert_reshape, + "subtract": self._convert_sub, + "sum": self._convert_sum, + "transpose": self._convert_transpose, + "zeros": self._convert_zeros, + } - elif name == "assign.static": - result = self._convert_static_assignment() + if name in converters: + return converters[name]() - elif name == "array": - result = self._convert_array() - - elif name == "broadcast_to": - result = self._convert_broadcast_to() - - elif name == "concatenate": - result = self._convert_concat() - - elif name == "conv1d": - result = self._convert_conv1d() - - elif name == "conv2d": - result = self._convert_conv2d() - - elif name == "conv3d": - result = self._convert_conv3d() - - elif name == "dot": - result = self._convert_dot() - - elif name == "index.static": - result = self._convert_static_indexing() - - elif name == "matmul": - result = self._convert_matmul() - - elif name == "maxpool": - result = self._convert_maxpool() - - elif name == "multiply": - result = self._convert_mul() - - elif name == "negative": - result = self._convert_neg() - - elif name == "ones": - result = self._convert_ones() - - elif name in ["reshape", "expand_dims"]: - result = self._convert_reshape() - - elif name == "subtract": - result = self._convert_sub() - - elif name == "sum": - result = self._convert_sum() - - elif name == "transpose": - result = self._convert_transpose() - - elif name == "zeros": - result = self._convert_zeros() - - else: - assert_that(self.node.converted_to_table_lookup) - result = self._convert_tlu() - - return result - - # pylint: enable=too-many-branches + assert_that(self.node.converted_to_table_lookup) + return self._convert_tlu() def _convert_add(self) -> OpResult: """