diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index fecc54d58..4eeebbe6d 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from ..data_types import BaseValue from ..data_types.base import BaseDataType @@ -73,11 +73,11 @@ class IntermediateNode(ABC): ) @abstractmethod - def evaluate(self, inputs: Mapping[int, Any]) -> Any: + def evaluate(self, inputs: Dict[int, Any]) -> Any: """Function to simulate what the represented computation would output for the given inputs. Args: - inputs (Mapping[int, Any]): Mapping containing the inputs for the evaluation + inputs (Dict[int, Any]): Dict containing the inputs for the evaluation Returns: Any: the result of the computation @@ -90,7 +90,7 @@ class Add(IntermediateNode): __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative - def evaluate(self, inputs: Mapping[int, Any]) -> Any: + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] + inputs[1] @@ -100,7 +100,7 @@ class Sub(IntermediateNode): __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_non_commutative - def evaluate(self, inputs: Mapping[int, Any]) -> Any: + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] - inputs[1] @@ -110,7 +110,7 @@ class Mul(IntermediateNode): __init__ = IntermediateNode._init_binary is_equivalent_to = IntermediateNode._is_equivalent_to_binary_commutative - def evaluate(self, inputs: Mapping[int, Any]) -> Any: + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] * inputs[1] @@ -132,7 +132,7 @@ class Input(IntermediateNode): self.program_input_idx = program_input_idx self.outputs = [deepcopy(self.inputs[0])] - def evaluate(self, inputs: Mapping[int, Any]) -> Any: + def evaluate(self, inputs: Dict[int, Any]) -> Any: return inputs[0] def is_equivalent_to(self, other: object) -> bool: @@ -172,7 +172,7 @@ class ConstantInput(IntermediateNode): elif isinstance(constant_data, float): self.outputs = [ClearValue(Float(64))] - def evaluate(self, inputs: Mapping[int, Any]) -> Any: + def evaluate(self, inputs: Dict[int, Any]) -> Any: return self.constant_data def is_equivalent_to(self, other: object) -> bool: @@ -211,7 +211,7 @@ class ArbitraryFunction(IntermediateNode): self.outputs = [EncryptedValue(output_dtype)] self.op_name = op_name if op_name is not None else self.__class__.__name__ - def evaluate(self, inputs: Mapping[int, Any]) -> Any: + def evaluate(self, inputs: Dict[int, Any]) -> Any: # This is the continuation of the mypy bug workaround assert self.arbitrary_func is not None return self.arbitrary_func(inputs[0], *self.op_args, **self.op_kwargs)