refactor: remove Scalars type hint

- add a helper function to determine BaseDataType of a constant python
scalar, int or float in dtype_helpers.py
- make BaseTracer type agnostic
- make ConstantInput type agnostic
This commit is contained in:
Arthur Meyre
2021-08-19 16:33:47 +02:00
parent 5e258ca443
commit c528d72e62
4 changed files with 52 additions and 39 deletions

View File

@@ -1,7 +1,8 @@
"""File to hold helper functions for data types related stuff."""
from copy import deepcopy
from typing import Union, cast
from functools import partial
from typing import Callable, Union, cast
from .base import BaseDataType
from .floats import Float
@@ -185,3 +186,24 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
elif isinstance(constant_data, float):
constant_data_type = Float(64)
return constant_data_type
def get_base_value_for_python_constant_data(
constant_data: Union[int, float]
) -> Callable[..., ScalarValue]:
"""Function to wrap the BaseDataType to hold the input constant data in a ScalarValue partial.
The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue
by calling it with the proper arguments forwarded to the ScalarValue `__init__` function
Args:
constant_data (Union[int, float]): The constant data for which to determine the
corresponding ScalarValue and BaseDataType.
Returns:
Callable[..., ScalarValue]: A partial object that will return the proper ScalarValue when
called with `encrypted` as keyword argument (forwarded to the ScalarValue `__init__`
method).
"""
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(ScalarValue, data_type=constant_data_type)

View File

@@ -1,6 +0,0 @@
"""File holding code to represent data types used for constants in programs."""
from typing import Union
# TODO: deal with more types
Scalars = Union[int, float]

View File

@@ -6,11 +6,11 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from ..data_types import BaseValue
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
from ..data_types.floats import Float
from ..data_types.integers import Integer, get_bits_to_represent_value_as_integer
from ..data_types.scalars import Scalars
from ..data_types.values import ClearValue, EncryptedValue
from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
mix_scalar_values_determine_holding_dtype,
)
from ..data_types.values import EncryptedValue
class IntermediateNode(ABC):
@@ -147,30 +147,18 @@ class Input(IntermediateNode):
class ConstantInput(IntermediateNode):
"""Node representing a constant of the program."""
constant_data: Scalars
_constant_data: Any
def __init__(
self,
constant_data: Scalars,
constant_data: Any,
) -> None:
super().__init__([])
self.constant_data = constant_data
assert isinstance(
constant_data, (int, float)
), "Only int and float are support for constant input"
if isinstance(constant_data, int):
is_signed = constant_data < 0
self.outputs = [
ClearValue(
Integer(
get_bits_to_represent_value_as_integer(constant_data, is_signed),
is_signed,
)
)
]
elif isinstance(constant_data, float):
self.outputs = [ClearValue(Float(64))]
base_value_class = get_base_value_for_python_constant_data(constant_data)
self._constant_data = constant_data
self.outputs = [base_value_class(is_encrypted=False)]
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return self.constant_data
@@ -182,6 +170,15 @@ class ConstantInput(IntermediateNode):
and super().is_equivalent_to(other)
)
@property
def constant_data(self) -> Any:
"""Returns the constant_data stored in the ConstantInput node.
Returns:
Any: The constant data that was stored.
"""
return self._constant_data
class ArbitraryFunction(IntermediateNode):
"""Node representing a univariate arbitrary function, e.g. sin(x)."""

View File

@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Tuple, Type, Union
from ..data_types import BaseValue
from ..data_types.scalars import Scalars
from ..representation import intermediate as ir
@@ -39,13 +38,14 @@ class BaseTracer(ABC):
def instantiate_output_tracers(
self,
inputs: Iterable[Union["BaseTracer", Scalars]],
inputs: Iterable[Union["BaseTracer", Any]],
computation_to_trace: Type[ir.IntermediateNode],
) -> Tuple["BaseTracer", ...]:
"""Helper functions to instantiate all output BaseTracer for a given computation.
Args:
inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node
inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs
for a new node.
computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
to instantiate for the computation being traced
@@ -71,7 +71,7 @@ class BaseTracer(ABC):
return output_tracers
def __add__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented
@@ -88,7 +88,7 @@ class BaseTracer(ABC):
# some changes
__radd__ = __add__
def __sub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented
@@ -100,7 +100,7 @@ class BaseTracer(ABC):
assert len(result_tracer) == 1
return result_tracer[0]
def __rsub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented
@@ -112,7 +112,7 @@ class BaseTracer(ABC):
assert len(result_tracer) == 1
return result_tracer[0]
def __mul__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
if not self._supports_other_operand(other):
return NotImplemented
@@ -130,12 +130,12 @@ class BaseTracer(ABC):
__rmul__ = __mul__
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Scalars) -> BaseTracer:
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Any) -> BaseTracer:
"""Helper function to create a tracer for a constant input.
Args:
tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for
constant_data (Scalars): the constant
constant_data (Any): the constant
Returns:
BaseTracer: The BaseTracer for that constant