dev(ir): add Dot IntermediateNode

This commit is contained in:
Arthur Meyre
2021-08-23 15:53:52 +02:00
parent 4655bea987
commit 6d663ef63d
3 changed files with 145 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ IR_NODE_COLOR_MAPPING = {
ir.Sub: "yellow",
ir.Mul: "green",
ir.ArbitraryFunction: "orange",
ir.Dot: "purple",
"ArbitraryFunction": "orange",
"TLU": "grey",
"output": "magenta",

View File

@@ -9,7 +9,7 @@ from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
mix_scalar_values_determine_holding_dtype,
)
from ..values import BaseValue
from ..values import BaseValue, ClearValue, EncryptedValue, TensorValue
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
@@ -291,3 +291,66 @@ class ArbitraryFunction(IntermediateNode):
def label(self) -> str:
return self.op_name
def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
"""Default python dot implementation for 1D iterable arrays.
Args:
lhs (Any): lhs vector of the dot.
rhs (Any): rhs vector of the dot.
Returns:
Any: the result of the dot operation.
"""
return sum(lhs * rhs for lhs, rhs in zip(lhs, rhs))
class Dot(IntermediateNode):
"""Node representing a dot product."""
_n_in: int = 2
# Optional, same issue as in ArbitraryFunction for mypy
evaluation_function: Optional[Callable[[Any, Any], Any]]
# Allows to use specialized implementations from e.g. numpy
def __init__(
self,
inputs: Iterable[BaseValue],
output_dtype: BaseDataType,
delegate_evaluation_function: Optional[
Callable[[Any, Any], Any]
] = default_dot_evaluation_function,
) -> None:
super().__init__(inputs)
assert len(self.inputs) == 2
assert all(
isinstance(input_value, TensorValue) and input_value.ndim == 1
for input_value in self.inputs
), f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)"
output_scalar_value = (
EncryptedValue
if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted)
else ClearValue
)
self.outputs = [output_scalar_value(output_dtype)]
self.evaluation_function = delegate_evaluation_function
def evaluate(self, inputs: Dict[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.evaluation_function is not None
return self.evaluation_function(inputs[0], inputs[1])
def is_equivalent_to(self, other: object) -> bool:
return (
isinstance(other, self.__class__)
and self.evaluation_function == other.evaluation_function
and super().is_equivalent_to(other)
)
# TODO: Coverage will come with the ability to trace the operator in a subsequent PR
def label(self) -> str: # pragma: no cover
return "dot"

View File

@@ -1,10 +1,12 @@
"""Test file for intermediate representation"""
import numpy
import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.common.representation import intermediate as ir
from hdk.common.values import ClearValue, EncryptedValue
from hdk.common.values import ClearTensor, ClearValue, EncryptedTensor, EncryptedValue
@pytest.mark.parametrize(
@@ -72,6 +74,46 @@ from hdk.common.values import ClearValue, EncryptedValue
4,
id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)",
),
pytest.param(
ir.Dot(
[
EncryptedTensor(Integer(32, True), shape=(4,)),
ClearTensor(Integer(32, True), shape=(4,)),
],
Integer(32, True),
),
[[1, 2, 3, 4], [4, 3, 2, 1]],
20,
id="Dot, [1, 2, 3, 4], [4, 3, 2, 1]",
),
pytest.param(
ir.Dot(
[
EncryptedTensor(Float(32), shape=(4,)),
ClearTensor(Float(32), shape=(4,)),
],
Float(32),
),
[[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]],
20,
id="Dot, [1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]",
),
pytest.param(
ir.Dot(
[
EncryptedTensor(Integer(32, True), shape=(4,)),
ClearTensor(Integer(32, True), shape=(4,)),
],
Integer(32, True),
delegate_evaluation_function=numpy.dot,
),
[
numpy.array([1, 2, 3, 4], dtype=numpy.int32),
numpy.array([4, 3, 2, 1], dtype=numpy.int32),
],
20,
id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])",
),
],
)
def test_evaluate(
@@ -191,6 +233,43 @@ def test_evaluate(
ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)),
False,
),
(
ir.Dot(
[
EncryptedTensor(Integer(32, True), shape=(4,)),
ClearTensor(Integer(32, True), shape=(4,)),
],
Integer(32, True),
delegate_evaluation_function=numpy.dot,
),
ir.Dot(
[
EncryptedTensor(Integer(32, True), shape=(4,)),
ClearTensor(Integer(32, True), shape=(4,)),
],
Integer(32, True),
delegate_evaluation_function=numpy.dot,
),
True,
),
(
ir.Dot(
[
EncryptedTensor(Integer(32, True), shape=(4,)),
ClearTensor(Integer(32, True), shape=(4,)),
],
Integer(32, True),
delegate_evaluation_function=numpy.dot,
),
ir.Dot(
[
EncryptedTensor(Integer(32, True), shape=(4,)),
ClearTensor(Integer(32, True), shape=(4,)),
],
Integer(32, True),
),
False,
),
],
)
def test_is_equivalent_to(