refactor: change type hints in IntermediateNode evaluate

This commit is contained in:
Arthur Meyre
2021-08-19 16:25:58 +02:00
parent 5048992707
commit 5e258ca443

View File

@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from ..data_types import BaseValue
from ..data_types.base import BaseDataType
@@ -73,11 +73,11 @@ class IntermediateNode(ABC):
)
@abstractmethod
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
def evaluate(self, inputs: Dict[int, Any]) -> Any:
"""Function to simulate what the represented computation would output for the given inputs.
Args:
inputs (Mapping[int, Any]): Mapping containing the inputs for the evaluation
inputs (Dict[int, Any]): Dict containing the inputs for the evaluation
Returns:
Any: the result of the computation
@@ -90,7 +90,7 @@ class Add(IntermediateNode):
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] + inputs[1]
@@ -100,7 +100,7 @@ class Sub(IntermediateNode):
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] - inputs[1]
@@ -110,7 +110,7 @@ class Mul(IntermediateNode):
__init__ = IntermediateNode._init_binary
is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] * inputs[1]
@@ -132,7 +132,7 @@ class Input(IntermediateNode):
self.program_input_idx = program_input_idx
self.outputs = [deepcopy(self.inputs[0])]
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0]
def is_equivalent_to(self, other: object) -> bool:
@@ -172,7 +172,7 @@ class ConstantInput(IntermediateNode):
elif isinstance(constant_data, float):
self.outputs = [ClearValue(Float(64))]
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return self.constant_data
def is_equivalent_to(self, other: object) -> bool:
@@ -211,7 +211,7 @@ class ArbitraryFunction(IntermediateNode):
self.outputs = [EncryptedValue(output_dtype)]
self.op_name = op_name if op_name is not None else self.__class__.__name__
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
def evaluate(self, inputs: Dict[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.arbitrary_func is not None
return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)