From 73f21c79a6be3362ad62af0117b3efe5f84fe65f Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 2 Aug 2021 13:26:13 +0200 Subject: [PATCH] dev: add a function to check that a program is an actual integer program --- hdk/common/__init__.py | 1 + hdk/common/common_helpers.py | 51 +++++++++++++++++++++++++ tests/common/test_common_helpers.py | 58 +++++++++++++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 hdk/common/common_helpers.py create mode 100644 tests/common/test_common_helpers.py diff --git a/hdk/common/__init__.py b/hdk/common/__init__.py index eecdb8723..f498f0d8d 100644 --- a/hdk/common/__init__.py +++ b/hdk/common/__init__.py @@ -1,2 +1,3 @@ """Module for shared data structures and code""" from . import data_types, debugging, representation +from .common_helpers import check_op_graph_is_integer_program diff --git a/hdk/common/common_helpers.py b/hdk/common/common_helpers.py new file mode 100644 index 000000000..e0c872c9a --- /dev/null +++ b/hdk/common/common_helpers.py @@ -0,0 +1,51 @@ +"""File to hold some helper code""" + +from typing import List, Optional + +from .data_types.integers import Integer +from .operator_graph import OPGraph +from .representation import intermediate as ir + + +def ir_nodes_has_integer_input_and_output(node: ir.IntermediateNode) -> bool: + """Check if an ir node has Integer inputs and outputs + + Args: + node (ir.IntermediateNode): Node to check + + Returns: + bool: True if all input and output values hold Integers + """ + return all(map(lambda x: isinstance(x.data_type, Integer), node.inputs)) and all( + map(lambda x: isinstance(x.data_type, Integer), node.outputs) + ) + + +# This check makes sense as long as the compiler backend only manages integers, to be removed in the +# long run probably +def check_op_graph_is_integer_program( + op_graph: OPGraph, + offending_nodes_out: Optional[List[ir.IntermediateNode]] = None, +) -> bool: + """Check if an op_graph inputs, outputs and intermediate values are Integers + + Args: + op_graph (OPGraph): The OPGraph to check + offending_nodes_out (Optional[List[ir.IntermediateNode]]): Optionally pass a list that will + be populated with offending nodes, the list will be cleared before being filled + + Returns: + bool: True if inputs, outputs and intermediate values are Integers, False otherwise + """ + offending_nodes = [] if offending_nodes_out is None else offending_nodes_out + + assert isinstance( + offending_nodes, list + ), f"offending_nodes_out must be a list, got {type(offending_nodes_out)}" + + offending_nodes.clear() + offending_nodes.extend( + node for node in op_graph.graph.nodes() if not ir_nodes_has_integer_input_and_output(node) + ) + + return len(offending_nodes) == 0 diff --git a/tests/common/test_common_helpers.py b/tests/common/test_common_helpers.py new file mode 100644 index 000000000..fdbee8538 --- /dev/null +++ b/tests/common/test_common_helpers.py @@ -0,0 +1,58 @@ +"""Test file for common helpers""" + +from copy import deepcopy + +from hdk.common import check_op_graph_is_integer_program +from hdk.common.data_types.base import BaseDataType +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import EncryptedValue +from hdk.hnumpy.tracing import trace_numpy_function + + +class DummyNotInteger(BaseDataType): + """Dummy helper data type class""" + + +def test_check_op_graph_is_integer_program(): + """Test function for check_op_graph_is_integer_program""" + + def function(x, y): + return x + y - y * y + x * y + + op_graph = trace_numpy_function( + function, {"x": EncryptedValue(Integer(64, True)), "y": EncryptedValue(Integer(64, True))} + ) + + # Test without and with output list + offending_nodes = [] + assert check_op_graph_is_integer_program(op_graph) + assert check_op_graph_is_integer_program(op_graph, offending_nodes) + assert len(offending_nodes) == 0 + + op_graph_copy = deepcopy(op_graph) + op_graph_copy.output_nodes[0].outputs[0].data_type = DummyNotInteger() + + offending_nodes = [] + assert not check_op_graph_is_integer_program(op_graph_copy) + assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes) + assert len(offending_nodes) == 1 + assert offending_nodes == [op_graph_copy.output_nodes[0]] + + op_graph_copy = deepcopy(op_graph) + op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger() + + offending_nodes = [] + assert not check_op_graph_is_integer_program(op_graph_copy) + assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes) + assert len(offending_nodes) == 1 + assert offending_nodes == [op_graph_copy.input_nodes[0]] + + op_graph_copy = deepcopy(op_graph) + op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger() + op_graph_copy.input_nodes[1].inputs[0].data_type = DummyNotInteger() + + offending_nodes = [] + assert not check_op_graph_is_integer_program(op_graph_copy) + assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes) + assert len(offending_nodes) == 2 + assert set(offending_nodes) == set([op_graph_copy.input_nodes[0], op_graph_copy.input_nodes[1]])