diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h
index 3f6bbfd46..90ab3dd94 100644
--- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h
+++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h
@@ -104,3 +104,23 @@ using namespace mlir::linalg;
 #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
 
 #endif
+
+namespace mlir {
+namespace concretelang {
+namespace FHELinalg {
+
+/// Get padding from the Conv2dOp if defined, or return default value
+mlir::SmallVector<int64_t, 4>
+getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
+
+/// Get strides from the Conv2dOp if defined, or return default value
+mlir::SmallVector<int64_t, 2>
+getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
+
+/// Get dilations from the Conv2dOp if defined, or return default value
+mlir::SmallVector<int64_t, 2>
+getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
+
+} // namespace FHELinalg
+} // namespace concretelang
+} // namespace mlir
diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
index 670c0a87c..0cecd3f7b 100644
--- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
+++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
@@ -3,6 +3,7 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -631,6 +632,23 @@ def ConcatOp : FHELinalg_Op<"concat"> {
   }];
 }
 
+def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
+  let summary = "Returns the 2D convolution of a tensor in the form NCHW with weights in the form FCHW";
+  let arguments = (ins
+    Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$input,
+    Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$weight,
+    Optional<Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>>:$bias,
+    // Since there is no U64ElementsAttr, we use I64 and make sure there are no negative values during verification
+    OptionalAttr<I64ElementsAttr>:$padding,
+    OptionalAttr<I64ElementsAttr>:$strides,
+    OptionalAttr<I64ElementsAttr>:$dilations
+  );
+  let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
+  let verifier = [{
+    return ::mlir::concretelang::FHELinalg::verifyConv2d(*this);
+  }];
+}
+
 class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
     : Op<Linalg_Dialect, mnemonic, !listconcat(props, [
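The three getters declared in the header normalize the optional attributes to concrete values, with defaults `padding = [0, 0, 0, 0]`, `strides = [1, 1]`, and `dilations = [1, 1]`, so downstream code never has to re-inspect the optional attributes itself. A minimal sketch of a consumer, assuming a standard `OpRewritePattern` setup (the pattern class and its body are illustrative, not part of this diff):

```cpp
// Hypothetical consumer of the attribute getters; only the three
// get*FromConv2d calls correspond to the API added by this diff.
namespace FHELinalg = mlir::concretelang::FHELinalg;

struct Conv2dLoweringSketch
    : public mlir::OpRewritePattern<FHELinalg::Conv2dOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(FHELinalg::Conv2dOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Attribute values, or the documented defaults when unset.
    mlir::SmallVector<int64_t, 4> padding = FHELinalg::getPaddingFromConv2d(op);
    mlir::SmallVector<int64_t, 2> strides = FHELinalg::getStridesFromConv2d(op);
    mlir::SmallVector<int64_t, 2> dilations =
        FHELinalg::getDilationsFromConv2d(op);
    // ... build the lowered convolution from the normalized values ...
    return mlir::failure();
  }
};
```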
diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
index 74b1c1296..176b65107 100644
--- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
+++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
@@ -625,6 +625,253 @@ template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
   return mlir::success();
 }
 
+mlir::SmallVector<int64_t, 4>
+getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
+  mlir::SmallVector<int64_t, 4> paddingInts;
+  llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
+  if (optionalPadding.hasValue()) {
+    auto paddingAttr = optionalPadding.getValue();
+    auto paddingAttrShape =
+        paddingAttr.getType().cast<mlir::RankedTensorType>().getShape();
+    assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 &&
+           "incorrect padding shape");
+    paddingInts.insert(paddingInts.begin(),
+                       paddingAttr.value_begin<int64_t>(),
+                       paddingAttr.value_end<int64_t>());
+  } else {
+    paddingInts.insert(paddingInts.begin(), {0, 0, 0, 0});
+  }
+  return paddingInts;
+}
+
+mlir::SmallVector<int64_t, 2>
+getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
+  mlir::SmallVector<int64_t, 2> stridesInts;
+  llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
+  if (optionalStrides.hasValue()) {
+    auto stridesAttr = optionalStrides.getValue();
+    auto stridesAttrShape =
+        stridesAttr.getType().cast<mlir::RankedTensorType>().getShape();
+    assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 &&
+           "incorrect strides shape");
+    stridesInts.insert(stridesInts.begin(),
+                       stridesAttr.value_begin<int64_t>(),
+                       stridesAttr.value_end<int64_t>());
+  } else {
+    stridesInts.insert(stridesInts.begin(), {1, 1});
+  }
+  return stridesInts;
+}
+
+mlir::SmallVector<int64_t, 2>
+getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
+  mlir::SmallVector<int64_t, 2> dilationsInts;
+  llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
+      convOp.dilations();
+  if (optionalDilations.hasValue()) {
+    auto dilationsAttr = optionalDilations.getValue();
+    auto dilationsAttrShape =
+        dilationsAttr.getType().cast<mlir::RankedTensorType>().getShape();
+    assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 &&
+           "incorrect dilations shape");
+    dilationsInts.insert(dilationsInts.begin(),
+                         dilationsAttr.value_begin<int64_t>(),
+                         dilationsAttr.value_end<int64_t>());
+  } else {
+    dilationsInts.insert(dilationsInts.begin(), {1, 1});
+  }
+  return dilationsInts;
+}
+
+/// Verify the Conv2d shapes, attributes, and expected output dimensions
+mlir::LogicalResult
+verifyConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
+  auto inputTy =
+      ((mlir::Type)convOp.input().getType()).cast<mlir::RankedTensorType>();
+  auto weightTy =
+      ((mlir::Type)convOp.weight().getType()).cast<mlir::RankedTensorType>();
+  auto resultTy =
+      ((mlir::Type)convOp.getResult().getType()).cast<mlir::RankedTensorType>();
+  auto inputShape = inputTy.getShape();
+  auto weightShape = weightTy.getShape();
+  auto resultShape = resultTy.getShape();
+
+  auto p = inputTy.getElementType()
+               .cast<mlir::concretelang::FHE::EncryptedIntegerType>()
+               .getWidth();
+  auto weightElementTyWidth =
+      weightTy.getElementType().cast<mlir::IntegerType>().getWidth();
+  if (weightElementTyWidth != p + 1) {
+    convOp.emitOpError() << "expected weight element type to have width "
+                         << p + 1 << " but got " << weightElementTyWidth;
+    return mlir::failure();
+  }
+
+  // Checking dimensions
+  if (inputShape.size() != 4) {
+    convOp.emitOpError() << "input should have 4 dimensions (N*C*H*W) but got "
+                         << inputShape.size();
+    return mlir::failure();
+  }
+  if (weightShape.size() != 4) {
+    convOp.emitOpError()
+        << "weight should have 4 dimensions (F*C*H*W) but got "
+        << weightShape.size();
+    return mlir::failure();
+  }
+  if (resultShape.size() != 4) {
+    convOp.emitOpError()
+        << "result should have 4 dimensions (N*C*H*W) but got "
+        << resultShape.size();
+    return mlir::failure();
+  }
+
+  // Checking attributes
+  mlir::SmallVector<int64_t, 4> paddingInts = getPaddingFromConv2d(convOp);
+  llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
+  if (optionalPadding.hasValue()) {
+    auto paddingAttr = optionalPadding.getValue();
+    auto paddingAttrShape =
+        paddingAttr.getType().cast<mlir::RankedTensorType>().getShape();
+    if (paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) {
+      convOp.emitOpError()
+          << "padding should have a single dimension of size 4, but got shape ["
+          << paddingAttrShape << "]";
+      return mlir::failure();
+    }
+    for (auto i = 0; i < 4; i++) {
+      // TODO: Support padding (#427)
+      if (paddingInts[i] != 0) {
+        convOp.emitOpError()
+            << "padding isn't yet supported, but got a non zero value ("
+            << paddingInts[i] << ") at index " << i;
+        return mlir::failure();
+      }
+
+      if (paddingInts[i] < 0) {
+        convOp.emitOpError() << "padding can't have a negative value, but got "
+                             << paddingInts[i] << " at index " << i;
+        return mlir::failure();
+      }
+    }
+  }
+  mlir::SmallVector<int64_t, 2> stridesInts = getStridesFromConv2d(convOp);
+  llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
+  if (optionalStrides.hasValue()) {
+    auto stridesAttr = optionalStrides.getValue();
+    auto stridesAttrShape =
+        stridesAttr.getType().cast<mlir::RankedTensorType>().getShape();
+    if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
+      convOp.emitOpError()
+          << "strides should have a single dimension of size 2, but got shape ["
+          << stridesAttrShape << "]";
+      return mlir::failure();
+    }
+    for (auto i = 0; i < 2; i++) {
+      if (stridesInts[i] < 1) {
+        convOp.emitOpError()
+            << "strides can't have a value less than 1, but got "
+            << stridesInts[i] << " at index " << i;
+        return mlir::failure();
+      }
+    }
+  }
+  mlir::SmallVector<int64_t, 2> dilationsInts = getDilationsFromConv2d(convOp);
+  llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
+      convOp.dilations();
+  if (optionalDilations.hasValue()) {
+    auto dilationsAttr = optionalDilations.getValue();
+    auto dilationsAttrShape =
+        dilationsAttr.getType().cast<mlir::RankedTensorType>().getShape();
+    if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
+      convOp.emitOpError() << "dilations should have a single dimension of "
+                              "size 2, but got shape ["
+                           << dilationsAttrShape << "]";
+      return mlir::failure();
+    }
+    for (auto i = 0; i < 2; i++) {
+      if (dilationsInts[i] < 1) {
+        convOp.emitOpError()
+            << "dilations can't have a value less than 1, but got "
+            << dilationsInts[i] << " at index " << i;
+        return mlir::failure();
+      }
+    }
+  }
+
+  // Extracting dimensions
+  int64_t inputN = inputShape[0], inputC = inputShape[1],
+          inputH = inputShape[2], inputW = inputShape[3];
+  int64_t weightF = weightShape[0], weightC = weightShape[1],
+          weightH = weightShape[2], weightW = weightShape[3];
+  int64_t resultN = resultShape[0], resultC = resultShape[1],
+          resultH = resultShape[2], resultW = resultShape[3];
+
+  // Bias check if specified
+  mlir::Value bias = convOp.bias();
+  if (bias) {
+    auto biasTy = ((mlir::Type)bias.getType()).cast<mlir::RankedTensorType>();
+    auto biasShape = biasTy.getShape();
+    if (biasShape.size() != 1) {
+      convOp.emitOpError() << "bias should have 1 dimension but got "
+                           << biasShape.size();
+      return mlir::failure();
+    }
+    if (biasShape[0] != weightF) {
+      convOp.emitOpError() << "expected bias vector to have size " << weightF
+                           << " but got " << biasShape[0];
+      return mlir::failure();
+    }
+    auto biasElementTyWidth =
+        biasTy.getElementType().cast<mlir::IntegerType>().getWidth();
+    if (biasElementTyWidth != p + 1) {
+      convOp.emitOpError() << "expected bias element type to have width "
+                           << p + 1 << " but got " << biasElementTyWidth;
+      return mlir::failure();
+    }
+  }
+
+  // Dimension sizes checks
+  if (resultN != inputN) {
+    convOp.emitOpError()
+        << "expected result batch size to be equal to input batch size ("
+        << inputN << ") but got " << resultN;
+    return mlir::failure();
+  }
+  if (inputC != weightC) {
+    convOp.emitOpError()
+        << "expected number of channels in weight to be equal "
+           "to number of channels in input ("
+        << inputC << ") but got " << weightC;
+    return mlir::failure();
+  }
+  if (weightF != resultC) {
+    convOp.emitOpError()
+        << "expected number of output channels to be equal to "
+           "the number of filters ("
+        << weightF << ") but got " << resultC;
+    return mlir::failure();
+  }
+
+  int64_t paddingH = paddingInts[0] + paddingInts[2];
+  int64_t paddingW = paddingInts[1] + paddingInts[3];
+  int64_t dilationH = dilationsInts[0];
+  int64_t dilationW = dilationsInts[1];
+  int64_t strideH = stridesInts[0];
+  int64_t strideW = stridesInts[1];
+  int64_t expectedResultH =
+      floor((inputH + paddingH - dilationH * (weightH - 1) - 1) / strideH) + 1;
+  int64_t expectedResultW =
+      floor((inputW + paddingW - dilationW * (weightW - 1) - 1) / strideW) + 1;
+
+  if (expectedResultH != resultH) {
+    convOp.emitOpError() << "expected height of output to be equal to "
+                         << expectedResultH << " but got " << resultH;
+    return mlir::failure();
+  }
+  if (expectedResultW != resultW) {
+    convOp.emitOpError() << "expected width of output to be equal to "
+                         << expectedResultW << " but got " << resultW;
+    return mlir::failure();
+  }
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Implementation of FhelinalgConv2DNchwFchwOp
 // This is a generated functions from `make generate_conv_op`, and some helpers
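The `expectedResultH`/`expectedResultW` computation above is the standard convolution output-size formula, out = floor((in + pad_total - dilation * (kernel - 1) - 1) / stride) + 1; since all operands are non-negative `int64_t` values, the truncating integer division already floors, so the `floor(...)` call is effectively a no-op. Plugging in the shapes used by the tests below (in = 28, kernel = 14, padding = 0): stride 1 gives 15, stride 2 gives 8, kernel 2 gives 27, and dilation 2 gives 2, exactly the values the expected-error messages quote. A self-contained cross-check of those numbers (not part of the diff):

```cpp
#include <cassert>
#include <cstdint>

// Same formula as the expectedResultH/expectedResultW computation in
// verifyConv2d, written without the redundant floor().
int64_t expectedOutDim(int64_t in, int64_t kernel, int64_t padTotal,
                       int64_t stride, int64_t dilation) {
  return (in + padTotal - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  assert(expectedOutDim(28, 14, 0, /*stride=*/1, /*dilation=*/1) == 15);
  assert(expectedOutDim(28, 14, 0, /*stride=*/2, /*dilation=*/1) == 8);
  assert(expectedOutDim(28, 2, 0, /*stride=*/1, /*dilation=*/1) == 27);
  assert(expectedOutDim(28, 14, 0, /*stride=*/1, /*dilation=*/2) == 2);
  return 0;
}
```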
diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir
index 6fd4d88df..97d060450 100644
--- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir
+++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.invalid.mlir
@@ -270,3 +270,94 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> {
   %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
   return %1 : tensor<4x2x!FHE.eint<2>>
 }
+
+// -----
+
+/////////////////////////////////////////////////
+// FHELinalg.conv2d
+/////////////////////////////////////////////////
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op padding isn't yet supported, but got a non zero value (1) at index 0}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[1, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 8 but got 15}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[2, 2]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<101x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected result batch size to be equal to input batch size (101) but got 100}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<101x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x4x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected number of channels in weight to be equal to number of channels in input (3) but got 4}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x4x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi4>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected weight element type to have width 3 but got 4}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi4>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi4>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected bias element type to have width 3 but got 4}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi4>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x2x2xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 27 but got 15}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x2x2xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x2xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected width of output to be equal to 27 but got 15}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x2xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 2 but got 15}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  // expected-error @+1 {{'FHELinalg.conv2d' op expected width of output to be equal to 2 but got 15}}
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 2]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir
index 1d4577f50..a5f869df3 100644
--- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir
+++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir
@@ -345,3 +345,36 @@ func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
   %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
+
+/////////////////////////////////////////////////
+// FHELinalg.conv2d
+/////////////////////////////////////////////////
+
+// CHECK: func @conv2d(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) {dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// CHECK: func @conv2d_without_attr(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @conv2d_without_attr(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias): (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
+
+// CHECK: func @conv2d_without_bias(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]]) {dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @conv2d_without_bias(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  %1 = "FHELinalg.conv2d"(%input, %weight){strides = dense<[1, 1]> : tensor<2xi64>, dilations = dense<[1, 1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
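The round-trip tests above only exercise the parsed form of the op. For completeness, a hypothetical sketch of building the same op programmatically; the exact ODS-generated builder signature (argument order, attribute parameter types) is an assumption here, not something this diff shows:

```cpp
// Hypothetical programmatic construction mirroring the @conv2d test above;
// the create<> call assumes the default ODS-generated builder.
mlir::Value buildConv2d(mlir::OpBuilder &b, mlir::Location loc,
                        mlir::Type resultTy, mlir::Value input,
                        mlir::Value weight, mlir::Value bias) {
  // Build a 1-D i64 elements attribute, e.g. dense<[1, 1]> : tensor<2xi64>.
  auto i64Elements = [&](llvm::ArrayRef<int64_t> values) {
    auto ty = mlir::RankedTensorType::get(
        {static_cast<int64_t>(values.size())}, b.getI64Type());
    return mlir::DenseIntElementsAttr::get(ty, values);
  };
  return b.create<mlir::concretelang::FHELinalg::Conv2dOp>(
      loc, resultTy, input, weight, bias,
      /*padding=*/i64Elements({0, 0, 0, 0}),
      /*strides=*/i64Elements({1, 1}),
      /*dilations=*/i64Elements({1, 1}));
}
```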