dev: add a function to check that a program is an actual integer program

This commit is contained in:
Arthur Meyre
2021-08-02 13:26:13 +02:00
parent c6a2b4b35c
commit 73f21c79a6
3 changed files with 110 additions and 0 deletions

View File

@@ -1,2 +1,3 @@
"""Module for shared data structures and code"""
from . import data_types, debugging, representation
from .common_helpers import check_op_graph_is_integer_program

View File

@@ -0,0 +1,51 @@
"""File to hold some helper code"""
from typing import List, Optional
from .data_types.integers import Integer
from .operator_graph import OPGraph
from .representation import intermediate as ir
def ir_nodes_has_integer_input_and_output(node: ir.IntermediateNode) -> bool:
"""Check if an ir node has Integer inputs and outputs
Args:
node (ir.IntermediateNode): Node to check
Returns:
bool: True if all input and output values hold Integers
"""
return all(map(lambda x: isinstance(x.data_type, Integer), node.inputs)) and all(
map(lambda x: isinstance(x.data_type, Integer), node.outputs)
)
# This check makes sense as long as the compiler backend only manages integers, to be removed in the
# long run probably
def check_op_graph_is_integer_program(
op_graph: OPGraph,
offending_nodes_out: Optional[List[ir.IntermediateNode]] = None,
) -> bool:
"""Check if an op_graph inputs, outputs and intermediate values are Integers
Args:
op_graph (OPGraph): The OPGraph to check
offending_nodes_out (Optional[List[ir.IntermediateNode]]): Optionally pass a list that will
be populated with offending nodes, the list will be cleared before being filled
Returns:
bool: True if inputs, outputs and intermediate values are Integers, False otherwise
"""
offending_nodes = [] if offending_nodes_out is None else offending_nodes_out
assert isinstance(
offending_nodes, list
), f"offending_nodes_out must be a list, got {type(offending_nodes_out)}"
offending_nodes.clear()
offending_nodes.extend(
node for node in op_graph.graph.nodes() if not ir_nodes_has_integer_input_and_output(node)
)
return len(offending_nodes) == 0

View File

@@ -0,0 +1,58 @@
"""Test file for common helpers"""
from copy import deepcopy
from hdk.common import check_op_graph_is_integer_program
from hdk.common.data_types.base import BaseDataType
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.hnumpy.tracing import trace_numpy_function
class DummyNotInteger(BaseDataType):
"""Dummy helper data type class"""
def test_check_op_graph_is_integer_program():
"""Test function for check_op_graph_is_integer_program"""
def function(x, y):
return x + y - y * y + x * y
op_graph = trace_numpy_function(
function, {"x": EncryptedValue(Integer(64, True)), "y": EncryptedValue(Integer(64, True))}
)
# Test without and with output list
offending_nodes = []
assert check_op_graph_is_integer_program(op_graph)
assert check_op_graph_is_integer_program(op_graph, offending_nodes)
assert len(offending_nodes) == 0
op_graph_copy = deepcopy(op_graph)
op_graph_copy.output_nodes[0].outputs[0].data_type = DummyNotInteger()
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
assert len(offending_nodes) == 1
assert offending_nodes == [op_graph_copy.output_nodes[0]]
op_graph_copy = deepcopy(op_graph)
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
assert len(offending_nodes) == 1
assert offending_nodes == [op_graph_copy.input_nodes[0]]
op_graph_copy = deepcopy(op_graph)
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
op_graph_copy.input_nodes[1].inputs[0].data_type = DummyNotInteger()
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
assert len(offending_nodes) == 2
assert set(offending_nodes) == set([op_graph_copy.input_nodes[0], op_graph_copy.input_nodes[1]])