From 6d663ef63d2af0639bf0d3f421da6fd938a7f8d7 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 23 Aug 2021 15:53:52 +0200 Subject: [PATCH] dev(ir): add Dot IntermediateNode --- hdk/common/debugging/drawing.py | 1 + hdk/common/representation/intermediate.py | 65 ++++++++++++++- .../representation/test_intermediate.py | 81 ++++++++++++++++++- 3 files changed, 145 insertions(+), 2 deletions(-) diff --git a/hdk/common/debugging/drawing.py b/hdk/common/debugging/drawing.py index 62a17aba3..a4662da8b 100644 --- a/hdk/common/debugging/drawing.py +++ b/hdk/common/debugging/drawing.py @@ -19,6 +19,7 @@ IR_NODE_COLOR_MAPPING = { ir.Sub: "yellow", ir.Mul: "green", ir.ArbitraryFunction: "orange", + ir.Dot: "purple", "ArbitraryFunction": "orange", "TLU": "grey", "output": "magenta", diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 5978ce617..2e563b301 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -9,7 +9,7 @@ from ..data_types.dtypes_helpers import ( get_base_value_for_python_constant_data, mix_scalar_values_determine_holding_dtype, ) -from ..values import BaseValue +from ..values import BaseValue, ClearValue, EncryptedValue, TensorValue IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" @@ -291,3 +291,66 @@ class ArbitraryFunction(IntermediateNode): def label(self) -> str: return self.op_name + + +def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any: + """Default python dot implementation for 1D iterable arrays. + + Args: + lhs (Any): lhs vector of the dot. + rhs (Any): rhs vector of the dot. + + Returns: + Any: the result of the dot operation. + """ + return sum(lhs * rhs for lhs, rhs in zip(lhs, rhs)) + + +class Dot(IntermediateNode): + """Node representing a dot product.""" + + _n_in: int = 2 + # Optional, same issue as in ArbitraryFunction for mypy + evaluation_function: Optional[Callable[[Any, Any], Any]] + # Allows to use specialized implementations from e.g. numpy + + def __init__( + self, + inputs: Iterable[BaseValue], + output_dtype: BaseDataType, + delegate_evaluation_function: Optional[ + Callable[[Any, Any], Any] + ] = default_dot_evaluation_function, + ) -> None: + super().__init__(inputs) + assert len(self.inputs) == 2 + + assert all( + isinstance(input_value, TensorValue) and input_value.ndim == 1 + for input_value in self.inputs + ), f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)" + + output_scalar_value = ( + EncryptedValue + if (self.inputs[0].is_encrypted or self.inputs[1].is_encrypted) + else ClearValue + ) + + self.outputs = [output_scalar_value(output_dtype)] + self.evaluation_function = delegate_evaluation_function + + def evaluate(self, inputs: Dict[int, Any]) -> Any: + # This is the continuation of the mypy bug workaround + assert self.evaluation_function is not None + return self.evaluation_function(inputs[0], inputs[1]) + + def is_equivalent_to(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and self.evaluation_function == other.evaluation_function + and super().is_equivalent_to(other) + ) + + # TODO: Coverage will come with the ability to trace the operator in a subsequent PR + def label(self) -> str: # pragma: no cover + return "dot" diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index 27f6618c4..9283fabd0 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -1,10 +1,12 @@ """Test file for intermediate representation""" +import numpy import pytest +from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer from hdk.common.representation import intermediate as ir -from hdk.common.values import ClearValue, EncryptedValue +from hdk.common.values import ClearTensor, ClearValue, EncryptedTensor, EncryptedValue @pytest.mark.parametrize( @@ -72,6 +74,46 @@ from hdk.common.values import ClearValue, EncryptedValue 4, id="ArbitraryFunction, x, y -> y[3], where y is constant == (1, 2, 3, 4)", ), + pytest.param( + ir.Dot( + [ + EncryptedTensor(Integer(32, True), shape=(4,)), + ClearTensor(Integer(32, True), shape=(4,)), + ], + Integer(32, True), + ), + [[1, 2, 3, 4], [4, 3, 2, 1]], + 20, + id="Dot, [1, 2, 3, 4], [4, 3, 2, 1]", + ), + pytest.param( + ir.Dot( + [ + EncryptedTensor(Float(32), shape=(4,)), + ClearTensor(Float(32), shape=(4,)), + ], + Float(32), + ), + [[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]], + 20, + id="Dot, [1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]", + ), + pytest.param( + ir.Dot( + [ + EncryptedTensor(Integer(32, True), shape=(4,)), + ClearTensor(Integer(32, True), shape=(4,)), + ], + Integer(32, True), + delegate_evaluation_function=numpy.dot, + ), + [ + numpy.array([1, 2, 3, 4], dtype=numpy.int32), + numpy.array([4, 3, 2, 1], dtype=numpy.int32), + ], + 20, + id="Dot, np.array([1, 2, 3, 4]), np.array([4, 3, 2, 1])", + ), ], ) def test_evaluate( @@ -191,6 +233,43 @@ def test_evaluate( ir.ArbitraryFunction(EncryptedValue(Integer(8, False)), lambda x: x, Integer(8, False)), False, ), + ( + ir.Dot( + [ + EncryptedTensor(Integer(32, True), shape=(4,)), + ClearTensor(Integer(32, True), shape=(4,)), + ], + Integer(32, True), + delegate_evaluation_function=numpy.dot, + ), + ir.Dot( + [ + EncryptedTensor(Integer(32, True), shape=(4,)), + ClearTensor(Integer(32, True), shape=(4,)), + ], + Integer(32, True), + delegate_evaluation_function=numpy.dot, + ), + True, + ), + ( + ir.Dot( + [ + EncryptedTensor(Integer(32, True), shape=(4,)), + ClearTensor(Integer(32, True), shape=(4,)), + ], + Integer(32, True), + delegate_evaluation_function=numpy.dot, + ), + ir.Dot( + [ + EncryptedTensor(Integer(32, True), shape=(4,)), + ClearTensor(Integer(32, True), shape=(4,)), + ], + Integer(32, True), + ), + False, + ), ], ) def test_is_equivalent_to(