mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: change type hints in IntermediateNode evaluate
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.base import BaseDataType
|
||||
@@ -73,11 +73,11 @@ class IntermediateNode(ABC):
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
"""Function to simulate what the represented computation would output for the given inputs.
|
||||
|
||||
Args:
|
||||
inputs (Mapping[int, Any]): Mapping containing the inputs for the evaluation
|
||||
inputs (Dict[int, Any]): Dict containing the inputs for the evaluation
|
||||
|
||||
Returns:
|
||||
Any: the result of the computation
|
||||
@@ -90,7 +90,7 @@ class Add(IntermediateNode):
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] + inputs[1]
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ class Sub(IntermediateNode):
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] - inputs[1]
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class Mul(IntermediateNode):
|
||||
__init__ = IntermediateNode._init_binary
|
||||
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] * inputs[1]
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ class Input(IntermediateNode):
|
||||
self.program_input_idx = program_input_idx
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0]
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
@@ -172,7 +172,7 @@ class ConstantInput(IntermediateNode):
|
||||
elif isinstance(constant_data, float):
|
||||
self.outputs = [ClearValue(Float(64))]
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
def is_equivalent_to(self, other: object) -> bool:
|
||||
@@ -211,7 +211,7 @@ class ArbitraryFunction(IntermediateNode):
|
||||
self.outputs = [EncryptedValue(output_dtype)]
|
||||
self.op_name = op_name if op_name is not None else self.__class__.__name__
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
# This is the continuation of the mypy bug workaround
|
||||
assert self.arbitrary_func is not None
|
||||
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user