mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: implement extensions module
This commit is contained in:
6
concrete/numpy/extensions/__init__.py
Normal file
6
concrete/numpy/extensions/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Declaration of `concrete.numpy.extensions` namespace.
|
||||
"""
|
||||
|
||||
from .convolution import conv2d
|
||||
from .table import LookupTable
|
||||
231
concrete/numpy/extensions/convolution.py
Normal file
231
concrete/numpy/extensions/convolution.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Declaration of `conv2d` function.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import EncryptedTensor
|
||||
|
||||
SUPPORTED_AUTO_PAD = {
|
||||
"NOTSET",
|
||||
}
|
||||
|
||||
|
||||
def conv2d(
    x: Union[np.ndarray, Tracer],
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]] = None,
    pads: Union[Tuple[int, int, int, int], List[int]] = (0, 0, 0, 0),
    strides: Union[Tuple[int, int], List[int]] = (1, 1),
    dilations: Union[Tuple[int, int], List[int]] = (1, 1),
    auto_pad: str = "NOTSET",
) -> Union[np.ndarray, Tracer]:
    """
    Trace or evaluate 2D convolution.

    Args:
        x (Union[np.ndarray, Tracer]):
            input of shape (N, C, H, W)

        weight (Union[np.ndarray, Tracer]):
            kernel of shape (F, C, H, W)

        bias (Optional[Union[np.ndarray, Tracer]], default = None):
            bias of shape (F,)

        pads (Union[Tuple[int, int, int, int], List[int]], default = (0, 0, 0, 0)):
            padding over each height and width (H_beg, W_beg, H_end, W_end)

        strides (Union[Tuple[int, int], List[int]], default = (1, 1)):
            stride over height and width

        dilations (Union[Tuple[int, int], List[int]], default = (1, 1)):
            dilation over height and width

        auto_pad (str, default = "NOTSET"):
            padding strategy

    Returns:
        Union[np.ndarray, Tracer]:
            evaluation result or traced computation

    Raises:
        ValueError:
            if arguments are not appropriate
    """

    if auto_pad not in SUPPORTED_AUTO_PAD:
        raise ValueError(f"Auto pad should be in {SUPPORTED_AUTO_PAD} but it's {repr(auto_pad)}")

    if len(pads) != 4:
        raise ValueError(
            f"Pads should be of form "
            f"(height_begin_pad, width_begin_pad, height_end_pad, width_end_pad) "
            f"but it's {pads}"
        )
    if len(strides) != 2:
        raise ValueError(
            f"Strides should be of form (height_stride, width_stride) but it's {strides}"
        )
    if len(dilations) != 2:
        raise ValueError(
            f"Dilations should be of form "
            f"(height_dilation, width_dilation) "
            f"but it's {dilations}"
        )

    # symbolic input: defer to the tracing helper, which records a graph node
    if isinstance(x, Tracer):
        return _trace_conv2d(x, weight, bias, pads, strides, dilations)

    if not isinstance(weight, np.ndarray):
        raise ValueError("Weight should be of type np.ndarray for evaluation")

    if bias is not None and not isinstance(bias, np.ndarray):
        raise ValueError("Bias should be of type np.ndarray for evaluation")

    if bias is None:
        # use an integer zero bias so that a missing bias does not promote an
        # otherwise all-integer convolution to float64; this mirrors the default
        # bias created with dtype=np.int64 in the tracing path
        bias = np.zeros(weight.shape[0], dtype=np.int64)
    return _evaluate_conv2d(x, weight, bias, pads, strides, dilations)
|
||||
|
||||
|
||||
def _trace_conv2d(
    x: Tracer,
    weight: Union[np.ndarray, Tracer],
    bias: Optional[Union[np.ndarray, Tracer]],
    pads: Union[Tuple[int, int, int, int], List[int]],
    strides: Union[Tuple[int, int], List[int]],
    dilations: Union[Tuple[int, int], List[int]],
) -> Tracer:
    """
    Trace 2D convolution.

    Args:
        x (Tracer):
            input of shape (N, C, H, W)

        weight (Union[np.ndarray, Tracer]):
            kernel of shape (F, C, H, W)

        bias (Optional[Union[np.ndarray, Tracer]]):
            bias of shape (F,)

        pads (Union[Tuple[int, int, int, int], List[int]]):
            padding over each axis (H_beg, W_beg, H_end, W_end)

        strides (Union[Tuple[int, int], List[int]]):
            stride over height and width

        dilations (Union[Tuple[int, int], List[int]]):
            dilation over height and width

    Returns:
        Tracer:
            traced computation
    """

    if x.output.ndim != 4:
        raise ValueError(
            f"Input should be of shape (N, C, H, W) but it's of shape {x.output.shape}",
        )

    # lift a constant kernel into the graph as a constant node
    if not isinstance(weight, Tracer):
        weight = Tracer(Node.constant(weight), [])

    if weight.output.ndim != 4:
        raise ValueError(
            f"Weight should be of shape (F, C, H, W) but it's of shape {weight.output.shape}",
        )

    inputs = [x, weight]
    input_values = [x.output, weight.output]

    if bias is not None:
        # constant biases become constant nodes, just like the kernel
        if not isinstance(bias, Tracer):
            bias = Tracer(Node.constant(bias), [])
        if bias.output.ndim != 1:
            raise ValueError(
                f"Bias should be of shape (F,) but it's of shape {bias.output.shape}",
            )
        input_values.append(bias.output)
        inputs.append(bias)

    batches, _, in_height, in_width = x.output.shape
    filters, _, kernel_height, kernel_width = weight.output.shape

    total_pad_height = pads[0] + pads[2]
    total_pad_width = pads[1] + pads[3]

    # standard convolution output-size formula, floored per stride step
    out_height = (
        math.floor(
            (in_height + total_pad_height - dilations[0] * (kernel_height - 1) - 1) / strides[0]
        )
        + 1
    )
    out_width = (
        math.floor(
            (in_width + total_pad_width - dilations[1] * (kernel_width - 1) - 1) / strides[1]
        )
        + 1
    )

    result_value = EncryptedTensor(
        dtype=x.output.dtype,
        shape=(batches, filters, out_height, out_width),
    )

    # when no bias tracer is part of the graph, a constant zero bias is passed
    # to the evaluator as an extra positional argument instead
    extra_args = () if bias is not None else (np.zeros(weight.output.shape[0], dtype=np.int64),)

    node = Node.generic(
        "conv2d",
        input_values,
        result_value,
        _evaluate_conv2d,
        args=extra_args,
        kwargs={"pads": pads, "strides": strides, "dilations": dilations},
    )
    return Tracer(node, inputs)
|
||||
|
||||
|
||||
def _evaluate_conv2d(
|
||||
x: np.ndarray,
|
||||
weight: np.ndarray,
|
||||
bias: np.ndarray,
|
||||
pads: Union[Tuple[int, int, int, int], List[int]], # pylint: disable=unused-argument
|
||||
strides: Union[Tuple[int, int], List[int]],
|
||||
dilations: Union[Tuple[int, int], List[int]],
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Evaluate 2D convolution.
|
||||
|
||||
Args:
|
||||
x (np.ndarray):
|
||||
input of shape (N, C, H, W)
|
||||
|
||||
weight (np.ndarray):
|
||||
kernel of shape (F, C, H, W)
|
||||
|
||||
bias (np.ndarray):
|
||||
bias of shape (F,)
|
||||
|
||||
pads (Union[Tuple[int, int, int, int], List[int]]):
|
||||
padding over each axis (H_beg, W_beg, H_end, W_end)
|
||||
|
||||
strides (Union[Tuple[int, int], List[int]]):
|
||||
stride over height and width
|
||||
|
||||
dilations (Union[Tuple[int, int], List[int]]):
|
||||
dilation over height and width
|
||||
|
||||
Returns:
|
||||
np.ndarray:
|
||||
result of the convolution
|
||||
"""
|
||||
|
||||
# pylint: disable=no-member
|
||||
|
||||
dtype = (
|
||||
torch.float64
|
||||
if np.issubdtype(x.dtype, np.floating)
|
||||
or np.issubdtype(weight.dtype, np.floating)
|
||||
or np.issubdtype(bias.dtype, np.floating)
|
||||
else torch.long
|
||||
)
|
||||
|
||||
return torch.conv2d(
|
||||
torch.tensor(x, dtype=dtype),
|
||||
torch.tensor(weight, dtype=dtype),
|
||||
torch.tensor(bias, dtype=dtype),
|
||||
stride=strides,
|
||||
dilation=dilations,
|
||||
).numpy()
|
||||
|
||||
# pylint: enable=no-member
|
||||
129
concrete/numpy/extensions/table.py
Normal file
129
concrete/numpy/extensions/table.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Declaration of `LookupTable` class.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes import BaseDataType, Integer
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
|
||||
|
||||
class LookupTable:
    """
    LookupTable class, to provide a way to do direct table lookups.
    """

    # either a 1D integer array, or an ndarray whose items are all `LookupTable`s
    table: np.ndarray
    # integer type wide enough to represent every value reachable through `table`
    output_dtype: BaseDataType

    def __init__(self, table: Any):
        """
        Construct a lookup table.

        Args:
            table (Any):
                anything convertible to a one-dimensional integer np.ndarray,
                or to an np.ndarray whose items are all `LookupTable`s

        Raises:
            ValueError:
                if `table` cannot be used as a lookup table
        """

        is_valid = True
        try:
            self.table = table if isinstance(table, np.ndarray) else np.array(table)
        except Exception:  # pragma: no cover # pylint: disable=broad-except
            # here we try our best to convert the table to np.ndarray
            # if it fails we raise the exception at the end of the function
            is_valid = False

        if is_valid:
            # empty tables are rejected
            is_valid = self.table.size > 0

        if is_valid:
            minimum: int = 0
            maximum: int = 0

            if np.issubdtype(self.table.dtype, np.integer):
                # plain integer table: must be one-dimensional
                minimum = int(self.table.min())
                maximum = int(self.table.max())
                if self.table.ndim != 1:
                    is_valid = False
            else:
                # non-integer dtype: only valid as a container of `LookupTable`s
                is_valid = all(isinstance(item, LookupTable) for item in self.table.flat)
                if is_valid:
                    # range spans every value of every inner table
                    minimum = int(self.table.flat[0].table.min())
                    maximum = int(self.table.flat[0].table.max())
                    for item in self.table.flat:
                        minimum = min(minimum, item.table.min())
                        maximum = max(maximum, item.table.max())

            self.output_dtype = Integer.that_can_represent([minimum, maximum])

        if not is_valid:
            raise ValueError(f"LookupTable cannot be constructed with {repr(table)}")

    def __repr__(self):
        # rendered like a plain Python list, e.g. "[1, 4, 9]"
        return str(list(self.table))

    def __getitem__(self, key: Union[int, np.integer, np.ndarray, Tracer]):
        """
        Look up `key`, evaluating directly for concrete keys or tracing otherwise.

        Args:
            key (Union[int, np.integer, np.ndarray, Tracer]):
                lookup key

        Returns:
            Union[int, np.integer, np.ndarray, Tracer]:
                lookup result, or traced computation for a `Tracer` key

        Raises:
            ValueError:
                if the table cannot be looked up with `key`
        """

        if not isinstance(key, Tracer):
            # concrete key: evaluate immediately
            return LookupTable.apply(key, self.table)

        # traced keys must carry an integer dtype
        if not isinstance(key.output.dtype, Integer):
            raise ValueError(f"LookupTable cannot be looked up with {key.output}")

        table = self.table
        if not np.issubdtype(self.table.dtype, np.integer):
            # table of tables: there is one inner table per element of the key,
            # so the container must broadcast to the key's shape
            try:
                table = np.broadcast_to(table, key.output.shape)
            except Exception as error:
                raise ValueError(
                    f"LookupTable of shape {self.table.shape} "
                    f"cannot be looked up with {key.output}"
                ) from error

        # result has the key's shape but this table's output dtype
        output = deepcopy(key.output)
        output.dtype = self.output_dtype

        computation = Node.generic(
            "tlu",
            [key.output],
            output,
            LookupTable.apply,
            kwargs={"table": table},
        )
        return Tracer(computation, [key])

    @staticmethod
    def apply(
        key: Union[int, np.integer, np.ndarray],
        table: np.ndarray,
    ) -> Union[int, np.integer, np.ndarray]:
        """
        Apply lookup table.

        Args:
            key (Union[int, np.integer, np.ndarray]):
                lookup key

            table (np.ndarray):
                lookup table

        Returns:
            Union[int, np.integer, np.ndarray]:
                lookup result

        Raises:
            ValueError:
                if `table` cannot be looked up with `key`
        """

        # keys must be integers or integer arrays
        if not isinstance(key, (int, np.integer, np.ndarray)) or (
            isinstance(key, np.ndarray) and not np.issubdtype(key.dtype, np.integer)
        ):
            raise ValueError(f"LookupTable cannot be looked up with {key}")

        if np.issubdtype(table.dtype, np.integer):
            # flat integer table: plain (possibly fancy) numpy indexing
            return table[key]

        # table of tables: key must supply exactly one index per inner table
        if not isinstance(key, np.ndarray) or key.shape != table.shape:
            raise ValueError(f"LookupTable of shape {table.shape} cannot be looked up with {key}")

        # look each index up in its paired inner table, then restore the shape
        flat_result = np.fromiter(
            (lt.table[k] for lt, k in zip(table.flat, key.flat)),
            dtype=np.longlong,
        )
        return flat_result.reshape(table.shape)
|
||||
Reference in New Issue
Block a user