feat: adding constant management

refs #49
This commit is contained in:
Benoit Chevallier-Mames
2021-08-03 15:18:44 +02:00
committed by Benoit Chevallier
parent 0baa02549c
commit 6157e4680b
8 changed files with 204 additions and 22 deletions

View File

@@ -0,0 +1,6 @@
"""File holding code to represent data types used for constants in programs"""
from typing import Union
# TODO: deal with more types
Scalars = Union[int, float]

View File

@@ -7,7 +7,13 @@ import networkx as nx
from hdk.common.operator_graph import OPGraph
from hdk.common.representation import intermediate as ir
IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"}
IR_NODE_COLOR_MAPPING = {
ir.Input: "blue",
ir.ConstantInput: "cyan",
ir.Add: "red",
ir.Sub: "yellow",
ir.Mul: "green",
}
def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict:
@@ -117,10 +123,14 @@ def draw_graph(
# For most types, we just pick the operation as the label, but for Input,
# we take the name of the variable, ie the argument name of the function
# to compile
label_dict = {
node: node.input_name if isinstance(node, ir.Input) else node.__class__.__name__
for node in graph.nodes()
}
def get_proper_name(node):
if isinstance(node, ir.Input):
return node.input_name
if isinstance(node, ir.ConstantInput):
return str(node.constant_data)
return node.__class__.__name__
label_dict = {node: get_proper_name(node) for node in graph.nodes()}
# Draw nodes
nx.draw_networkx_nodes(
@@ -222,7 +232,11 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
for node in nx.topological_sort(graph):
if not isinstance(node, ir.Input):
if isinstance(node, ir.Input):
what_to_print = node.input_name
elif isinstance(node, ir.ConstantInput):
what_to_print = f"ConstantInput({node.constant_data})"
else:
what_to_print = node.__class__.__name__ + "("
# Find all the names of the current predecessors of the node
@@ -243,9 +257,6 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
list_of_arg_name.sort()
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
else:
what_to_print = node.input_name
returned_str += f"\n%{i} = {what_to_print}"
map_table[node] = i
i += 1

View File

@@ -27,11 +27,9 @@ class OPGraph:
self.input_nodes = {
node.program_input_idx: node
for node in self.graph.nodes()
if len(self.graph.pred[node]) == 0
if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input)
}
assert all(map(lambda x: isinstance(x, ir.Input), self.input_nodes.values()))
graph_outputs = set(node for node in self.graph.nodes() if len(self.graph.succ[node]) == 0)
assert set(self.output_nodes.values()) == graph_outputs

View File

@@ -6,6 +6,9 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from ..data_types import BaseValue
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..data_types.integers import Integer, get_bits_to_represent_int
from ..data_types.scalars import Scalars
from ..data_types.values import ClearValue
class IntermediateNode(ABC):
@@ -117,7 +120,7 @@ class Mul(IntermediateNode):
class Input(IntermediateNode):
"""Node representing an input of the numpy program"""
"""Node representing an input of the program"""
input_name: str
program_input_idx: int
@@ -136,3 +139,26 @@ class Input(IntermediateNode):
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return inputs[0]
class ConstantInput(IntermediateNode):
"""Node representing a constant of the program"""
constant_data: Scalars
def __init__(
self,
constant_data: Scalars,
) -> None:
super().__init__([])
self.constant_data = constant_data
# TODO: manage other cases, we can't call get_bits_to_represent_int
assert isinstance(constant_data, int)
is_signed = constant_data < 0
self.outputs = [
ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed))
]
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return self.constant_data

View File

