From 3a4723a0b81aebfe079f48b0819ac914d0bd68ac Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 22 Mar 2022 09:28:34 +0100 Subject: [PATCH] feat: add FHELinalg.transpose operation --- .../Dialect/FHELinalg/IR/FHELinalgOps.td | 30 ++++++++++++++++ .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 36 +++++++++++++++++++ .../Dialect/FHELinalg/FHELinalg/ops.mlir | 31 ++++++++++++++++ .../FHELinalg/transpose.invalid.mlir | 28 +++++++++++++++ 4 files changed, 125 insertions(+) create mode 100644 compiler/tests/Dialect/FHELinalg/FHELinalg/transpose.invalid.mlir diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index d09910869..f758b24f1 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -1001,5 +1001,35 @@ $_state.addAttribute("dilations", dilations); }]; } +def TransposeOp : FHELinalg_Op<"transpose", []> { + let summary = "Returns a tensor that contains the transposition of the input tensor."; + + let description = [{ + Performs a transpose operation on an N-dimensional tensor. + + ```mlir + "FHELinalg.transpose"(%a) : (tensor) -> tensor + ``` + + Examples: + ```mlir + // Transpose the input tensor + // [1,2] [1, 3, 5] + // [3,4] => [2, 4, 6] + // [5,6] + // + "FHELinalg.transpose"(%a) : (tensor<3x2xi7>) -> tensor<2x3xi7> + ``` + }]; + + // Args and result are both supposed to be of tensor of enccrypted integers, and the verifier does check that + let arguments = (ins AnyType:$tensor); + let results = (outs AnyType); + + let verifier = [{ + return ::mlir::concretelang::FHELinalg::verifyTranspose(*this); + }]; +} + #endif diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 2027745cc..e41d6f0d8 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1605,6 +1605,42 @@ void FhelinalgConv2DNchwFchwOp::getEffects( outputBuffers); } +/// Verify the transpose shapes +mlir::LogicalResult verifyTranspose(TransposeOp &transposeOp) { + mlir::Type tensorTy = ((mlir::Type)transposeOp.tensor().getType()); + if (!tensorTy.isa()) { + transposeOp->emitOpError() << "should have operand as tensor"; + return mlir::failure(); + } + mlir::Type resultTy = ((mlir::Type)transposeOp.getResult().getType()); + if (!resultTy.isa()) { + transposeOp->emitOpError() << "should have result as tensor"; + return mlir::failure(); + } + auto tensorShapedTy = tensorTy.dyn_cast_or_null(); + auto resultShapedTy = resultTy.dyn_cast_or_null(); + if (tensorShapedTy.getShape().size() != resultShapedTy.getShape().size()) { + transposeOp.emitOpError() + << "input and output tensors should have the same number of dimensions"; + return mlir::failure(); + } + if (tensorShapedTy.getElementType() != resultShapedTy.getElementType()) { + transposeOp.emitOpError() + << "input and output tensors should have the same element type"; + return mlir::failure(); + } + size_t n_dims = tensorShapedTy.getShape().size(); + for (size_t i = 0; i < n_dims; i++) { + if (tensorShapedTy.getDimSize(i) != + resultShapedTy.getDimSize(n_dims - (i + 1))) { + transposeOp.emitOpError() + << "output tensor should have inverted dimensions of input"; + return mlir::failure(); + } + } + return mlir::success(); +} + } // namespace FHELinalg } // namespace concretelang } // namespace mlir diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir index 65558c74d..3a08ad05b 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir @@ -352,3 +352,34 @@ func @conv2d_without_bias(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: ten %1 = "FHELinalg.conv2d"(%input, %weight){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> return %1 : tensor<100x4x15x15x!FHE.eint<2>> } + +///////////////////////////////////////////////// +// FHELinalg.transpose +///////////////////////////////////////////////// + +// CHECK-LABEL: @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> +func @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> + // CHECK-NEXT: return %[[v0]] : tensor<10x2x!FHE.eint<6>> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> + return %c : tensor<10x2x!FHE.eint<6>> +} + +// CHECK-LABEL: @transpose_int_2D(%arg0: tensor<2x10xi6>) -> tensor<10x2xi6> +func @transpose_int_2D(%arg0: tensor<2x10xi6>) -> tensor<10x2xi6> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<2x10xi6>) -> tensor<10x2xi6> + // CHECK-NEXT: return %[[v0]] : tensor<10x2xi6> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) : (tensor<2x10xi6>) -> tensor<10x2xi6> + return %c : tensor<10x2xi6> +} + +// CHECK-LABEL: @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> +func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> + // CHECK-NEXT: return %[[v0]] : tensor<5x4x3x!FHE.eint<6>> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> + return %c : tensor<5x4x3x!FHE.eint<6>> +} diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/transpose.invalid.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/transpose.invalid.mlir new file mode 100644 index 000000000..5d48ff24d --- /dev/null +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/transpose.invalid.mlir @@ -0,0 +1,28 @@ +// RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s + +// Incompatible types +func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.transpose' op input and output tensors should have the same element type}} + %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<7>> + return %c : tensor<5x4x3x!FHE.eint<7>> +} + +// ----- + +// Incompatible shapes +func @transpose_eint(%arg0: tensor<3x4x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> { + // expected-error @+1 {{'FHELinalg.transpose' op input and output tensors should have the same number of dimensions}} + %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> + return %c : tensor<5x4x3x!FHE.eint<6>> +} + +// ----- + +// Incompatible shapes +func @transpose_eint(%arg0: tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> { + // expected-error @+1 {{'FHELinalg.transpose' op output tensor should have inverted dimensions of input}} + %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> + return %c : tensor<5x4x3x!FHE.eint<6>> +} + +// ----- \ No newline at end of file