feat: add a Conv2d operation in FHELinalg

Author: youben11
Date: 2022-02-15 11:11:09 +01:00
Committed by: Ayoub Benaissa
Parent: 78596f899f
Commit: 3668b2d73a
5 changed files with 409 additions and 0 deletions


@@ -104,3 +104,23 @@ using namespace mlir::linalg;
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
#endif
namespace mlir {
namespace concretelang {
namespace FHELinalg {
/// Get the padding from the Conv2dOp if defined, or return the default value
mlir::SmallVector<int64_t, 4>
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
/// Get the strides from the Conv2dOp if defined, or return the default value
mlir::SmallVector<int64_t, 2>
getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
/// Get the dilations from the Conv2dOp if defined, or return the default value
mlir::SmallVector<int64_t, 2>
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
} // namespace FHELinalg
} // namespace concretelang
} // namespace mlir


@@ -3,6 +3,7 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -631,6 +632,23 @@ def ConcatOp : FHELinalg_Op<"concat"> {
}];
}
def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
let summary = "Returns the 2D convolution of a tensor in the form NCHW with weights in the form FCHW";
let arguments = (ins
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$input,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$weight,
Optional<Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>>:$bias,
// Since there is no U64ElementsAttr, we use I64 and make sure there are no negative values during verification
OptionalAttr<I64ElementsAttr>:$padding,
OptionalAttr<I64ElementsAttr>:$strides,
OptionalAttr<I64ElementsAttr>:$dilations
);
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::concretelang::FHELinalg::verifyConv2d(*this);
}];
}
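On the C++ side, the `I64ElementsAttr` constraints above materialize as `mlir::DenseIntElementsAttr`. As a minimal sketch of how such attributes can be built by hand (assuming an `mlir::Builder &builder` in scope; `getI64TensorAttr` is the stock builder helper returning a `DenseIntElementsAttr` of type `tensor<Nxi64>`):

// Sketch: building the conv2d attributes by hand (builder assumed in scope).
mlir::DenseIntElementsAttr strides = builder.getI64TensorAttr({1, 1});
mlir::DenseIntElementsAttr dilations = builder.getI64TensorAttr({1, 1});
mlir::DenseIntElementsAttr padding = builder.getI64TensorAttr({0, 0, 0, 0});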
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic, !listconcat([
SingleBlockImplicitTerminator<"YieldOp">,


@@ -625,6 +625,253 @@ template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
return mlir::success();
}
mlir::SmallVector<int64_t, 4>
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 4> paddingInts;
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
if (optionalPadding.hasValue()) {
auto paddingAttr = optionalPadding.getValue();
auto paddingAttrShape =
paddingAttr.getType().cast<RankedTensorType>().getShape();
assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 &&
"incorrect padding shape");
paddingInts.insert(paddingInts.begin(), paddingAttr.value_begin<int64_t>(),
paddingAttr.value_end<int64_t>());
} else {
paddingInts.insert(paddingInts.begin(), {0, 0, 0, 0});
}
return paddingInts;
}
mlir::SmallVector<int64_t, 2>
getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 2> stridesInts;
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
if (optionalStrides.hasValue()) {
auto stridesAttr = optionalStrides.getValue();
auto stridesAttrShape =
stridesAttr.getType().cast<RankedTensorType>().getShape();
assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 &&
"incorrect strides shape");
stridesInts.insert(stridesInts.begin(), stridesAttr.value_begin<int64_t>(),
stridesAttr.value_end<int64_t>());
} else {
stridesInts.insert(stridesInts.begin(), {1, 1});
}
return stridesInts;
}
mlir::SmallVector<int64_t, 2>
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
mlir::SmallVector<int64_t, 2> dilationsInts;
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
convOp.dilations();
if (optionalDilations.hasValue()) {
auto dilationsAttr = optionalDilations.getValue();
auto dilationsAttrShape =
dilationsAttr.getType().cast<RankedTensorType>().getShape();
assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 &&
"incorrect dilations shape");
dilationsInts.insert(dilationsInts.begin(),
dilationsAttr.value_begin<int64_t>(),
dilationsAttr.value_end<int64_t>());
} else {
dilationsInts.insert(dilationsInts.begin(), {1, 1});
}
return dilationsInts;
}
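The three accessors above share the same shape-checked extraction pattern, differing only in arity and defaults. A generic helper could factor this out; a sketch of such a refactoring (hypothetical, not part of this commit):

/// Hypothetical: return the attribute's N values if present, else `defaults`.
template <unsigned N>
static mlir::SmallVector<int64_t, N>
getI64VectorAttrOr(llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr,
                   std::initializer_list<int64_t> defaults) {
  mlir::SmallVector<int64_t, N> values;
  if (!optionalAttr.hasValue()) {
    // Attribute absent: fall back to the given defaults
    values.insert(values.begin(), defaults);
    return values;
  }
  auto attr = optionalAttr.getValue();
  auto shape = attr.getType().cast<mlir::RankedTensorType>().getShape();
  assert(shape.size() == 1 && shape[0] == N && "unexpected attribute shape");
  values.insert(values.begin(), attr.value_begin<int64_t>(),
                attr.value_end<int64_t>());
  return values;
}

With it, `getPaddingFromConv2d` would reduce to `return getI64VectorAttrOr<4>(convOp.padding(), {0, 0, 0, 0});`, and similarly for strides and dilations.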
/// Verify the Conv2d shapes, attributes, and expected output dimensions
mlir::LogicalResult
verifyConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
auto inputTy =
((mlir::Type)convOp.input().getType()).cast<mlir::RankedTensorType>();
auto weightTy =
((mlir::Type)convOp.weight().getType()).cast<mlir::RankedTensorType>();
auto resultTy =
((mlir::Type)convOp.getResult().getType()).cast<mlir::RankedTensorType>();
auto inputShape = inputTy.getShape();
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();
auto p = inputTy.getElementType()
.cast<mlir::concretelang::FHE::EncryptedIntegerType>()
.getWidth();
auto weightElementTyWidth =
weightTy.getElementType().cast<mlir::IntegerType>().getWidth();
if (weightElementTyWidth != p + 1) {
convOp.emitOpError() << "expected weight element type to have width "
<< p + 1 << " but got " << weightElementTyWidth;
return mlir::failure();
}
// Checking dimensions
if (inputShape.size() != 4) {
convOp.emitOpError() << "input should have 4 dimensions (N*C*H*W) but got "
<< inputShape.size();
return mlir::failure();
}
if (weightShape.size() != 4) {
convOp.emitOpError() << "weight should have 4 dimensions (F*C*H*W) but got "
<< weightShape.size();
return mlir::failure();
}
if (resultShape.size() != 4) {
convOp.emitOpError() << "result should have 4 dimensions (N*C*H*W) but got "
<< resultShape.size();
return mlir::failure();
}
// Checking attributes
mlir::SmallVector<int64_t, 4> paddingInts = getPaddingFromConv2d(convOp);
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
if (optionalPadding.hasValue()) {
auto paddingAttr = optionalPadding.getValue();
auto paddingAttrShape =
paddingAttr.getType().cast<RankedTensorType>().getShape();
if (paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) {
convOp.emitOpError()
<< "padding should have a single dimension of size 4, but got shape ["
<< paddingAttrShape << "]";
return mlir::failure();
}
for (auto i = 0; i < 4; i++) {
if (paddingInts[i] < 0) {
convOp.emitOpError() << "padding can't have a negative value, but got "
<< paddingInts[i] << " at index " << i;
return mlir::failure();
}
// TODO: Support padding (#427)
// This check must come after the negative-value check, otherwise the
// negative-value diagnostic would be unreachable
if (paddingInts[i] != 0) {
convOp.emitOpError()
<< "padding isn't yet supported, but got a non zero value ("
<< paddingInts[i] << ") at index " << i;
return mlir::failure();
}
}
}
mlir::SmallVector<int64_t, 2> stridesInts = getStridesFromConv2d(convOp);
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
if (optionalStrides.hasValue()) {
auto stridesAttr = optionalStrides.getValue();
auto stridesAttrShape =
stridesAttr.getType().cast<RankedTensorType>().getShape();
if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
convOp.emitOpError()
<< "strides should have a single dimension of size 2, but got shape ["
<< stridesAttrShape << "]";
return mlir::failure();
}
for (auto i = 0; i < 2; i++) {
if (stridesInts[i] < 1) {
convOp.emitOpError()
<< "strides can't have a value less than 1, but got "
<< stridesInts[i] << " at index " << i;
return mlir::failure();
}
}
}
mlir::SmallVector<int64_t, 2> dilationsInts = getDilationsFromConv2d(convOp);
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
convOp.dilations();
if (optionalDilations.hasValue()) {
auto dilationsAttr = optionalDilations.getValue();
auto dilationsAttrShape =
dilationsAttr.getType().cast<RankedTensorType>().getShape();
if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
convOp.emitOpError() << "dilations should have a single dimension of "
"size 2, but got shape ["
<< dilationsAttrShape << "]";
return mlir::failure();
}
for (auto i = 0; i < 2; i++) {
if (dilationsInts[i] < 1) {
convOp.emitOpError()
<< "dilations can't have a value less than 1, but got "
<< dilationsInts[i] << " at index " << i;
return mlir::failure();
}
}
}
// Extracting dimensions
int64_t inputN = inputShape[0], inputC = inputShape[1],
inputH = inputShape[2], inputW = inputShape[3];
int64_t weightF = weightShape[0], weightC = weightShape[1],
weightH = weightShape[2], weightW = weightShape[3];
int64_t resultN = resultShape[0], resultC = resultShape[1],
resultH = resultShape[2], resultW = resultShape[3];
// Bias check if specified
mlir::Value bias = convOp.bias();
if (bias) {
auto biasTy = ((mlir::Type)bias.getType()).cast<mlir::RankedTensorType>();
auto biasShape = biasTy.getShape();
if (biasShape.size() != 1) {
convOp.emitOpError() << "bias should have 1 dimension but got "
<< biasShape.size();
return mlir::failure();
}
if (biasShape[0] != weightF) {
convOp.emitOpError() << "expected bias vector to have size " << weightF
<< " but got " << biasShape[0];
return mlir::failure();
}
auto biasElementTyWidth =
biasTy.getElementType().cast<mlir::IntegerType>().getWidth();
if (biasElementTyWidth != p + 1) {
convOp.emitOpError() << "expected bias element type to have width "
<< p + 1 << " but got " << biasElementTyWidth;
return mlir::failure();
}
}
// Dimension sizes checks
if (resultN != inputN) {
convOp.emitOpError()
<< "expected result batch size to be equal to input batch size ("
<< inputN << ") but got " << resultN;
return mlir::failure();
}
if (inputC != weightC) {
convOp.emitOpError() << "expected number of channels in weight to be equal "
"to number of channels in input ("
<< inputC << ") but got " << weightC;
return mlir::failure();
}
if (weightF != resultC) {
convOp.emitOpError() << "expected number of output channels to be equal to "
"the number of filters ("
<< weightF << ") but got " << resultC;
return mlir::failure();
}
int64_t paddingH = paddingInts[0] + paddingInts[2];
int64_t paddingW = paddingInts[1] + paddingInts[3];
int64_t dilationH = dilationsInts[0];
int64_t dilationW = dilationsInts[1];
int64_t strideH = stridesInts[0];
int64_t strideW = stridesInts[1];
// The operands are integers, so the division below already floors the
// result; no explicit call to floor() is needed
int64_t expectedResultH =
(inputH + paddingH - dilationH * (weightH - 1) - 1) / strideH + 1;
int64_t expectedResultW =
(inputW + paddingW - dilationW * (weightW - 1) - 1) / strideW + 1;
if (expectedResultH != resultH) {
convOp.emitOpError() << "expected height of output to be equal to "
<< expectedResultH << " but got " << resultH;
return mlir::failure();
}
if (expectedResultW != resultW) {
convOp.emitOpError() << "expected width of output to be equal to "
<< expectedResultW << " but got " << resultW;
return mlir::failure();
}
return mlir::success();
}
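The expected output dimensions computed at the end of the verifier follow the standard convolution arithmetic. In LaTeX form, with $P_H$ the total padding, $D_H$ the dilation, $K_H$ the kernel height, and $S_H$ the stride:

$$H_{out} = \left\lfloor \frac{H_{in} + P_H - D_H\,(K_H - 1) - 1}{S_H} \right\rfloor + 1$$

and analogously for $W_{out}$. As a worked instance matching the tests below: $H_{in} = 28$, $K_H = 14$, $P_H = 0$, $D_H = S_H = 1$ gives $\lfloor (28 + 0 - 13 - 1)/1 \rfloor + 1 = 15$, hence the `tensor<100x4x15x15x!FHE.eint<2>>` result type used throughout.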
//===----------------------------------------------------------------------===//
// Implementation of FhelinalgConv2DNchwFchwOp
// These are functions generated by `make generate_conv_op`, plus some helpers


@@ -270,3 +270,94 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) ->
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
return %1 : tensor<4x2x!FHE.eint<2>>
}
// -----
/////////////////////////////////////////////////
// FHELinalg.conv2d
/////////////////////////////////////////////////
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op padding isn't yet supported, but got a non zero value (1) at index 0}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[1,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 8 but got 15}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0 ,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<101x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected result batch size to be equal to input batch size (101) but got 100}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0 ,0, 0, 0]> : tensor<4xi64>}: (tensor<101x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x4x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected number of channels in weight to be equal to number of channels in input (3) but got 4}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x4x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi4>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected weight element type to have width 3 but got 4}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi4>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi4>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected bias element type to have width 3 but got 4}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi4>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x2x2xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 27 but got 15}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x2x2xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x2xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected width of output to be equal to 27 but got 15}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x2xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 2 but got 15}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[2,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected width of output to be equal to 2 but got 15}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,2]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
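As a sanity check on two of the cases above: with stride $S_H = 2$ the output-shape formula gives $\lfloor (28 - 13 - 1)/2 \rfloor + 1 = 8$, which is the expected height reported against the declared 15; with dilation $D_H = 2$ it gives $\lfloor (28 - 2 \cdot 13 - 1)/1 \rfloor + 1 = 2$, matching the dilation test's expected height.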


@@ -345,3 +345,36 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) ->
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
/////////////////////////////////////////////////
// FHELinalg.conv2d
/////////////////////////////////////////////////
// CHECK: func @conv2d(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) {dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: }
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// CHECK: func @conv2d_without_attr(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: }
func @conv2d_without_attr(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias): (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// CHECK: func @conv2d_without_bias(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]]) {dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: }
func @conv2d_without_bias(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
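The `conv2d_without_attr` case above exercises the default paths in the accessors from the first file: when the optional attributes are absent, the getters fall back to their defaults. A sketch (assuming `op` is a parsed `FHELinalg::Conv2dOp` with no attributes set):

// Sketch: defaults returned when the optional attributes are absent.
auto padding = getPaddingFromConv2d(op);     // {0, 0, 0, 0}
auto strides = getStridesFromConv2d(op);     // {1, 1}
auto dilations = getDilationsFromConv2d(op); // {1, 1}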