From 98bec17050fcd4ecbaf59a03ba0cb9ee2fdf5d70 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 22 Feb 2022 11:14:23 +0100 Subject: [PATCH] feat: add convolution extension extend the current tracing and compilation with convolution, which should compile to the FHELinalg.conv2d operation from the compiler --- concrete/common/debugging/drawing.py | 2 + concrete/common/extensions/__init__.py | 2 +- concrete/common/extensions/convolution.py | 161 +++++++++++++++ concrete/common/mlir/node_converter.py | 49 +++++ concrete/common/mlir/utils.py | 5 +- .../common/representation/intermediate.py | 106 ++++++++++ concrete/numpy/__init__.py | 12 +- deps_licenses/licenses_linux_user.txt | 4 +- poetry.lock | 133 +++++++----- pyproject.toml | 3 +- tests/common/extensions/test_convolution.py | 193 ++++++++++++++++++ tests/conftest.py | 7 + tests/numpy/test_compile_conv.py | 44 ++++ 13 files changed, 663 insertions(+), 58 deletions(-) create mode 100644 concrete/common/extensions/convolution.py create mode 100644 tests/common/extensions/test_convolution.py create mode 100644 tests/numpy/test_compile_conv.py diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index 0a301943e..0bb5fc6ca 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -15,6 +15,7 @@ from ..representation.intermediate import ( ALL_IR_NODES, Add, Constant, + Conv2D, Dot, GenericFunction, IndexConstant, @@ -27,6 +28,7 @@ from ..representation.intermediate import ( IR_NODE_COLOR_MAPPING = { Input: "blue", Constant: "cyan", + Conv2D: "brown", Add: "red", Sub: "yellow", Mul: "green", diff --git a/concrete/common/extensions/__init__.py b/concrete/common/extensions/__init__.py index b99aba1fa..95ccf9a83 100644 --- a/concrete/common/extensions/__init__.py +++ b/concrete/common/extensions/__init__.py @@ -1,2 +1,2 @@ """Extensions module to provide additional functionality to our users.""" -from . import multi_table, table +from . import convolution, multi_table, table diff --git a/concrete/common/extensions/convolution.py b/concrete/common/extensions/convolution.py new file mode 100644 index 000000000..79b9f1f0d --- /dev/null +++ b/concrete/common/extensions/convolution.py @@ -0,0 +1,161 @@ +"""This file contains tracers for convolution operations.""" + +from typing import List, Optional, Tuple, Union, cast + +import numpy as np + +from ...numpy.tracing import NPConstant, NPTracer +from ..representation.intermediate import Conv2D +from ..tracing.base_tracer import BaseTracer + +SUPPORTED_AUTO_PAD = [ + "NOTSET", +] + + +def conv2d( + x: Union[np.ndarray, BaseTracer], + weight: Union[np.ndarray, BaseTracer], + bias: Optional[Union[np.ndarray, BaseTracer]] = None, + pads: Union[Tuple[int, int, int, int], List[int]] = (0, 0, 0, 0), + strides: Union[Tuple[int, int], List[int]] = (1, 1), + dilations: Union[Tuple[int, int], List[int]] = (1, 1), + auto_pad: str = "NOTSET", +) -> Union[np.ndarray, NPTracer]: + """Trace or evaluate 2D convolution. + + Args: + x (Union[np.ndarray, BaseTracer]): Input of shape (NxCxHxW) + weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW) + bias (Optional[Union[np.ndarray, BaseTracer]], optional): Bias vector of size (F). + Defaults to None. + pads (Union[Tuple[int, int, int, int], List[int]], optional): Padding over each axis + (H_beg, W_beg, H_end, W_end). Defaults to (0, 0, 0, 0). + strides (Union[Tuple[int, int], List[int]], optional): Stride over each axis + (height and width). Defaults to (1, 1). + dilations (Union[Tuple[int, int], List[int]], optional): Dilation over each axis + (height and width). Defaults to (1, 1). + auto_pad (str, optional): Padding strategy. Defaults to "NOTSET". + + Raises: + ValueError: If one argument isn't in the range of expected values. + TypeError: If one argument isn't of the appropriate type. + + Returns: + Union[np.ndarray, BaseTracer]: Evaluation result, or traced computation + """ + if auto_pad not in SUPPORTED_AUTO_PAD: + raise ValueError("invalid auto_pad is specified") + + if not isinstance(x, (np.ndarray, BaseTracer)): + raise TypeError(f"input x must be an ndarray, or a BaseTracer, not a {type(x)}") + if not isinstance(weight, (np.ndarray, BaseTracer)): + raise TypeError(f"weight must be an ndarray, or a BaseTracer, not a {type(weight)}") + if not isinstance(bias, (np.ndarray, BaseTracer, type(None))): + raise TypeError(f"bias must be an ndarray, a BaseTracer, or None, not a {type(bias)}") + if not isinstance(pads, (tuple, list)): + raise TypeError(f"padding must be a tuple, or list, not a {type(pads)}") + if not isinstance(strides, (tuple, list)): + raise TypeError(f"strides must be a tuple, or list, not a {type(strides)}") + if not isinstance(dilations, (tuple, list)): + raise TypeError(f"dilations must be a tuple, or list, not a {type(dilations)}") + + if len(pads) != 4: + raise ValueError( + f"padding should be of the form (pad_height_begin, pad_width_begin, pad_height_end, " + f" pad_width_end), but got {type(pads)} of length {len(pads)}" + ) + if len(strides) != 2: + raise ValueError( + f"strides should be of the form (stride_height, stride_width), but got {type(strides)}" + f" of length {len(strides)}" + ) + if len(dilations) != 2: + raise ValueError( + f"dilations should be of the form (dilation_height, dilation_width), but got" + f" {type(dilations)} of length {len(dilations)}" + ) + + assert len(x.shape) == 4, f"input x should have size (N x C x H x W), not {x.shape}" + assert len(weight.shape) == 4, f"weight should have size (F x C x H x W), not {weight.shape}" + if bias is not None: + assert len(bias.shape) == 1, f"bias should have size (F), not {bias.shape}" + + if isinstance(x, BaseTracer): + return _trace_conv2d(x, weight, bias, pads, strides, dilations) + # X is an ndarray + bias = np.zeros(weight.shape[0]) if bias is None else bias + # For mypy + weight = cast(np.ndarray, weight) + bias = cast(np.ndarray, bias) + return _evaluate_conv2d(x, weight, bias, pads, strides, dilations) + + +def _trace_conv2d( + x: BaseTracer, + weight: Union[np.ndarray, BaseTracer], + bias: Optional[Union[np.ndarray, BaseTracer]], + pads: Union[Tuple[int, int, int, int], List[int]], + strides: Union[Tuple[int, int], List[int]], + dilations: Union[Tuple[int, int], List[int]], +) -> NPTracer: + """Trace 2D convolution. + + Args: + x (BaseTracer): Input of shape (NxCxHxW) + weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW) + bias (Optional[Union[np.ndarray, BaseTracer]]): Bias vector of size (F) + pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each + axis (H_beg, W_beg, H_end, W_end) + strides (Union[Tuple[int, int], List[int]]): Stride over each + axis (height and width) + dilations (Union[Tuple[int, int], List[int]]): Dilation over each + axis (height and width) + + Returns: + BaseTracer: Traced computation + """ + weight_tracer = ( + weight if isinstance(weight, BaseTracer) else NPTracer([], NPConstant(weight), 0) + ) + inputs = [x.output, weight_tracer.output] + output_tracer_inputs = [x, weight_tracer] + if bias is not None: + bias_tracer = bias if isinstance(bias, BaseTracer) else NPTracer([], NPConstant(bias), 0) + inputs.append(bias_tracer.output) + # For mypy + bias = cast(BaseTracer, bias_tracer) + output_tracer_inputs.append(bias) + + traced_computation = Conv2D(inputs, x.output.dtype, pads, strides, dilations) + output_tracer = x.__class__( + output_tracer_inputs, traced_computation=traced_computation, output_idx=0 + ) + # For mypy + assert isinstance(output_tracer, NPTracer) + return output_tracer + + +def _evaluate_conv2d( + x: np.ndarray, + weight: np.ndarray, + bias: np.ndarray, + pads: Union[Tuple[int, int, int, int], List[int]], + strides: Union[Tuple[int, int], List[int]], + dilations: Union[Tuple[int, int], List[int]], +) -> np.ndarray: + """Evaluate 2D convolution. + + Args: + x (np.ndarray): Input of shape (NxCxHxW) + weight (np.ndarray): Weight (kernel) of shape (FxCxHxW) + bias (np.ndarray): Bias vector of size (F) + pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each + axis (H_beg, W_beg, H_end, W_end) + strides (Union[Tuple[int, int], List[int]]): Stride over each axis (height and width) + dilations (Union[Tuple[int, int], List[int]]): Dilation over each axis (height and width) + + Returns: + np.ndarray: Result of the convolution of shape (NxCxHxW) + """ + return Conv2D.evaluate_conv2d(x, weight, bias, pads, strides, dilations) diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index fbfaf6dd3..8266a1c79 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -29,6 +29,7 @@ from ..operator_graph import OPGraph from ..representation.intermediate import ( Add, Constant, + Conv2D, Dot, GenericFunction, IndexConstant, @@ -140,6 +141,9 @@ class IntermediateNodeConverter: elif isinstance(self.node, Sub): result = self.convert_sub() + elif isinstance(self.node, Conv2D): + result = self.convert_conv2d() + else: # pragma: no cover # this branch is not covered as unsupported opeations fail on check mlir compatibility raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet") @@ -282,6 +286,51 @@ class IntermediateNodeConverter: return arith.ConstantOp(resulting_type, attr).result + def convert_conv2d(self) -> OpResult: + """Convert a Conv2D node to its corresponding MLIR representation. + + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 2 or len(self.node.inputs) == 3) + assert_true(len(self.node.outputs) == 1) + has_bias = len(self.node.inputs) == 3 + + x = self.node.inputs[0] + weight = self.node.inputs[1] + if not (x.is_encrypted and weight.is_clear): # pragma: no cover + raise NotImplementedError( + f"Conv2D with input {x} and weight {weight} cannot be converted to MLIR yet", + ) + + resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + preds = self.preds + + node = cast(Conv2D, self.node) + integer_type = IntegerType.get_signless(64, context=self.ctx) + strides = DenseElementsAttr.get( + numpy.array(list(node.strides), dtype=numpy.uint64), + context=self.ctx, + type=integer_type, + ) + dilations = DenseElementsAttr.get( + numpy.array(list(node.dilations), dtype=numpy.uint64), + context=self.ctx, + type=integer_type, + ) + pads = DenseElementsAttr.get( + numpy.array(list(node.pads), dtype=numpy.uint64), context=self.ctx, type=integer_type + ) + if has_bias: + result = fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result + else: + result = fhelinalg.Conv2dOp( + resulting_type, *preds, None, pads, strides, dilations + ).result + + return result + def convert_dot(self) -> OpResult: """Convert a Dot node to its corresponding MLIR representation. diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index ecf6b63e8..98731edbd 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -15,7 +15,7 @@ from ..debugging import format_operation_graph from ..debugging.custom_assert import assert_not_reached, assert_true from ..operator_graph import OPGraph from ..representation import intermediate -from ..representation.intermediate import IntermediateNode +from ..representation.intermediate import Conv2D, IntermediateNode # TODO: should come from compiler, through an API, #402 ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7 @@ -103,6 +103,9 @@ def check_node_compatibility_with_mlir( elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication assert_true(len(inputs) == 2) + elif isinstance(node, Conv2D): + assert_true(len(inputs) in [2, 3]) + else: # pragma: no cover assert_not_reached("Non IntermediateNode object in the OPGraph") diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 2267f2a8e..5faaed6b8 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -4,8 +4,11 @@ from abc import ABC, abstractmethod from collections import deque from copy import deepcopy from enum import Enum, unique +from math import floor from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast +import numpy as np +import torch from loguru import logger from ..data_types.base import BaseDataType @@ -221,6 +224,109 @@ class Constant(IntermediateNode): return self._constant_data +class Conv2D(IntermediateNode): + """Return the node representing a 2d-convolution.""" + + def __init__( + self, + inputs: Iterable[BaseValue], + output_dtype: BaseDataType, + pads: Union[List[int], Tuple[int, int, int, int]], + strides: Union[List[int], Tuple[int, int]], + dilations: Union[List[int], Tuple[int, int]], + ) -> None: + + # TODO: remove this when padding is supported (#427) + assert all(pad == 0 for pad in pads), "conv2d doesn't support padding yet" + + super().__init__(inputs) + self.pads = pads + self.strides = strides + self.dilations = dilations + + self._n_in = len(self.inputs) + assert_true(len(self.inputs) == 2 or len(self.inputs) == 3) + + assert_true( + all( + isinstance(input_value, TensorValue) and input_value.ndim == 4 + for input_value in self.inputs[:2] + ), + f"Conv2D only supports input and weight tensors of 4 dimensions" + f"({TensorValue.__name__} with ndim == 4)", + ) + bias = cast(TensorValue, self.inputs[2]) if len(self.inputs) == 3 else None + if bias is not None: + assert_true( + isinstance(bias, TensorValue) and bias.ndim == 1, + f"Conv2D only supports bias 1 dimension ({TensorValue.__name__} with ndim == 1)", + ) + + x = cast(TensorValue, self.inputs[0]) + weight = cast(TensorValue, self.inputs[1]) + + # Compute output shape + input_n, _, input_h, input_w = x.shape + weight_f, _, weight_h, weight_w = weight.shape + pads_h = pads[0] + pads[2] + pads_w = pads[1] + pads[3] + output_h = floor((input_h + pads_h - dilations[0] * (weight_h - 1) - 1) / strides[0]) + 1 + output_w = floor((input_w + pads_w - dilations[1] * (weight_w - 1) - 1) / strides[1]) + 1 + output_shape = (input_n, weight_f, output_h, output_w) + + output_value = EncryptedTensor(dtype=output_dtype, shape=output_shape) + self.outputs = [output_value] + + def text_for_drawing(self) -> str: + return "conv2d" + + def evaluate(self, inputs: Dict[int, Any]) -> Any: + + assert_true( + len(inputs) == self._n_in, f"expected {self.n_in} inputs, but got {len(inputs)}" + ) + x, weight = inputs[0], inputs[1] + bias = inputs[2] if len(inputs) == 3 else np.zeros(weight.shape[0]) + + return self.evaluate_conv2d(x, weight, bias, self.pads, self.strides, self.dilations) + + @staticmethod + def evaluate_conv2d( + x: np.ndarray, + weight: np.ndarray, + bias: np.ndarray, + # TODO: use padding when supported (#427) + _: Union[Tuple[int, int, int, int], List[int]], + strides: Union[Tuple[int, int], List[int]], + dilations: Union[Tuple[int, int], List[int]], + ): + """Evaluate 2D convolution. + + Args: + x (np.ndarray): Input of shape (NxCxHxW) + weight (np.ndarray): Weight (kernel) of shape (FxCxHxW) + bias (np.ndarray): Bias vector of size (F) + pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each + axis (H_beg, W_beg, H_end, W_end) + strides (Union[Tuple[int, int], List[int]]): Stride over each + axis (height and width) + dilations (Union[Tuple[int, int], List[int]]): Dilation over each + axis (height and width) + + Returns: + np.ndarray: Result of the convolution of shape (NxCxHxW) + """ + # pylint: disable=no-member + return torch.conv2d( + torch.tensor(x, dtype=torch.long), + torch.tensor(weight, dtype=torch.long), + torch.tensor(bias, dtype=torch.long), + stride=strides, + dilation=dilations, + ).numpy() + # pylint: enable=no-member + + class IndexConstant(IntermediateNode): """Node representing a constant indexing in the program. diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 0627a474d..a268d866e 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -1,5 +1,13 @@ """Module for compiling numpy functions to homomorphic equivalents.""" +# Import differently to put at the top, and avoid circular import issues +from concrete.numpy.compile import ( + compile_numpy_function, + compile_numpy_function_into_op_graph_and_measure_bounds, +) +from concrete.numpy.np_fhe_compiler import NPFHECompiler +from concrete.numpy.tracing import trace_numpy_function + from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import ( Float, @@ -11,9 +19,7 @@ from ..common.data_types import ( UnsignedInteger, ) from ..common.debugging import draw_graph, format_operation_graph +from ..common.extensions.convolution import conv2d from ..common.extensions.multi_table import MultiLookupTable from ..common.extensions.table import LookupTable from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue -from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds -from .np_fhe_compiler import NPFHECompiler -from .tracing import trace_numpy_function diff --git a/deps_licenses/licenses_linux_user.txt b/deps_licenses/licenses_linux_user.txt index f4c604980..dc75d04a1 100644 --- a/deps_licenses/licenses_linux_user.txt +++ b/deps_licenses/licenses_linux_user.txt @@ -1,7 +1,7 @@ Name Version License Pillow 9.0.1 Historical Permission Notice and Disclaimer (HPND) PyYAML 6.0 MIT License - concrete-compiler 0.2.0 BSD-3 + concrete-compiler 0.3.1 BSD-3 cycler 0.11.0 BSD License fonttools 4.29.1 MIT License kiwisolver 1.3.2 BSD License @@ -16,3 +16,5 @@ setuptools-scm 6.4.2 MIT License six 1.16.0 MIT License tomli 1.2.3 MIT License + torch 1.10.2 BSD License + typing-extensions 4.1.1 Python Software Foundation License diff --git a/poetry.lock b/poetry.lock index 4d33a6663..7c24049b2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -243,7 +243,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "concrete-compiler" -version = "0.2.0" +version = "0.3.1" description = "Concrete Compiler" category = "main" optional = false @@ -255,7 +255,7 @@ PyYAML = "*" [[package]] name = "coverage" -version = "6.3.1" +version = "6.3.2" description = "Code coverage measurement for Python" category = "dev" optional = false @@ -2111,6 +2111,17 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "torch" +version = "1.10.2" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +typing-extensions = "*" + [[package]] name = "tornado" version = "6.1" @@ -2186,7 +2197,7 @@ python-versions = "*" name = "typing-extensions" version = "4.1.1" description = "Backported and Experimental Type Hints for Python 3.6+" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -2268,7 +2279,7 @@ full = ["pygraphviz"] [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.10" -content-hash = "c46630b3a44a45815631ebc5f42ef77b14b4648cccce3405231597465b316884" +content-hash = "89f3c912cef146d06a3da96e36347367e1703fa5670f3153637e1822c5e81897" [metadata.files] alabaster = [ @@ -2419,55 +2430,54 @@ colorama = [ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] concrete-compiler = [ - {file = "concrete_compiler-0.2.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:ea282b4a2ba1af46ec72c044b44342ebec8c9d44734963281a3214e7fa4d68da"}, - {file = "concrete_compiler-0.2.0-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:56a6d37f717f0e85360e9dec790f3d3d1c2a1a99fb26524063414fd447f43a34"}, - {file = "concrete_compiler-0.2.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:711dd3093179194af629ec2dbab28309c4e7a7d188fe692e6c98c481f53639de"}, - {file = "concrete_compiler-0.2.0-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:e640e3f944f3599ad0ab126de3edd984d7eff9b3be094ae4c1d98546e224e784"}, - {file = "concrete_compiler-0.2.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:9203d7c68016a985dab81c4b047391af41573d3dd2884d8e7a645feae50f7609"}, - {file = "concrete_compiler-0.2.0-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:b0112ea4be2a81f8528c50c439e26a756566b7d9b08dd88b48105cd228668ad1"}, + {file = "concrete_compiler-0.3.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a6f3d27e3246af55d3f8116c12bd63747691939ca6b0c97170d6d7eda5bd3bb7"}, + {file = "concrete_compiler-0.3.1-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:69e77d45a5df39758bbd38c3fa154d479ad7855afdc06bb7f93c75424d00eae8"}, + {file = "concrete_compiler-0.3.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:fc2d87aebf0c6772dc9e443194b8c3a363614c6fb1042a5e86a9f5f77a76e360"}, + {file = "concrete_compiler-0.3.1-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:9adc23818a2d64d24e0ab94fd40938b6aaf5ae32c8ad6b562158bd55914ac319"}, + {file = "concrete_compiler-0.3.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:90ff4fc19dea3f28d7f177ab53979be533443424534f7ee7cce2f0622b82eb58"}, ] coverage = [ - {file = "coverage-6.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeffd96882d8c06d31b65dddcf51db7c612547babc1c4c5db6a011abe9798525"}, - {file = "coverage-6.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:621f6ea7260ea2ffdaec64fe5cb521669984f567b66f62f81445221d4754df4c"}, - {file = "coverage-6.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84f2436d6742c01136dd940ee158bfc7cf5ced3da7e4c949662b8703b5cd8145"}, - {file = "coverage-6.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de73fca6fb403dd72d4da517cfc49fcf791f74eee697d3219f6be29adf5af6ce"}, - {file = "coverage-6.3.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78fbb2be068a13a5d99dce9e1e7d168db880870f7bc73f876152130575bd6167"}, - {file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f5a4551dfd09c3bd12fca8144d47fe7745275adf3229b7223c2f9e29a975ebda"}, - {file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7bff3a98f63b47464480de1b5bdd80c8fade0ba2832c9381253c9b74c4153c27"}, - {file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a06c358f4aed05fa1099c39decc8022261bb07dfadc127c08cfbd1391b09689e"}, - {file = "coverage-6.3.1-cp310-cp310-win32.whl", hash = "sha256:9fff3ff052922cb99f9e52f63f985d4f7a54f6b94287463bc66b7cdf3eb41217"}, - {file = "coverage-6.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:276b13cc085474e482566c477c25ed66a097b44c6e77132f3304ac0b039f83eb"}, - {file = "coverage-6.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:56c4a409381ddd7bbff134e9756077860d4e8a583d310a6f38a2315b9ce301d0"}, - {file = "coverage-6.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eb494070aa060ceba6e4bbf44c1bc5fa97bfb883a0d9b0c9049415f9e944793"}, - {file = "coverage-6.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e15d424b8153756b7c903bde6d4610be0c3daca3986173c18dd5c1a1625e4cd"}, - {file = "coverage-6.3.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61d47a897c1e91f33f177c21de897267b38fbb45f2cd8e22a710bcef1df09ac1"}, - {file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:25e73d4c81efa8ea3785274a2f7f3bfbbeccb6fcba2a0bdd3be9223371c37554"}, - {file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:fac0bcc5b7e8169bffa87f0dcc24435446d329cbc2b5486d155c2e0f3b493ae1"}, - {file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:72128176fea72012063200b7b395ed8a57849282b207321124d7ff14e26988e8"}, - {file = "coverage-6.3.1-cp37-cp37m-win32.whl", hash = "sha256:1bc6d709939ff262fd1432f03f080c5042dc6508b6e0d3d20e61dd045456a1a0"}, - {file = "coverage-6.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:618eeba986cea7f621d8607ee378ecc8c2504b98b3fdc4952b30fe3578304687"}, - {file = "coverage-6.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ed164af5c9078596cfc40b078c3b337911190d3faeac830c3f1274f26b8320"}, - {file = "coverage-6.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:352c68e233409c31048a3725c446a9e48bbff36e39db92774d4f2380d630d8f8"}, - {file = "coverage-6.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:448d7bde7ceb6c69e08474c2ddbc5b4cd13c9e4aa4a717467f716b5fc938a734"}, - {file = "coverage-6.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9fde6b90889522c220dd56a670102ceef24955d994ff7af2cb786b4ba8fe11e4"}, - {file = "coverage-6.3.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e647a0be741edbb529a72644e999acb09f2ad60465f80757da183528941ff975"}, - {file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a5cdc3adb4f8bb8d8f5e64c2e9e282bc12980ef055ec6da59db562ee9bdfefa"}, - {file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2dd70a167843b4b4b2630c0c56f1b586fe965b4f8ac5da05b6690344fd065c6b"}, - {file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9ad0a117b8dc2061ce9461ea4c1b4799e55edceb236522c5b8f958ce9ed8fa9a"}, - {file = "coverage-6.3.1-cp38-cp38-win32.whl", hash = "sha256:e92c7a5f7d62edff50f60a045dc9542bf939758c95b2fcd686175dd10ce0ed10"}, - {file = "coverage-6.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:482fb42eea6164894ff82abbcf33d526362de5d1a7ed25af7ecbdddd28fc124f"}, - {file = "coverage-6.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c5b81fb37db76ebea79aa963b76d96ff854e7662921ce742293463635a87a78d"}, - {file = "coverage-6.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a4f923b9ab265136e57cc14794a15b9dcea07a9c578609cd5dbbfff28a0d15e6"}, - {file = "coverage-6.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56d296cbc8254a7dffdd7bcc2eb70be5a233aae7c01856d2d936f5ac4e8ac1f1"}, - {file = "coverage-6.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1245ab82e8554fa88c4b2ab1e098ae051faac5af829efdcf2ce6b34dccd5567c"}, - {file = "coverage-6.3.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f2b05757c92ad96b33dbf8e8ec8d4ccb9af6ae3c9e9bd141c7cc44d20c6bcba"}, - {file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9e3dd806f34de38d4c01416344e98eab2437ac450b3ae39c62a0ede2f8b5e4ed"}, - {file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d651fde74a4d3122e5562705824507e2f5b2d3d57557f1916c4b27635f8fbe3f"}, - {file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:704f89b87c4f4737da2860695a18c852b78ec7279b24eedacab10b29067d3a38"}, - {file = "coverage-6.3.1-cp39-cp39-win32.whl", hash = "sha256:2aed4761809640f02e44e16b8b32c1a5dee5e80ea30a0ff0912158bde9c501f2"}, - {file = "coverage-6.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:9976fb0a5709988778ac9bc44f3d50fccd989987876dfd7716dee28beed0a9fa"}, - {file = "coverage-6.3.1-pp36.pp37.pp38-none-any.whl", hash = "sha256:463e52616ea687fd323888e86bf25e864a3cc6335a043fad6bbb037dbf49bbe2"}, - {file = "coverage-6.3.1.tar.gz", hash = "sha256:6c3f6158b02ac403868eea390930ae64e9a9a2a5bbfafefbb920d29258d9f2f8"}, + {file = "coverage-6.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9b27d894748475fa858f9597c0ee1d4829f44683f3813633aaf94b19cb5453cf"}, + {file = "coverage-6.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37d1141ad6b2466a7b53a22e08fe76994c2d35a5b6b469590424a9953155afac"}, + {file = "coverage-6.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9987b0354b06d4df0f4d3e0ec1ae76d7ce7cbca9a2f98c25041eb79eec766f1"}, + {file = "coverage-6.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:26e2deacd414fc2f97dd9f7676ee3eaecd299ca751412d89f40bc01557a6b1b4"}, + {file = "coverage-6.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dd8bafa458b5c7d061540f1ee9f18025a68e2d8471b3e858a9dad47c8d41903"}, + {file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:46191097ebc381fbf89bdce207a6c107ac4ec0890d8d20f3360345ff5976155c"}, + {file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6f89d05e028d274ce4fa1a86887b071ae1755082ef94a6740238cd7a8178804f"}, + {file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58303469e9a272b4abdb9e302a780072c0633cdcc0165db7eec0f9e32f901e05"}, + {file = "coverage-6.3.2-cp310-cp310-win32.whl", hash = "sha256:2fea046bfb455510e05be95e879f0e768d45c10c11509e20e06d8fcaa31d9e39"}, + {file = "coverage-6.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:a2a8b8bcc399edb4347a5ca8b9b87e7524c0967b335fbb08a83c8421489ddee1"}, + {file = "coverage-6.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f1555ea6d6da108e1999b2463ea1003fe03f29213e459145e70edbaf3e004aaa"}, + {file = "coverage-6.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5f4e1edcf57ce94e5475fe09e5afa3e3145081318e5fd1a43a6b4539a97e518"}, + {file = "coverage-6.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a15dc0a14008f1da3d1ebd44bdda3e357dbabdf5a0b5034d38fcde0b5c234b7"}, + {file = "coverage-6.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21b7745788866028adeb1e0eca3bf1101109e2dc58456cb49d2d9b99a8c516e6"}, + {file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8ce257cac556cb03be4a248d92ed36904a59a4a5ff55a994e92214cde15c5bad"}, + {file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b0be84e5a6209858a1d3e8d1806c46214e867ce1b0fd32e4ea03f4bd8b2e3359"}, + {file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:acf53bc2cf7282ab9b8ba346746afe703474004d9e566ad164c91a7a59f188a4"}, + {file = "coverage-6.3.2-cp37-cp37m-win32.whl", hash = "sha256:8bdde1177f2311ee552f47ae6e5aa7750c0e3291ca6b75f71f7ffe1f1dab3dca"}, + {file = "coverage-6.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b31651d018b23ec463e95cf10070d0b2c548aa950a03d0b559eaa11c7e5a6fa3"}, + {file = "coverage-6.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:07e6db90cd9686c767dcc593dff16c8c09f9814f5e9c51034066cad3373b914d"}, + {file = "coverage-6.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c6dbb42f3ad25760010c45191e9757e7dce981cbfb90e42feef301d71540059"}, + {file = "coverage-6.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c76aeef1b95aff3905fb2ae2d96e319caca5b76fa41d3470b19d4e4a3a313512"}, + {file = "coverage-6.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cf5cfcb1521dc3255d845d9dca3ff204b3229401994ef8d1984b32746bb45ca"}, + {file = "coverage-6.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fbbdc8d55990eac1b0919ca69eb5a988a802b854488c34b8f37f3e2025fa90d"}, + {file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ec6bc7fe73a938933d4178c9b23c4e0568e43e220aef9472c4f6044bfc6dd0f0"}, + {file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9baff2a45ae1f17c8078452e9e5962e518eab705e50a0aa8083733ea7d45f3a6"}, + {file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd9e830e9d8d89b20ab1e5af09b32d33e1a08ef4c4e14411e559556fd788e6b2"}, + {file = "coverage-6.3.2-cp38-cp38-win32.whl", hash = "sha256:f7331dbf301b7289013175087636bbaf5b2405e57259dd2c42fdcc9fcc47325e"}, + {file = "coverage-6.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:68353fe7cdf91f109fc7d474461b46e7f1f14e533e911a2a2cbb8b0fc8613cf1"}, + {file = "coverage-6.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b78e5afb39941572209f71866aa0b206c12f0109835aa0d601e41552f9b3e620"}, + {file = "coverage-6.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e21876082ed887baed0146fe222f861b5815455ada3b33b890f4105d806128d"}, + {file = "coverage-6.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34626a7eee2a3da12af0507780bb51eb52dca0e1751fd1471d0810539cefb536"}, + {file = "coverage-6.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ebf730d2381158ecf3dfd4453fbca0613e16eaa547b4170e2450c9707665ce7"}, + {file = "coverage-6.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd6fe30bd519694b356cbfcaca9bd5c1737cddd20778c6a581ae20dc8c04def2"}, + {file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:96f8a1cb43ca1422f36492bebe63312d396491a9165ed3b9231e778d43a7fca4"}, + {file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dd035edafefee4d573140a76fdc785dc38829fe5a455c4bb12bac8c20cfc3d69"}, + {file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5ca5aeb4344b30d0bec47481536b8ba1181d50dbe783b0e4ad03c95dc1296684"}, + {file = "coverage-6.3.2-cp39-cp39-win32.whl", hash = "sha256:f5fa5803f47e095d7ad8443d28b01d48c0359484fec1b9d8606d0e3282084bc4"}, + {file = "coverage-6.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:9548f10d8be799551eb3a9c74bbf2b4934ddb330e08a73320123c07f95cc2d92"}, + {file = "coverage-6.3.2-pp36.pp37.pp38-none-any.whl", hash = "sha256:18d520c6860515a771708937d2f78f63cc47ab3b80cb78e86573b0a760161faf"}, + {file = "coverage-6.3.2.tar.gz", hash = "sha256:03e2a7826086b91ef345ff18742ee9fc47a6839ccd517061ef8fa1976e652ce9"}, ] cryptography = [ {file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:73bc2d3f2444bcfeac67dd130ff2ea598ea5f20b40e36d19821b4df8c9c5037b"}, @@ -3533,6 +3543,27 @@ tomlkit = [ {file = "tomlkit-0.7.0-py2.py3-none-any.whl", hash = "sha256:6babbd33b17d5c9691896b0e68159215a9387ebfa938aa3ac42f4a4beeb2b831"}, {file = "tomlkit-0.7.0.tar.gz", hash = "sha256:ac57f29693fab3e309ea789252fcce3061e19110085aa31af5446ca749325618"}, ] +torch = [ + {file = "torch-1.10.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8f3fd2e3ffc3bb867133fdf7fbcc8a0bb2e62a5c0696396f51856f5abf9045a8"}, + {file = "torch-1.10.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:258a0729fb77a3457d5822d84b536057cd119b08049a8d3c41dc3dcdeb48d56e"}, + {file = "torch-1.10.2-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:935e5ac804c5093c79f23a7e6ca5b912c166071aa9d8b4a0a3d6a85126d6a47b"}, + {file = "torch-1.10.2-cp36-cp36m-win_amd64.whl", hash = "sha256:65fd02ed889c63fd82bf1a440c5a94c1310c29f3e6f9f62add416d34da355d97"}, + {file = "torch-1.10.2-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:6a81f886823bbd15edc2dc0908fa214070df61c9f7ab8831f0a03630275cca5a"}, + {file = "torch-1.10.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:3eee3cf53c1f8fb3f1fe107a22025a8501fc6440d14e09599ba7153002531f84"}, + {file = "torch-1.10.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ef99b8cca5f9358119b07956915faf6e7906f433ab4a603c160ae9de88918371"}, + {file = "torch-1.10.2-cp37-cp37m-win_amd64.whl", hash = "sha256:d43bc3f3a2d89ae185ef96d903c935c335219231e57685658648396984e2a67a"}, + {file = "torch-1.10.2-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:6da1b877880435440a5aa9678ef0f01986d4886416844db1d97ebfb7fd1778d0"}, + {file = "torch-1.10.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ab77a9f838874f295ed5410c0686fa22547456e0116efb281c66ef5f9d46fe28"}, + {file = "torch-1.10.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9ef4c004f9e5168bd1c1930c6aff25fed5b097de81db6271ffbb2e4fb8b89319"}, + {file = "torch-1.10.2-cp38-cp38-win_amd64.whl", hash = "sha256:376fc18407add20daa6bbaaffc5a5e06d733abe53bcbd60ef2532bfed34bc091"}, + {file = "torch-1.10.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:f281438ee99bd72ad65c0bba1026a32e45c3b636bc067fc145ad291e9ea2faab"}, + {file = "torch-1.10.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:3592d3dd62b32760c82624e7586222747fe2281240e8653970b35f1d6d4a434c"}, + {file = "torch-1.10.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:fbaf18c1b3e0b31af194a9d853e3739464cf982d279df9d34dd18f1c2a471878"}, + {file = "torch-1.10.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:97b7b0c667e8b0dd1fc70137a36e0a4841ec10ef850bda60500ad066bef3e2de"}, + {file = "torch-1.10.2-cp39-cp39-win_amd64.whl", hash = "sha256:901b52787baeb2e9e1357ca7037da0028bc6ad743f530e0040ae96ef8e27156c"}, + {file = "torch-1.10.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:5b68e9108bd7ebd99eee941686046c517cfaac5331f757bcf440fe02f2e3ced1"}, + {file = "torch-1.10.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b07ef01e36b716d0d65ca60c4db0ac9d094a0e797d9b55290da4dcda91463b6c"}, +] tornado = [ {file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"}, {file = "tornado-6.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:0d321a39c36e5f2c4ff12b4ed58d41390460f798422c4504e09eb5678e09998c"}, diff --git a/pyproject.toml b/pyproject.toml index 003c377a4..907384f65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,8 @@ pygraphviz = { version = "^1.7", optional = true } Pillow = "^9.0.0" loguru = "^0.5.3" setuptools = "*" -concrete-compiler = "^0.2.0" +concrete-compiler = "^0.3.1" +torch = "^1.10.2" [tool.poetry.extras] full = ["pygraphviz"] diff --git a/tests/common/extensions/test_convolution.py b/tests/common/extensions/test_convolution.py new file mode 100644 index 000000000..8c2214ba1 --- /dev/null +++ b/tests/common/extensions/test_convolution.py @@ -0,0 +1,193 @@ +"""Test file for convolution""" + +import numpy as np +import pytest +import torch + +from concrete.common.extensions import convolution +from concrete.common.representation.intermediate import Conv2D +from concrete.common.tracing.base_tracer import BaseTracer +from concrete.common.values.tensors import TensorValue +from concrete.numpy.tracing import NPConstant, NPTracer + + +@pytest.mark.parametrize( + "kwargs, error_msg", + [ + pytest.param( + {"x": None, "weight": np.zeros(1)}, + "input x must be an ndarray, or a BaseTracer, not a", + ), + pytest.param( + {"x": np.zeros(1), "weight": None}, + "weight must be an ndarray, or a BaseTracer, not a", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "bias": 0}, + "bias must be an ndarray, a BaseTracer, or None, not a", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "strides": None}, + "strides must be a tuple, or list, not a", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "dilations": None}, + "dilations must be a tuple, or list, not a", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "pads": None}, + "padding must be a tuple, or list, not a", + ), + ], +) +def test_invalid_arg_types(kwargs, error_msg): + """Test function to make sure convolution doesn't accept invalid types""" + + with pytest.raises(TypeError) as err: + convolution.conv2d(**kwargs) + + assert error_msg in str(err) + + +@pytest.mark.parametrize( + "kwargs, error_msg", + [ + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1)}, + "input x should have size (N x C x H x W), not", + ), + pytest.param( + {"x": np.zeros((1, 2, 3, 4)), "weight": np.zeros(1)}, + "weight should have size (F x C x H x W), not", + ), + pytest.param( + { + "x": np.zeros((1, 2, 3, 4)), + "weight": np.zeros((1, 2, 3, 4)), + "bias": np.zeros((1, 2)), + }, + "bias should have size (F), not", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "strides": (1,)}, + "strides should be of the form", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "dilations": (1,)}, + "dilations should be of the form", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "pads": (1,)}, + "padding should be of the form", + ), + pytest.param( + {"x": np.zeros(1), "weight": np.zeros(1), "auto_pad": None}, + "invalid auto_pad is specified", + ), + ], +) +def test_invalid_input_shape(kwargs, error_msg): + """Test function to make sure convolution doesn't accept invalid shapes""" + + with pytest.raises((ValueError, AssertionError)) as err: + convolution.conv2d(**kwargs) + + assert error_msg in str(err) + + +@pytest.mark.parametrize( + "input_shape, weight_shape", + [ + pytest.param((1, 1, 4, 4), (1, 1, 2, 2)), + pytest.param((3, 1, 4, 4), (1, 1, 2, 2)), + pytest.param((1, 1, 4, 4), (3, 1, 2, 2)), + pytest.param((1, 3, 4, 4), (1, 3, 2, 2)), + pytest.param((4, 3, 4, 4), (3, 3, 2, 2)), + pytest.param((4, 3, 16, 16), (3, 3, 2, 2)), + pytest.param((4, 3, 16, 16), (3, 3, 3, 3)), + ], +) +@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)]) +@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)]) +@pytest.mark.parametrize("has_bias", [True, False]) +@pytest.mark.parametrize("use_ndarray", [True, False]) +def test_tracing(input_shape, weight_shape, strides, dilations, has_bias, use_ndarray): + """Test function to make sure tracong of conv2d works properly""" + if has_bias: + bias = np.random.randint(0, 4, size=(weight_shape[0],)) + if not use_ndarray: + bias = NPTracer([], NPConstant(bias), 0) + else: + bias = None + + x = NPTracer([], NPConstant(np.random.randint(0, 4, size=input_shape)), 0) + weight = np.random.randint(0, 4, size=weight_shape) + if not use_ndarray: + weight = NPTracer([], NPConstant(weight), 0) + + output_tracer = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations) + traced_computation = output_tracer.traced_computation + assert isinstance(traced_computation, Conv2D) + + if has_bias: + assert len(output_tracer.inputs) == 3 + else: + assert len(output_tracer.inputs) == 2 + + assert all( + isinstance(input_, BaseTracer) for input_ in output_tracer.inputs + ), f"{output_tracer.inputs}" + + assert len(traced_computation.outputs) == 1 + output_value = traced_computation.outputs[0] + assert isinstance(output_value, TensorValue) and output_value.is_encrypted + # pylint: disable=no-member + expected_shape = torch.conv2d( + torch.randn(input_shape), + torch.randn(weight_shape), + torch.randn((weight_shape[0])), + stride=strides, + dilation=dilations, + ).shape + # pylint: enable=no-member + + assert output_value.shape == expected_shape + + +@pytest.mark.parametrize( + "input_shape, weight_shape", + [ + pytest.param((1, 1, 4, 4), (1, 1, 2, 2)), + pytest.param((3, 1, 4, 4), (1, 1, 2, 2)), + pytest.param((1, 1, 4, 4), (3, 1, 2, 2)), + pytest.param((1, 3, 4, 4), (1, 3, 2, 2)), + pytest.param((4, 3, 4, 4), (3, 3, 2, 2)), + pytest.param((4, 3, 16, 16), (3, 3, 2, 2)), + pytest.param((4, 3, 16, 16), (3, 3, 3, 3)), + ], +) +@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)]) +@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)]) +@pytest.mark.parametrize("has_bias", [True, False]) +def test_evaluation(input_shape, weight_shape, strides, dilations, has_bias): + """Test function to make sure evaluation of conv2d on plain data works properly""" + if has_bias: + bias = np.random.randint(0, 4, size=(weight_shape[0],)) + else: + bias = np.zeros((weight_shape[0],)) + x = np.random.randint(0, 4, size=input_shape) + weight = np.random.randint(0, 4, size=weight_shape) + # pylint: disable=no-member + expected = torch.conv2d( + torch.tensor(x, dtype=torch.long), + torch.tensor(weight, dtype=torch.long), + torch.tensor(bias, dtype=torch.long), + stride=strides, + dilation=dilations, + ).numpy() + # pylint: enable=no-member + # conv2d should handle None biases + if not has_bias: + bias = None + result = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations) + assert (result == expected).all() diff --git a/tests/conftest.py b/tests/conftest.py index cd4a95c07..ad3c6947b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ from concrete.common.representation.intermediate import ( ALL_IR_NODES, Add, Constant, + Conv2D, Dot, GenericFunction, IndexConstant, @@ -261,6 +262,11 @@ def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool: return isinstance(rhs, MatMul) and is_equivalent_intermediate_node(lhs, rhs) +def is_equivalent_conv2d(lhs: Conv2D, rhs: object) -> bool: + """Helper function to check if a Conv2D node is equivalent to an other object.""" + return isinstance(rhs, Conv2D) and is_equivalent_intermediate_node(lhs, rhs) + + def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool: """Helper function to check if an IntermediateNode node is equivalent to an other object.""" return ( @@ -274,6 +280,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = { Add: is_equivalent_add, GenericFunction: is_equivalent_arbitrary_function, Constant: is_equivalent_constant, + Conv2D: is_equivalent_conv2d, Dot: is_equivalent_dot, IndexConstant: is_equivalent_index_constant, Input: is_equivalent_input, diff --git a/tests/numpy/test_compile_conv.py b/tests/numpy/test_compile_conv.py new file mode 100644 index 000000000..9cbceddf7 --- /dev/null +++ b/tests/numpy/test_compile_conv.py @@ -0,0 +1,44 @@ +"""Test module for convolution compilation and execution.""" + +import numpy as np +import pytest + +import concrete.numpy as hnp +from concrete.common.data_types.integers import Integer +from concrete.common.values.tensors import EncryptedTensor +from concrete.numpy.compile import compile_numpy_function + + +@pytest.mark.parametrize( + "input_shape, weight_shape", + [ + pytest.param((1, 1, 4, 4), (1, 1, 2, 2)), + pytest.param((4, 3, 4, 4), (2, 3, 2, 2)), + ], +) +@pytest.mark.parametrize("strides", [(2, 2)]) +@pytest.mark.parametrize("dilations", [(1, 1)]) +@pytest.mark.parametrize("has_bias", [True, False]) +def test_compile_and_run( + input_shape, weight_shape, strides, dilations, has_bias, default_compilation_configuration +): + """Test function to make sure compilation and execution of conv2d works properly""" + if has_bias: + bias = np.random.randint(0, 4, size=(weight_shape[0],)) + else: + bias = None + weight = np.random.randint(0, 4, size=weight_shape) + + def conv(x): + return hnp.conv2d(x, weight, bias, strides=strides, dilations=dilations) + + compiler_engine = compile_numpy_function( + conv, + {"x": EncryptedTensor(Integer(64, False), input_shape)}, + [np.random.randint(0, 4, size=input_shape) for i in range(20)], + default_compilation_configuration, + ) + x = np.random.randint(0, 4, size=input_shape, dtype=np.uint8) + expected = conv(x) + result = compiler_engine.run(x) + assert (expected == result).all()