mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev: add a function to check that a program is an actual integer program
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
"""Module for shared data structures and code"""
|
||||
from . import data_types, debugging, representation
|
||||
from .common_helpers import check_op_graph_is_integer_program
|
||||
|
||||
51
hdk/common/common_helpers.py
Normal file
51
hdk/common/common_helpers.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""File to hold some helper code"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .data_types.integers import Integer
|
||||
from .operator_graph import OPGraph
|
||||
from .representation import intermediate as ir
|
||||
|
||||
|
||||
def ir_nodes_has_integer_input_and_output(node: ir.IntermediateNode) -> bool:
|
||||
"""Check if an ir node has Integer inputs and outputs
|
||||
|
||||
Args:
|
||||
node (ir.IntermediateNode): Node to check
|
||||
|
||||
Returns:
|
||||
bool: True if all input and output values hold Integers
|
||||
"""
|
||||
return all(map(lambda x: isinstance(x.data_type, Integer), node.inputs)) and all(
|
||||
map(lambda x: isinstance(x.data_type, Integer), node.outputs)
|
||||
)
|
||||
|
||||
|
||||
# This check makes sense as long as the compiler backend only manages integers, to be removed in the
|
||||
# long run probably
|
||||
def check_op_graph_is_integer_program(
|
||||
op_graph: OPGraph,
|
||||
offending_nodes_out: Optional[List[ir.IntermediateNode]] = None,
|
||||
) -> bool:
|
||||
"""Check if an op_graph inputs, outputs and intermediate values are Integers
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to check
|
||||
offending_nodes_out (Optional[List[ir.IntermediateNode]]): Optionally pass a list that will
|
||||
be populated with offending nodes, the list will be cleared before being filled
|
||||
|
||||
Returns:
|
||||
bool: True if inputs, outputs and intermediate values are Integers, False otherwise
|
||||
"""
|
||||
offending_nodes = [] if offending_nodes_out is None else offending_nodes_out
|
||||
|
||||
assert isinstance(
|
||||
offending_nodes, list
|
||||
), f"offending_nodes_out must be a list, got {type(offending_nodes_out)}"
|
||||
|
||||
offending_nodes.clear()
|
||||
offending_nodes.extend(
|
||||
node for node in op_graph.graph.nodes() if not ir_nodes_has_integer_input_and_output(node)
|
||||
)
|
||||
|
||||
return len(offending_nodes) == 0
|
||||
58
tests/common/test_common_helpers.py
Normal file
58
tests/common/test_common_helpers.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Test file for common helpers"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from hdk.common import check_op_graph_is_integer_program
|
||||
from hdk.common.data_types.base import BaseDataType
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.hnumpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
class DummyNotInteger(BaseDataType):
|
||||
"""Dummy helper data type class"""
|
||||
|
||||
|
||||
def test_check_op_graph_is_integer_program():
|
||||
"""Test function for check_op_graph_is_integer_program"""
|
||||
|
||||
def function(x, y):
|
||||
return x + y - y * y + x * y
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function, {"x": EncryptedValue(Integer(64, True)), "y": EncryptedValue(Integer(64, True))}
|
||||
)
|
||||
|
||||
# Test without and with output list
|
||||
offending_nodes = []
|
||||
assert check_op_graph_is_integer_program(op_graph)
|
||||
assert check_op_graph_is_integer_program(op_graph, offending_nodes)
|
||||
assert len(offending_nodes) == 0
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.output_nodes[0].outputs[0].data_type = DummyNotInteger()
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
|
||||
assert len(offending_nodes) == 1
|
||||
assert offending_nodes == [op_graph_copy.output_nodes[0]]
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
|
||||
assert len(offending_nodes) == 1
|
||||
assert offending_nodes == [op_graph_copy.input_nodes[0]]
|
||||
|
||||
op_graph_copy = deepcopy(op_graph)
|
||||
op_graph_copy.input_nodes[0].inputs[0].data_type = DummyNotInteger()
|
||||
op_graph_copy.input_nodes[1].inputs[0].data_type = DummyNotInteger()
|
||||
|
||||
offending_nodes = []
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy)
|
||||
assert not check_op_graph_is_integer_program(op_graph_copy, offending_nodes)
|
||||
assert len(offending_nodes) == 2
|
||||
assert set(offending_nodes) == set([op_graph_copy.input_nodes[0], op_graph_copy.input_nodes[1]])
|
||||
Reference in New Issue
Block a user