From 8925fbd2db64829a39b6fb3a6f9bee9b89e7b76c Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 28 Jul 2021 14:44:25 +0200 Subject: [PATCH] dev(ir): add evalute to simulate the computation represented by IR nodes --- hdk/common/representation/intermediate.py | 27 ++++++++++++- .../representation/test_intermediate.py | 40 +++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 tests/common/representation/test_intermediate.py diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 9e86ed569..a9a6d6c77 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -1,8 +1,8 @@ """File containing HDK's intermdiate representation of source programs operations""" -from abc import ABC +from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple from ..data_types import BaseValue from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype @@ -74,6 +74,17 @@ class IntermediateNode(ABC): and self.op_kwargs == other.op_kwargs ) + @abstractmethod + def evaluate(self, inputs: Mapping[int, Any]) -> Any: + """Function to simulate what the represented computation would output for the given inputs + + Args: + inputs (Mapping[int, Any]): Mapping containing the inputs for the evaluation + + Returns: + Any: the result of the computation + """ + class Add(IntermediateNode): """Addition between two values""" @@ -81,6 +92,9 @@ class Add(IntermediateNode): __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative + def evaluate(self, inputs: Mapping[int, Any]) -> Any: + return inputs[0] + inputs[1] + class Sub(IntermediateNode): """Subtraction between two values""" @@ -88,6 +102,9 @@ class Sub(IntermediateNode): __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative + def evaluate(self, inputs: Mapping[int, Any]) -> Any: + return inputs[0] - inputs[1] + class Mul(IntermediateNode): """Multiplication between two values""" @@ -95,6 +112,9 @@ class Mul(IntermediateNode): __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative + def evaluate(self, inputs: Mapping[int, Any]) -> Any: + return inputs[0] * inputs[1] + class Input(IntermediateNode): """Node representing an input of the numpy program""" @@ -113,3 +133,6 @@ class Input(IntermediateNode): self.input_name = input_name self.program_input_idx = program_input_idx self.outputs = [deepcopy(self.inputs[0])] + + def evaluate(self, inputs: Mapping[int, Any]) -> Any: + return inputs[0] diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py new file mode 100644 index 000000000..d2956eb9a --- /dev/null +++ b/tests/common/representation/test_intermediate.py @@ -0,0 +1,40 @@ +"""Test file for HDK's common/representation/intermediate.py""" + +import pytest + +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import ClearValue, EncryptedValue +from hdk.common.representation import intermediate as ir + + +@pytest.mark.parametrize( + "node,input_data,expected_result", + [ + pytest.param( + ir.Add([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]), + [10, 4589], + 4599, + id="Add", + ), + pytest.param( + ir.Sub([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]), + [10, 4589], + -4579, + id="Sub", + ), + pytest.param( + ir.Mul([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]), + [10, 4589], + 45890, + id="Mul", + ), + pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"), + ], +) +def test_evaluate( + node: ir.IntermediateNode, + input_data, + expected_result: int, +): + """Test evaluate methods on IntermediateNodes""" + assert node.evaluate(input_data) == expected_result