mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: remove Scalars type hint
- add a helper function to determine BaseDataType of a constant python scalar, int or float in dtype_helpers.py - make BaseTracer type agnostic - make ConstantInput type agnostic
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""File to hold helper functions for data types related stuff."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Union, cast
|
||||
from functools import partial
|
||||
from typing import Callable, Union, cast
|
||||
|
||||
from .base import BaseDataType
|
||||
from .floats import Float
|
||||
@@ -185,3 +186,24 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
|
||||
elif isinstance(constant_data, float):
|
||||
constant_data_type = Float(64)
|
||||
return constant_data_type
|
||||
|
||||
|
||||
def get_base_value_for_python_constant_data(
|
||||
constant_data: Union[int, float]
|
||||
) -> Callable[..., ScalarValue]:
|
||||
"""Function to wrap the BaseDataType to hold the input constant data in a ScalarValue partial.
|
||||
|
||||
The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue
|
||||
by calling it with the proper arguments forwarded to the ScalarValue `__init__` function
|
||||
|
||||
Args:
|
||||
constant_data (Union[int, float]): The constant data for which to determine the
|
||||
corresponding ScalarValue and BaseDataType.
|
||||
|
||||
Returns:
|
||||
Callable[..., ScalarValue]: A partial object that will return the proper ScalarValue when
|
||||
called with `encrypted` as keyword argument (forwarded to the ScalarValue `__init__`
|
||||
method).
|
||||
"""
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return partial(ScalarValue, data_type=constant_data_type)
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""File holding code to represent data types used for constants in programs."""
|
||||
|
||||
from typing import Union
|
||||
|
||||
# TODO: deal with more types
|
||||
Scalars = Union[int, float]
|
||||
@@ -6,11 +6,11 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import mix_scalar_values_determine_holding_dtype
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer, get_bits_to_represent_value_as_integer
|
||||
from ..data_types.scalars import Scalars
|
||||
from ..data_types.values import ClearValue, EncryptedValue
|
||||
from ..data_types.dtypes_helpers import (
|
||||
get_base_value_for_python_constant_data,
|
||||
mix_scalar_values_determine_holding_dtype,
|
||||
)
|
||||
from ..data_types.values import EncryptedValue
|
||||
|
||||
|
||||
class IntermediateNode(ABC):
|
||||
@@ -147,30 +147,18 @@ class Input(IntermediateNode):
|
||||
class ConstantInput(IntermediateNode):
|
||||
"""Node representing a constant of the program."""
|
||||
|
||||
constant_data: Scalars
|
||||
_constant_data: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
constant_data: Scalars,
|
||||
constant_data: Any,
|
||||
) -> None:
|
||||
super().__init__([])
|
||||
self.constant_data = constant_data
|
||||
|
||||
assert isinstance(
|
||||
constant_data, (int, float)
|
||||
), "Only int and float are support for constant input"
|
||||
if isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
self.outputs = [
|
||||
ClearValue(
|
||||
Integer(
|
||||
get_bits_to_represent_value_as_integer(constant_data, is_signed),
|
||||
is_signed,
|
||||
)
|
||||
)
|
||||
]
|
||||
elif isinstance(constant_data, float):
|
||||
self.outputs = [ClearValue(Float(64))]
|
||||
base_value_class = get_base_value_for_python_constant_data(constant_data)
|
||||
|
||||
self._constant_data = constant_data
|
||||
self.outputs = [base_value_class(is_encrypted=False)]
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
@@ -182,6 +170,15 @@ class ConstantInput(IntermediateNode):
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
@property
|
||||
def constant_data(self) -> Any:
|
||||
"""Returns the constant_data stored in the ConstantInput node.
|
||||
|
||||
Returns:
|
||||
Any: The constant data that was stored.
|
||||
"""
|
||||
return self._constant_data
|
||||
|
||||
|
||||
class ArbitraryFunction(IntermediateNode):
|
||||
"""Node representing a univariate arbitrary function, e.g. sin(x)."""
|
||||
|
||||
@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable, List, Tuple, Type, Union
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.scalars import Scalars
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
@@ -39,13 +38,14 @@ class BaseTracer(ABC):
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: Iterable[Union["BaseTracer", Scalars]],
|
||||
inputs: Iterable[Union["BaseTracer", Any]],
|
||||
computation_to_trace: Type[ir.IntermediateNode],
|
||||
) -> Tuple["BaseTracer", ...]:
|
||||
"""Helper functions to instantiate all output BaseTracer for a given computation.
|
||||
|
||||
Args:
|
||||
inputs (List[BaseTracer]): Previous BaseTracer used as inputs for a new node
|
||||
inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs
|
||||
for a new node.
|
||||
computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
|
||||
to instantiate for the computation being traced
|
||||
|
||||
@@ -71,7 +71,7 @@ class BaseTracer(ABC):
|
||||
|
||||
return output_tracers
|
||||
|
||||
def __add__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
@@ -88,7 +88,7 @@ class BaseTracer(ABC):
|
||||
# some changes
|
||||
__radd__ = __add__
|
||||
|
||||
def __sub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
@@ -100,7 +100,7 @@ class BaseTracer(ABC):
|
||||
assert len(result_tracer) == 1
|
||||
return result_tracer[0]
|
||||
|
||||
def __rsub__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
def __rsub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
@@ -112,7 +112,7 @@ class BaseTracer(ABC):
|
||||
assert len(result_tracer) == 1
|
||||
return result_tracer[0]
|
||||
|
||||
def __mul__(self, other: Union["BaseTracer", Scalars]) -> "BaseTracer":
|
||||
def __mul__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
|
||||
if not self._supports_other_operand(other):
|
||||
return NotImplemented
|
||||
|
||||
@@ -130,12 +130,12 @@ class BaseTracer(ABC):
|
||||
__rmul__ = __mul__
|
||||
|
||||
|
||||
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Scalars) -> BaseTracer:
|
||||
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Any) -> BaseTracer:
|
||||
"""Helper function to create a tracer for a constant input.
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for
|
||||
constant_data (Scalars): the constant
|
||||
constant_data (Any): the constant
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that constant
|
||||
|
||||
Reference in New Issue
Block a user