From a060aaae99f30a25b133f39266d60f9caee9ef8b Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 21 Jul 2021 11:09:34 +0200 Subject: [PATCH] feat(tracing): add tracing facilities - add BaseTracer which will hold most of the boilerplate code - add hnumpy with a bare NPTracer and tracing function - update IR to be compatible with tracing helpers - update test helper to properly check that graphs are equivalent - add test tracing a simple addition - rename common/data_types/helpers.py to .../dtypes_helpers.py to avoid having too many files with the same name - ignore missing type stubs in the default mypy command - add a comfort Makefile target to get errors about missing mypy stubs --- Makefile | 9 +- hdk/__init__.py | 2 +- hdk/common/data_types/__init__.py | 2 +- .../{helpers.py => dtypes_helpers.py} | 0 hdk/common/data_types/integers.py | 7 ++ hdk/common/data_types/values.py | 3 + hdk/common/representation/intermediate.py | 40 +++++++- hdk/common/tracing/__init__.py | 7 ++ hdk/common/tracing/base_tracer.py | 67 +++++++++++++ hdk/common/tracing/tracing_helpers.py | 95 +++++++++++++++++++ hdk/hnumpy/__init__.py | 2 + hdk/hnumpy/tracing.py | 48 ++++++++++ ...test_helpers.py => test_dtypes_helpers.py} | 4 +- tests/common/tracing/test_tracing_helpers.py | 26 +++++ tests/conftest.py | 8 +- tests/helpers/test_conftest.py | 12 ++- tests/hnumpy/test_tracing.py | 87 +++++++++++++++++ 17 files changed, 404 insertions(+), 15 deletions(-) rename hdk/common/data_types/{helpers.py => dtypes_helpers.py} (100%) create mode 100644 hdk/common/tracing/__init__.py create mode 100644 hdk/common/tracing/base_tracer.py create mode 100644 hdk/common/tracing/tracing_helpers.py create mode 100644 hdk/hnumpy/__init__.py create mode 100644 hdk/hnumpy/tracing.py rename tests/common/data_types/{test_helpers.py => test_dtypes_helpers.py} (94%) create mode 100644 tests/common/tracing/test_tracing_helpers.py create mode 100644 tests/hnumpy/test_tracing.py diff --git a/Makefile b/Makefile index 
02e5a8349..033e1b5d0 100644 --- a/Makefile +++ b/Makefile @@ -30,10 +30,17 @@ pytest: poetry run pytest --cov=hdk -vv --cov-report=xml tests/ .PHONY: pytest +# Not a huge fan of ignoring missing imports, but some packages do not have typing stubs mypy: - poetry run mypy -p hdk + poetry run mypy -p hdk --ignore-missing-imports .PHONY: mypy +# Friendly target to run mypy without ignoring missing stubs and still get error messages +# Lets us see which stubs we are missing +mypy_ns: + poetry run mypy -p hdk +.PHONY: mypy_ns + docs: cd docs && poetry run make html .PHONY: docs diff --git a/hdk/__init__.py b/hdk/__init__.py index fb269a812..a121a3f0c 100644 --- a/hdk/__init__.py +++ b/hdk/__init__.py @@ -1,2 +1,2 @@ """HDK's top import""" -from . import common +from . import common, hnumpy diff --git a/hdk/common/data_types/__init__.py b/hdk/common/data_types/__init__.py index 1703a0aaf..5c2244e21 100644 --- a/hdk/common/data_types/__init__.py +++ b/hdk/common/data_types/__init__.py @@ -1,3 +1,3 @@ """HDK's module for data types code and data structures""" -from . import helpers, integers, values +from . 
import dtypes_helpers, integers, values from .values import BaseValue diff --git a/hdk/common/data_types/helpers.py b/hdk/common/data_types/dtypes_helpers.py similarity index 100% rename from hdk/common/data_types/helpers.py rename to hdk/common/data_types/dtypes_helpers.py diff --git a/hdk/common/data_types/integers.py b/hdk/common/data_types/integers.py index d8b431adc..91b34d992 100644 --- a/hdk/common/data_types/integers.py +++ b/hdk/common/data_types/integers.py @@ -17,6 +17,13 @@ class Integer(base.BaseDataType): signed_str = "signed" if self.is_signed else "unsigned" return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>" + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and self.bit_width == other.bit_width + and self.is_signed == other.is_signed + ) + def min_value(self) -> int: """Minimum value representable by the Integer""" if self.is_signed: diff --git a/hdk/common/data_types/values.py b/hdk/common/data_types/values.py index b00ddf5f5..9ca75d249 100644 --- a/hdk/common/data_types/values.py +++ b/hdk/common/data_types/values.py @@ -16,6 +16,9 @@ class BaseValue(ABC): def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.data_type!r}>" + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) and self.data_type == other.data_type + class ClearValue(BaseValue): """Class representing a clear/plaintext value (constant or not)""" diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 7b1cb5433..360aa1c05 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -22,9 +22,28 @@ class IntermediateNode(ABC): op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: self.inputs = list(inputs) + assert all(map(lambda x: isinstance(x, BaseValue), self.inputs)) self.op_args = op_args self.op_kwargs = op_kwargs + def is_equivalent_to(self, other: object) -> bool: + 
"""Overriding __eq__ has unwanted side effects, this provides the same facility without + disrupting expected behavior too much + + Args: + other (object): Other object to check against + + Returns: + bool: True if the other object is equivalent + """ + return ( + isinstance(other, self.__class__) + and self.inputs == other.inputs + and self.outputs == other.outputs + and self.op_args == other.op_args + and self.op_kwargs == other.op_kwargs + ) + class Add(IntermediateNode): """Addition between two values""" @@ -32,14 +51,26 @@ class Add(IntermediateNode): def __init__( self, inputs: Iterable[BaseValue], + op_args: Optional[Tuple[Any, ...]] = None, + op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - super().__init__(inputs) + assert op_args is None, f"Expected op_args to be None, got {op_args}" + assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}" + + super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs) assert len(self.inputs) == 2 # For now copy the first input type for the output type # We don't perform checks or enforce consistency here for now, so this is OK self.outputs = [deepcopy(self.inputs[0])] + def is_equivalent_to(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and (self.inputs == other.inputs or self.inputs == other.inputs[::-1]) + and self.outputs == other.outputs + ) + class Input(IntermediateNode): """Node representing an input of the numpy program""" @@ -47,7 +78,12 @@ class Input(IntermediateNode): def __init__( self, inputs: Iterable[BaseValue], + op_args: Optional[Tuple[Any, ...]] = None, + op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - super().__init__(inputs) + assert op_args is None, f"Expected op_args to be None, got {op_args}" + assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}" + + super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs) assert len(self.inputs) == 1 self.outputs = [deepcopy(self.inputs[0])] diff --git 
a/hdk/common/tracing/__init__.py b/hdk/common/tracing/__init__.py new file mode 100644 index 000000000..1818cb5d9 --- /dev/null +++ b/hdk/common/tracing/__init__.py @@ -0,0 +1,7 @@ +"""HDK's module for basic tracing facilities""" +from .base_tracer import BaseTracer +from .tracing_helpers import ( + create_graph_from_output_tracers, + make_input_tracer, + prepare_function_parameters, +) diff --git a/hdk/common/tracing/base_tracer.py b/hdk/common/tracing/base_tracer.py new file mode 100644 index 000000000..925678fb4 --- /dev/null +++ b/hdk/common/tracing/base_tracer.py @@ -0,0 +1,67 @@ +"""This file holds the code that can be shared between tracers""" + +from abc import ABC +from typing import Any, Dict, List, Optional, Tuple, Type + +from ..data_types import BaseValue +from ..representation import intermediate as ir + + +class BaseTracer(ABC): + """Base class for implementing tracers""" + + inputs: List["BaseTracer"] + traced_computation: ir.IntermediateNode + output: BaseValue + + def __init__( + self, + inputs: List["BaseTracer"], + traced_computation: ir.IntermediateNode, + output_index: int, + ) -> None: + self.inputs = inputs + self.traced_computation = traced_computation + self.output = traced_computation.outputs[output_index] + + def instantiate_output_tracers( + self, + inputs: List["BaseTracer"], + computation_to_trace: Type[ir.IntermediateNode], + op_args: Optional[Tuple[Any, ...]] = None, + op_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple["BaseTracer", ...]: + """Helper functions to instantiate all output BaseTracer for a given computation + + Args: + inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node + computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class + to instantiate for the computation being traced + op_args: *args coming from the call being traced + op_kwargs: **kwargs coming from the call being traced + + + Returns: + Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output 
function + """ + traced_computation = computation_to_trace( + map(lambda x: x.output, inputs), + op_args=op_args, + op_kwargs=op_kwargs, + ) + + output_tracers = tuple( + self.__class__(inputs, traced_computation, output_index) + for output_index in range(len(traced_computation.outputs)) + ) + + return output_tracers + + def __add__(self, other: "BaseTracer") -> "BaseTracer": + result_tracer = self.instantiate_output_tracers( + [self, other], + ir.Add, + ) + + assert len(result_tracer) == 1 + return result_tracer[0] diff --git a/hdk/common/tracing/tracing_helpers.py b/hdk/common/tracing/tracing_helpers.py new file mode 100644 index 000000000..a101ec5db --- /dev/null +++ b/hdk/common/tracing/tracing_helpers.py @@ -0,0 +1,95 @@ +"""Helper functions for tracing""" +from inspect import signature +from typing import Callable, Dict, Iterable, Set, Tuple, Type + +import networkx as nx +from networkx.algorithms.dag import is_directed_acyclic_graph + +from ..data_types import BaseValue +from ..representation import intermediate as ir +from .base_tracer import BaseTracer + + +def make_input_tracer(tracer_class: Type[BaseTracer], input_value: BaseValue) -> BaseTracer: + """Helper function to create a tracer for an input value + + Args: + tracer_class (Type[BaseTracer]): the class of tracer to create an Input for + input_value (BaseValue): the Value that is an input and needs to be wrapped in an + BaseTracer + + Returns: + BaseTracer: The BaseTracer for that input value + """ + return tracer_class([], ir.Input([input_value]), 0) + + +def prepare_function_parameters( + function_to_trace: Callable, function_parameters: Dict[str, BaseValue] +) -> Dict[str, BaseValue]: + """Function to filter the passed function_parameters to trace function_to_trace + + Args: + function_to_trace (Callable): function that will be traced for which parameters are checked + function_parameters (Dict[str, BaseValue]): parameters given to trace the function + + Raises: + ValueError: Raised when some 
parameters are missing to trace function_to_trace + + Returns: + Dict[str, BaseValue]: filtered function_parameters dictionary + """ + function_signature = signature(function_to_trace) + + missing_args = function_signature.parameters.keys() - function_parameters.keys() + + if len(missing_args) > 0: + raise ValueError( + f"The function '{function_to_trace.__name__}' requires the following parameters" + f"that were not provided: {', '.join(sorted(missing_args))}" + ) + + useless_arguments = function_parameters.keys() - function_signature.parameters.keys() + useful_arguments = function_signature.parameters.keys() - useless_arguments + + return {k: function_parameters[k] for k in useful_arguments} + + +def create_graph_from_output_tracers( + output_tracers: Iterable[BaseTracer], +) -> nx.MultiDiGraph: + """Generate a networkx Directed Graph that will represent the computation from a traced function + + Args: + output_tracers (Iterable[BaseTracer]): the output tracers resulting from running the + function over the proper input tracers + + Returns: + nx.MultiDiGraph: Directed Graph that is guaranteed to be a DAG containing the ir nodes + representing the traced program/function + """ + graph = nx.MultiDiGraph() + + visited_tracers: Set[BaseTracer] = set() + current_tracers = tuple(output_tracers) + + while current_tracers: + next_tracers: Tuple[BaseTracer, ...] 
= tuple() + for tracer in current_tracers: + current_ir_node = tracer.traced_computation + graph.add_node(current_ir_node, content=current_ir_node) + + for input_idx, input_tracer in enumerate(tracer.inputs): + input_ir_node = input_tracer.traced_computation + graph.add_node(input_ir_node, content=input_ir_node) + graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx) + if input_tracer not in visited_tracers: + next_tracers += (input_tracer,) + + visited_tracers.add(tracer) + + current_tracers = next_tracers + + assert is_directed_acyclic_graph(graph) + + return graph diff --git a/hdk/hnumpy/__init__.py b/hdk/hnumpy/__init__.py new file mode 100644 index 000000000..5af83dc23 --- /dev/null +++ b/hdk/hnumpy/__init__.py @@ -0,0 +1,2 @@ +"""HDK's module for compiling numpy functions to homomorphic equivalents""" +from . import tracing diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py new file mode 100644 index 000000000..8e1ac38a9 --- /dev/null +++ b/hdk/hnumpy/tracing.py @@ -0,0 +1,48 @@ +"""hnumpy tracing utilities""" +from typing import Callable, Dict + +import networkx as nx + +from ..common.data_types import BaseValue +from ..common.tracing import ( + BaseTracer, + create_graph_from_output_tracers, + make_input_tracer, + prepare_function_parameters, +) + + +class NPTracer(BaseTracer): + """Tracer class for numpy operations""" + + +def trace_numpy_function( + function_to_trace: Callable, function_parameters: Dict[str, BaseValue] +) -> nx.MultiDiGraph: + """Function used to trace a numpy function + + Args: + function_to_trace (Callable): The function you want to trace + function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the + function is e.g. 
an EncryptedValue holding a 7bits unsigned Integer + + Returns: + nx.MultiDiGraph: The graph containing the ir nodes representing the computation done in the + input function + """ + function_parameters = prepare_function_parameters(function_to_trace, function_parameters) + + input_tracers = { + param_name: make_input_tracer(NPTracer, param) + for param_name, param in function_parameters.items() + } + + # We could easily create a graph of NPTracer, but we may end up with dead nodes starting from + # the inputs that's why we create the graph starting from the outputs + output_tracers = function_to_trace(**input_tracers) + if isinstance(output_tracers, NPTracer): + output_tracers = (output_tracers,) + + graph = create_graph_from_output_tracers(output_tracers) + + return graph diff --git a/tests/common/data_types/test_helpers.py b/tests/common/data_types/test_dtypes_helpers.py similarity index 94% rename from tests/common/data_types/test_helpers.py rename to tests/common/data_types/test_dtypes_helpers.py index cb76467eb..43b5fd872 100644 --- a/tests/common/data_types/test_helpers.py +++ b/tests/common/data_types/test_dtypes_helpers.py @@ -1,8 +1,8 @@ -"""Test file for HDK's common/data_types/helpers.py""" +"""Test file for HDK's data types helpers""" import pytest -from hdk.common.data_types.helpers import ( +from hdk.common.data_types.dtypes_helpers import ( value_is_encrypted_integer, value_is_encrypted_unsigned_integer, ) diff --git a/tests/common/tracing/test_tracing_helpers.py b/tests/common/tracing/test_tracing_helpers.py new file mode 100644 index 000000000..20adb02ad --- /dev/null +++ b/tests/common/tracing/test_tracing_helpers.py @@ -0,0 +1,26 @@ +"""Test file for HDK's common tracing helpers""" + +from typing import Any, Dict + +import pytest + +from hdk.common.tracing.tracing_helpers import prepare_function_parameters + + +@pytest.mark.parametrize( + "function,function_parameters,ref_dict", + [ + pytest.param(lambda x: None, {}, {}, id="Missing x", 
marks=pytest.mark.xfail(strict=True)), + pytest.param(lambda x: None, {"x": None}, {"x": None}, id="Only x"), + pytest.param( + lambda x: None, {"x": None, "y": None}, {"x": None}, id="Additional y filtered" + ), + ], +) +def test_prepare_function_parameters( + function, function_parameters: Dict[str, Any], ref_dict: Dict[str, Any] +): + """Test prepare_function_parameters""" + prepared_dict = prepare_function_parameters(function, function_parameters) + + assert prepared_dict == ref_dict diff --git a/tests/conftest.py b/tests/conftest.py index b68f743fe..93206f8f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,11 +8,13 @@ class TestHelpers: """Class allowing to pass helper functions to tests""" @staticmethod - def digraphs_are_equivalent(reference: nx.DiGraph, to_compare: nx.DiGraph): + def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph): """Check that two digraphs are equivalent without modifications""" # edge_match is a copy of node_match - edge_matcher = iso.categorical_node_match("input_idx", None) - node_matcher = iso.categorical_node_match("content", None) + edge_matcher = iso.categorical_multiedge_match("input_idx", None) + node_matcher = iso.generic_node_match( + "content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs) + ) graphs_are_isomorphic = nx.is_isomorphic( reference, to_compare, diff --git a/tests/helpers/test_conftest.py b/tests/helpers/test_conftest.py index d1c5b4cc4..9ed6185af 100644 --- a/tests/helpers/test_conftest.py +++ b/tests/helpers/test_conftest.py @@ -16,11 +16,13 @@ def test_digraphs_are_equivalent(test_helpers): def __hash__(self) -> int: return self.computation.__hash__() - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: return self.computation == other.computation - g_1 = nx.DiGraph() - g_2 = nx.DiGraph() + is_equivalent_to = __eq__ + + g_1 = nx.MultiDiGraph() + g_2 = nx.MultiDiGraph() t_0 = TestNode("Add") t_1 = TestNode("Mul") @@ -44,7 +46,7 @@ def 
test_digraphs_are_equivalent(test_helpers): for node in g_2: g_2.add_node(node, content=node) - bad_g2 = nx.DiGraph() + bad_g2 = nx.MultiDiGraph() bad_t0 = TestNode("Not Add") @@ -55,7 +57,7 @@ def test_digraphs_are_equivalent(test_helpers): for node in bad_g2: bad_g2.add_node(node, content=node) - bad_g3 = nx.DiGraph() + bad_g3 = nx.MultiDiGraph() bad_g3.add_edge(t_0, t_2, input_idx=1) bad_g3.add_edge(t_1, t_2, input_idx=0) diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py new file mode 100644 index 000000000..1dbe878ec --- /dev/null +++ b/tests/hnumpy/test_tracing.py @@ -0,0 +1,87 @@ +"""Test file for HDK's hnumpy tracing""" + +import networkx as nx +import pytest + +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import ClearValue, EncryptedValue +from hdk.common.representation import intermediate as ir +from hdk.hnumpy import tracing + + +@pytest.mark.parametrize( + "x", + [ + pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"), + pytest.param( + EncryptedValue(Integer(64, is_signed=True)), + id="Encrypted int", + ), + pytest.param( + ClearValue(Integer(64, is_signed=False)), + id="Clear uint", + ), + pytest.param( + ClearValue(Integer(64, is_signed=True)), + id="Clear int", + ), + ], +) +@pytest.mark.parametrize( + "y", + [ + pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"), + pytest.param( + EncryptedValue(Integer(64, is_signed=True)), + id="Encrypted int", + ), + pytest.param( + ClearValue(Integer(64, is_signed=False)), + id="Clear uint", + ), + pytest.param( + ClearValue(Integer(64, is_signed=True)), + id="Clear int", + ), + ], +) +def test_hnumpy_tracing_add(x, y, test_helpers): + "Test hnumpy tracing __add__" + + def simple_add_function(x, y): + z = x + x + return z + y + + graph = tracing.trace_numpy_function(simple_add_function, {"x": x, "y": y}) + + ref_graph = nx.MultiDiGraph() + + input_x = ir.Input((x,)) + input_y = ir.Input((y,)) 
+ + add_node_z = ir.Add( + ( + input_x.outputs[0], + input_x.outputs[0], + ) + ) + + return_add_node = ir.Add( + ( + add_node_z.outputs[0], + input_y.outputs[0], + ) + ) + + ref_graph.add_node(input_x, content=input_x) + ref_graph.add_node(input_y, content=input_y) + ref_graph.add_node(add_node_z, content=add_node_z) + ref_graph.add_node(return_add_node, content=return_add_node) + + ref_graph.add_edge(input_x, add_node_z, input_idx=0) + ref_graph.add_edge(input_x, add_node_z, input_idx=1) + + ref_graph.add_edge(add_node_z, return_add_node, input_idx=0) + ref_graph.add_edge(input_y, return_add_node, input_idx=1) + + assert test_helpers.digraphs_are_equivalent(ref_graph, graph)