feat: add convolution extension

extend the current tracing and compilation with convolution, which
should compile to the FHELinalg.conv2d operation from the compiler
This commit is contained in:
youben11
2022-02-22 11:14:23 +01:00
committed by Ayoub Benaissa
parent 8d12a53651
commit 98bec17050
13 changed files with 663 additions and 58 deletions

View File

@@ -15,6 +15,7 @@ from ..representation.intermediate import (
ALL_IR_NODES,
Add,
Constant,
Conv2D,
Dot,
GenericFunction,
IndexConstant,
@@ -27,6 +28,7 @@ from ..representation.intermediate import (
IR_NODE_COLOR_MAPPING = {
Input: "blue",
Constant: "cyan",
Conv2D: "brown",
Add: "red",
Sub: "yellow",
Mul: "green",

View File

@@ -1,2 +1,2 @@
"""Extensions module to provide additional functionality to our users."""
from . import multi_table, table
from . import convolution, multi_table, table

View File

@@ -0,0 +1,161 @@
"""This file contains tracers for convolution operations."""
from typing import List, Optional, Tuple, Union, cast
import numpy as np
from ...numpy.tracing import NPConstant, NPTracer
from ..representation.intermediate import Conv2D
from ..tracing.base_tracer import BaseTracer
SUPPORTED_AUTO_PAD = [
"NOTSET",
]
def conv2d(
    x: Union[np.ndarray, BaseTracer],
    weight: Union[np.ndarray, BaseTracer],
    bias: Optional[Union[np.ndarray, BaseTracer]] = None,
    pads: Union[Tuple[int, int, int, int], List[int]] = (0, 0, 0, 0),
    strides: Union[Tuple[int, int], List[int]] = (1, 1),
    dilations: Union[Tuple[int, int], List[int]] = (1, 1),
    auto_pad: str = "NOTSET",
) -> Union[np.ndarray, NPTracer]:
    """Trace or evaluate 2D convolution.

    Args:
        x (Union[np.ndarray, BaseTracer]): Input of shape (NxCxHxW)
        weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW)
        bias (Optional[Union[np.ndarray, BaseTracer]], optional): Bias vector of size (F).
            Defaults to None.
        pads (Union[Tuple[int, int, int, int], List[int]], optional): Padding over each axis
            (H_beg, W_beg, H_end, W_end). Defaults to (0, 0, 0, 0).
        strides (Union[Tuple[int, int], List[int]], optional): Stride over each axis
            (height and width). Defaults to (1, 1).
        dilations (Union[Tuple[int, int], List[int]], optional): Dilation over each axis
            (height and width). Defaults to (1, 1).
        auto_pad (str, optional): Padding strategy. Defaults to "NOTSET".

    Raises:
        ValueError: If one argument isn't in the range of expected values.
        TypeError: If one argument isn't of the appropriate type.

    Returns:
        Union[np.ndarray, NPTracer]: Evaluation result, or traced computation
    """
    if auto_pad not in SUPPORTED_AUTO_PAD:
        raise ValueError("invalid auto_pad is specified")

    if not isinstance(x, (np.ndarray, BaseTracer)):
        raise TypeError(f"input x must be an ndarray, or a BaseTracer, not a {type(x)}")
    if not isinstance(weight, (np.ndarray, BaseTracer)):
        raise TypeError(f"weight must be an ndarray, or a BaseTracer, not a {type(weight)}")
    if not isinstance(bias, (np.ndarray, BaseTracer, type(None))):
        raise TypeError(f"bias must be an ndarray, a BaseTracer, or None, not a {type(bias)}")
    if not isinstance(pads, (tuple, list)):
        raise TypeError(f"padding must be a tuple, or list, not a {type(pads)}")
    if not isinstance(strides, (tuple, list)):
        raise TypeError(f"strides must be a tuple, or list, not a {type(strides)}")
    if not isinstance(dilations, (tuple, list)):
        raise TypeError(f"dilations must be a tuple, or list, not a {type(dilations)}")

    if len(pads) != 4:
        raise ValueError(
            f"padding should be of the form (pad_height_begin, pad_width_begin, pad_height_end, "
            f" pad_width_end), but got {type(pads)} of length {len(pads)}"
        )
    if len(strides) != 2:
        raise ValueError(
            f"strides should be of the form (stride_height, stride_width), but got {type(strides)}"
            f" of length {len(strides)}"
        )
    if len(dilations) != 2:
        raise ValueError(
            f"dilations should be of the form (dilation_height, dilation_width), but got"
            f" {type(dilations)} of length {len(dilations)}"
        )

    # Shape checks raise ValueError instead of using `assert`, so they match the
    # documented contract above and are not stripped out under `python -O`.
    if len(x.shape) != 4:
        raise ValueError(f"input x should have size (N x C x H x W), not {x.shape}")
    if len(weight.shape) != 4:
        raise ValueError(f"weight should have size (F x C x H x W), not {weight.shape}")
    if bias is not None and len(bias.shape) != 1:
        raise ValueError(f"bias should have size (F), not {bias.shape}")

    if isinstance(x, BaseTracer):
        return _trace_conv2d(x, weight, bias, pads, strides, dilations)

    # x is an ndarray: evaluate directly, defaulting to a zero bias when none was given
    bias = np.zeros(weight.shape[0]) if bias is None else bias
    # For mypy
    weight = cast(np.ndarray, weight)
    bias = cast(np.ndarray, bias)
    return _evaluate_conv2d(x, weight, bias, pads, strides, dilations)
def _trace_conv2d(
    x: BaseTracer,
    weight: Union[np.ndarray, BaseTracer],
    bias: Optional[Union[np.ndarray, BaseTracer]],
    pads: Union[Tuple[int, int, int, int], List[int]],
    strides: Union[Tuple[int, int], List[int]],
    dilations: Union[Tuple[int, int], List[int]],
) -> NPTracer:
    """Trace 2D convolution.

    Args:
        x (BaseTracer): Input of shape (NxCxHxW)
        weight (Union[np.ndarray, BaseTracer]): Weight (kernel) of shape (FxCxHxW)
        bias (Optional[Union[np.ndarray, BaseTracer]]): Bias vector of size (F)
        pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
            axis (H_beg, W_beg, H_end, W_end)
        strides (Union[Tuple[int, int], List[int]]): Stride over each
            axis (height and width)
        dilations (Union[Tuple[int, int], List[int]]): Dilation over each
            axis (height and width)

    Returns:
        NPTracer: Traced computation
    """

    def as_tracer(value: Union[np.ndarray, BaseTracer]) -> BaseTracer:
        # Plain ndarrays are wrapped into constant tracers; tracers pass through.
        return value if isinstance(value, BaseTracer) else NPTracer([], NPConstant(value), 0)

    tracer_args = [x, as_tracer(weight)]
    if bias is not None:
        tracer_args.append(as_tracer(bias))

    conv_node = Conv2D(
        [tracer.output for tracer in tracer_args], x.output.dtype, pads, strides, dilations
    )
    traced = x.__class__(tracer_args, traced_computation=conv_node, output_idx=0)
    # For mypy
    assert isinstance(traced, NPTracer)
    return traced
def _evaluate_conv2d(
    x: np.ndarray,
    weight: np.ndarray,
    bias: np.ndarray,
    pads: Union[Tuple[int, int, int, int], List[int]],
    strides: Union[Tuple[int, int], List[int]],
    dilations: Union[Tuple[int, int], List[int]],
) -> np.ndarray:
    """Evaluate 2D convolution on concrete ndarrays.

    Args:
        x (np.ndarray): Input of shape (NxCxHxW)
        weight (np.ndarray): Weight (kernel) of shape (FxCxHxW)
        bias (np.ndarray): Bias vector of size (F)
        pads (Union[Tuple[int, int, int, int], List[int]]): Padding over each
            axis (H_beg, W_beg, H_end, W_end)
        strides (Union[Tuple[int, int], List[int]]): Stride over each axis (height and width)
        dilations (Union[Tuple[int, int], List[int]]): Dilation over each axis (height and width)

    Returns:
        np.ndarray: Result of the convolution, of shape (N x F x H_out x W_out)
    """
    # Delegate to the intermediate node's static evaluator so tracing and
    # direct evaluation share a single implementation.
    return Conv2D.evaluate_conv2d(x, weight, bias, pads, strides, dilations)

View File

@@ -29,6 +29,7 @@ from ..operator_graph import OPGraph
from ..representation.intermediate import (
Add,
Constant,
Conv2D,
Dot,
GenericFunction,
IndexConstant,
@@ -140,6 +141,9 @@ class IntermediateNodeConverter:
elif isinstance(self.node, Sub):
result = self.convert_sub()
elif isinstance(self.node, Conv2D):
result = self.convert_conv2d()
else: # pragma: no cover
# this branch is not covered as unsupported operations fail the MLIR compatibility check
raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")
@@ -282,6 +286,51 @@ class IntermediateNodeConverter:
return arith.ConstantOp(resulting_type, attr).result
def convert_conv2d(self) -> OpResult:
    """Convert a Conv2D node to its corresponding MLIR representation.

    Returns:
        OpResult: MLIR value produced by the emitted FHELinalg.conv2d operation
    """
    # Conv2D nodes carry either (x, weight) or (x, weight, bias)
    assert_true(len(self.node.inputs) == 2 or len(self.node.inputs) == 3)
    assert_true(len(self.node.outputs) == 1)
    has_bias = len(self.node.inputs) == 3
    x = self.node.inputs[0]
    weight = self.node.inputs[1]
    # Only the encrypted-input / clear-weight combination has a compiler op today
    if not (x.is_encrypted and weight.is_clear):  # pragma: no cover
        raise NotImplementedError(
            f"Conv2D with input {x} and weight {weight} cannot be converted to MLIR yet",
        )
    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
    preds = self.preds
    node = cast(Conv2D, self.node)
    # pads/strides/dilations become dense i64 attributes on the op
    integer_type = IntegerType.get_signless(64, context=self.ctx)
    strides = DenseElementsAttr.get(
        numpy.array(list(node.strides), dtype=numpy.uint64),
        context=self.ctx,
        type=integer_type,
    )
    dilations = DenseElementsAttr.get(
        numpy.array(list(node.dilations), dtype=numpy.uint64),
        context=self.ctx,
        type=integer_type,
    )
    pads = DenseElementsAttr.get(
        numpy.array(list(node.pads), dtype=numpy.uint64), context=self.ctx, type=integer_type
    )
    if has_bias:
        # preds is (x, weight, bias): all operands are supplied by predecessors
        result = fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result
    else:
        # preds is (x, weight): pass None explicitly for the absent bias operand
        result = fhelinalg.Conv2dOp(
            resulting_type, *preds, None, pads, strides, dilations
        ).result
    return result
def convert_dot(self) -> OpResult:
"""Convert a Dot node to its corresponding MLIR representation.

View File

@@ -15,7 +15,7 @@ from ..debugging import format_operation_graph
from ..debugging.custom_assert import assert_not_reached, assert_true
from ..operator_graph import OPGraph
from ..representation import intermediate
from ..representation.intermediate import IntermediateNode
from ..representation.intermediate import Conv2D, IntermediateNode
# TODO: should come from compiler, through an API, #402
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
@@ -103,6 +103,9 @@ def check_node_compatibility_with_mlir(
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
assert_true(len(inputs) == 2)
elif isinstance(node, Conv2D):
assert_true(len(inputs) in [2, 3])
else: # pragma: no cover
assert_not_reached("Non IntermediateNode object in the OPGraph")

View File

@@ -4,8 +4,11 @@ from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from enum import Enum, unique
from math import floor
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
import numpy as np
import torch
from loguru import logger
from ..data_types.base import BaseDataType
@@ -221,6 +224,109 @@ class Constant(IntermediateNode):
return self._constant_data
class Conv2D(IntermediateNode):
    """Return the node representing a 2d-convolution."""

    def __init__(
        self,
        inputs: Iterable[BaseValue],
        output_dtype: BaseDataType,
        pads: Union[List[int], Tuple[int, int, int, int]],
        strides: Union[List[int], Tuple[int, int]],
        dilations: Union[List[int], Tuple[int, int]],
    ) -> None:
        """Create a Conv2D node.

        Args:
            inputs (Iterable[BaseValue]): (x, weight) or (x, weight, bias) values
            output_dtype (BaseDataType): dtype of the convolution output
            pads (Union[List[int], Tuple[int, int, int, int]]): Padding over each
                axis (H_beg, W_beg, H_end, W_end)
            strides (Union[List[int], Tuple[int, int]]): Stride over each axis
                (height and width)
            dilations (Union[List[int], Tuple[int, int]]): Dilation over each axis
                (height and width)
        """
        # TODO: remove this when padding is supported (#427)
        assert all(pad == 0 for pad in pads), "conv2d doesn't support padding yet"
        super().__init__(inputs)
        self.pads = pads
        self.strides = strides
        self.dilations = dilations
        self._n_in = len(self.inputs)

        assert_true(len(self.inputs) == 2 or len(self.inputs) == 3)
        assert_true(
            all(
                isinstance(input_value, TensorValue) and input_value.ndim == 4
                for input_value in self.inputs[:2]
            ),
            f"Conv2D only supports input and weight tensors of 4 dimensions"
            f"({TensorValue.__name__} with ndim == 4)",
        )
        bias = cast(TensorValue, self.inputs[2]) if len(self.inputs) == 3 else None
        if bias is not None:
            assert_true(
                isinstance(bias, TensorValue) and bias.ndim == 1,
                f"Conv2D only supports bias 1 dimension ({TensorValue.__name__} with ndim == 1)",
            )

        x = cast(TensorValue, self.inputs[0])
        weight = cast(TensorValue, self.inputs[1])

        # Standard convolution output arithmetic:
        # out = floor((in + pad_total - dilation * (kernel - 1) - 1) / stride) + 1
        input_n, _, input_h, input_w = x.shape
        weight_f, _, weight_h, weight_w = weight.shape
        pads_h = pads[0] + pads[2]
        pads_w = pads[1] + pads[3]
        output_h = floor((input_h + pads_h - dilations[0] * (weight_h - 1) - 1) / strides[0]) + 1
        output_w = floor((input_w + pads_w - dilations[1] * (weight_w - 1) - 1) / strides[1]) + 1
        output_shape = (input_n, weight_f, output_h, output_w)

        output_value = EncryptedTensor(dtype=output_dtype, shape=output_shape)
        self.outputs = [output_value]

    def text_for_drawing(self) -> str:
        """Return the label used when drawing the operation graph."""
        return "conv2d"

    def evaluate(self, inputs: Dict[int, Any]) -> Any:
        """Evaluate the convolution on concrete inputs.

        Args:
            inputs (Dict[int, Any]): mapping from input index to value;
                index 0 is x, 1 is weight, and the optional index 2 is bias

        Returns:
            Any: result of the convolution
        """
        # Fixed: the message previously read `self.n_in`, but the attribute
        # stored in __init__ is `_n_in`, so a failing assert raised AttributeError.
        assert_true(
            len(inputs) == self._n_in, f"expected {self._n_in} inputs, but got {len(inputs)}"
        )
        x, weight = inputs[0], inputs[1]
        # Default to a zero bias when the node was built without one
        bias = inputs[2] if len(inputs) == 3 else np.zeros(weight.shape[0])
        return self.evaluate_conv2d(x, weight, bias, self.pads, self.strides, self.dilations)

    @staticmethod
    def evaluate_conv2d(
        x: np.ndarray,
        weight: np.ndarray,
        bias: np.ndarray,
        # TODO: use padding when supported (#427)
        _: Union[Tuple[int, int, int, int], List[int]],
        strides: Union[Tuple[int, int], List[int]],
        dilations: Union[Tuple[int, int], List[int]],
    ) -> np.ndarray:
        """Evaluate 2D convolution.

        Args:
            x (np.ndarray): Input of shape (NxCxHxW)
            weight (np.ndarray): Weight (kernel) of shape (FxCxHxW)
            bias (np.ndarray): Bias vector of size (F)
            _ (Union[Tuple[int, int, int, int], List[int]]): Padding over each
                axis (H_beg, W_beg, H_end, W_end); currently ignored (#427)
            strides (Union[Tuple[int, int], List[int]]): Stride over each
                axis (height and width)
            dilations (Union[Tuple[int, int], List[int]]): Dilation over each
                axis (height and width)

        Returns:
            np.ndarray: Result of the convolution, of shape (N x F x H_out x W_out)
        """
        # pylint: disable=no-member
        result = torch.conv2d(
            torch.tensor(x, dtype=torch.long),
            torch.tensor(weight, dtype=torch.long),
            torch.tensor(bias, dtype=torch.long),
            stride=strides,
            dilation=dilations,
        ).numpy()
        # pylint: enable=no-member
        return result
class IndexConstant(IntermediateNode):
"""Node representing a constant indexing in the program.

View File

@@ -1,5 +1,13 @@
"""Module for compiling numpy functions to homomorphic equivalents."""
# Import differently to put at the top, and avoid circular import issues
from concrete.numpy.compile import (
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
)
from concrete.numpy.np_fhe_compiler import NPFHECompiler
from concrete.numpy.tracing import trace_numpy_function
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import (
Float,
@@ -11,9 +19,7 @@ from ..common.data_types import (
UnsignedInteger,
)
from ..common.debugging import draw_graph, format_operation_graph
from ..common.extensions.convolution import conv2d
from ..common.extensions.multi_table import MultiLookupTable
from ..common.extensions.table import LookupTable
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds
from .np_fhe_compiler import NPFHECompiler
from .tracing import trace_numpy_function

View File

@@ -1,7 +1,7 @@
Name Version License
Pillow 9.0.1 Historical Permission Notice and Disclaimer (HPND)
PyYAML 6.0 MIT License
concrete-compiler 0.2.0 BSD-3
concrete-compiler 0.3.1 BSD-3
cycler 0.11.0 BSD License
fonttools 4.29.1 MIT License
kiwisolver 1.3.2 BSD License
@@ -16,3 +16,5 @@
setuptools-scm 6.4.2 MIT License
six 1.16.0 MIT License
tomli 1.2.3 MIT License
torch 1.10.2 BSD License
typing-extensions 4.1.1 Python Software Foundation License

133
poetry.lock generated
View File

@@ -243,7 +243,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "concrete-compiler"
version = "0.2.0"
version = "0.3.1"
description = "Concrete Compiler"
category = "main"
optional = false
@@ -255,7 +255,7 @@ PyYAML = "*"
[[package]]
name = "coverage"
version = "6.3.1"
version = "6.3.2"
description = "Code coverage measurement for Python"
category = "dev"
optional = false
@@ -2111,6 +2111,17 @@ category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "torch"
version = "1.10.2"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
category = "main"
optional = false
python-versions = ">=3.6.2"
[package.dependencies]
typing-extensions = "*"
[[package]]
name = "tornado"
version = "6.1"
@@ -2186,7 +2197,7 @@ python-versions = "*"
name = "typing-extensions"
version = "4.1.1"
description = "Backported and Experimental Type Hints for Python 3.6+"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.6"
@@ -2268,7 +2279,7 @@ full = ["pygraphviz"]
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.10"
content-hash = "c46630b3a44a45815631ebc5f42ef77b14b4648cccce3405231597465b316884"
content-hash = "89f3c912cef146d06a3da96e36347367e1703fa5670f3153637e1822c5e81897"
[metadata.files]
alabaster = [
@@ -2419,55 +2430,54 @@ colorama = [
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
]
concrete-compiler = [
{file = "concrete_compiler-0.2.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:ea282b4a2ba1af46ec72c044b44342ebec8c9d44734963281a3214e7fa4d68da"},
{file = "concrete_compiler-0.2.0-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:56a6d37f717f0e85360e9dec790f3d3d1c2a1a99fb26524063414fd447f43a34"},
{file = "concrete_compiler-0.2.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:711dd3093179194af629ec2dbab28309c4e7a7d188fe692e6c98c481f53639de"},
{file = "concrete_compiler-0.2.0-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:e640e3f944f3599ad0ab126de3edd984d7eff9b3be094ae4c1d98546e224e784"},
{file = "concrete_compiler-0.2.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:9203d7c68016a985dab81c4b047391af41573d3dd2884d8e7a645feae50f7609"},
{file = "concrete_compiler-0.2.0-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:b0112ea4be2a81f8528c50c439e26a756566b7d9b08dd88b48105cd228668ad1"},
{file = "concrete_compiler-0.3.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a6f3d27e3246af55d3f8116c12bd63747691939ca6b0c97170d6d7eda5bd3bb7"},
{file = "concrete_compiler-0.3.1-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:69e77d45a5df39758bbd38c3fa154d479ad7855afdc06bb7f93c75424d00eae8"},
{file = "concrete_compiler-0.3.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:fc2d87aebf0c6772dc9e443194b8c3a363614c6fb1042a5e86a9f5f77a76e360"},
{file = "concrete_compiler-0.3.1-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:9adc23818a2d64d24e0ab94fd40938b6aaf5ae32c8ad6b562158bd55914ac319"},
{file = "concrete_compiler-0.3.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:90ff4fc19dea3f28d7f177ab53979be533443424534f7ee7cce2f0622b82eb58"},
]
coverage = [
{file = "coverage-6.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeffd96882d8c06d31b65dddcf51db7c612547babc1c4c5db6a011abe9798525"},
{file = "coverage-6.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:621f6ea7260ea2ffdaec64fe5cb521669984f567b66f62f81445221d4754df4c"},
{file = "coverage-6.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84f2436d6742c01136dd940ee158bfc7cf5ced3da7e4c949662b8703b5cd8145"},
{file = "coverage-6.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de73fca6fb403dd72d4da517cfc49fcf791f74eee697d3219f6be29adf5af6ce"},
{file = "coverage-6.3.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78fbb2be068a13a5d99dce9e1e7d168db880870f7bc73f876152130575bd6167"},
{file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f5a4551dfd09c3bd12fca8144d47fe7745275adf3229b7223c2f9e29a975ebda"},
{file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7bff3a98f63b47464480de1b5bdd80c8fade0ba2832c9381253c9b74c4153c27"},
{file = "coverage-6.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a06c358f4aed05fa1099c39decc8022261bb07dfadc127c08cfbd1391b09689e"},
{file = "coverage-6.3.1-cp310-cp310-win32.whl", hash = "sha256:9fff3ff052922cb99f9e52f63f985d4f7a54f6b94287463bc66b7cdf3eb41217"},
{file = "coverage-6.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:276b13cc085474e482566c477c25ed66a097b44c6e77132f3304ac0b039f83eb"},
{file = "coverage-6.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:56c4a409381ddd7bbff134e9756077860d4e8a583d310a6f38a2315b9ce301d0"},
{file = "coverage-6.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eb494070aa060ceba6e4bbf44c1bc5fa97bfb883a0d9b0c9049415f9e944793"},
{file = "coverage-6.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e15d424b8153756b7c903bde6d4610be0c3daca3986173c18dd5c1a1625e4cd"},
{file = "coverage-6.3.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61d47a897c1e91f33f177c21de897267b38fbb45f2cd8e22a710bcef1df09ac1"},
{file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:25e73d4c81efa8ea3785274a2f7f3bfbbeccb6fcba2a0bdd3be9223371c37554"},
{file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:fac0bcc5b7e8169bffa87f0dcc24435446d329cbc2b5486d155c2e0f3b493ae1"},
{file = "coverage-6.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:72128176fea72012063200b7b395ed8a57849282b207321124d7ff14e26988e8"},
{file = "coverage-6.3.1-cp37-cp37m-win32.whl", hash = "sha256:1bc6d709939ff262fd1432f03f080c5042dc6508b6e0d3d20e61dd045456a1a0"},
{file = "coverage-6.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:618eeba986cea7f621d8607ee378ecc8c2504b98b3fdc4952b30fe3578304687"},
{file = "coverage-6.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ed164af5c9078596cfc40b078c3b337911190d3faeac830c3f1274f26b8320"},
{file = "coverage-6.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:352c68e233409c31048a3725c446a9e48bbff36e39db92774d4f2380d630d8f8"},
{file = "coverage-6.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:448d7bde7ceb6c69e08474c2ddbc5b4cd13c9e4aa4a717467f716b5fc938a734"},
{file = "coverage-6.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9fde6b90889522c220dd56a670102ceef24955d994ff7af2cb786b4ba8fe11e4"},
{file = "coverage-6.3.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e647a0be741edbb529a72644e999acb09f2ad60465f80757da183528941ff975"},
{file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a5cdc3adb4f8bb8d8f5e64c2e9e282bc12980ef055ec6da59db562ee9bdfefa"},
{file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2dd70a167843b4b4b2630c0c56f1b586fe965b4f8ac5da05b6690344fd065c6b"},
{file = "coverage-6.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9ad0a117b8dc2061ce9461ea4c1b4799e55edceb236522c5b8f958ce9ed8fa9a"},
{file = "coverage-6.3.1-cp38-cp38-win32.whl", hash = "sha256:e92c7a5f7d62edff50f60a045dc9542bf939758c95b2fcd686175dd10ce0ed10"},
{file = "coverage-6.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:482fb42eea6164894ff82abbcf33d526362de5d1a7ed25af7ecbdddd28fc124f"},
{file = "coverage-6.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c5b81fb37db76ebea79aa963b76d96ff854e7662921ce742293463635a87a78d"},
{file = "coverage-6.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a4f923b9ab265136e57cc14794a15b9dcea07a9c578609cd5dbbfff28a0d15e6"},
{file = "coverage-6.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56d296cbc8254a7dffdd7bcc2eb70be5a233aae7c01856d2d936f5ac4e8ac1f1"},
{file = "coverage-6.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1245ab82e8554fa88c4b2ab1e098ae051faac5af829efdcf2ce6b34dccd5567c"},
{file = "coverage-6.3.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f2b05757c92ad96b33dbf8e8ec8d4ccb9af6ae3c9e9bd141c7cc44d20c6bcba"},
{file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9e3dd806f34de38d4c01416344e98eab2437ac450b3ae39c62a0ede2f8b5e4ed"},
{file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d651fde74a4d3122e5562705824507e2f5b2d3d57557f1916c4b27635f8fbe3f"},
{file = "coverage-6.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:704f89b87c4f4737da2860695a18c852b78ec7279b24eedacab10b29067d3a38"},
{file = "coverage-6.3.1-cp39-cp39-win32.whl", hash = "sha256:2aed4761809640f02e44e16b8b32c1a5dee5e80ea30a0ff0912158bde9c501f2"},
{file = "coverage-6.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:9976fb0a5709988778ac9bc44f3d50fccd989987876dfd7716dee28beed0a9fa"},
{file = "coverage-6.3.1-pp36.pp37.pp38-none-any.whl", hash = "sha256:463e52616ea687fd323888e86bf25e864a3cc6335a043fad6bbb037dbf49bbe2"},
{file = "coverage-6.3.1.tar.gz", hash = "sha256:6c3f6158b02ac403868eea390930ae64e9a9a2a5bbfafefbb920d29258d9f2f8"},
{file = "coverage-6.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9b27d894748475fa858f9597c0ee1d4829f44683f3813633aaf94b19cb5453cf"},
{file = "coverage-6.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37d1141ad6b2466a7b53a22e08fe76994c2d35a5b6b469590424a9953155afac"},
{file = "coverage-6.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9987b0354b06d4df0f4d3e0ec1ae76d7ce7cbca9a2f98c25041eb79eec766f1"},
{file = "coverage-6.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:26e2deacd414fc2f97dd9f7676ee3eaecd299ca751412d89f40bc01557a6b1b4"},
{file = "coverage-6.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dd8bafa458b5c7d061540f1ee9f18025a68e2d8471b3e858a9dad47c8d41903"},
{file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:46191097ebc381fbf89bdce207a6c107ac4ec0890d8d20f3360345ff5976155c"},
{file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6f89d05e028d274ce4fa1a86887b071ae1755082ef94a6740238cd7a8178804f"},
{file = "coverage-6.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58303469e9a272b4abdb9e302a780072c0633cdcc0165db7eec0f9e32f901e05"},
{file = "coverage-6.3.2-cp310-cp310-win32.whl", hash = "sha256:2fea046bfb455510e05be95e879f0e768d45c10c11509e20e06d8fcaa31d9e39"},
{file = "coverage-6.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:a2a8b8bcc399edb4347a5ca8b9b87e7524c0967b335fbb08a83c8421489ddee1"},
{file = "coverage-6.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f1555ea6d6da108e1999b2463ea1003fe03f29213e459145e70edbaf3e004aaa"},
{file = "coverage-6.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5f4e1edcf57ce94e5475fe09e5afa3e3145081318e5fd1a43a6b4539a97e518"},
{file = "coverage-6.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a15dc0a14008f1da3d1ebd44bdda3e357dbabdf5a0b5034d38fcde0b5c234b7"},
{file = "coverage-6.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21b7745788866028adeb1e0eca3bf1101109e2dc58456cb49d2d9b99a8c516e6"},
{file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8ce257cac556cb03be4a248d92ed36904a59a4a5ff55a994e92214cde15c5bad"},
{file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b0be84e5a6209858a1d3e8d1806c46214e867ce1b0fd32e4ea03f4bd8b2e3359"},
{file = "coverage-6.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:acf53bc2cf7282ab9b8ba346746afe703474004d9e566ad164c91a7a59f188a4"},
{file = "coverage-6.3.2-cp37-cp37m-win32.whl", hash = "sha256:8bdde1177f2311ee552f47ae6e5aa7750c0e3291ca6b75f71f7ffe1f1dab3dca"},
{file = "coverage-6.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b31651d018b23ec463e95cf10070d0b2c548aa950a03d0b559eaa11c7e5a6fa3"},
{file = "coverage-6.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:07e6db90cd9686c767dcc593dff16c8c09f9814f5e9c51034066cad3373b914d"},
{file = "coverage-6.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c6dbb42f3ad25760010c45191e9757e7dce981cbfb90e42feef301d71540059"},
{file = "coverage-6.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c76aeef1b95aff3905fb2ae2d96e319caca5b76fa41d3470b19d4e4a3a313512"},
{file = "coverage-6.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cf5cfcb1521dc3255d845d9dca3ff204b3229401994ef8d1984b32746bb45ca"},
{file = "coverage-6.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fbbdc8d55990eac1b0919ca69eb5a988a802b854488c34b8f37f3e2025fa90d"},
{file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ec6bc7fe73a938933d4178c9b23c4e0568e43e220aef9472c4f6044bfc6dd0f0"},
{file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9baff2a45ae1f17c8078452e9e5962e518eab705e50a0aa8083733ea7d45f3a6"},
{file = "coverage-6.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd9e830e9d8d89b20ab1e5af09b32d33e1a08ef4c4e14411e559556fd788e6b2"},
{file = "coverage-6.3.2-cp38-cp38-win32.whl", hash = "sha256:f7331dbf301b7289013175087636bbaf5b2405e57259dd2c42fdcc9fcc47325e"},
{file = "coverage-6.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:68353fe7cdf91f109fc7d474461b46e7f1f14e533e911a2a2cbb8b0fc8613cf1"},
{file = "coverage-6.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b78e5afb39941572209f71866aa0b206c12f0109835aa0d601e41552f9b3e620"},
{file = "coverage-6.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e21876082ed887baed0146fe222f861b5815455ada3b33b890f4105d806128d"},
{file = "coverage-6.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34626a7eee2a3da12af0507780bb51eb52dca0e1751fd1471d0810539cefb536"},
{file = "coverage-6.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ebf730d2381158ecf3dfd4453fbca0613e16eaa547b4170e2450c9707665ce7"},
{file = "coverage-6.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd6fe30bd519694b356cbfcaca9bd5c1737cddd20778c6a581ae20dc8c04def2"},
{file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:96f8a1cb43ca1422f36492bebe63312d396491a9165ed3b9231e778d43a7fca4"},
{file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dd035edafefee4d573140a76fdc785dc38829fe5a455c4bb12bac8c20cfc3d69"},
{file = "coverage-6.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5ca5aeb4344b30d0bec47481536b8ba1181d50dbe783b0e4ad03c95dc1296684"},
{file = "coverage-6.3.2-cp39-cp39-win32.whl", hash = "sha256:f5fa5803f47e095d7ad8443d28b01d48c0359484fec1b9d8606d0e3282084bc4"},
{file = "coverage-6.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:9548f10d8be799551eb3a9c74bbf2b4934ddb330e08a73320123c07f95cc2d92"},
{file = "coverage-6.3.2-pp36.pp37.pp38-none-any.whl", hash = "sha256:18d520c6860515a771708937d2f78f63cc47ab3b80cb78e86573b0a760161faf"},
{file = "coverage-6.3.2.tar.gz", hash = "sha256:03e2a7826086b91ef345ff18742ee9fc47a6839ccd517061ef8fa1976e652ce9"},
]
cryptography = [
{file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:73bc2d3f2444bcfeac67dd130ff2ea598ea5f20b40e36d19821b4df8c9c5037b"},
@@ -3533,6 +3543,27 @@ tomlkit = [
{file = "tomlkit-0.7.0-py2.py3-none-any.whl", hash = "sha256:6babbd33b17d5c9691896b0e68159215a9387ebfa938aa3ac42f4a4beeb2b831"},
{file = "tomlkit-0.7.0.tar.gz", hash = "sha256:ac57f29693fab3e309ea789252fcce3061e19110085aa31af5446ca749325618"},
]
torch = [
{file = "torch-1.10.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8f3fd2e3ffc3bb867133fdf7fbcc8a0bb2e62a5c0696396f51856f5abf9045a8"},
{file = "torch-1.10.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:258a0729fb77a3457d5822d84b536057cd119b08049a8d3c41dc3dcdeb48d56e"},
{file = "torch-1.10.2-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:935e5ac804c5093c79f23a7e6ca5b912c166071aa9d8b4a0a3d6a85126d6a47b"},
{file = "torch-1.10.2-cp36-cp36m-win_amd64.whl", hash = "sha256:65fd02ed889c63fd82bf1a440c5a94c1310c29f3e6f9f62add416d34da355d97"},
{file = "torch-1.10.2-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:6a81f886823bbd15edc2dc0908fa214070df61c9f7ab8831f0a03630275cca5a"},
{file = "torch-1.10.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:3eee3cf53c1f8fb3f1fe107a22025a8501fc6440d14e09599ba7153002531f84"},
{file = "torch-1.10.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ef99b8cca5f9358119b07956915faf6e7906f433ab4a603c160ae9de88918371"},
{file = "torch-1.10.2-cp37-cp37m-win_amd64.whl", hash = "sha256:d43bc3f3a2d89ae185ef96d903c935c335219231e57685658648396984e2a67a"},
{file = "torch-1.10.2-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:6da1b877880435440a5aa9678ef0f01986d4886416844db1d97ebfb7fd1778d0"},
{file = "torch-1.10.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ab77a9f838874f295ed5410c0686fa22547456e0116efb281c66ef5f9d46fe28"},
{file = "torch-1.10.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9ef4c004f9e5168bd1c1930c6aff25fed5b097de81db6271ffbb2e4fb8b89319"},
{file = "torch-1.10.2-cp38-cp38-win_amd64.whl", hash = "sha256:376fc18407add20daa6bbaaffc5a5e06d733abe53bcbd60ef2532bfed34bc091"},
{file = "torch-1.10.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:f281438ee99bd72ad65c0bba1026a32e45c3b636bc067fc145ad291e9ea2faab"},
{file = "torch-1.10.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:3592d3dd62b32760c82624e7586222747fe2281240e8653970b35f1d6d4a434c"},
{file = "torch-1.10.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:fbaf18c1b3e0b31af194a9d853e3739464cf982d279df9d34dd18f1c2a471878"},
{file = "torch-1.10.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:97b7b0c667e8b0dd1fc70137a36e0a4841ec10ef850bda60500ad066bef3e2de"},
{file = "torch-1.10.2-cp39-cp39-win_amd64.whl", hash = "sha256:901b52787baeb2e9e1357ca7037da0028bc6ad743f530e0040ae96ef8e27156c"},
{file = "torch-1.10.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:5b68e9108bd7ebd99eee941686046c517cfaac5331f757bcf440fe02f2e3ced1"},
{file = "torch-1.10.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b07ef01e36b716d0d65ca60c4db0ac9d094a0e797d9b55290da4dcda91463b6c"},
]
tornado = [
{file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"},
{file = "tornado-6.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:0d321a39c36e5f2c4ff12b4ed58d41390460f798422c4504e09eb5678e09998c"},

View File

@@ -45,7 +45,8 @@ pygraphviz = { version = "^1.7", optional = true }
Pillow = "^9.0.0"
loguru = "^0.5.3"
setuptools = "*"
concrete-compiler = "^0.2.0"
concrete-compiler = "^0.3.1"
torch = "^1.10.2"
[tool.poetry.extras]
full = ["pygraphviz"]

View File

@@ -0,0 +1,193 @@
"""Test file for convolution"""
import numpy as np
import pytest
import torch
from concrete.common.extensions import convolution
from concrete.common.representation.intermediate import Conv2D
from concrete.common.tracing.base_tracer import BaseTracer
from concrete.common.values.tensors import TensorValue
from concrete.numpy.tracing import NPConstant, NPTracer
@pytest.mark.parametrize(
    "kwargs, error_msg",
    [
        pytest.param(
            {"x": None, "weight": np.zeros(1)},
            "input x must be an ndarray, or a BaseTracer, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": None},
            "weight must be an ndarray, or a BaseTracer, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "bias": 0},
            "bias must be an ndarray, a BaseTracer, or None, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "strides": None},
            "strides must be a tuple, or list, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "dilations": None},
            "dilations must be a tuple, or list, not a",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "pads": None},
            "padding must be a tuple, or list, not a",
        ),
    ],
)
def test_invalid_arg_types(kwargs, error_msg):
    """Test function to make sure convolution doesn't accept invalid types.

    Each case passes a single argument of the wrong type and checks that
    conv2d raises a TypeError whose message names the offending argument.
    """
    with pytest.raises(TypeError) as err:
        convolution.conv2d(**kwargs)
    # Match against the exception object itself: str(err) is the ExceptionInfo's
    # traceback-entry representation, not (reliably) the raw error message.
    assert error_msg in str(err.value)
@pytest.mark.parametrize(
    "kwargs, error_msg",
    [
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1)},
            "input x should have size (N x C x H x W), not",
        ),
        pytest.param(
            {"x": np.zeros((1, 2, 3, 4)), "weight": np.zeros(1)},
            "weight should have size (F x C x H x W), not",
        ),
        pytest.param(
            {
                "x": np.zeros((1, 2, 3, 4)),
                "weight": np.zeros((1, 2, 3, 4)),
                "bias": np.zeros((1, 2)),
            },
            "bias should have size (F), not",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "strides": (1,)},
            "strides should be of the form",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "dilations": (1,)},
            "dilations should be of the form",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "pads": (1,)},
            "padding should be of the form",
        ),
        pytest.param(
            {"x": np.zeros(1), "weight": np.zeros(1), "auto_pad": None},
            "invalid auto_pad is specified",
        ),
    ],
)
def test_invalid_input_shape(kwargs, error_msg):
    """Test function to make sure convolution doesn't accept invalid shapes.

    Each case passes arguments of the right type but the wrong shape (or an
    unsupported auto_pad) and checks the raised error mentions the problem.
    """
    with pytest.raises((ValueError, AssertionError)) as err:
        convolution.conv2d(**kwargs)
    # Match against the exception object itself: str(err) is the ExceptionInfo's
    # traceback-entry representation, not (reliably) the raw error message.
    assert error_msg in str(err.value)
@pytest.mark.parametrize(
    "input_shape, weight_shape",
    [
        pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((3, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((1, 1, 4, 4), (3, 1, 2, 2)),
        pytest.param((1, 3, 4, 4), (1, 3, 2, 2)),
        pytest.param((4, 3, 4, 4), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 3, 3)),
    ],
)
@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("use_ndarray", [True, False])
def test_tracing(input_shape, weight_shape, strides, dilations, has_bias, use_ndarray):
    """Test function to make sure tracing of conv2d works properly"""
    # Bias is optional; exercise both the raw-ndarray and traced-constant forms.
    if has_bias:
        bias = np.random.randint(0, 4, size=(weight_shape[0],))
        if not use_ndarray:
            bias = NPTracer([], NPConstant(bias), 0)
    else:
        bias = None
    x = NPTracer([], NPConstant(np.random.randint(0, 4, size=input_shape)), 0)
    weight = np.random.randint(0, 4, size=weight_shape)
    if not use_ndarray:
        weight = NPTracer([], NPConstant(weight), 0)
    output_tracer = convolution.conv2d(x, weight, bias, strides=strides, dilations=dilations)
    traced_computation = output_tracer.traced_computation
    # Tracing must produce a Conv2D node; ndarray arguments are expected to be
    # wrapped into tracers, so every node input should be a BaseTracer.
    assert isinstance(traced_computation, Conv2D)
    if has_bias:
        assert len(output_tracer.inputs) == 3
    else:
        assert len(output_tracer.inputs) == 2
    assert all(
        isinstance(input_, BaseTracer) for input_ in output_tracer.inputs
    ), f"{output_tracer.inputs}"
    assert len(traced_computation.outputs) == 1
    output_value = traced_computation.outputs[0]
    assert isinstance(output_value, TensorValue) and output_value.is_encrypted
    # torch.conv2d on random data of the same shapes serves as the reference
    # for the expected output shape — only the shape is compared, not values.
    # pylint: disable=no-member
    expected_shape = torch.conv2d(
        torch.randn(input_shape),
        torch.randn(weight_shape),
        torch.randn((weight_shape[0])),
        stride=strides,
        dilation=dilations,
    ).shape
    # pylint: enable=no-member
    assert output_value.shape == expected_shape
@pytest.mark.parametrize(
    "input_shape, weight_shape",
    [
        pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((3, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((1, 1, 4, 4), (3, 1, 2, 2)),
        pytest.param((1, 3, 4, 4), (1, 3, 2, 2)),
        pytest.param((4, 3, 4, 4), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 2, 2)),
        pytest.param((4, 3, 16, 16), (3, 3, 3, 3)),
    ],
)
@pytest.mark.parametrize("strides", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("dilations", [(1, 1), (1, 2), (2, 1), (2, 2)])
@pytest.mark.parametrize("has_bias", [True, False])
def test_evaluation(input_shape, weight_shape, strides, dilations, has_bias):
    """Check that evaluating conv2d on plain ndarrays matches torch.conv2d."""
    num_filters = weight_shape[0]
    data = np.random.randint(0, 4, size=input_shape)
    kernel = np.random.randint(0, 4, size=weight_shape)
    # torch always needs an explicit bias tensor; a zero vector stands in
    # for the no-bias case since it does not change the result.
    bias = (
        np.random.randint(0, 4, size=(num_filters,))
        if has_bias
        else np.zeros((num_filters,))
    )
    # pylint: disable=no-member
    expected = torch.conv2d(
        torch.tensor(data, dtype=torch.long),
        torch.tensor(kernel, dtype=torch.long),
        torch.tensor(bias, dtype=torch.long),
        stride=strides,
        dilation=dilations,
    ).numpy()
    # pylint: enable=no-member
    # conv2d itself must accept bias=None, so drop the placeholder on that path.
    result = convolution.conv2d(
        data, kernel, bias if has_bias else None, strides=strides, dilations=dilations
    )
    assert (result == expected).all()

View File

@@ -21,6 +21,7 @@ from concrete.common.representation.intermediate import (
ALL_IR_NODES,
Add,
Constant,
Conv2D,
Dot,
GenericFunction,
IndexConstant,
@@ -261,6 +262,11 @@ def is_equivalent_matmul(lhs: MatMul, rhs: object) -> bool:
return isinstance(rhs, MatMul) and is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_conv2d(lhs: Conv2D, rhs: object) -> bool:
    """Helper function to check if a Conv2D node is equivalent to an other object."""
    # Mirrors the sibling is_equivalent_* helpers: a type check plus the
    # generic intermediate-node equivalence is all the comparison needed.
    return isinstance(rhs, Conv2D) and is_equivalent_intermediate_node(lhs, rhs)
def is_equivalent_intermediate_node(lhs: IntermediateNode, rhs: object) -> bool:
"""Helper function to check if an IntermediateNode node is equivalent to an other object."""
return (
@@ -274,6 +280,7 @@ EQUIVALENT_TEST_FUNC: Dict[Type, Callable[..., bool]] = {
Add: is_equivalent_add,
GenericFunction: is_equivalent_arbitrary_function,
Constant: is_equivalent_constant,
Conv2D: is_equivalent_conv2d,
Dot: is_equivalent_dot,
IndexConstant: is_equivalent_index_constant,
Input: is_equivalent_input,

View File

@@ -0,0 +1,44 @@
"""Test module for convolution compilation and execution."""
import numpy as np
import pytest
import concrete.numpy as hnp
from concrete.common.data_types.integers import Integer
from concrete.common.values.tensors import EncryptedTensor
from concrete.numpy.compile import compile_numpy_function
@pytest.mark.parametrize(
    "input_shape, weight_shape",
    [
        pytest.param((1, 1, 4, 4), (1, 1, 2, 2)),
        pytest.param((4, 3, 4, 4), (2, 3, 2, 2)),
    ],
)
@pytest.mark.parametrize("strides", [(2, 2)])
@pytest.mark.parametrize("dilations", [(1, 1)])
@pytest.mark.parametrize("has_bias", [True, False])
def test_compile_and_run(
    input_shape, weight_shape, strides, dilations, has_bias, default_compilation_configuration
):
    """Test function to make sure compilation and execution of conv2d works properly.

    Compiles a conv2d circuit over an encrypted input tensor and checks that
    running the compiled engine matches the cleartext evaluation of conv2d.
    """
    if has_bias:
        bias = np.random.randint(0, 4, size=(weight_shape[0],))
    else:
        bias = None
    weight = np.random.randint(0, 4, size=weight_shape)

    def conv(x):
        return hnp.conv2d(x, weight, bias, strides=strides, dilations=dilations)

    compiler_engine = compile_numpy_function(
        conv,
        {"x": EncryptedTensor(Integer(64, False), input_shape)},
        # 20 random input sets used by compilation to measure value bounds;
        # the loop index is unused, so name it `_`.
        [np.random.randint(0, 4, size=input_shape) for _ in range(20)],
        default_compilation_configuration,
    )
    x = np.random.randint(0, 4, size=input_shape, dtype=np.uint8)
    expected = conv(x)
    result = compiler_engine.run(x)
    assert (expected == result).all()