fix: register IR nodes to check when nodes are missing debug draw colors

This commit is contained in:
Arthur Meyre
2021-08-24 14:43:01 +02:00
parent b41029d9c0
commit 4655bea987
2 changed files with 17 additions and 1 deletions

View File

@@ -10,6 +10,7 @@ from PIL import Image
from ..operator_graph import OPGraph
from ..representation import intermediate as ir
from ..representation.intermediate import ALL_IR_NODES
IR_NODE_COLOR_MAPPING = {
ir.Input: "blue",
@@ -23,6 +24,14 @@ IR_NODE_COLOR_MAPPING = {
"output": "magenta",
}
_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys()
assert len(_missing_nodes_in_mapping) == 0, (
f"Missing IR node in IR_NODE_COLOR_MAPPING : "
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
)
del _missing_nodes_in_mapping
def draw_graph(
opgraph: OPGraph,

View File

@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import (
@@ -13,6 +13,8 @@ from ..values import BaseValue
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
ALL_IR_NODES: Set[Type] = set()
class IntermediateNode(ABC):
"""Abstract Base Class to derive from to represent source program operations."""
@@ -29,6 +31,11 @@ class IntermediateNode(ABC):
self.inputs = list(inputs)
assert all(isinstance(x, BaseValue) for x in self.inputs)
# Register all IR nodes
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
ALL_IR_NODES.add(cls)
def _init_binary(
self,
inputs: Iterable[BaseValue],