@@ -1,9 +1,10 @@
"""This file holds the code that can be shared between tracers"""
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ..data_types import BaseValue
from ..data_types.scalars import Scalars
from ..representation import intermediate as ir
@@ -26,7 +27,7 @@ class BaseTracer(ABC):
def instantiate_output_tracers(
self,
inputs: List["BaseTracer"],
inputs: List[Union["BaseTracer", Scalars]],
computation_to_trace: Type[ir.IntermediateNode],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
@@ -44,20 +45,30 @@ class BaseTracer(ABC):
Returns:
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
"""
# For inputs which are actually constant, first convert into a tracer
def sanitize(inp):
if not isinstance(inp, BaseTracer):
return make_const_input_tracer(self.__class__, inp)
return inp
sanitized_inputs = [sanitize(inp) for inp in inputs]
traced_computation = computation_to_trace(
map(lambda x: x.output, inputs),
map(lambda x: x.output, sanitized_inputs),
op_args=op_args,
op_kwargs=op_kwargs,
)
output_tracers = tuple(
self.__class__(inputs, traced_computation, output_index)
self.__class__(sanitized_inputs, traced_computation, output_index)
for output_index in range(len(traced_computation.outputs))
)
return output_tracers
def __add__(self, other: "BaseTracer") -> "BaseTracer":
def __add__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
result_tracer = self.instantiate_output_tracers(
[self, other],
ir.Add,
@@ -66,7 +77,13 @@ class BaseTracer(ABC):
assert len(result_tracer) == 1
return result_tracer[0]
def __sub__(self, other: "BaseTracer") -> "BaseTracer":
# With that is that x + 1 and 1 + x have the same graph. If we want to keep
# the order, we need to do as in __rsub__, ie mostly a copy of __sub__ +
# some changes
__radd__ = __add__
def __sub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
result_tracer = self.instantiate_output_tracers(
[self, other],
ir.Sub,
@@ -75,7 +92,17 @@ class BaseTracer(ABC):
assert len(result_tracer) == 1
return result_tracer[0]
def __mul__(self, other: "BaseTracer") -> "BaseTracer":
def __rsub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
result_tracer = self.instantiate_output_tracers(
[other, self],
ir.Sub,
)
assert len(result_tracer) == 1
return result_tracer[0]
def __mul__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
result_tracer = self.instantiate_output_tracers(
[self, other],
ir.Mul,
@@ -83,3 +110,21 @@ class BaseTracer(ABC):
assert len(result_tracer) == 1
return result_tracer[0]
# With that is that x * 3 and 3 * x have the same graph. If we want to keep
# the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
# some changes
__rmul__ = __mul__
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Scalars) -> BaseTracer:
"""Helper function to create a tracer for a constant input
Args:
tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for
constant_data (Scalars): the constant
Returns:
BaseTracer: The BaseTracer for that constant
"""
return tracer_class([], ir.ConstantInput(constant_data), 0)

View File

@@ -23,7 +23,28 @@ from hdk.hnumpy.tracing import trace_numpy_function
((-10, 2), (-4, 5)),
(-14, 7),
Integer(5, is_signed=True),
id="x + y, (-10, 2), (-4, 5), (-14, 9)",
id="x + y, (-10, 2), (-4, 5), (-14, 7)",
),
pytest.param(
lambda x, y: x + y + 1,
((-10, 2), (-4, 5)),
(-13, 8),
Integer(5, is_signed=True),
id="x + y + 1, (-10, 2), (-4, 5), (-13, 8)",
),
pytest.param(
lambda x, y: x + y + (-3),
((-10, 2), (-4, 5)),
(-17, 4),
Integer(6, is_signed=True),
id="x + y + 1, (-10, 2), (-4, 5), (-17, 4)",
),
pytest.param(
lambda x, y: (1 + x) + y,
((-10, 2), (-4, 5)),
(-13, 8),
Integer(5, is_signed=True),
id="(1 + x) + y, (-10, 2), (-4, 5), (-13, 8)",
),
pytest.param(
lambda x, y: x - y,
@@ -39,6 +60,27 @@ from hdk.hnumpy.tracing import trace_numpy_function
Integer(5, is_signed=True),
id="x - y, (-10, 2), (-4, 5), (-15, 6)",
),
pytest.param(
lambda x, y: x - y - 42,
((-10, 2), (-4, 5)),
(-57, -36),
Integer(7, is_signed=True),
id="x - y, (-10, 2), (-4, 5), (-57, -36)",
),
pytest.param(
lambda x, y: 3 - x + y,
((-10, 2), (-4, 5)),
(-3, 18),
Integer(6, is_signed=True),
id="x - y, (-10, 2), (-4, 5), (-3, 18)",
),
pytest.param(
lambda x, y: (-13) - x + y,
((-10, 2), (-4, 5)),
(-19, 2),
Integer(6, is_signed=True),
id="x - y, (-10, 2), (-4, 5), (-16, 2)",
),
pytest.param(
lambda x, y: x * y,
((-10, 10), (-10, 10)),
@@ -53,6 +95,27 @@ from hdk.hnumpy.tracing import trace_numpy_function
Integer(7, is_signed=True),
id="x * y, (-10, 2), (-4, 5), (-50, 40)",
),
pytest.param(
lambda x, y: (3 * x) * y,
((-10, 2), (-4, 5)),
(-150, 120),
Integer(9, is_signed=True),
id="x * y, (-10, 2), (-4, 5), (-150, 120)",
),
pytest.param(
lambda x, y: (x * 11) * y,
((-10, 2), (-4, 5)),
(-550, 440),
Integer(11, is_signed=True),
id="x * y, (-10, 2), (-4, 5), (-550, 440)",
),
pytest.param(
lambda x, y: (x * (-11)) * y,
((-10, 2), (-4, 5)),
(-440, 550),
Integer(11, is_signed=True),
id="x * y, (-10, 2), (-4, 5), (-440, 550)",
),
pytest.param(
lambda x, y: x + x + y,
((-10, 10), (-10, 10)),
@@ -88,6 +151,13 @@ from hdk.hnumpy.tracing import trace_numpy_function
Integer(7, is_signed=True),
id="x * y - x, (-10, 2), (-4, 5), (-40, 50),",
),
pytest.param(
lambda x, y: (x * 3) * y - (x + 3) + (y - 13) + x * (11 + y) * (12 + y) + (15 - x),
((-10, 2), (-4, 5)),
(-2846, 574),
Integer(13, is_signed=True),
id="x * y - x, (-10, 2), (-4, 5), (-2846, 574),",
),
],
)
def test_eval_op_graph_bounds_on_dataset(

View File

@@ -29,6 +29,8 @@ from hdk.common.representation import intermediate as ir
id="Mul",
),
pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"),
pytest.param(ir.ConstantInput(42), None, 42, id="ConstantInput"),
pytest.param(ir.ConstantInput(-42), None, -42, id="ConstantInput"),
],
)
def test_evaluate(

View File

@@ -13,11 +13,34 @@ from hdk.hnumpy import tracing
[
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"),
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"),
(lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)"),
(
lambda x, y: x + x - y * y * y + x,
"\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)",
),
(lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
(lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
(lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)"),
(lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
(lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
(lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)"),
(lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)"),
(lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)"),
(lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)"),
(
lambda x, y: x + 13 - y * (-21) * y + 44,
"\n%0 = ConstantInput(44)"
"\n%1 = x"
"\n%2 = ConstantInput(13)"
"\n%3 = y"
"\n%4 = ConstantInput(-21)"
"\n%5 = Add(1, 2)"
"\n%6 = Mul(3, 4)"
"\n%7 = Mul(6, 3)"
"\n%8 = Sub(5, 7)"
"\n%9 = Add(8, 0)",
),
],
)
@pytest.mark.parametrize(
@@ -48,6 +71,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
str_of_the_graph = get_printable_graph(graph)
print(f"\n{str_of_the_graph}\n")
print(f"\nGot {str_of_the_graph}\n")
print(f"\nExp {ref_graph_str}\n")
assert str_of_the_graph == ref_graph_str