feat(tracing): add tracing facilities

- add BaseTracer which will hold most of the boilerplate code
- add hnumpy with a bare NPTracer and tracing function
- update IR to be compatible with tracing helpers
- update test helper to properly check that graphs are equivalent
- add test tracing a simple addition
- rename common/data_types/helpers.py to .../dtypes_helpers.py to avoid
having too many files with the same name
- ignore missing type stubs in the default mypy command
- add a comfort Makefile target to get errors about missing mypy stubs
This commit is contained in:
Arthur Meyre
2021-07-21 11:09:34 +02:00
parent b45944b66a
commit a060aaae99
17 changed files with 404 additions and 15 deletions

View File

@@ -30,10 +30,17 @@ pytest:
poetry run pytest --cov=hdk -vv --cov-report=xml tests/
.PHONY: pytest
# Not a huge fan of ignoring missing imports, but some packages do not have typing stubs
mypy:
poetry run mypy -p hdk
poetry run mypy -p hdk --ignore-missing-imports
.PHONY: mypy
# Friendly target to run mypy without ignoring missing stubs and still have errors messages
# Allows to see which stubs we are missing
mypy_ns:
poetry run mypy -p hdk
.PHONY: mypy_ns
docs:
cd docs && poetry run make html
.PHONY: docs

View File

@@ -1,2 +1,2 @@
"""HDK's top import"""
from . import common
from . import common, hnumpy

View File

@@ -1,3 +1,3 @@
"""HDK's module for data types code and data structures"""
from . import helpers, integers, values
from . import dtypes_helpers, integers, values
from .values import BaseValue

View File

@@ -17,6 +17,13 @@ class Integer(base.BaseDataType):
signed_str = "signed" if self.is_signed else "unsigned"
return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>"
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and self.bit_width == other.bit_width
and self.is_signed == other.is_signed
)
def min_value(self) -> int:
"""Minimum value representable by the Integer"""
if self.is_signed:

View File

@@ -16,6 +16,9 @@ class BaseValue(ABC):
def __repr__(self) -> str:
return f"{self.__class__.__name__}<{self.data_type!r}>"
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.data_type == other.data_type
class ClearValue(BaseValue):
"""Class representing a clear/plaintext value (constant or not)"""

View File

@@ -22,9 +22,28 @@ class IntermediateNode(ABC):
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.inputs = list(inputs)
assert all(map(lambda x: isinstance(x, BaseValue), self.inputs))
self.op_args = op_args
self.op_kwargs = op_kwargs
def is_equivalent_to(self, other: object) -> bool:
"""Overriding __eq__ has unwanted side effects, this provides the same facility without
disrupting expected behavior too much
Args:
other (object): Other object to check against
Returns:
bool: True if the other object is equivalent
"""
return (
isinstance(other, self.__class__)
and self.inputs == other.inputs
and self.outputs == other.outputs
and self.op_args == other.op_args
and self.op_kwargs == other.op_kwargs
)
class Add(IntermediateNode):
"""Addition between two values"""
@@ -32,14 +51,26 @@ class Add(IntermediateNode):
def __init__(
self,
inputs: Iterable[BaseValue],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(inputs)
assert op_args is None, f"Expected op_args to be None, got {op_args}"
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs)
assert len(self.inputs) == 2
# For now copy the first input type for the output type
# We don't perform checks or enforce consistency here for now, so this is OK
self.outputs = [deepcopy(self.inputs[0])]
def is_equivalent_to(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and (self.inputs == other.inputs or self.inputs == other.inputs[::-1])
and self.outputs == other.outputs
)
class Input(IntermediateNode):
"""Node representing an input of the numpy program"""
@@ -47,7 +78,12 @@ class Input(IntermediateNode):
def __init__(
self,
inputs: Iterable[BaseValue],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(inputs)
assert op_args is None, f"Expected op_args to be None, got {op_args}"
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs)
assert len(self.inputs) == 1
self.outputs = [deepcopy(self.inputs[0])]

View File

@@ -0,0 +1,7 @@
"""HDK's module for basic tracing facilities"""
from .base_tracer import BaseTracer
from .tracing_helpers import (
create_graph_from_output_tracers,
make_input_tracer,
prepare_function_parameters,
)

View File

