dev(ir): add evalute to simulate the computation represented by IR nodes

This commit is contained in:
Arthur Meyre
2021-07-28 14:44:25 +02:00
parent d739e6672d
commit 8925fbd2db
2 changed files with 65 additions and 2 deletions

View File

@@ -1,8 +1,8 @@
"""File containing HDK's intermdiate representation of source programs operations"""
from abc import ABC
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from ..data_types import BaseValue
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
@@ -74,6 +74,17 @@ class IntermediateNode(ABC):
and self.op_kwargs == other.op_kwargs
)
@abstractmethod
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
"""Function to simulate what the represented computation would output for the given inputs
Args:
inputs (Mapping[int, Any]): Mapping containing the inputs for the evaluation
Returns:
Any: the result of the computation
"""
class Add(IntermediateNode):
"""Addition between two values"""
@@ -81,6 +92,9 @@ class Add(IntermediateNode):
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return inputs[0] + inputs[1]
class Sub(IntermediateNode):
"""Subtraction between two values"""
@@ -88,6 +102,9 @@ class Sub(IntermediateNode):
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return inputs[0] - inputs[1]
class Mul(IntermediateNode):
"""Multiplication between two values"""
@@ -95,6 +112,9 @@ class Mul(IntermediateNode):
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return inputs[0] * inputs[1]
class Input(IntermediateNode):
"""Node representing an input of the numpy program"""
@@ -113,3 +133,6 @@ class Input(IntermediateNode):
self.input_name = input_name
self.program_input_idx = program_input_idx
self.outputs = [deepcopy(self.inputs[0])]
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
return inputs[0]

View File

@@ -0,0 +1,40 @@
"""Test file for HDK's common/representation/intermediate.py"""
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.representation import intermediate as ir
@pytest.mark.parametrize(
"node,input_data,expected_result",
[
pytest.param(
ir.Add([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]),
[10, 4589],
4599,
id="Add",
),
pytest.param(
ir.Sub([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]),
[10, 4589],
-4579,
id="Sub",
),
pytest.param(
ir.Mul([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]),
[10, 4589],
45890,
id="Mul",
),
pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"),
],
)
def test_evaluate(
node: ir.IntermediateNode,
input_data,
expected_result: int,
):
"""Test evaluate methods on IntermediateNodes"""
assert node.evaluate(input_data) == expected_result