From 6157e4680b99b445370b8211ff4eb1f230fd3566 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Tue, 3 Aug 2021 15:18:44 +0200 Subject: [PATCH] feat: adding constant management refs #49 --- hdk/common/data_types/scalars.py | 6 ++ hdk/common/debugging/draw_graph.py | 29 +++++--- hdk/common/operator_graph.py | 4 +- hdk/common/representation/intermediate.py | 28 +++++++- hdk/common/tracing/base_tracer.py | 59 +++++++++++++-- .../bounds_measurement/test_dataset_eval.py | 72 ++++++++++++++++++- .../representation/test_intermediate.py | 2 + tests/hnumpy/test_debugging.py | 26 ++++++- 8 files changed, 204 insertions(+), 22 deletions(-) create mode 100644 hdk/common/data_types/scalars.py diff --git a/hdk/common/data_types/scalars.py b/hdk/common/data_types/scalars.py new file mode 100644 index 000000000..968d0b7fb --- /dev/null +++ b/hdk/common/data_types/scalars.py @@ -0,0 +1,6 @@ +"""File holding code to represent data types used for constants in programs""" + +from typing import Union + +# TODO: deal with more types +Scalars = Union[int, float] diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py index e91623a04..d29c42f07 100644 --- a/hdk/common/debugging/draw_graph.py +++ b/hdk/common/debugging/draw_graph.py @@ -7,7 +7,13 @@ import networkx as nx from hdk.common.operator_graph import OPGraph from hdk.common.representation import intermediate as ir -IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"} +IR_NODE_COLOR_MAPPING = { + ir.Input: "blue", + ir.ConstantInput: "cyan", + ir.Add: "red", + ir.Sub: "yellow", + ir.Mul: "green", +} def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict: @@ -117,10 +123,14 @@ def draw_graph( # For most types, we just pick the operation as the label, but for Input, # we take the name of the variable, ie the argument name of the function # to compile - label_dict = { - node: node.input_name if isinstance(node, ir.Input) else node.__class__.__name__ - for node in graph.nodes() - } + def get_proper_name(node): + if isinstance(node, ir.Input): + return node.input_name + if isinstance(node, ir.ConstantInput): + return str(node.constant_data) + return node.__class__.__name__ + + label_dict = {node: get_proper_name(node) for node in graph.nodes()} # Draw nodes nx.draw_networkx_nodes( @@ -222,7 +232,11 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str: for node in nx.topological_sort(graph): - if not isinstance(node, ir.Input): + if isinstance(node, ir.Input): + what_to_print = node.input_name + elif isinstance(node, ir.ConstantInput): + what_to_print = f"ConstantInput({node.constant_data})" + else: what_to_print = node.__class__.__name__ + "(" # Find all the names of the current predecessors of the node @@ -243,9 +257,6 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str: list_of_arg_name.sort() what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")" - else: - what_to_print = node.input_name - returned_str += f"\n%{i} = {what_to_print}" map_table[node] = i i += 1 diff --git a/hdk/common/operator_graph.py b/hdk/common/operator_graph.py index b1ad61619..31927c1ad 100644 --- a/hdk/common/operator_graph.py +++ b/hdk/common/operator_graph.py @@ -27,11 +27,9 @@ class OPGraph: self.input_nodes = { node.program_input_idx: node for node in self.graph.nodes() - if len(self.graph.pred[node]) == 0 + if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input) } - assert all(map(lambda x: isinstance(x, ir.Input), self.input_nodes.values())) - graph_outputs = set(node for node in self.graph.nodes() if len(self.graph.succ[node]) == 0) assert set(self.output_nodes.values()) == graph_outputs diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 9bc0ffddb..163b87d60 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -6,6 +6,9 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple from ..data_types import BaseValue from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype +from ..data_types.integers import Integer, get_bits_to_represent_int +from ..data_types.scalars import Scalars +from ..data_types.values import ClearValue class IntermediateNode(ABC): @@ -117,7 +120,7 @@ class Mul(IntermediateNode): class Input(IntermediateNode): - """Node representing an input of the numpy program""" + """Node representing an input of the program""" input_name: str program_input_idx: int @@ -136,3 +139,26 @@ class Input(IntermediateNode): def evaluate(self, inputs: Mapping[int, Any]) -> Any: return inputs[0] + + +class ConstantInput(IntermediateNode): + """Node representing a constant of the program""" + + constant_data: Scalars + + def __init__( + self, + constant_data: Scalars, + ) -> None: + super().__init__([]) + self.constant_data = constant_data + + # TODO: manage other cases, we can't call get_bits_to_represent_int + assert isinstance(constant_data, int) + is_signed = constant_data < 0 + self.outputs = [ + ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed)) + ] + + def evaluate(self, inputs: Mapping[int, Any]) -> Any: + return self.constant_data diff --git a/hdk/common/tracing/base_tracer.py b/hdk/common/tracing/base_tracer.py index 73b674fef..3519f30d4 100644 --- a/hdk/common/tracing/base_tracer.py +++ b/hdk/common/tracing/base_tracer.py @@ -1,9 +1,10 @@ """This file holds the code that can be shared between tracers""" from abc import ABC -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..data_types import BaseValue +from ..data_types.scalars import Scalars from ..representation import intermediate as ir @@ -26,7 +27,7 @@ class BaseTracer(ABC): def instantiate_output_tracers( self, - inputs: List["BaseTracer"], + inputs: List[Union["BaseTracer", Scalars]], computation_to_trace: Type[ir.IntermediateNode], op_args: Optional[Tuple[Any, ...]] = None, op_kwargs: Optional[Dict[str, Any]] = None, @@ -44,20 +45,30 @@ class BaseTracer(ABC): Returns: Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function """ + + # For inputs which are actually constant, first convert into a tracer + def sanitize(inp): + if not isinstance(inp, BaseTracer): + return make_const_input_tracer(self.__class__, inp) + return inp + + sanitized_inputs = [sanitize(inp) for inp in inputs] + traced_computation = computation_to_trace( - map(lambda x: x.output, inputs), + map(lambda x: x.output, sanitized_inputs), op_args=op_args, op_kwargs=op_kwargs, ) output_tracers = tuple( - self.__class__(inputs, traced_computation, output_index) + self.__class__(sanitized_inputs, traced_computation, output_index) for output_index in range(len(traced_computation.outputs)) ) return output_tracers - def __add__(self, other: "BaseTracer") -> "BaseTracer": + def __add__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": + result_tracer = self.instantiate_output_tracers( [self, other], ir.Add, @@ -66,7 +77,13 @@ class BaseTracer(ABC): assert len(result_tracer) == 1 return result_tracer[0] - def __sub__(self, other: "BaseTracer") -> "BaseTracer": + # With that is that x + 1 and 1 + x have the same graph. If we want to keep + # the order, we need to do as in __rsub__, ie mostly a copy of __sub__ + + # some changes + __radd__ = __add__ + + def __sub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": + result_tracer = self.instantiate_output_tracers( [self, other], ir.Sub, @@ -75,7 +92,17 @@ class BaseTracer(ABC): assert len(result_tracer) == 1 return result_tracer[0] - def __mul__(self, other: "BaseTracer") -> "BaseTracer": + def __rsub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": + + result_tracer = self.instantiate_output_tracers( + [other, self], + ir.Sub, + ) + + assert len(result_tracer) == 1 + return result_tracer[0] + + def __mul__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": result_tracer = self.instantiate_output_tracers( [self, other], ir.Mul, @@ -83,3 +110,21 @@ class BaseTracer(ABC): assert len(result_tracer) == 1 return result_tracer[0] + + # With that is that x * 3 and 3 * x have the same graph. If we want to keep + # the order, we need to do as in __rmul__, ie mostly a copy of __mul__ + + # some changes + __rmul__ = __mul__ + + +def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Scalars) -> BaseTracer: + """Helper function to create a tracer for a constant input + + Args: + tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for + constant_data (Scalars): the constant + + Returns: + BaseTracer: The BaseTracer for that constant + """ + return tracer_class([], ir.ConstantInput(constant_data), 0) diff --git a/tests/common/bounds_measurement/test_dataset_eval.py b/tests/common/bounds_measurement/test_dataset_eval.py index 398f190cf..2fd640c97 100644 --- a/tests/common/bounds_measurement/test_dataset_eval.py +++ b/tests/common/bounds_measurement/test_dataset_eval.py @@ -23,7 +23,28 @@ from hdk.hnumpy.tracing import trace_numpy_function ((-10, 2), (-4, 5)), (-14, 7), Integer(5, is_signed=True), - id="x + y, (-10, 2), (-4, 5), (-14, 9)", + id="x + y, (-10, 2), (-4, 5), (-14, 7)", + ), + pytest.param( + lambda x, y: x + y + 1, + ((-10, 2), (-4, 5)), + (-13, 8), + Integer(5, is_signed=True), + id="x + y + 1, (-10, 2), (-4, 5), (-13, 8)", + ), + pytest.param( + lambda x, y: x + y + (-3), + ((-10, 2), (-4, 5)), + (-17, 4), + Integer(6, is_signed=True), + id="x + y + 1, (-10, 2), (-4, 5), (-17, 4)", + ), + pytest.param( + lambda x, y: (1 + x) + y, + ((-10, 2), (-4, 5)), + (-13, 8), + Integer(5, is_signed=True), + id="(1 + x) + y, (-10, 2), (-4, 5), (-13, 8)", ), pytest.param( lambda x, y: x - y, @@ -39,6 +60,27 @@ from hdk.hnumpy.tracing import trace_numpy_function Integer(5, is_signed=True), id="x - y, (-10, 2), (-4, 5), (-15, 6)", ), + pytest.param( + lambda x, y: x - y - 42, + ((-10, 2), (-4, 5)), + (-57, -36), + Integer(7, is_signed=True), + id="x - y, (-10, 2), (-4, 5), (-57, -36)", + ), + pytest.param( + lambda x, y: 3 - x + y, + ((-10, 2), (-4, 5)), + (-3, 18), + Integer(6, is_signed=True), + id="x - y, (-10, 2), (-4, 5), (-3, 18)", + ), + pytest.param( + lambda x, y: (-13) - x + y, + ((-10, 2), (-4, 5)), + (-19, 2), + Integer(6, is_signed=True), + id="x - y, (-10, 2), (-4, 5), (-16, 2)", + ), pytest.param( lambda x, y: x * y, ((-10, 10), (-10, 10)), @@ -53,6 +95,27 @@ from hdk.hnumpy.tracing import trace_numpy_function Integer(7, is_signed=True), id="x * y, (-10, 2), (-4, 5), (-50, 40)", ), + pytest.param( + lambda x, y: (3 * x) * y, + ((-10, 2), (-4, 5)), + (-150, 120), + Integer(9, is_signed=True), + id="x * y, (-10, 2), (-4, 5), (-150, 120)", + ), + pytest.param( + lambda x, y: (x * 11) * y, + ((-10, 2), (-4, 5)), + (-550, 440), + Integer(11, is_signed=True), + id="x * y, (-10, 2), (-4, 5), (-550, 440)", + ), + pytest.param( + lambda x, y: (x * (-11)) * y, + ((-10, 2), (-4, 5)), + (-440, 550), + Integer(11, is_signed=True), + id="x * y, (-10, 2), (-4, 5), (-440, 550)", + ), pytest.param( lambda x, y: x + x + y, ((-10, 10), (-10, 10)), @@ -88,6 +151,13 @@ from hdk.hnumpy.tracing import trace_numpy_function Integer(7, is_signed=True), id="x * y - x, (-10, 2), (-4, 5), (-40, 50),", ), + pytest.param( + lambda x, y: (x * 3) * y - (x + 3) + (y - 13) + x * (11 + y) * (12 + y) + (15 - x), + ((-10, 2), (-4, 5)), + (-2846, 574), + Integer(13, is_signed=True), + id="x * y - x, (-10, 2), (-4, 5), (-2846, 574),", + ), ], ) def test_eval_op_graph_bounds_on_dataset( diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index ba8e8a629..2c4a665ec 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -29,6 +29,8 @@ from hdk.common.representation import intermediate as ir id="Mul", ), pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"), + pytest.param(ir.ConstantInput(42), None, 42, id="ConstantInput"), + pytest.param(ir.ConstantInput(-42), None, -42, id="ConstantInput"), ], ) def test_evaluate( diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 5faf2fe6f..e28b956f0 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -13,11 +13,34 @@ from hdk.hnumpy import tracing [ (lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"), (lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"), + (lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)"), ( lambda x, y: x + x - y * y * y + x, "\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)" "\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)", ), + (lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"), + (lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"), + (lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)"), + (lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"), + (lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"), + (lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)"), + (lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)"), + (lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)"), + (lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)"), + ( + lambda x, y: x + 13 - y * (-21) * y + 44, + "\n%0 = ConstantInput(44)" + "\n%1 = x" + "\n%2 = ConstantInput(13)" + "\n%3 = y" + "\n%4 = ConstantInput(-21)" + "\n%5 = Add(1, 2)" + "\n%6 = Mul(3, 4)" + "\n%7 = Mul(6, 3)" + "\n%8 = Sub(5, 7)" + "\n%9 = Add(8, 0)", + ), ], ) @pytest.mark.parametrize( @@ -48,6 +71,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): str_of_the_graph = get_printable_graph(graph) - print(f"\n{str_of_the_graph}\n") + print(f"\nGot {str_of_the_graph}\n") + print(f"\nExp {ref_graph_str}\n") assert str_of_the_graph == ref_graph_str