@@ -0,0 +1,67 @@
"""This file holds the code that can be shared between tracers"""
from abc import ABC
from typing import Any, Dict, List, Optional, Tuple, Type
from ..data_types import BaseValue
from ..representation import intermediate as ir
class BaseTracer(ABC):
"""Base class for implementing tracers"""
inputs: List["BaseTracer"]
traced_computation: ir.IntermediateNode
output: BaseValue
def __init__(
self,
inputs: List["BaseTracer"],
traced_computation: ir.IntermediateNode,
output_index: int,
) -> None:
self.inputs = inputs
self.traced_computation = traced_computation
self.output = traced_computation.outputs[output_index]
def instantiate_output_tracers(
self,
inputs: List["BaseTracer"],
computation_to_trace: Type[ir.IntermediateNode],
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple["BaseTracer", ...]:
"""Helper functions to instantiate all output BaseTracer for a given computation
Args:
inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node
computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
to instantiate for the computation being traced
op_args: *args coming from the call being traced
op_kwargs: **kwargs coming from the call being traced
Returns:
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
"""
traced_computation = computation_to_trace(
map(lambda x: x.output, inputs),
op_args=op_args,
op_kwargs=op_kwargs,
)
output_tracers = tuple(
self.__class__(inputs, traced_computation, output_index)
for output_index in range(len(traced_computation.outputs))
)
return output_tracers
def __add__(self, other: "BaseTracer") -> "BaseTracer":
result_tracer = self.instantiate_output_tracers(
[self, other],
ir.Add,
)
assert len(result_tracer) == 1
return result_tracer[0]

View File

@@ -0,0 +1,95 @@
"""Helper functions for tracing"""
from inspect import signature
from typing import Callable, Dict, Iterable, Set, Tuple, Type
import networkx as nx
from networkx.algorithms.dag import is_directed_acyclic_graph
from ..data_types import BaseValue
from ..representation import intermediate as ir
from .base_tracer import BaseTracer
def make_input_tracer(tracer_class: Type[BaseTracer], input_value: BaseValue) -> BaseTracer:
"""Helper function to create a tracer for an input value
Args:
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
input_value (BaseValue): the Value that is an input and needs to be wrapped in an
BaseTracer
Returns:
BaseTracer: The BaseTracer for that input value
"""
return tracer_class([], ir.Input([input_value]), 0)
def prepare_function_parameters(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
) -> Dict[str, BaseValue]:
"""Function to filter the passed function_parameters to trace function_to_trace
Args:
function_to_trace (Callable): function that will be traced for which parameters are checked
function_parameters (Dict[str, BaseValue]): parameters given to trace the function
Raises:
ValueError: Raised when some parameters are missing to trace function_to_trace
Returns:
Dict[str, BaseValue]: filtered function_parameters dictionary
"""
function_signature = signature(function_to_trace)
missing_args = function_signature.parameters.keys() - function_parameters.keys()
if len(missing_args) > 0:
raise ValueError(
f"The function '{function_to_trace.__name__}' requires the following parameters"
f"that were not provided: {', '.join(sorted(missing_args))}"
)
useless_arguments = function_parameters.keys() - function_signature.parameters.keys()
useful_arguments = function_signature.parameters.keys() - useless_arguments
return {k: function_parameters[k] for k in useful_arguments}
def create_graph_from_output_tracers(
output_tracers: Iterable[BaseTracer],
) -> nx.MultiDiGraph:
"""Generate a networkx Directed Graph that will represent the computation from a traced function
Args:
output_tracers (Iterable[BaseTracer]): the output tracers resulting from running the
function over the proper input tracers
Returns:
nx.MultiDiGraph: Directed Graph that is guaranteed to be a DAG containing the ir nodes
representing the traced program/function
"""
graph = nx.MultiDiGraph()
visited_tracers: Set[BaseTracer] = set()
current_tracers = tuple(output_tracers)
while current_tracers:
next_tracers: Tuple[BaseTracer, ...] = tuple()
for tracer in current_tracers:
current_ir_node = tracer.traced_computation
graph.add_node(current_ir_node, content=current_ir_node)
for input_idx, input_tracer in enumerate(tracer.inputs):
input_ir_node = input_tracer.traced_computation
graph.add_node(input_ir_node, content=input_ir_node)
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
if input_tracer not in visited_tracers:
next_tracers += (input_tracer,)
visited_tracers.add(tracer)
current_tracers = next_tracers
assert is_directed_acyclic_graph(graph)
return graph

2
hdk/hnumpy/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""HDK's module for compiling numpy functions to homomorphic equivalents"""
from . import tracing

48
hdk/hnumpy/tracing.py Normal file
View File

@@ -0,0 +1,48 @@
"""hnumpy tracing utilities"""
from typing import Callable, Dict
import networkx as nx
from ..common.data_types import BaseValue
from ..common.tracing import (
BaseTracer,
create_graph_from_output_tracers,
make_input_tracer,
prepare_function_parameters,
)
class NPTracer(BaseTracer):
"""Tracer class for numpy operations"""
def trace_numpy_function(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
) -> nx.MultiDiGraph:
"""Function used to trace a numpy function
Args:
function_to_trace (Callable): The function you want to trace
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedValue holding a 7bits unsigned Integer
Returns:
nx.MultiDiGraph: The graph containing the ir nodes representing the computation done in the
input function
"""
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
input_tracers = {
param_name: make_input_tracer(NPTracer, param)
for param_name, param in function_parameters.items()
}
# We could easily create a graph of NPTracer, but we may end up with dead nodes starting from
# the inputs that's why we create the graph starting from the outputs
output_tracers = function_to_trace(**input_tracers)
if isinstance(output_tracers, NPTracer):
output_tracers = (output_tracers,)
graph = create_graph_from_output_tracers(output_tracers)
return graph

View File

@@ -1,8 +1,8 @@
"""Test file for HDK's common/data_types/helpers.py"""
"""Test file for HDK's data types helpers"""
import pytest
from hdk.common.data_types.helpers import (
from hdk.common.data_types.dtypes_helpers import (
value_is_encrypted_integer,
value_is_encrypted_unsigned_integer,
)

View File

@@ -0,0 +1,26 @@
"""Test file for HDK's common tracing helpers"""
from typing import Any, Dict
import pytest
from hdk.common.tracing.tracing_helpers import prepare_function_parameters
@pytest.mark.parametrize(
"function,function_parameters,ref_dict",
[
pytest.param(lambda x: None, {}, {}, id="Missing x", marks=pytest.mark.xfail(strict=True)),
pytest.param(lambda x: None, {"x": None}, {"x": None}, id="Only x"),
pytest.param(
lambda x: None, {"x": None, "y": None}, {"x": None}, id="Additional y filtered"
),
],
)
def test_prepare_function_parameters(
function, function_parameters: Dict[str, Any], ref_dict: Dict[str, Any]
):
"""Test prepare_function_parameters"""
prepared_dict = prepare_function_parameters(function, function_parameters)
assert prepared_dict == ref_dict

View File

@@ -8,11 +8,13 @@ class TestHelpers:
"""Class allowing to pass helper functions to tests"""
@staticmethod
def digraphs_are_equivalent(reference: nx.DiGraph, to_compare: nx.DiGraph):
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
"""Check that two digraphs are equivalent without modifications"""
# edge_match is a copy of node_match
edge_matcher = iso.categorical_node_match("input_idx", None)
node_matcher = iso.categorical_node_match("content", None)
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
node_matcher = iso.generic_node_match(
"content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
)
graphs_are_isomorphic = nx.is_isomorphic(
reference,
to_compare,

View File

@@ -16,11 +16,13 @@ def test_digraphs_are_equivalent(test_helpers):
def __hash__(self) -> int:
return self.computation.__hash__()
def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
return self.computation == other.computation
g_1 = nx.DiGraph()
g_2 = nx.DiGraph()
is_equivalent_to = __eq__
g_1 = nx.MultiDiGraph()
g_2 = nx.MultiDiGraph()
t_0 = TestNode("Add")
t_1 = TestNode("Mul")
@@ -44,7 +46,7 @@ def test_digraphs_are_equivalent(test_helpers):
for node in g_2:
g_2.add_node(node, content=node)
bad_g2 = nx.DiGraph()
bad_g2 = nx.MultiDiGraph()
bad_t0 = TestNode("Not Add")
@@ -55,7 +57,7 @@ def test_digraphs_are_equivalent(test_helpers):
for node in bad_g2:
bad_g2.add_node(node, content=node)
bad_g3 = nx.DiGraph()
bad_g3 = nx.MultiDiGraph()
bad_g3.add_edge(t_0, t_2, input_idx=1)
bad_g3.add_edge(t_1, t_2, input_idx=0)

View File

@@ -0,0 +1,87 @@
"""Test file for HDK's hnumpy tracing"""
import networkx as nx
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.representation import intermediate as ir
from hdk.hnumpy import tracing
@pytest.mark.parametrize(
"x",
[
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
pytest.param(
EncryptedValue(Integer(64, is_signed=True)),
id="Encrypted int",
),
pytest.param(
ClearValue(Integer(64, is_signed=False)),
id="Clear uint",
),
pytest.param(
ClearValue(Integer(64, is_signed=True)),
id="Clear int",
),
],
)
@pytest.mark.parametrize(
"y",
[
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
pytest.param(
EncryptedValue(Integer(64, is_signed=True)),
id="Encrypted int",
),
pytest.param(
ClearValue(Integer(64, is_signed=False)),
id="Clear uint",
),
pytest.param(
ClearValue(Integer(64, is_signed=True)),
id="Clear int",
),
],
)
def test_hnumpy_tracing_add(x, y, test_helpers):
"Test hnumpy tracing __add__"
def simple_add_function(x, y):
z = x + x
return z + y
graph = tracing.trace_numpy_function(simple_add_function, {"x": x, "y": y})
ref_graph = nx.MultiDiGraph()
input_x = ir.Input((x,))
input_y = ir.Input((y,))
add_node_z = ir.Add(
(
input_x.outputs[0],
input_x.outputs[0],
)
)
return_add_node = ir.Add(
(
add_node_z.outputs[0],
input_y.outputs[0],
)
)
ref_graph.add_node(input_x, content=input_x)
ref_graph.add_node(input_y, content=input_y)
ref_graph.add_node(add_node_z, content=add_node_z)
ref_graph.add_node(return_add_node, content=return_add_node)
ref_graph.add_edge(input_x, add_node_z, input_idx=0)
ref_graph.add_edge(input_x, add_node_z, input_idx=1)
ref_graph.add_edge(add_node_z, return_add_node, input_idx=0)
ref_graph.add_edge(input_y, return_add_node, input_idx=1)
assert test_helpers.digraphs_are_equivalent(ref_graph, graph)