diff --git a/Makefile b/Makefile index 288770470..2f59e6fd6 100644 --- a/Makefile +++ b/Makefile @@ -152,7 +152,7 @@ build_and_open_docs: clean_docs docs open_docs pydocstyle: @# From http://www.pydocstyle.org/en/stable/error_codes.html - poetry run pydocstyle $(SRC_DIR) --convention google --add-ignore=D1,D202 + poetry run pydocstyle $(SRC_DIR) --convention google --add-ignore=D1,D202 --add-select=D401 .PHONY: pydocstyle strip_nb: diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py index 28b16d6fa..a0f5760f0 100644 --- a/concrete/common/compilation/artifacts.py +++ b/concrete/common/compilation/artifacts.py @@ -47,7 +47,7 @@ class CompilationArtifacts: self.mlir_of_the_final_operation_graph = None def add_function_to_compile(self, function: Union[Callable, str]): - """Adds the function to compile to artifacts. + """Add the function to compile to artifacts. Args: function (Union[Callable, str]): the function to compile or source code of it @@ -61,7 +61,7 @@ class CompilationArtifacts: ) def add_parameter_of_function_to_compile(self, name: str, value: Union[BaseValue, str]): - """Adds a parameter of the function to compile to the artifacts. + """Add a parameter of the function to compile to the artifacts. Args: name (str): name of the parameter @@ -74,7 +74,7 @@ class CompilationArtifacts: self.parameters_of_the_function_to_compile[name] = str(value) def add_operation_graph(self, name: str, operation_graph: OPGraph): - """Adds an operation graph to the artifacts. + """Add an operation graph to the artifacts. Args: name (str): name of the graph @@ -93,7 +93,7 @@ class CompilationArtifacts: self.final_operation_graph = operation_graph def add_final_operation_graph_bounds(self, bounds: Dict[ir.IntermediateNode, Dict[str, Any]]): - """Adds the bounds of the final operation graph to the artifacts. + """Add the bounds of the final operation graph to the artifacts. Args: bounds (Dict[ir.IntermediateNode, Dict[str, Any]]): the bound dictionary @@ -106,7 +106,7 @@ class CompilationArtifacts: self.bounds_of_the_final_operation_graph = bounds def add_final_operation_graph_mlir(self, mlir: str): - """Adds the mlir of the final operation graph to the artifacts. + """Add the mlir of the final operation graph to the artifacts. Args: mlir (str): the mlir code of the final operation graph @@ -119,7 +119,7 @@ class CompilationArtifacts: self.mlir_of_the_final_operation_graph = mlir def export(self): - """Exports the artifacts to a the output directory. + """Export the artifacts to a the output directory. Returns: None diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 0879521ee..7234c311f 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -23,7 +23,7 @@ BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is an encrypted ScalarValue of type Integer. + """Check that a value is an encrypted ScalarValue of type Integer. Args: value_to_check (BaseValue): The value to check @@ -35,7 +35,7 @@ def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool: def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is an encrypted ScalarValue of type unsigned Integer. + """Check that a value is an encrypted ScalarValue of type unsigned Integer. Args: value_to_check (BaseValue): The value to check @@ -51,7 +51,7 @@ def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> boo def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is a clear ScalarValue of type Integer. + """Check that a value is a clear ScalarValue of type Integer. Args: value_to_check (BaseValue): The value to check @@ -63,7 +63,7 @@ def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool: def value_is_scalar_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is a ScalarValue of type Integer. + """Check that a value is a ScalarValue of type Integer. Args: value_to_check (BaseValue): The value to check @@ -77,7 +77,7 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool: def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is an encrypted TensorValue of type Integer. + """Check that a value is an encrypted TensorValue of type Integer. Args: value_to_check (BaseValue): The value to check @@ -89,7 +89,7 @@ def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool: def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is an encrypted TensorValue of type unsigned Integer. + """Check that a value is an encrypted TensorValue of type unsigned Integer. Args: value_to_check (BaseValue): The value to check @@ -105,7 +105,7 @@ def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> boo def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is a clear TensorValue of type Integer. + """Check that a value is a clear TensorValue of type Integer. Args: value_to_check (BaseValue): The value to check @@ -117,7 +117,7 @@ def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool: def value_is_tensor_integer(value_to_check: BaseValue) -> bool: - """Helper function to check that a value is a TensorValue of type Integer. + """Check that a value is a TensorValue of type Integer. Args: value_to_check (BaseValue): The value to check @@ -294,7 +294,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType: - """Helper function to determine the BaseDataType to hold the input constant data. + """Determine the BaseDataType to hold the input constant data. Args: constant_data (Union[int, float]): The constant data for which to determine the @@ -320,7 +320,7 @@ def get_base_data_type_for_python_constant_data(constant_data: Union[int, float] def get_base_value_for_python_constant_data( constant_data: Union[int, float] ) -> Callable[..., ScalarValue]: - """Function to wrap the BaseDataType to hold the input constant data in a ScalarValue partial. + """Wrap the BaseDataType to hold the input constant data in a ScalarValue partial. The returned object can then be instantiated as an Encrypted or Clear version of the ScalarValue by calling it with the proper arguments forwarded to the ScalarValue `__init__` function diff --git a/concrete/common/data_types/integers.py b/concrete/common/data_types/integers.py index 7e5b0d79b..7ef0674d7 100644 --- a/concrete/common/data_types/integers.py +++ b/concrete/common/data_types/integers.py @@ -43,7 +43,7 @@ class Integer(base.BaseDataType): return 2 ** self.bit_width - 1 def can_represent_value(self, value_to_represent: int) -> bool: - """A helper function to check if a value is representable by the Integer. + """Check if a value is representable by the Integer. Args: value_to_represent (int): Value to check @@ -55,7 +55,7 @@ class Integer(base.BaseDataType): def create_signed_integer(bit_width: int) -> Integer: - """Convenience function to create a signed integer. + """Create a signed integer. Args: bit_width (int): width of the integer @@ -70,7 +70,7 @@ SignedInteger = create_signed_integer def create_unsigned_integer(bit_width: int) -> Integer: - """Convenience function to create an unsigned integer. + """Create an unsigned integer. Args: bit_width (int): width of the integer @@ -85,7 +85,7 @@ UnsignedInteger = create_unsigned_integer def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer: - """Returns an Integer able to hold all values, it is possible to force the Integer to be signed. + """Return an Integer able to hold all values, it is possible to force the Integer to be signed. Args: values (Iterable[Any]): The values to hold @@ -108,7 +108,7 @@ def make_integer_to_hold(values: Iterable[Any], force_signed: bool) -> Integer: def get_bits_to_represent_value_as_integer(value: Any, force_signed: bool) -> int: - """Returns how many bits are required to represent a numerical Value. + """Return how many bits are required to represent a numerical Value. Args: value (Any): The value for which we want to know how many bits are required. diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index d5ab22fee..1af29dd75 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -25,7 +25,7 @@ from ..representation import intermediate as ir def add(node, preds, ir_to_mlir_node, ctx): - """Converter function for the addition intermediate node.""" + """Convert an addition intermediate node.""" assert len(node.inputs) == 2, "addition should have two inputs" assert len(node.outputs) == 1, "addition should have a single output" if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( @@ -47,7 +47,7 @@ def add(node, preds, ir_to_mlir_node, ctx): def _add_eint_int(node, preds, ir_to_mlir_node, ctx): - """Converter function for the addition intermediate node with operands (eint, int).""" + """Convert an addition intermediate node with (eint, int).""" lhs_node, rhs_node = preds lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] return hlfhe.AddEintIntOp( @@ -58,7 +58,7 @@ def _add_eint_int(node, preds, ir_to_mlir_node, ctx): def _add_eint_eint(node, preds, ir_to_mlir_node, ctx): - """Converter function for the addition intermediate node with operands (eint, int).""" + """Convert an addition intermediate node with (eint, int).""" lhs_node, rhs_node = preds lhs, rhs = lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] return hlfhe.AddEintOp( @@ -69,7 +69,7 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx): def sub(node, preds, ir_to_mlir_node, ctx): - """Converter function for the subtraction intermediate node.""" + """Convert a subtraction intermediate node.""" assert len(node.inputs) == 2, "subtraction should have two inputs" assert len(node.outputs) == 1, "subtraction should have a single output" if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer( @@ -82,7 +82,7 @@ def sub(node, preds, ir_to_mlir_node, ctx): def _sub_int_eint(node, preds, ir_to_mlir_node, ctx): - """Converter function for the subtraction intermediate node with operands (int, eint).""" + """Convert a subtraction intermediate node with (int, eint).""" lhs_node, rhs_node = preds lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] return hlfhe.SubIntEintOp( @@ -93,7 +93,7 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx): def mul(node, preds, ir_to_mlir_node, ctx): - """Converter function for the multiplication intermediate node.""" + """Convert a multiplication intermediate node.""" assert len(node.inputs) == 2, "multiplication should have two inputs" assert len(node.outputs) == 1, "multiplication should have a single output" if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer( @@ -111,7 +111,7 @@ def mul(node, preds, ir_to_mlir_node, ctx): def _mul_eint_int(node, preds, ir_to_mlir_node, ctx): - """Converter function for the multiplication intermediate node with operands (eint, int).""" + """Convert a multiplication intermediate node with (eint, int).""" lhs_node, rhs_node = preds lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] return hlfhe.MulEintIntOp( @@ -122,7 +122,7 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx): def constant(node, _, __, ctx): - """Converter function for constant inputs.""" + """Convert a constant inputs.""" if not value_is_clear_scalar_integer(node.outputs[0]): raise TypeError("Don't support non-integer constants") dtype = cast(Integer, node.outputs[0].data_type) @@ -133,7 +133,7 @@ def constant(node, _, __, ctx): def apply_lut(node, preds, ir_to_mlir_node, ctx): - """Converter function for the arbitrary function intermediate node.""" + """Convert an arbitrary function intermediate node.""" assert len(node.inputs) == 1, "LUT should have a single input" assert len(node.outputs) == 1, "LUT should have a single output" if not value_is_encrypted_scalar_unsigned_integer(node.inputs[0]): @@ -159,7 +159,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx): def dot(node, preds, ir_to_mlir_node, ctx): - """Converter function for the dot intermediate node.""" + """Convert a dot intermediate node.""" assert len(node.inputs) == 2, "Dot should have two inputs" assert len(node.outputs) == 1, "Dot should have a single output" if not ( diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index d42a15419..313e25ec1 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -109,7 +109,7 @@ class OPGraph: return [self.output_nodes[idx] for idx in range(len(self.output_nodes))] def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]: - """Function to evaluate a graph and get intermediate values for all nodes. + """Evaluate a graph and get intermediate values for all nodes. Args: inputs (Dict[int, Any]): The inputs to the program @@ -195,7 +195,7 @@ class OPGraph: succ.inputs[input_idx] = deepcopy(node.outputs[0]) def prune_nodes(self): - """Function to remove unreachable nodes from outputs.""" + """Remove unreachable nodes from outputs.""" current_nodes = set(self.output_nodes.values()) useful_nodes: Set[ir.IntermediateNode] = set() diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index cc66bb731..f80ec2161 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -15,7 +15,7 @@ def fuse_float_operations( op_graph: OPGraph, compilation_artifacts: Optional[CompilationArtifacts] = None, ): - """Finds and fuses float domains into single Integer to Integer ArbitraryFunction. + """Find and fuse float domains into single Integer to Integer ArbitraryFunction. Args: op_graph (OPGraph): The OPGraph to simplify @@ -90,7 +90,7 @@ def convert_float_subgraph_to_fused_node( terminal_node: ir.IntermediateNode, subgraph_all_nodes: Set[ir.IntermediateNode], ) -> Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]: - """Converts a float subgraph to an equivalent fused ArbitraryFunction node. + """Convert a float subgraph to an equivalent fused ArbitraryFunction node. Args: op_graph (OPGraph): The OPGraph the float subgraph is part of. diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 3bcf86507..5f4ba2d23 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -52,7 +52,7 @@ class IntermediateNode(ABC): @abstractmethod def evaluate(self, inputs: Dict[int, Any]) -> Any: - """Function to simulate what the represented computation would output for the given inputs. + """Simulate what the represented computation would output for the given inputs. Args: inputs (Dict[int, Any]): Dict containing the inputs for the evaluation @@ -63,7 +63,7 @@ class IntermediateNode(ABC): @classmethod def n_in(cls) -> int: - """Returns how many inputs the node has. + """Return how many inputs the node has. Returns: int: The number of inputs of the node. @@ -72,7 +72,7 @@ class IntermediateNode(ABC): @classmethod def requires_mix_values_func(cls) -> bool: - """Function to determine whether the Class requires a mix_values_func to be built. + """Determine whether the Class requires a mix_values_func to be built. Returns: bool: True if __init__ expects a mix_values_func argument. @@ -81,7 +81,7 @@ class IntermediateNode(ABC): @abstractmethod def label(self) -> str: - """Function to get the label of the node. + """Get the label of the node. Returns: str: the label of the node @@ -182,7 +182,7 @@ class Constant(IntermediateNode): @property def constant_data(self) -> Any: - """Returns the constant_data stored in the Constant node. + """Return the constant_data stored in the Constant node. Returns: Any: The constant data that was stored. @@ -230,7 +230,7 @@ class ArbitraryFunction(IntermediateNode): return self.op_name def get_table(self) -> List[Any]: - """Function to get the table for the current input value of this ArbitraryFunction. + """Get the table for the current input value of this ArbitraryFunction. Returns: List[Any]: The table. @@ -255,7 +255,7 @@ class ArbitraryFunction(IntermediateNode): def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any: - """Default python dot implementation for 1D iterable arrays. + """Return the default python dot implementation for 1D iterable arrays. Args: lhs (Any): lhs vector of the dot. @@ -268,7 +268,7 @@ def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any: class Dot(IntermediateNode): - """Node representing a dot product.""" + """Return the node representing a dot product.""" _n_in: int = 2 # Optional, same issue as in ArbitraryFunction for mypy diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 851362802..dfaa12548 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -28,7 +28,7 @@ class BaseTracer(ABC): @abstractmethod def _supports_other_operand(self, other: Any) -> bool: - """Function to check if the current class supports tracing with the other operand. + """Check if the current class supports tracing with the other operand. Args: other (Any): the operand to check compatibility with. @@ -40,7 +40,7 @@ class BaseTracer(ABC): @abstractmethod def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer": - """Helper function to create a tracer for a constant input. + """Create a tracer for a constant input. Args: constant_data (Any): The constant to store. @@ -63,7 +63,7 @@ class BaseTracer(ABC): inputs: Iterable[Union["BaseTracer", Any]], computation_to_trace: Type[ir.IntermediateNode], ) -> Tuple["BaseTracer", ...]: - """Helper functions to instantiate all output BaseTracer for a given computation. + """Instantiate all output BaseTracer for a given computation. Args: inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py index 34043bda7..8e0824c47 100644 --- a/concrete/common/tracing/tracing_helpers.py +++ b/concrete/common/tracing/tracing_helpers.py @@ -15,7 +15,7 @@ def make_input_tracers( tracer_class: Type[BaseTracer], function_parameters: OrderedDict[str, BaseValue], ) -> OrderedDict[str, BaseTracer]: - """Helper function to create tracers for a function's parameters. + """Create tracers for a function's parameters. Args: tracer_class (Type[BaseTracer]): the class of tracer to create an Input for @@ -37,7 +37,7 @@ def make_input_tracer( input_idx: int, input_value: BaseValue, ) -> BaseTracer: - """Helper function to create a tracer for an input value. + """Create a tracer for an input value. Args: tracer_class (Type[BaseTracer]): the class of tracer to create an Input for @@ -55,7 +55,7 @@ def make_input_tracer( def prepare_function_parameters( function_to_trace: Callable, function_parameters: Dict[str, BaseValue] ) -> OrderedDict[str, BaseValue]: - """Function to filter the passed function_parameters to trace function_to_trace. + """Filter the passed function_parameters to trace function_to_trace. Args: function_to_trace (Callable): function that will be traced for which parameters are checked diff --git a/concrete/common/values/scalars.py b/concrete/common/values/scalars.py index 38a9a87e4..1b057fa99 100644 --- a/concrete/common/values/scalars.py +++ b/concrete/common/values/scalars.py @@ -16,7 +16,7 @@ class ScalarValue(BaseValue): def make_clear_scalar(data_type: BaseDataType) -> ScalarValue: - """Helper to create a clear ScalarValue. + """Create a clear ScalarValue. Args: data_type (BaseDataType): The data type for the value. @@ -28,7 +28,7 @@ def make_clear_scalar(data_type: BaseDataType) -> ScalarValue: def make_encrypted_scalar(data_type: BaseDataType) -> ScalarValue: - """Helper to create an encrypted ScalarValue. + """Create an encrypted ScalarValue. Args: data_type (BaseDataType): The data type for the value. diff --git a/concrete/common/values/tensors.py b/concrete/common/values/tensors.py index 966f7c8e7..dc3d421d0 100644 --- a/concrete/common/values/tensors.py +++ b/concrete/common/values/tensors.py @@ -41,7 +41,7 @@ class TensorValue(BaseValue): @property def shape(self) -> Tuple[int, ...]: - """The TensorValue shape property. + """Return the TensorValue shape property. Returns: Tuple[int, ...]: The TensorValue shape. @@ -50,7 +50,7 @@ class TensorValue(BaseValue): @property def ndim(self) -> int: - """The TensorValue ndim property. + """Return the TensorValue ndim property. Returns: int: The TensorValue ndim. @@ -59,7 +59,7 @@ class TensorValue(BaseValue): @property def size(self) -> int: - """The TensorValue size property. + """Return the TensorValue size property. Returns: int: The TensorValue size. @@ -71,7 +71,7 @@ def make_clear_tensor( data_type: BaseDataType, shape: Optional[Tuple[int, ...]] = None, ) -> TensorValue: - """Helper to create a clear TensorValue. + """Create a clear TensorValue. Args: data_type (BaseDataType): The data type for the tensor. @@ -87,7 +87,7 @@ def make_encrypted_tensor( data_type: BaseDataType, shape: Optional[Tuple[int, ...]] = None, ) -> TensorValue: - """Helper to create an encrypted TensorValue. + """Create an encrypted TensorValue. Args: data_type (BaseDataType): The data type for the tensor. diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index f3700a1fc..e3de2a337 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -199,7 +199,7 @@ def _compile_numpy_function_internal( compilation_artifacts: CompilationArtifacts, show_mlir: bool, ) -> CompilerEngine: - """Internal part of the API to be able to compile an homomorphic program. + """Compile an homomorphic program (internal part of the API). Args: function_to_compile (Callable): The function you want to compile @@ -254,7 +254,7 @@ def compile_numpy_function( compilation_artifacts: Optional[CompilationArtifacts] = None, show_mlir: bool = False, ) -> CompilerEngine: - """Main API to be able to compile an homomorphic program. + """Compile an homomorphic program (main API). Args: function_to_compile (Callable): The function to compile diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index cbba8f9df..69586c1cc 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -33,7 +33,7 @@ SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_ def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType: - """Helper function to get the corresponding BaseDataType from a numpy dtype. + """Get the corresponding BaseDataType from a numpy dtype. Args: numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype @@ -99,7 +99,7 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType: - """Helper function to determine the BaseDataType to hold the input constant data. + """Determine the BaseDataType to hold the input constant data. Args: constant_data (Any): The constant data for which to determine the @@ -124,7 +124,7 @@ def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> def get_base_value_for_numpy_or_python_constant_data( constant_data: Any, ) -> Callable[..., BaseValue]: - """Helper function to determine the BaseValue and BaseDataType to hold the input constant data. + """Determine the BaseValue and BaseDataType to hold the input constant data. This function is able to handle numpy types @@ -158,7 +158,7 @@ def get_numpy_function_output_dtype( function: Union[numpy.ufunc, Callable], input_dtypes: List[BaseDataType], ) -> List[numpy.dtype]: - """Function to record the output dtype of a numpy function given some input types. + """Record the output dtype of a numpy function given some input types. Args: function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index d5247dc38..0c1739852 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -134,7 +134,7 @@ class NPTracer(BaseTracer): def _unary_operator( cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs ) -> "NPTracer": - """Function to trace an unary operator. + """Trace an unary operator. Returns: NPTracer: The output NPTracer containing the traced function @@ -158,7 +158,7 @@ class NPTracer(BaseTracer): return output_tracer def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer": - """Function to trace numpy.dot. + """Trace numpy.dot. Returns: NPTracer: The output NPTracer containing the traced function @@ -285,7 +285,7 @@ class NPTracer(BaseTracer): def _get_fun(function: numpy.ufunc): - """Helper function to wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING.""" + """Wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING.""" # We have to access this method to be able to build NPTracer.UFUNC_ROUTING # dynamically @@ -303,7 +303,7 @@ NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORT def trace_numpy_function( function_to_trace: Callable, function_parameters: Dict[str, BaseValue] ) -> OPGraph: - """Function used to trace a numpy function. + """Trace a numpy function. Args: function_to_trace (Callable): The function you want to trace