mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(tracing): add tracing facilities
- add BaseTracer which will hold most of the boilerplate code - add hnumpy with a bare NPTracer and tracing function - update IR to be compatible with tracing helpers - update test helper to properly check that graphs are equivalent - add test tracing a simple addition - rename common/data_types/helpers.py to .../dtypes_helpers.py to avoid having too many files with the same name - ignore missing type stubs in the default mypy command - add a comfort Makefile target to get errors about missing mypy stubs
This commit is contained in:
9
Makefile
9
Makefile
@@ -30,10 +30,17 @@ pytest:
|
||||
poetry run pytest --cov=hdk -vv --cov-report=xml tests/
|
||||
.PHONY: pytest
|
||||
|
||||
# Not a huge fan of ignoring missing imports, but some packages do not have typing stubs
|
||||
mypy:
|
||||
poetry run mypy -p hdk
|
||||
poetry run mypy -p hdk --ignore-missing-imports
|
||||
.PHONY: mypy
|
||||
|
||||
# Friendly target to run mypy without ignoring missing stubs and still have errors messages
|
||||
# Allows to see which stubs we are missing
|
||||
mypy_ns:
|
||||
poetry run mypy -p hdk
|
||||
.PHONY: mypy_ns
|
||||
|
||||
docs:
|
||||
cd docs && poetry run make html
|
||||
.PHONY: docs
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
"""HDK's top import"""
|
||||
from . import common
|
||||
from . import common, hnumpy
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""HDK's module for data types code and data structures"""
|
||||
from . import helpers, integers, values
|
||||
from . import dtypes_helpers, integers, values
|
||||
from .values import BaseValue
|
||||
|
||||
@@ -17,6 +17,13 @@ class Integer(base.BaseDataType):
|
||||
signed_str = "signed" if self.is_signed else "unsigned"
|
||||
return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.bit_width == other.bit_width
|
||||
and self.is_signed == other.is_signed
|
||||
)
|
||||
|
||||
def min_value(self) -> int:
|
||||
"""Minimum value representable by the Integer"""
|
||||
if self.is_signed:
|
||||
|
||||
@@ -16,6 +16,9 @@ class BaseValue(ABC):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}<{self.data_type!r}>"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.data_type == other.data_type
|
||||
|
||||
|
||||
class ClearValue(BaseValue):
|
||||
"""Class representing a clear/plaintext value (constant or not)"""
|
||||
|
||||
@@ -22,9 +22,28 @@ class IntermediateNode(ABC):
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
assert all(map(lambda x: isinstance(x, BaseValue), self.inputs))
|
||||
self.op_args = op_args
|
||||
self.op_kwargs = op_kwargs
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
"""Overriding __eq__ has unwanted side effects, this provides the same facility without
|
||||
disrupting expected behavior too much
|
||||
|
||||
Args:
|
||||
other (object): Other object to check against
|
||||
|
||||
Returns:
|
||||
bool: True if the other object is equivalent
|
||||
"""
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.inputs == other.inputs
|
||||
and self.outputs == other.outputs
|
||||
and self.op_args == other.op_args
|
||||
and self.op_kwargs == other.op_kwargs
|
||||
)
|
||||
|
||||
|
||||
class Add(IntermediateNode):
|
||||
"""Addition between two values"""
|
||||
@@ -32,14 +51,26 @@ class Add(IntermediateNode):
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(inputs)
|
||||
assert op_args is None, f"Expected op_args to be None, got {op_args}"
|
||||
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
|
||||
|
||||
super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs)
|
||||
assert len(self.inputs) == 2
|
||||
|
||||
# For now copy the first input type for the output type
|
||||
# We don't perform checks or enforce consistency here for now, so this is OK
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and (self.inputs == other.inputs or self.inputs == other.inputs[::-1])
|
||||
and self.outputs == other.outputs
|
||||
)
|
||||
|
||||
|
||||
class Input(IntermediateNode):
|
||||
"""Node representing an input of the numpy program"""
|
||||
@@ -47,7 +78,12 @@ class Input(IntermediateNode):
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(inputs)
|
||||
assert op_args is None, f"Expected op_args to be None, got {op_args}"
|
||||
assert op_kwargs is None, f"Expected op_kwargs to be None, got {op_kwargs}"
|
||||
|
||||
super().__init__(inputs, op_args=op_args, op_kwargs=op_kwargs)
|
||||
assert len(self.inputs) == 1
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
7
hdk/common/tracing/__init__.py
Normal file
7
hdk/common/tracing/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""HDK's module for basic tracing facilities"""
|
||||
from .base_tracer import BaseTracer
|
||||
from .tracing_helpers import (
|
||||
create_graph_from_output_tracers,
|
||||
make_input_tracer,
|
||||
prepare_function_parameters,
|
||||
)
|
||||
67
hdk/common/tracing/base_tracer.py
Normal file
67
hdk/common/tracing/base_tracer.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""This file holds the code that can be shared between tracers"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
class BaseTracer(ABC):
|
||||
"""Base class for implementing tracers"""
|
||||
|
||||
inputs: List["BaseTracer"]
|
||||
traced_computation: ir.IntermediateNode
|
||||
output: BaseValue
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: List["BaseTracer"],
|
||||
traced_computation: ir.IntermediateNode,
|
||||
output_index: int,
|
||||
) -> None:
|
||||
self.inputs = inputs
|
||||
self.traced_computation = traced_computation
|
||||
self.output = traced_computation.outputs[output_index]
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: List["BaseTracer"],
|
||||
computation_to_trace: Type[ir.IntermediateNode],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple["BaseTracer", ...]:
|
||||
"""Helper functions to instantiate all output BaseTracer for a given computation
|
||||
|
||||
Args:
|
||||
inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node
|
||||
computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
|
||||
to instantiate for the computation being traced
|
||||
op_args: *args coming from the call being traced
|
||||
op_kwargs: **kwargs coming from the call being traced
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
|
||||
"""
|
||||
traced_computation = computation_to_trace(
|
||||
map(lambda x: x.output, inputs),
|
||||
op_args=op_args,
|
||||
op_kwargs=op_kwargs,
|
||||
)
|
||||
|
||||
output_tracers = tuple(
|
||||
self.__class__(inputs, traced_computation, output_index)
|
||||
for output_index in range(len(traced_computation.outputs))
|
||||
)
|
||||
|
||||
return output_tracers
|
||||
|
||||
def __add__(self, other: "BaseTracer") -> "BaseTracer":
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
ir.Add,
|
||||
)
|
||||
|
||||
assert len(result_tracer) == 1
|
||||
return result_tracer[0]
|
||||
95
hdk/common/tracing/tracing_helpers.py
Normal file
95
hdk/common/tracing/tracing_helpers.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Helper functions for tracing"""
|
||||
from inspect import signature
|
||||
from typing import Callable, Dict, Iterable, Set, Tuple, Type
|
||||
|
||||
import networkx as nx
|
||||
from networkx.algorithms.dag import is_directed_acyclic_graph
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..representation import intermediate as ir
|
||||
from .base_tracer import BaseTracer
|
||||
|
||||
|
||||
def make_input_tracer(tracer_class: Type[BaseTracer], input_value: BaseValue) -> BaseTracer:
|
||||
"""Helper function to create a tracer for an input value
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
input_value (BaseValue): the Value that is an input and needs to be wrapped in an
|
||||
BaseTracer
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that input value
|
||||
"""
|
||||
return tracer_class([], ir.Input([input_value]), 0)
|
||||
|
||||
|
||||
def prepare_function_parameters(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> Dict[str, BaseValue]:
|
||||
"""Function to filter the passed function_parameters to trace function_to_trace
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): function that will be traced for which parameters are checked
|
||||
function_parameters (Dict[str, BaseValue]): parameters given to trace the function
|
||||
|
||||
Raises:
|
||||
ValueError: Raised when some parameters are missing to trace function_to_trace
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseValue]: filtered function_parameters dictionary
|
||||
"""
|
||||
function_signature = signature(function_to_trace)
|
||||
|
||||
missing_args = function_signature.parameters.keys() - function_parameters.keys()
|
||||
|
||||
if len(missing_args) > 0:
|
||||
raise ValueError(
|
||||
f"The function '{function_to_trace.__name__}' requires the following parameters"
|
||||
f"that were not provided: {', '.join(sorted(missing_args))}"
|
||||
)
|
||||
|
||||
useless_arguments = function_parameters.keys() - function_signature.parameters.keys()
|
||||
useful_arguments = function_signature.parameters.keys() - useless_arguments
|
||||
|
||||
return {k: function_parameters[k] for k in useful_arguments}
|
||||
|
||||
|
||||
def create_graph_from_output_tracers(
|
||||
output_tracers: Iterable[BaseTracer],
|
||||
) -> nx.MultiDiGraph:
|
||||
"""Generate a networkx Directed Graph that will represent the computation from a traced function
|
||||
|
||||
Args:
|
||||
output_tracers (Iterable[BaseTracer]): the output tracers resulting from running the
|
||||
function over the proper input tracers
|
||||
|
||||
Returns:
|
||||
nx.MultiDiGraph: Directed Graph that is guaranteed to be a DAG containing the ir nodes
|
||||
representing the traced program/function
|
||||
"""
|
||||
graph = nx.MultiDiGraph()
|
||||
|
||||
visited_tracers: Set[BaseTracer] = set()
|
||||
current_tracers = tuple(output_tracers)
|
||||
|
||||
while current_tracers:
|
||||
next_tracers: Tuple[BaseTracer, ...] = tuple()
|
||||
for tracer in current_tracers:
|
||||
current_ir_node = tracer.traced_computation
|
||||
graph.add_node(current_ir_node, content=current_ir_node)
|
||||
|
||||
for input_idx, input_tracer in enumerate(tracer.inputs):
|
||||
input_ir_node = input_tracer.traced_computation
|
||||
graph.add_node(input_ir_node, content=input_ir_node)
|
||||
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
|
||||
if input_tracer not in visited_tracers:
|
||||
next_tracers += (input_tracer,)
|
||||
|
||||
visited_tracers.add(tracer)
|
||||
|
||||
current_tracers = next_tracers
|
||||
|
||||
assert is_directed_acyclic_graph(graph)
|
||||
|
||||
return graph
|
||||
2
hdk/hnumpy/__init__.py
Normal file
2
hdk/hnumpy/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""HDK's module for compiling numpy functions to homomorphic equivalents"""
|
||||
from . import tracing
|
||||
48
hdk/hnumpy/tracing.py
Normal file
48
hdk/hnumpy/tracing.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""hnumpy tracing utilities"""
|
||||
from typing import Callable, Dict
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.tracing import (
|
||||
BaseTracer,
|
||||
create_graph_from_output_tracers,
|
||||
make_input_tracer,
|
||||
prepare_function_parameters,
|
||||
)
|
||||
|
||||
|
||||
class NPTracer(BaseTracer):
|
||||
"""Tracer class for numpy operations"""
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> nx.MultiDiGraph:
|
||||
"""Function used to trace a numpy function
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): The function you want to trace
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedValue holding a 7bits unsigned Integer
|
||||
|
||||
Returns:
|
||||
nx.MultiDiGraph: The graph containing the ir nodes representing the computation done in the
|
||||
input function
|
||||
"""
|
||||
function_parameters = prepare_function_parameters(function_to_trace, function_parameters)
|
||||
|
||||
input_tracers = {
|
||||
param_name: make_input_tracer(NPTracer, param)
|
||||
for param_name, param in function_parameters.items()
|
||||
}
|
||||
|
||||
# We could easily create a graph of NPTracer, but we may end up with dead nodes starting from
|
||||
# the inputs that's why we create the graph starting from the outputs
|
||||
output_tracers = function_to_trace(**input_tracers)
|
||||
if isinstance(output_tracers, NPTracer):
|
||||
output_tracers = (output_tracers,)
|
||||
|
||||
graph = create_graph_from_output_tracers(output_tracers)
|
||||
|
||||
return graph
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Test file for HDK's common/data_types/helpers.py"""
|
||||
"""Test file for HDK's data types helpers"""
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.helpers import (
|
||||
from hdk.common.data_types.dtypes_helpers import (
|
||||
value_is_encrypted_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
26
tests/common/tracing/test_tracing_helpers.py
Normal file
26
tests/common/tracing/test_tracing_helpers.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Test file for HDK's common tracing helpers"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.tracing.tracing_helpers import prepare_function_parameters
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,function_parameters,ref_dict",
|
||||
[
|
||||
pytest.param(lambda x: None, {}, {}, id="Missing x", marks=pytest.mark.xfail(strict=True)),
|
||||
pytest.param(lambda x: None, {"x": None}, {"x": None}, id="Only x"),
|
||||
pytest.param(
|
||||
lambda x: None, {"x": None, "y": None}, {"x": None}, id="Additional y filtered"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_prepare_function_parameters(
|
||||
function, function_parameters: Dict[str, Any], ref_dict: Dict[str, Any]
|
||||
):
|
||||
"""Test prepare_function_parameters"""
|
||||
prepared_dict = prepare_function_parameters(function, function_parameters)
|
||||
|
||||
assert prepared_dict == ref_dict
|
||||
@@ -8,11 +8,13 @@ class TestHelpers:
|
||||
"""Class allowing to pass helper functions to tests"""
|
||||
|
||||
@staticmethod
|
||||
def digraphs_are_equivalent(reference: nx.DiGraph, to_compare: nx.DiGraph):
|
||||
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
|
||||
"""Check that two digraphs are equivalent without modifications"""
|
||||
# edge_match is a copy of node_match
|
||||
edge_matcher = iso.categorical_node_match("input_idx", None)
|
||||
node_matcher = iso.categorical_node_match("content", None)
|
||||
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
|
||||
node_matcher = iso.generic_node_match(
|
||||
"content", None, lambda lhs, rhs: lhs.is_equivalent_to(rhs)
|
||||
)
|
||||
graphs_are_isomorphic = nx.is_isomorphic(
|
||||
reference,
|
||||
to_compare,
|
||||
|
||||
@@ -16,11 +16,13 @@ def test_digraphs_are_equivalent(test_helpers):
|
||||
def __hash__(self) -> int:
|
||||
return self.computation.__hash__()
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return self.computation == other.computation
|
||||
|
||||
g_1 = nx.DiGraph()
|
||||
g_2 = nx.DiGraph()
|
||||
is_equivalent_to = __eq__
|
||||
|
||||
g_1 = nx.MultiDiGraph()
|
||||
g_2 = nx.MultiDiGraph()
|
||||
|
||||
t_0 = TestNode("Add")
|
||||
t_1 = TestNode("Mul")
|
||||
@@ -44,7 +46,7 @@ def test_digraphs_are_equivalent(test_helpers):
|
||||
for node in g_2:
|
||||
g_2.add_node(node, content=node)
|
||||
|
||||
bad_g2 = nx.DiGraph()
|
||||
bad_g2 = nx.MultiDiGraph()
|
||||
|
||||
bad_t0 = TestNode("Not Add")
|
||||
|
||||
@@ -55,7 +57,7 @@ def test_digraphs_are_equivalent(test_helpers):
|
||||
for node in bad_g2:
|
||||
bad_g2.add_node(node, content=node)
|
||||
|
||||
bad_g3 = nx.DiGraph()
|
||||
bad_g3 = nx.MultiDiGraph()
|
||||
|
||||
bad_g3.add_edge(t_0, t_2, input_idx=1)
|
||||
bad_g3.add_edge(t_1, t_2, input_idx=0)
|
||||
|
||||
87
tests/hnumpy/test_tracing.py
Normal file
87
tests/hnumpy/test_tracing.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Test file for HDK's hnumpy tracing"""
|
||||
|
||||
import networkx as nx
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.representation import intermediate as ir
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"x",
|
||||
[
|
||||
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(64, is_signed=True)),
|
||||
id="Encrypted int",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Integer(64, is_signed=False)),
|
||||
id="Clear uint",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Integer(64, is_signed=True)),
|
||||
id="Clear int",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"y",
|
||||
[
|
||||
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
|
||||
pytest.param(
|
||||
EncryptedValue(Integer(64, is_signed=True)),
|
||||
id="Encrypted int",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Integer(64, is_signed=False)),
|
||||
id="Clear uint",
|
||||
),
|
||||
pytest.param(
|
||||
ClearValue(Integer(64, is_signed=True)),
|
||||
id="Clear int",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hnumpy_tracing_add(x, y, test_helpers):
|
||||
"Test hnumpy tracing __add__"
|
||||
|
||||
def simple_add_function(x, y):
|
||||
z = x + x
|
||||
return z + y
|
||||
|
||||
graph = tracing.trace_numpy_function(simple_add_function, {"x": x, "y": y})
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
|
||||
input_x = ir.Input((x,))
|
||||
input_y = ir.Input((y,))
|
||||
|
||||
add_node_z = ir.Add(
|
||||
(
|
||||
input_x.outputs[0],
|
||||
input_x.outputs[0],
|
||||
)
|
||||
)
|
||||
|
||||
return_add_node = ir.Add(
|
||||
(
|
||||
add_node_z.outputs[0],
|
||||
input_y.outputs[0],
|
||||
)
|
||||
)
|
||||
|
||||
ref_graph.add_node(input_x, content=input_x)
|
||||
ref_graph.add_node(input_y, content=input_y)
|
||||
ref_graph.add_node(add_node_z, content=add_node_z)
|
||||
ref_graph.add_node(return_add_node, content=return_add_node)
|
||||
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=0)
|
||||
ref_graph.add_edge(input_x, add_node_z, input_idx=1)
|
||||
|
||||
ref_graph.add_edge(add_node_z, return_add_node, input_idx=0)
|
||||
ref_graph.add_edge(input_y, return_add_node, input_idx=1)
|
||||
|
||||
assert test_helpers.digraphs_are_equivalent(ref_graph, graph)
|
||||
Reference in New Issue
Block a user