mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(ir): add evalute to simulate the computation represented by IR nodes
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
"""File containing HDK's intermdiate representation of source programs operations"""
|
||||
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.dtypes_helpers import mix_values_determine_holding_dtype
|
||||
@@ -74,6 +74,17 @@ class IntermediateNode(ABC):
|
||||
and self.op_kwargs == other.op_kwargs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
"""Function to simulate what the represented computation would output for the given inputs
|
||||
|
||||
Args:
|
||||
inputs (Mapping[int, Any]): Mapping containing the inputs for the evaluation
|
||||
|
||||
Returns:
|
||||
Any: the result of the computation
|
||||
"""
|
||||
|
||||
|
||||
class Add(IntermediateNode):
|
||||
"""Addition between two values"""
|
||||
@@ -81,6 +92,9 @@ class Add(IntermediateNode):
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return inputs[0] + inputs[1]
|
||||
|
||||
|
||||
class Sub(IntermediateNode):
|
||||
"""Subtraction between two values"""
|
||||
@@ -88,6 +102,9 @@ class Sub(IntermediateNode):
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return inputs[0] - inputs[1]
|
||||
|
||||
|
||||
class Mul(IntermediateNode):
|
||||
"""Multiplication between two values"""
|
||||
@@ -95,6 +112,9 @@ class Mul(IntermediateNode):
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return inputs[0] * inputs[1]
|
||||
|
||||
|
||||
class Input(IntermediateNode):
|
||||
"""Node representing an input of the numpy program"""
|
||||
@@ -113,3 +133,6 @@ class Input(IntermediateNode):
|
||||
self.input_name = input_name
|
||||
self.program_input_idx = program_input_idx
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
return inputs[0]
|
||||
|
||||
40
tests/common/representation/test_intermediate.py
Normal file
40
tests/common/representation/test_intermediate.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Test file for HDK's common/representation/intermediate.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.representation import intermediate as ir
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"node,input_data,expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
ir.Add([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]),
|
||||
[10, 4589],
|
||||
4599,
|
||||
id="Add",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Sub([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]),
|
||||
[10, 4589],
|
||||
-4579,
|
||||
id="Sub",
|
||||
),
|
||||
pytest.param(
|
||||
ir.Mul([EncryptedValue(Integer(64, False)), EncryptedValue(Integer(64, False))]),
|
||||
[10, 4589],
|
||||
45890,
|
||||
id="Mul",
|
||||
),
|
||||
pytest.param(ir.Input(ClearValue(Integer(32, True)), "in", 0), [42], 42, id="Input"),
|
||||
],
|
||||
)
|
||||
def test_evaluate(
|
||||
node: ir.IntermediateNode,
|
||||
input_data,
|
||||
expected_result: int,
|
||||
):
|
||||
"""Test evaluate methods on IntermediateNodes"""
|
||||
assert node.evaluate(input_data) == expected_result
|
||||
Reference in New Issue
Block a user