diff --git a/hdk/common/data_types/dtypes_helpers.py b/hdk/common/data_types/dtypes_helpers.py index 6521ba12e..91db9f1f9 100644 --- a/hdk/common/data_types/dtypes_helpers.py +++ b/hdk/common/data_types/dtypes_helpers.py @@ -1,7 +1,8 @@ """File to hold helper functions for data types related stuff.""" from copy import deepcopy -from typing import Union, cast +from functools import partial +from typing import Callable, Union, cast from .base import BaseDataType from .floats import Float @@ -185,3 +186,24 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float] elif isinstance(constant_data, float): constant_data_type = Float(64) return constant_data_type + + +def get_base_value_for_python_constant_data( + constant_data: Union[int, float] +) -> Callable[..., ScalarValue]: + """Function to wrap the BaseDataType to hold the input constant data in a ScalarValue partial. + + The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue + by calling it with the proper arguments forwarded to the ScalarValue `__init__` function + + Args: + constant_data (Union[int, float]): The constant data for which to determine the + corresponding ScalarValue and BaseDataType. + + Returns: + Callable[..., ScalarValue]: A partial object that will return the proper ScalarValue when + called with `encrypted` as keyword argument (forwarded to the ScalarValue `__init__` + method). + """ + constant_data_type = get_base_data_type_for_python_constant_data(constant_data) + return partial(ScalarValue, data_type=constant_data_type) diff --git a/hdk/common/data_types/scalars.py b/hdk/common/data_types/scalars.py deleted file mode 100644 index 078777cb5..000000000 --- a/hdk/common/data_types/scalars.py +++ /dev/null @@ -1,6 +0,0 @@ -"""File holding code to represent data types used for constants in programs.""" - -from typing import Union - -# TODO: deal with more types -Scalars = Union[int, float] diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 4eeebbe6d..f591ca1cd 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -6,11 +6,11 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from ..data_types import BaseValue from ..data_types.base import BaseDataType -from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype -from ..data_types.floats import Float -from ..data_types.integers import Integer, get_bits_to_represent_value_as_integer -from ..data_types.scalars import Scalars -from ..data_types.values import ClearValue, EncryptedValue +from ..data_types.dtypes_helpers import ( + get_base_value_for_python_constant_data, + mix_scalar_values_determine_holding_dtype, +) +from ..data_types.values import EncryptedValue class IntermediateNode(ABC): @@ -147,30 +147,18 @@ class Input(IntermediateNode): class ConstantInput(IntermediateNode): """Node representing a constant of the program.""" - constant_data: Scalars + _constant_data: Any def __init__( self, - constant_data: Scalars, + constant_data: Any, ) -> None: super().__init__([]) - self.constant_data = constant_data - assert isinstance( - constant_data, (int, float) - ), "Only int and float are support for constant input" - if isinstance(constant_data, int): - is_signed = constant_data < 0 - self.outputs = [ - ClearValue( - Integer( - get_bits_to_represent_value_as_integer(constant_data, is_signed), - is_signed, - ) - ) - ] - elif isinstance(constant_data, float): - self.outputs = [ClearValue(Float(64))] + base_value_class = get_base_value_for_python_constant_data(constant_data) + + self._constant_data = constant_data + self.outputs = [base_value_class(is_encrypted=False)] def evaluate(self, inputs: Dict[int, Any]) -> Any: return self.constant_data @@ -182,6 +170,15 @@ class ConstantInput(IntermediateNode): and super().is_equivalent_to(other) ) + @property + def constant_data(self) -> Any: + """Returns the constant_data stored in the ConstantInput node. + + Returns: + Any: The constant data that was stored. + """ + return self._constant_data + class ArbitraryFunction(IntermediateNode): """Node representing a univariate arbitrary function, e.g. sin(x).""" diff --git a/hdk/common/tracing/base_tracer.py b/hdk/common/tracing/base_tracer.py index 950a9a764..83a52af07 100644 --- a/hdk/common/tracing/base_tracer.py +++ b/hdk/common/tracing/base_tracer.py @@ -4,7 +4,6 @@ from abc import ABC, abstractmethod from typing import Any, Iterable, List, Tuple, Type, Union from ..data_types import BaseValue -from ..data_types.scalars import Scalars from ..representation import intermediate as ir @@ -39,13 +38,14 @@ class BaseTracer(ABC): def instantiate_output_tracers( self, - inputs: Iterable[Union["BaseTracer", Scalars]], + inputs: Iterable[Union["BaseTracer", Any]], computation_to_trace: Type[ir.IntermediateNode], ) -> Tuple["BaseTracer", ...]: """Helper functions to instantiate all output BaseTracer for a given computation. Args: - inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node + inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs + for a new node. computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class to instantiate for the computation being traced @@ -71,7 +71,7 @@ class BaseTracer(ABC): return output_tracers - def __add__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": + def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented @@ -88,7 +88,7 @@ class BaseTracer(ABC): # some changes __radd__ = __add__ - def __sub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": + def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented @@ -100,7 +100,7 @@ class BaseTracer(ABC): assert len(result_tracer) == 1 return result_tracer[0] - def __rsub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": + def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented @@ -112,7 +112,7 @@ class BaseTracer(ABC): assert len(result_tracer) == 1 return result_tracer[0] - def __mul__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer": + def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented @@ -130,12 +130,12 @@ class BaseTracer(ABC): __rmul__ = __mul__ -def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Scalars) -> BaseTracer: +def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Any) -> BaseTracer: """Helper function to create a tracer for a constant input. Args: tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for - constant_data (Scalars): the constant + constant_data (Any): the constant Returns: BaseTracer: The BaseTracer for that constant