mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
0baa02549c
commit
6157e4680b
6
hdk/common/data_types/scalars.py
Normal file
6
hdk/common/data_types/scalars.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""File holding code to represent data types used for constants in programs"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
# TODO: deal with more types
|
||||
Scalars = Union[int, float]
|
||||
@@ -7,7 +7,13 @@ import networkx as nx
|
||||
from hdk.common.operator_graph import OPGraph
|
||||
from hdk.common.representation import intermediate as ir
|
||||
|
||||
IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"}
|
||||
IR_NODE_COLOR_MAPPING = {
|
||||
ir.Input: "blue",
|
||||
ir.ConstantInput: "cyan",
|
||||
ir.Add: "red",
|
||||
ir.Sub: "yellow",
|
||||
ir.Mul: "green",
|
||||
}
|
||||
|
||||
|
||||
def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict:
|
||||
@@ -117,10 +123,14 @@ def draw_graph(
|
||||
# For most types, we just pick the operation as the label, but for Input,
|
||||
# we take the name of the variable, ie the argument name of the function
|
||||
# to compile
|
||||
label_dict = {
|
||||
node: node.input_name if isinstance(node, ir.Input) else node.__class__.__name__
|
||||
for node in graph.nodes()
|
||||
}
|
||||
def get_proper_name(node):
|
||||
if isinstance(node, ir.Input):
|
||||
return node.input_name
|
||||
if isinstance(node, ir.ConstantInput):
|
||||
return str(node.constant_data)
|
||||
return node.__class__.__name__
|
||||
|
||||
label_dict = {node: get_proper_name(node) for node in graph.nodes()}
|
||||
|
||||
# Draw nodes
|
||||
nx.draw_networkx_nodes(
|
||||
@@ -222,7 +232,11 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
|
||||
|
||||
for node in nx.topological_sort(graph):
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
if isinstance(node, ir.Input):
|
||||
what_to_print = node.input_name
|
||||
elif isinstance(node, ir.ConstantInput):
|
||||
what_to_print = f"ConstantInput({node.constant_data})"
|
||||
else:
|
||||
what_to_print = node.__class__.__name__ + "("
|
||||
|
||||
# Find all the names of the current predecessors of the node
|
||||
@@ -243,9 +257,6 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
|
||||
list_of_arg_name.sort()
|
||||
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
|
||||
|
||||
else:
|
||||
what_to_print = node.input_name
|
||||
|
||||
returned_str += f"\n%{i} = {what_to_print}"
|
||||
map_table[node] = i
|
||||
i += 1
|
||||
|
||||
@@ -27,11 +27,9 @@ class OPGraph:
|
||||
self.input_nodes = {
|
||||
node.program_input_idx: node
|
||||
for node in self.graph.nodes()
|
||||
if len(self.graph.pred[node]) == 0
|
||||
if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input)
|
||||
}
|
||||
|
||||
assert all(map(lambda x: isinstance(x, ir.Input), self.input_nodes.values()))
|
||||
|
||||
graph_outputs = set(node for node in self.graph.nodes() if len(self.graph.succ[node]) == 0)
|
||||
|
||||
assert set(self.output_nodes.values()) == graph_outputs
|
||||
|
||||
@@ -6,6 +6,9 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
from ..data_types.integers import Integer, get_bits_to_represent_int
|
||||
from ..data_types.scalars import Scalars
|
||||
from ..data_types.values import ClearValue
|
||||
|
||||
|
||||
class IntermediateNode(ABC):
|
||||
@@ -117,7 +120,7 @@ class Mul(IntermediateNode):
|
||||
|
||||
|
||||
class Input(IntermediateNode):
|
||||
"""Node representing an input of the numpy program"""
|
||||
"""Node representing an input of the program"""
|
||||
|
||||
input_name: str
|
||||
program_input_idx: int
|
||||
@@ -136,3 +139,26 @@ class Input(IntermediateNode):
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return inputs[0]
|
||||
|
||||
|
||||
class ConstantInput(IntermediateNode):
|
||||
"""Node representing a constant of the program"""
|
||||
|
||||
constant_data: Scalars
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
constant_data: Scalars,
|
||||
) -> None:
|
||||
super().__init__([])
|
||||
self.constant_data = constant_data
|
||||
|
||||
# TODO: manage other cases, we can't call get_bits_to_represent_int
|
||||
assert isinstance(constant_data, int)
|
||||
is_signed = constant_data < 0
|
||||
self.outputs = [
|
||||
ClearValue(Integer(get_bits_to_represent_int(constant_data, is_signed), is_signed))
|
||||
]
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""This file holds the code that can be shared between tracers"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.scalars import Scalars
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
@@ -26,7 +27,7 @@ class BaseTracer(ABC):
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: List["BaseTracer"],
|
||||
inputs: List[Union["BaseTracer", Scalars]],
|
||||
computation_to_trace: Type[ir.IntermediateNode],
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -44,20 +45,30 @@ class BaseTracer(ABC):
|
||||
Returns:
|
||||
Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function
|
||||
"""
|
||||
|
||||
# For inputs which are actually constant, first convert into a tracer
|
||||
def sanitize(inp):
|
||||
if not isinstance(inp, BaseTracer):
|
||||
return make_const_input_tracer(self.__class__, inp)
|
||||
return inp
|
||||
|
||||
sanitized_inputs = [sanitize(inp) for inp in inputs]
|
||||
|
||||
traced_computation = computation_to_trace(
|
||||
map(lambda x: x.output, inputs),
|
||||
map(lambda x: x.output, sanitized_inputs),
|
||||
op_args=op_args,
|
||||
op_kwargs=op_kwargs,
|
||||
)
|
||||
|
||||
output_tracers = tuple(
|
||||
self.__class__(inputs, traced_computation, output_index)
|
||||
self.__class__(sanitized_inputs, traced_computation, output_index)
|
||||
for output_index in range(len(traced_computation.outputs))
|
||||
)
|
||||
|
||||
return output_tracers
|
||||
|
||||
def __add__(self, other: "BaseTracer") -> "BaseTracer":
|
||||
def __add__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
ir.Add,
|
||||
@@ -66,7 +77,13 @@ class BaseTracer(ABC):
|
||||
assert len(result_tracer) == 1
|
||||
return result_tracer[0]
|
||||
|
||||
def __sub__(self, other: "BaseTracer") -> "BaseTracer":
|
||||
# With that is that x + 1 and 1 + x have the same graph. If we want to keep
|
||||
# the order, we need to do as in __rsub__, ie mostly a copy of __sub__ +
|
||||
# some changes
|
||||
__radd__ = __add__
|
||||
|
||||
def __sub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
ir.Sub,
|
||||
@@ -75,7 +92,17 @@ class BaseTracer(ABC):
|
||||
assert len(result_tracer) == 1
|
||||
return result_tracer[0]
|
||||
|
||||
def __mul__(self, other: "BaseTracer") -> "BaseTracer":
|
||||
def __rsub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[other, self],
|
||||
ir.Sub,
|
||||
)
|
||||
|
||||
assert len(result_tracer) == 1
|
||||
return result_tracer[0]
|
||||
|
||||
def __mul__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
ir.Mul,
|
||||
@@ -83,3 +110,21 @@ class BaseTracer(ABC):
|
||||
|
||||
assert len(result_tracer) == 1
|
||||
return result_tracer[0]
|
||||
|
||||
# With that is that x * 3 and 3 * x have the same graph. If we want to keep
|
||||
# the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
|
||||
# some changes
|
||||
__rmul__ = __mul__
|
||||
|
||||
|
||||
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Scalars) -> BaseTracer:
|
||||
"""Helper function to create a tracer for a constant input
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for
|
||||
constant_data (Scalars): the constant
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that constant
|
||||
"""
|
||||
return tracer_class([], ir.ConstantInput(constant_data), 0)
|
||||
|
||||
@@ -23,7 +23,28 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
((-10, 2), (-4, 5)),
|
||||
(-14, 7),
|
||||
Integer(5, is_signed=True),
|
||||
id="x + y, (-10, 2), (-4, 5), (-14, 9)",
|
||||
id="x + y, (-10, 2), (-4, 5), (-14, 7)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y + 1,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-13, 8),
|
||||
Integer(5, is_signed=True),
|
||||
id="x + y + 1, (-10, 2), (-4, 5), (-13, 8)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y + (-3),
|
||||
((-10, 2), (-4, 5)),
|
||||
(-17, 4),
|
||||
Integer(6, is_signed=True),
|
||||
id="x + y + 1, (-10, 2), (-4, 5), (-17, 4)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (1 + x) + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-13, 8),
|
||||
Integer(5, is_signed=True),
|
||||
id="(1 + x) + y, (-10, 2), (-4, 5), (-13, 8)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y,
|
||||
@@ -39,6 +60,27 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
Integer(5, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-15, 6)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x - y - 42,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-57, -36),
|
||||
Integer(7, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-57, -36)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: 3 - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-3, 18),
|
||||
Integer(6, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-3, 18)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (-13) - x + y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-19, 2),
|
||||
Integer(6, is_signed=True),
|
||||
id="x - y, (-10, 2), (-4, 5), (-16, 2)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x * y,
|
||||
((-10, 10), (-10, 10)),
|
||||
@@ -53,6 +95,27 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
Integer(7, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-50, 40)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (3 * x) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-150, 120),
|
||||
Integer(9, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-150, 120)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * 11) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-550, 440),
|
||||
Integer(11, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-550, 440)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * (-11)) * y,
|
||||
((-10, 2), (-4, 5)),
|
||||
(-440, 550),
|
||||
Integer(11, is_signed=True),
|
||||
id="x * y, (-10, 2), (-4, 5), (-440, 550)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + x + y,
|
||||
((-10, 10), (-10, 10)),
|
||||
@@ -88,6 +151,13 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
Integer(7, is_signed=True),
|
||||
id="x * y - x, (-10, 2), (-4, 5), (-40, 50),",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x * 3) * y - (x + 3) + (y - 13) + x * (11 + y) * (12 + y) + (15 - x),
|
||||
((-10, 2), (-4, 5)),
|
||||
(-2846, 574),
|
||||
Integer(13, is_signed=True),
|
||||
id="x * y - x, (-10, 2), (-4, 5), (-2846, 574),",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_eval_op_graph_bounds_on_dataset(
|
||||
|
||||
@@ -29,6 +29,8 @@ from hdk.common.representation import intermediate as ir
|
||||
id="Mul",
|
||||
),
|
||||
pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"),
|
||||
pytest.param(ir.ConstantInput(42), None, 42, id="ConstantInput"),
|
||||
pytest.param(ir.ConstantInput(-42), None, -42, id="ConstantInput"),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
|
||||
@@ -13,11 +13,34 @@ from hdk.hnumpy import tracing
|
||||
[
|
||||
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"),
|
||||
(lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)"),
|
||||
(
|
||||
lambda x, y: x + x - y * y * y + x,
|
||||
"\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
|
||||
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)",
|
||||
),
|
||||
(lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
|
||||
(lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
|
||||
(lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)"),
|
||||
(lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)"),
|
||||
(lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)"),
|
||||
(lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)"),
|
||||
(
|
||||
lambda x, y: x + 13 - y * (-21) * y + 44,
|
||||
"\n%0 = ConstantInput(44)"
|
||||
"\n%1 = x"
|
||||
"\n%2 = ConstantInput(13)"
|
||||
"\n%3 = y"
|
||||
"\n%4 = ConstantInput(-21)"
|
||||
"\n%5 = Add(1, 2)"
|
||||
"\n%6 = Mul(3, 4)"
|
||||
"\n%7 = Mul(6, 3)"
|
||||
"\n%8 = Sub(5, 7)"
|
||||
"\n%9 = Add(8, 0)",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -48,6 +71,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
print(f"\nGot {str_of_the_graph}\n")
|
||||
print(f"\nExp {ref_graph_str}\n")
|
||||
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
Reference in New Issue
Block a user