From 4655bea98786f277eb6998285002a2906eed7a3e Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 24 Aug 2021 14:43:01 +0200 Subject: [PATCH] fix: register IR nodes to check when nodes are missing debug draw colors --- hdk/common/debugging/drawing.py | 9 +++++++++ hdk/common/representation/intermediate.py | 9 ++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/hdk/common/debugging/drawing.py b/hdk/common/debugging/drawing.py index 0f854f27d..62a17aba3 100644 --- a/hdk/common/debugging/drawing.py +++ b/hdk/common/debugging/drawing.py @@ -10,6 +10,7 @@ from PIL import Image from ..operator_graph import OPGraph from ..representation import intermediate as ir +from ..representation.intermediate import ALL_IR_NODES IR_NODE_COLOR_MAPPING = { ir.Input: "blue", @@ -23,6 +24,14 @@ IR_NODE_COLOR_MAPPING = { "output": "magenta", } +_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys() +assert len(_missing_nodes_in_mapping) == 0, ( + f"Missing IR node in IR_NODE_COLOR_MAPPING : " + f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}" +) + +del _missing_nodes_in_mapping + def draw_graph( opgraph: OPGraph, diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 1eba9aecf..5978ce617 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type from ..data_types.base import BaseDataType from ..data_types.dtypes_helpers import ( @@ -13,6 +13,8 @@ from ..values import BaseValue IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" +ALL_IR_NODES: Set[Type] = set() + class IntermediateNode(ABC): """Abstract Base Class to derive from to represent source program operations.""" @@ -29,6 +31,11 @@ class IntermediateNode(ABC): self.inputs = list(inputs) assert all(isinstance(x, BaseValue) for x in self.inputs) + # Register all IR nodes + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + ALL_IR_NODES.add(cls) + def _init_binary( self, inputs: Iterable[BaseValue],