mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: add a Conv2d operation in FHELinalg
This commit is contained in:
@@ -104,3 +104,23 @@ using namespace mlir::linalg;
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
|
||||
|
||||
#endif
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace FHELinalg {
|
||||
|
||||
/// Get padding from the Conv2dOp if defined, or return default value
|
||||
mlir::SmallVector<int64_t, 4>
|
||||
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
|
||||
|
||||
/// Get strides from the Conv2dOp if defined, or return default value
|
||||
mlir::SmallVector<int64_t, 2>
|
||||
getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
|
||||
|
||||
/// Get dilations from the Conv2dOp if defined, or return default value
|
||||
mlir::SmallVector<int64_t, 2>
|
||||
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
|
||||
|
||||
} // namespace FHELinalg
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
@@ -631,6 +632,23 @@ def ConcatOp : FHELinalg_Op<"concat"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
|
||||
let summary = "Returns the 2D convolution of a tensor in the form NCHW with weights in the form FCHW";
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$input,
|
||||
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$weight,
|
||||
Optional<Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>>:$bias,
|
||||
// Since there is no U64ElementsAttr, we use I64 and make sure there is no neg values during verification
|
||||
OptionalAttr<I64ElementsAttr>:$padding,
|
||||
OptionalAttr<I64ElementsAttr>:$strides,
|
||||
OptionalAttr<I64ElementsAttr>:$dilations
|
||||
);
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::FHELinalg::verifyConv2d(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
|
||||
: Op<Linalg_Dialect, mnemonic, !listconcat([
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
|
||||
@@ -625,6 +625,253 @@ template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::SmallVector<int64_t, 4>
|
||||
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 4> paddingInts;
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
|
||||
if (optionalPadding.hasValue()) {
|
||||
auto paddingAttr = optionalPadding.getValue();
|
||||
auto paddingAttrShape =
|
||||
paddingAttr.getType().cast<RankedTensorType>().getShape();
|
||||
assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 &&
|
||||
"incorrect padding shape");
|
||||
paddingInts.insert(paddingInts.begin(), paddingAttr.value_begin<int64_t>(),
|
||||
paddingAttr.value_end<int64_t>());
|
||||
} else {
|
||||
paddingInts.insert(paddingInts.begin(), {0, 0, 0, 0});
|
||||
}
|
||||
return paddingInts;
|
||||
}
|
||||
|
||||
mlir::SmallVector<int64_t, 2>
|
||||
getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 2> stridesInts;
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
|
||||
if (optionalStrides.hasValue()) {
|
||||
auto stridesAttr = optionalStrides.getValue();
|
||||
auto stridesAttrShape =
|
||||
stridesAttr.getType().cast<RankedTensorType>().getShape();
|
||||
assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 &&
|
||||
"incorrect strides shape");
|
||||
stridesInts.insert(stridesInts.begin(), stridesAttr.value_begin<int64_t>(),
|
||||
stridesAttr.value_end<int64_t>());
|
||||
} else {
|
||||
stridesInts.insert(stridesInts.begin(), {1, 1});
|
||||
}
|
||||
return stridesInts;
|
||||
}
|
||||
|
||||
mlir::SmallVector<int64_t, 2>
|
||||
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
mlir::SmallVector<int64_t, 2> dilationsInts;
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
|
||||
convOp.dilations();
|
||||
if (optionalDilations.hasValue()) {
|
||||
auto dilationsAttr = optionalDilations.getValue();
|
||||
auto dilationsAttrShape =
|
||||
dilationsAttr.getType().cast<RankedTensorType>().getShape();
|
||||
assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 &&
|
||||
"incorrect dilations shape");
|
||||
dilationsInts.insert(dilationsInts.begin(),
|
||||
dilationsAttr.value_begin<int64_t>(),
|
||||
dilationsAttr.value_end<int64_t>());
|
||||
} else {
|
||||
dilationsInts.insert(dilationsInts.begin(), {1, 1});
|
||||
}
|
||||
return dilationsInts;
|
||||
}
|
||||
|
||||
/// Verify the Conv2d shapes, attributes, and expected output dimensions
|
||||
mlir::LogicalResult
|
||||
verifyConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
||||
auto inputTy =
|
||||
((mlir::Type)convOp.input().getType()).cast<mlir::RankedTensorType>();
|
||||
auto weightTy =
|
||||
((mlir::Type)convOp.weight().getType()).cast<mlir::RankedTensorType>();
|
||||
auto resultTy =
|
||||
((mlir::Type)convOp.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto inputShape = inputTy.getShape();
|
||||
auto weightShape = weightTy.getShape();
|
||||
auto resultShape = resultTy.getShape();
|
||||
|
||||
auto p = inputTy.getElementType()
|
||||
.cast<mlir::concretelang::FHE::EncryptedIntegerType>()
|
||||
.getWidth();
|
||||
auto weightElementTyWidth =
|
||||
weightTy.getElementType().cast<mlir::IntegerType>().getWidth();
|
||||
if (weightElementTyWidth != p + 1) {
|
||||
convOp.emitOpError() << "expected weight element type to have width "
|
||||
<< p + 1 << " but got " << weightElementTyWidth;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// Checking dimensions
|
||||
if (inputShape.size() != 4) {
|
||||
convOp.emitOpError() << "input should have 4 dimensions (N*C*H*W) but got "
|
||||
<< inputShape.size();
|
||||
return mlir::failure();
|
||||
}
|
||||
if (weightShape.size() != 4) {
|
||||
convOp.emitOpError() << "weight should have 4 dimensions (F*C*H*W) but got "
|
||||
<< weightShape.size();
|
||||
return mlir::failure();
|
||||
}
|
||||
if (resultShape.size() != 4) {
|
||||
convOp.emitOpError() << "result should have 4 dimensions (N*C*H*W) but got "
|
||||
<< resultShape.size();
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// Checking attributes
|
||||
mlir::SmallVector<int64_t, 4> paddingInts = getPaddingFromConv2d(convOp);
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
|
||||
if (optionalPadding.hasValue()) {
|
||||
auto paddingAttr = optionalPadding.getValue();
|
||||
auto paddingAttrShape =
|
||||
paddingAttr.getType().cast<RankedTensorType>().getShape();
|
||||
if (paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) {
|
||||
convOp.emitOpError()
|
||||
<< "padding should have a single dimension of size 4, but got shape ["
|
||||
<< paddingAttrShape << "]";
|
||||
return mlir::failure();
|
||||
}
|
||||
for (auto i = 0; i < 4; i++) {
|
||||
// TODO: Support padding (#427)
|
||||
if (paddingInts[i] != 0) {
|
||||
convOp.emitOpError()
|
||||
<< "padding isn't yet supported, but got a non zero value ("
|
||||
<< paddingInts[i] << ") at index " << i;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (paddingInts[i] < 0) {
|
||||
convOp.emitOpError() << "padding can't have a negative value, but got "
|
||||
<< paddingInts[i] << " at index " << i;
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
mlir::SmallVector<int64_t, 2> stridesInts = getStridesFromConv2d(convOp);
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
|
||||
if (optionalStrides.hasValue()) {
|
||||
auto stridesAttr = optionalStrides.getValue();
|
||||
auto stridesAttrShape =
|
||||
stridesAttr.getType().cast<RankedTensorType>().getShape();
|
||||
if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
|
||||
convOp.emitOpError()
|
||||
<< "strides should have a single dimension of size 2, but got shape ["
|
||||
<< stridesAttrShape << "]";
|
||||
return mlir::failure();
|
||||
}
|
||||
for (auto i = 0; i < 2; i++) {
|
||||
if (stridesInts[i] < 1) {
|
||||
convOp.emitOpError()
|
||||
<< "strides can't have a value less than 1, but got "
|
||||
<< stridesInts[i] << " at index " << i;
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
mlir::SmallVector<int64_t, 2> dilationsInts = getDilationsFromConv2d(convOp);
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
|
||||
convOp.dilations();
|
||||
if (optionalDilations.hasValue()) {
|
||||
auto dilationsAttr = optionalDilations.getValue();
|
||||
auto dilationsAttrShape =
|
||||
dilationsAttr.getType().cast<RankedTensorType>().getShape();
|
||||
if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
|
||||
convOp.emitOpError() << "dilations should have a single dimension of "
|
||||
"size 2, but got shape ["
|
||||
<< dilationsAttrShape << "]";
|
||||
return mlir::failure();
|
||||
}
|
||||
for (auto i = 0; i < 2; i++) {
|
||||
if (dilationsInts[i] < 1) {
|
||||
convOp.emitOpError()
|
||||
<< "dilations can't have a value less than 1, but got "
|
||||
<< dilationsInts[i] << " at index " << i;
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extracting dimensions
|
||||
int64_t inputN = inputShape[0], inputC = inputShape[1],
|
||||
inputH = inputShape[2], inputW = inputShape[3];
|
||||
int64_t weightF = weightShape[0], weightC = weightShape[1],
|
||||
weightH = weightShape[2], weightW = weightShape[3];
|
||||
int64_t resultN = resultShape[0], resultC = resultShape[1],
|
||||
resultH = resultShape[2], resultW = resultShape[3];
|
||||
|
||||
// Bias check if specified
|
||||
mlir::Value bias = convOp.bias();
|
||||
if (bias) {
|
||||
auto biasTy = ((mlir::Type)bias.getType()).cast<mlir::RankedTensorType>();
|
||||
auto biasShape = biasTy.getShape();
|
||||
if (biasShape.size() != 1) {
|
||||
convOp.emitOpError() << "bias should have 1 dimension but got "
|
||||
<< biasShape.size();
|
||||
return mlir::failure();
|
||||
}
|
||||
if (biasShape[0] != weightF) {
|
||||
convOp.emitOpError() << "expected bias vector to have size " << weightF
|
||||
<< " but got " << biasShape[0];
|
||||
return mlir::failure();
|
||||
}
|
||||
auto biasElementTyWidth =
|
||||
biasTy.getElementType().cast<mlir::IntegerType>().getWidth();
|
||||
if (biasElementTyWidth != p + 1) {
|
||||
convOp.emitOpError() << "expected bias element type to have width "
|
||||
<< p + 1 << " but got " << biasElementTyWidth;
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Dimension sizes checks
|
||||
if (resultN != inputN) {
|
||||
convOp.emitOpError()
|
||||
<< "expected result batch size to be equal to input batch size ("
|
||||
<< inputN << ") but got " << resultN;
|
||||
return mlir::failure();
|
||||
}
|
||||
if (inputC != weightC) {
|
||||
convOp.emitOpError() << "expected number of channels in weight to be equal "
|
||||
"to number of channels in input ("
|
||||
<< inputC << ") but got " << weightC;
|
||||
return mlir::failure();
|
||||
}
|
||||
if (weightF != resultC) {
|
||||
convOp.emitOpError() << "expected number of output channels to be equal to "
|
||||
"the number of filters ("
|
||||
<< weightF << ") but got " << resultC;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
int64_t paddingH = paddingInts[0] + paddingInts[2];
|
||||
int64_t paddingW = paddingInts[1] + paddingInts[3];
|
||||
int64_t dilationH = dilationsInts[0];
|
||||
int64_t dilationW = dilationsInts[1];
|
||||
int64_t strideH = stridesInts[0];
|
||||
int64_t strideW = stridesInts[1];
|
||||
int64_t expectedResultH =
|
||||
floor((inputH + paddingH - dilationH * (weightH - 1) - 1) / strideH) + 1;
|
||||
int64_t expectedResultW =
|
||||
floor((inputW + paddingW - dilationW * (weightW - 1) - 1) / strideW) + 1;
|
||||
|
||||
if (expectedResultH != resultH) {
|
||||
convOp.emitOpError() << "expected height of output to be equal to "
|
||||
<< expectedResultH << " but got " << resultH;
|
||||
return mlir::failure();
|
||||
}
|
||||
if (expectedResultW != resultW) {
|
||||
convOp.emitOpError() << "expected width of output to be equal to "
|
||||
<< expectedResultW << " but got " << resultW;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Implementation of FhelinalgConv2DNchwFchwOp
|
||||
// This is a generated functions from `make generate_conv_op`, and some helpers
|
||||
|
||||
@@ -270,3 +270,94 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) ->
|
||||
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
|
||||
return %1 : tensor<4x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.conv2d
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op padding isn't yet supported, but got a non zero value (1) at index 0}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[1,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 8 but got 15}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0 ,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
func @conv2d(%input: tensor<101x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected result batch size to be equal to input batch size (101) but got 100}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0 ,0, 0, 0]> : tensor<4xi64>}: (tensor<101x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x4x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected number of channels in weight to be equal to number of channels in input (3) but got 4}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x4x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi4>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected weight element type to have width 3 but got 4}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi4>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi4>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected bias element type to have width 3 but got 4}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi4>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x2x2xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 27 but got 15}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x2x2xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x2xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected width of output to be equal to 27 but got 15}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x2xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 2 but got 15}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[2,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// expected-error @+1 {{'FHELinalg.conv2d' op expected width of output to be equal to 2 but got 15}}
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,2]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
@@ -345,3 +345,36 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) ->
|
||||
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.conv2d
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// CHECK: func @conv2d(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) {dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @conv2d_without_attr(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @conv2d_without_attr(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias): (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @conv2d_without_bias(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]]) {dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @conv2d_without_bias(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user