mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
24c0735490
commit
4778dc503b
2
Makefile
2
Makefile
@@ -152,7 +152,7 @@ build_and_open_docs: clean_docs docs open_docs
|
||||
|
||||
pydocstyle:
|
||||
@# From http://www.pydocstyle.org/en/stable/error_codes.html
|
||||
poetry run pydocstyle $(SRC_DIR) --convention google --add-ignore=D1,D202
|
||||
poetry run pydocstyle $(SRC_DIR) --convention google --add-ignore=D1,D202 --add-select=D401
|
||||
.PHONY: pydocstyle
|
||||
|
||||
strip_nb:
|
||||
|
||||
@@ -47,7 +47,7 @@ class CompilationArtifacts:
|
||||
self.mlir_of_the_final_operation_graph = None
|
||||
|
||||
def add_function_to_compile(self, function: Union[Callable, str]):
|
||||
"""Adds the function to compile to artifacts.
|
||||
"""Add the function to compile to artifacts.
|
||||
|
||||
Args:
|
||||
function (Union[Callable, str]): the function to compile or source code of it
|
||||
@@ -61,7 +61,7 @@ class CompilationArtifacts:
|
||||
)
|
||||
|
||||
def add_parameter_of_function_to_compile(self, name: str, value: Union[BaseValue, str]):
|
||||
"""Adds a parameter of the function to compile to the artifacts.
|
||||
"""Add a parameter of the function to compile to the artifacts.
|
||||
|
||||
Args:
|
||||
name (str): name of the parameter
|
||||
@@ -74,7 +74,7 @@ class CompilationArtifacts:
|
||||
self.parameters_of_the_function_to_compile[name] = str(value)
|
||||
|
||||
def add_operation_graph(self, name: str, operation_graph: OPGraph):
|
||||
"""Adds an operation graph to the artifacts.
|
||||
"""Add an operation graph to the artifacts.
|
||||
|
||||
Args:
|
||||
name (str): name of the graph
|
||||
@@ -93,7 +93,7 @@ class CompilationArtifacts:
|
||||
self.final_operation_graph = operation_graph
|
||||
|
||||
def add_final_operation_graph_bounds(self, bounds: Dict[ir.IntermediateNode, Dict[str, Any]]):
|
||||
"""Adds the bounds of the final operation graph to the artifacts.
|
||||
"""Add the bounds of the final operation graph to the artifacts.
|
||||
|
||||
Args:
|
||||
bounds (Dict[ir.IntermediateNode, Dict[str, Any]]): the bound dictionary
|
||||
@@ -106,7 +106,7 @@ class CompilationArtifacts:
|
||||
self.bounds_of_the_final_operation_graph = bounds
|
||||
|
||||
def add_final_operation_graph_mlir(self, mlir: str):
|
||||
"""Adds the mlir of the final operation graph to the artifacts.
|
||||
"""Add the mlir of the final operation graph to the artifacts.
|
||||
|
||||
Args:
|
||||
mlir (str): the mlir code of the final operation graph
|
||||
@@ -119,7 +119,7 @@ class CompilationArtifacts:
|
||||
self.mlir_of_the_final_operation_graph = mlir
|
||||
|
||||
def export(self):
|
||||
"""Exports the artifacts to a the output directory.
|
||||
"""Export the artifacts to a the output directory.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
@@ -23,7 +23,7 @@ BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES
|
||||
|
||||
|
||||
def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted ScalarValue of type Integer.
|
||||
"""Check that a value is an encrypted ScalarValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -35,7 +35,7 @@ def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
|
||||
|
||||
def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted ScalarValue of type unsigned Integer.
|
||||
"""Check that a value is an encrypted ScalarValue of type unsigned Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -51,7 +51,7 @@ def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> boo
|
||||
|
||||
|
||||
def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is a clear ScalarValue of type Integer.
|
||||
"""Check that a value is a clear ScalarValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -63,7 +63,7 @@ def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
|
||||
|
||||
def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is a ScalarValue of type Integer.
|
||||
"""Check that a value is a ScalarValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -77,7 +77,7 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted TensorValue of type Integer.
|
||||
"""Check that a value is an encrypted TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -89,7 +89,7 @@ def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is an encrypted TensorValue of type unsigned Integer.
|
||||
"""Check that a value is an encrypted TensorValue of type unsigned Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -105,7 +105,7 @@ def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> boo
|
||||
|
||||
|
||||
def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is a clear TensorValue of type Integer.
|
||||
"""Check that a value is a clear TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -117,7 +117,7 @@ def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
|
||||
|
||||
def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is a TensorValue of type Integer.
|
||||
"""Check that a value is a TensorValue of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
@@ -294,7 +294,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
|
||||
|
||||
|
||||
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
|
||||
"""Helper function to determine the BaseDataType to hold the input constant data.
|
||||
"""Determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Union[int, float]): The constant data for which to determine the
|
||||
@@ -320,7 +320,7 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]
|
||||
def get_base_value_for_python_constant_data(
|
||||
constant_data: Union[int, float]
|
||||
) -> Callable[..., ScalarValue]:
|
||||
"""Function to wrap the BaseDataType to hold the input constant data in a ScalarValue partial.
|
||||
"""Wrap the BaseDataType to hold the input constant data in a ScalarValue partial.
|
||||
|
||||
The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue
|
||||
by calling it with the proper arguments forwarded to the ScalarValue `__init__` function
|
||||
|
||||
@@ -43,7 +43,7 @@ class Integer(base.BaseDataType):
|
||||
return 2 ** self.bit_width - 1
|
||||
|
||||
def can_represent_value(self, value_to_represent: int) -> bool:
|
||||
"""A helper function to check if a value is representable by the Integer.
|
||||
"""Check if a value is representable by the Integer.
|
||||
|
||||
Args:
|
||||
value_to_represent (int): Value to check
|
||||
@@ -55,7 +55,7 @@ class Integer(base.BaseDataType):
|
||||
|
||||
|
||||
def create_signed_integer(bit_width: int) -> Integer:
|
||||
"""Convenience function to create a signed integer.
|
||||
"""Create a signed integer.
|
||||
|
||||
Args:
|
||||
bit_width (int): width of the integer
|
||||
@@ -70,7 +70,7 @@ SignedInteger = create_signed_integer
|
||||
|
||||
|
||||
def create_unsigned_integer(bit_width: int) -> Integer:
|
||||
"""Convenience function to create an unsigned integer.
|
||||
"""Create an unsigned integer.
|
||||
|
||||
Args:
|
||||
bit_width (int): width of the integer
|
||||
@@ -85,7 +85,7 @@ UnsignedInteger = create_unsigned_integer
|
||||
|
||||
|
||||
def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer:
|
||||
"""Returns an Integer able to hold all values, it is possible to force the Integer to be signed.
|
||||
"""Return an Integer able to hold all values, it is possible to force the Integer to be signed.
|
||||
|
||||
Args:
|
||||
values (Iterable[Any]): The values to hold
|
||||
@@ -108,7 +108,7 @@ def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer:
|
||||
|
||||
|
||||
def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int:
|
||||
"""Returns how many bits are required to represent a numerical Value.
|
||||
"""Return how many bits are required to represent a numerical Value.
|
||||
|
||||
Args:
|
||||
value (Any): The value for which we want to know how many bits are required.
|
||||
|
||||
@@ -25,7 +25,7 @@ from ..representation import intermediate as ir
|
||||
|
||||
|
||||
def add(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the addition intermediate node."""
|
||||
"""Convert an addition intermediate node."""
|
||||
assert len(node.inputs) == 2, "addition should have two inputs"
|
||||
assert len(node.outputs) == 1, "addition should have a single output"
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
@@ -47,7 +47,7 @@ def add(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the addition intermediate node with operands (eint, int)."""
|
||||
"""Convert an addition intermediate node with (eint, int)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.AddEintIntOp(
|
||||
@@ -58,7 +58,7 @@ def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the addition intermediate node with operands (eint, int)."""
|
||||
"""Convert an addition intermediate node with (eint, int)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.AddEintOp(
|
||||
@@ -69,7 +69,7 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def sub(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the subtraction intermediate node."""
|
||||
"""Convert a subtraction intermediate node."""
|
||||
assert len(node.inputs) == 2, "subtraction should have two inputs"
|
||||
assert len(node.outputs) == 1, "subtraction should have a single output"
|
||||
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer(
|
||||
@@ -82,7 +82,7 @@ def sub(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the subtraction intermediate node with operands (int, eint)."""
|
||||
"""Convert a subtraction intermediate node with (int, eint)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.SubIntEintOp(
|
||||
@@ -93,7 +93,7 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def mul(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the multiplication intermediate node."""
|
||||
"""Convert a multiplication intermediate node."""
|
||||
assert len(node.inputs) == 2, "multiplication should have two inputs"
|
||||
assert len(node.outputs) == 1, "multiplication should have a single output"
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
@@ -111,7 +111,7 @@ def mul(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the multiplication intermediate node with operands (eint, int)."""
|
||||
"""Convert a multiplication intermediate node with (eint, int)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.MulEintIntOp(
|
||||
@@ -122,7 +122,7 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def constant(node, _, __, ctx):
|
||||
"""Converter function for constant inputs."""
|
||||
"""Convert a constant inputs."""
|
||||
if not value_is_clear_scalar_integer(node.outputs[0]):
|
||||
raise TypeError("Don't support non-integer constants")
|
||||
dtype = cast(Integer, node.outputs[0].data_type)
|
||||
@@ -133,7 +133,7 @@ def constant(node, _, __, ctx):
|
||||
|
||||
|
||||
def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the arbitrary function intermediate node."""
|
||||
"""Convert an arbitrary function intermediate node."""
|
||||
assert len(node.inputs) == 1, "LUT should have a single input"
|
||||
assert len(node.outputs) == 1, "LUT should have a single output"
|
||||
if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]):
|
||||
@@ -159,7 +159,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
def dot(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the dot intermediate node."""
|
||||
"""Convert a dot intermediate node."""
|
||||
assert len(node.inputs) == 2, "Dot should have two inputs"
|
||||
assert len(node.outputs) == 1, "Dot should have a single output"
|
||||
if not (
|
||||
|
||||
@@ -109,7 +109,7 @@ class OPGraph:
|
||||
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]:
|
||||
"""Function to evaluate a graph and get intermediate values for all nodes.
|
||||
"""Evaluate a graph and get intermediate values for all nodes.
|
||||
|
||||
Args:
|
||||
inputs (Dict[int, Any]): The inputs to the program
|
||||
@@ -195,7 +195,7 @@ class OPGraph:
|
||||
succ.inputs[input_idx] = deepcopy(node.outputs[0])
|
||||
|
||||
def prune_nodes(self):
|
||||
"""Function to remove unreachable nodes from outputs."""
|
||||
"""Remove unreachable nodes from outputs."""
|
||||
|
||||
current_nodes = set(self.output_nodes.values())
|
||||
useful_nodes: Set[ir.IntermediateNode] = set()
|
||||
|
||||
@@ -15,7 +15,7 @@ def fuse_float_operations(
|
||||
op_graph: OPGraph,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
):
|
||||
"""Finds and fuses float domains into single Integer to Integer ArbitraryFunction.
|
||||
"""Find and fuse float domains into single Integer to Integer ArbitraryFunction.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to simplify
|
||||
@@ -90,7 +90,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
terminal_node: ir.IntermediateNode,
|
||||
subgraph_all_nodes: Set[ir.IntermediateNode],
|
||||
) -> Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]:
|
||||
"""Converts a float subgraph to an equivalent fused ArbitraryFunction node.
|
||||
"""Convert a float subgraph to an equivalent fused ArbitraryFunction node.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph the float subgraph is part of.
|
||||
|
||||
@@ -52,7 +52,7 @@ class IntermediateNode(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
"""Function to simulate what the represented computation would output for the given inputs.
|
||||
"""Simulate what the represented computation would output for the given inputs.
|
||||
|
||||
Args:
|
||||
inputs (Dict[int, Any]): Dict containing the inputs for the evaluation
|
||||
@@ -63,7 +63,7 @@ class IntermediateNode(ABC):
|
||||
|
||||
@classmethod
|
||||
def n_in(cls) -> int:
|
||||
"""Returns how many inputs the node has.
|
||||
"""Return how many inputs the node has.
|
||||
|
||||
Returns:
|
||||
int: The number of inputs of the node.
|
||||
@@ -72,7 +72,7 @@ class IntermediateNode(ABC):
|
||||
|
||||
@classmethod
|
||||
def requires_mix_values_func(cls) -> bool:
|
||||
"""Function to determine whether the Class requires a mix_values_func to be built.
|
||||
"""Determine whether the Class requires a mix_values_func to be built.
|
||||
|
||||
Returns:
|
||||
bool: True if __init__ expects a mix_values_func argument.
|
||||
@@ -81,7 +81,7 @@ class IntermediateNode(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def label(self) -> str:
|
||||
"""Function to get the label of the node.
|
||||
"""Get the label of the node.
|
||||
|
||||
Returns:
|
||||
str: the label of the node
|
||||
@@ -182,7 +182,7 @@ class Constant(IntermediateNode):
|
||||
|
||||
@property
|
||||
def constant_data(self) -> Any:
|
||||
"""Returns the constant_data stored in the Constant node.
|
||||
"""Return the constant_data stored in the Constant node.
|
||||
|
||||
Returns:
|
||||
Any: The constant data that was stored.
|
||||
@@ -230,7 +230,7 @@ class ArbitraryFunction(IntermediateNode):
|
||||
return self.op_name
|
||||
|
||||
def get_table(self) -> List[Any]:
|
||||
"""Function to get the table for the current input value of this ArbitraryFunction.
|
||||
"""Get the table for the current input value of this ArbitraryFunction.
|
||||
|
||||
Returns:
|
||||
List[Any]: The table.
|
||||
@@ -255,7 +255,7 @@ class ArbitraryFunction(IntermediateNode):
|
||||
|
||||
|
||||
def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
|
||||
"""Default python dot implementation for 1D iterable arrays.
|
||||
"""Return the default python dot implementation for 1D iterable arrays.
|
||||
|
||||
Args:
|
||||
lhs (Any): lhs vector of the dot.
|
||||
@@ -268,7 +268,7 @@ def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
|
||||
|
||||
|
||||
class Dot(IntermediateNode):
|
||||
"""Node representing a dot product."""
|
||||
"""Return the node representing a dot product."""
|
||||
|
||||
_n_in: int = 2
|
||||
# Optional, same issue as in ArbitraryFunction for mypy
|
||||
|
||||
@@ -28,7 +28,7 @@ class BaseTracer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _supports_other_operand(self, other: Any) -> bool:
|
||||
"""Function to check if the current class supports tracing with the other operand.
|
||||
"""Check if the current class supports tracing with the other operand.
|
||||
|
||||
Args:
|
||||
other (Any): the operand to check compatibility with.
|
||||
@@ -40,7 +40,7 @@ class BaseTracer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer":
|
||||
"""Helper function to create a tracer for a constant input.
|
||||
"""Create a tracer for a constant input.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant to store.
|
||||
@@ -63,7 +63,7 @@ class BaseTracer(ABC):
|
||||
inputs: Iterable[Union["BaseTracer", Any]],
|
||||
computation_to_trace: Type[ir.IntermediateNode],
|
||||
) -> Tuple["BaseTracer", ...]:
|
||||
"""Helper functions to instantiate all output BaseTracer for a given computation.
|
||||
"""Instantiate all output BaseTracer for a given computation.
|
||||
|
||||
Args:
|
||||
inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs
|
||||
|
||||
@@ -15,7 +15,7 @@ def make_input_tracers(
|
||||
tracer_class: Type[BaseTracer],
|
||||
function_parameters: OrderedDict[str, BaseValue],
|
||||
) -> OrderedDict[str, BaseTracer]:
|
||||
"""Helper function to create tracers for a function's parameters.
|
||||
"""Create tracers for a function's parameters.
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
@@ -37,7 +37,7 @@ def make_input_tracer(
|
||||
input_idx: int,
|
||||
input_value: BaseValue,
|
||||
) -> BaseTracer:
|
||||
"""Helper function to create a tracer for an input value.
|
||||
"""Create a tracer for an input value.
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create an Input for
|
||||
@@ -55,7 +55,7 @@ def make_input_tracer(
|
||||
def prepare_function_parameters(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> OrderedDict[str, BaseValue]:
|
||||
"""Function to filter the passed function_parameters to trace function_to_trace.
|
||||
"""Filter the passed function_parameters to trace function_to_trace.
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): function that will be traced for which parameters are checked
|
||||
|
||||
@@ -16,7 +16,7 @@ class ScalarValue(BaseValue):
|
||||
|
||||
|
||||
def make_clear_scalar(data_type: BaseDataType) -> ScalarValue:
|
||||
"""Helper to create a clear ScalarValue.
|
||||
"""Create a clear ScalarValue.
|
||||
|
||||
Args:
|
||||
data_type (BaseDataType): The data type for the value.
|
||||
@@ -28,7 +28,7 @@ def make_clear_scalar(data_type: BaseDataType) -> ScalarValue:
|
||||
|
||||
|
||||
def make_encrypted_scalar(data_type: BaseDataType) -> ScalarValue:
|
||||
"""Helper to create an encrypted ScalarValue.
|
||||
"""Create an encrypted ScalarValue.
|
||||
|
||||
Args:
|
||||
data_type (BaseDataType): The data type for the value.
|
||||
|
||||
@@ -41,7 +41,7 @@ class TensorValue(BaseValue):
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""The TensorValue shape property.
|
||||
"""Return the TensorValue shape property.
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The TensorValue shape.
|
||||
@@ -50,7 +50,7 @@ class TensorValue(BaseValue):
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
"""The TensorValue ndim property.
|
||||
"""Return the TensorValue ndim property.
|
||||
|
||||
Returns:
|
||||
int: The TensorValue ndim.
|
||||
@@ -59,7 +59,7 @@ class TensorValue(BaseValue):
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""The TensorValue size property.
|
||||
"""Return the TensorValue size property.
|
||||
|
||||
Returns:
|
||||
int: The TensorValue size.
|
||||
@@ -71,7 +71,7 @@ def make_clear_tensor(
|
||||
data_type: BaseDataType,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
) -> TensorValue:
|
||||
"""Helper to create a clear TensorValue.
|
||||
"""Create a clear TensorValue.
|
||||
|
||||
Args:
|
||||
data_type (BaseDataType): The data type for the tensor.
|
||||
@@ -87,7 +87,7 @@ def make_encrypted_tensor(
|
||||
data_type: BaseDataType,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
) -> TensorValue:
|
||||
"""Helper to create an encrypted TensorValue.
|
||||
"""Create an encrypted TensorValue.
|
||||
|
||||
Args:
|
||||
data_type (BaseDataType): The data type for the tensor.
|
||||
|
||||
@@ -199,7 +199,7 @@ def _compile_numpy_function_internal(
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
show_mlir: bool,
|
||||
) -> CompilerEngine:
|
||||
"""Internal part of the API to be able to compile an homomorphic program.
|
||||
"""Compile an homomorphic program (internal part of the API).
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function you want to compile
|
||||
@@ -254,7 +254,7 @@ def compile_numpy_function(
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
show_mlir: bool = False,
|
||||
) -> CompilerEngine:
|
||||
"""Main API to be able to compile an homomorphic program.
|
||||
"""Compile an homomorphic program (main API).
|
||||
|
||||
Args:
|
||||
function_to_compile (Callable): The function to compile
|
||||
|
||||
@@ -33,7 +33,7 @@ SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_
|
||||
|
||||
|
||||
def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType:
|
||||
"""Helper function to get the corresponding BaseDataType from a numpy dtype.
|
||||
"""Get the corresponding BaseDataType from a numpy dtype.
|
||||
|
||||
Args:
|
||||
numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype
|
||||
@@ -99,7 +99,7 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d
|
||||
|
||||
|
||||
def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType:
|
||||
"""Helper function to determine the BaseDataType to hold the input constant data.
|
||||
"""Determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant data for which to determine the
|
||||
@@ -124,7 +124,7 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) ->
|
||||
def get_base_value_for_numpy_or_python_constant_data(
|
||||
constant_data: Any,
|
||||
) -> Callable[..., BaseValue]:
|
||||
"""Helper function to determine the BaseValue and BaseDataType to hold the input constant data.
|
||||
"""Determine the BaseValue and BaseDataType to hold the input constant data.
|
||||
|
||||
This function is able to handle numpy types
|
||||
|
||||
@@ -158,7 +158,7 @@ def get_numpy_function_output_dtype(
|
||||
function: Union[numpy.ufunc, Callable],
|
||||
input_dtypes: List[BaseDataType],
|
||||
) -> List[numpy.dtype]:
|
||||
"""Function to record the output dtype of a numpy function given some input types.
|
||||
"""Record the output dtype of a numpy function given some input types.
|
||||
|
||||
Args:
|
||||
function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to
|
||||
|
||||
@@ -134,7 +134,7 @@ class NPTracer(BaseTracer):
|
||||
def _unary_operator(
|
||||
cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
) -> "NPTracer":
|
||||
"""Function to trace an unary operator.
|
||||
"""Trace an unary operator.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
@@ -158,7 +158,7 @@ class NPTracer(BaseTracer):
|
||||
return output_tracer
|
||||
|
||||
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.dot.
|
||||
"""Trace numpy.dot.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
@@ -285,7 +285,7 @@ class NPTracer(BaseTracer):
|
||||
|
||||
|
||||
def _get_fun(function: numpy.ufunc):
|
||||
"""Helper function to wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
|
||||
"""Wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
|
||||
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
# dynamically
|
||||
@@ -303,7 +303,7 @@ NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORT
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> OPGraph:
|
||||
"""Function used to trace a numpy function.
|
||||
"""Trace a numpy function.
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): The function you want to trace
|
||||
|
||||
Reference in New Issue
Block a user