diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 5b9439155..fe50fb0ac 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -950,6 +950,13 @@ def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> { let description = [{ Performs a transpose operation on an N-dimensional tensor. + Attributes: + + - axes: I64ArrayAttr = [] + list of dimensions to perform the transposition + contains a permutation of [0,1,..,N-1] where N is the number of axes + think of it as a way to rearrange axes (see the example below) + ```mlir "FHELinalg.transpose"(%a) : (tensor) -> tensor ``` @@ -963,10 +970,14 @@ def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> { // "FHELinalg.transpose"(%a) : (tensor<3x2xi7>) -> tensor<2x3xi7> ``` + + ```mlir + "FHELinalg.transpose"(%a) { axes = [1, 3, 0, 2] } : (tensor<2x3x4x5xi7>) -> tensor<3x5x2x4xi7> + ``` }]; - // Args and result are both supposed to be of tensor of enccrypted integers, and the verifier does check that - let arguments = (ins AnyType:$tensor); + // Args and result are both supposed to be of tensor of encrypted integers, and the verifier does check that + let arguments = (ins AnyType:$tensor, DefaultValuedAttr<I64ArrayAttr, "{}">:$axes); let results = (outs AnyType); let hasVerifier = 1; diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 8975db510..aa6da0a4f 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1287,23 +1287,33 @@ struct TransposeToLinalgGeneric auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>(); auto outputType = output.getType().dyn_cast<mlir::RankedTensorType>(); + auto n_dim = inputType.getShape().size(); + mlir::Location location = 
transposeOp.getLoc(); // Initialize empty tensor to fill with transpose result mlir::Value zeroTensor = rewriter.create<FHE::ZeroTensorOp>(location, outputType).getResult(); - // Inverted dimensions to create a transposition std::vector<unsigned> perms = {}; - auto n_dim = inputType.getShape().size(); - for (int i = n_dim - 1; i >= 0; i--) - perms.push_back(i); + + mlir::ArrayAttr axes = transposeOp.axes(); + if (axes.empty()) { + for (int i = n_dim - 1; i >= 0; i--) { + perms.push_back(i); + } + } else { + for (mlir::Attribute axisAttribute : axes) { + int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt(); + perms.push_back(axis); + } + } llvm::SmallVector<mlir::Type, 1> resultTypes{zeroTensor.getType()}; auto ins = llvm::SmallVector<mlir::Value>{input}; auto outs = llvm::SmallVector<mlir::Value>{zeroTensor}; llvm::SmallVector<mlir::AffineMap, 2> maps{ - mlir::AffineMap::getPermutationMap(perms, this->getContext()), mlir::AffineMap::getMultiDimIdentityMap(n_dim, this->getContext()), + mlir::AffineMap::getPermutationMap(perms, this->getContext()), }; auto iteratorTypes = parallelIteratorType(n_dim); // The maps will be responsible for changing item positions, we just return diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index b8cd724d0..39635a3da 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1089,15 +1089,59 @@ mlir::LogicalResult TransposeOp::verify() { << "input and output tensors should have the same element type"; return mlir::failure(); } - size_t n_dims = tensorShapedTy.getShape().size(); - for (size_t i = 0; i < n_dims; i++) { - if (tensorShapedTy.getDimSize(i) != - resultShapedTy.getDimSize(n_dims - (i + 1))) { + + llvm::ArrayRef<int64_t> inShape = tensorShapedTy.getShape(); + llvm::ArrayRef<int64_t> outShape = resultShapedTy.getShape(); + + int64_t inputDimensions = (int64_t)inShape.size(); + + mlir::ArrayAttr axes = this->axes(); + if (axes.empty()) { + for (int64_t i = 0; i < inputDimensions; i++) { + if (inShape[i] != 
outShape[inputDimensions - (i + 1)]) { + this->emitOpError() + << "output tensor should have inverted dimensions of input"; + return mlir::failure(); + } + } + } else { + if (axes.size() != (size_t)inputDimensions) { + this->emitOpError() << "has invalid axes attribute (doesn't have " + << inputDimensions << " elements)"; + return mlir::failure(); + } + + auto seenAxes = std::unordered_set<int64_t>{}; + + size_t i = 0; + for (mlir::Attribute axisAttribute : axes) { + int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt(); + + bool axisIsValid = (0 <= axis) && (axis < inputDimensions); + if (!axisIsValid) { + this->emitOpError() + << "has invalid axes attribute (axes[" << i << "] " + << "isn't in range [0, " << inputDimensions - 1 << "])"; + return mlir::failure(); + } + + seenAxes.insert(axis); + + if (outShape[i] != inShape[axis]) { + this->emitOpError() << "has invalid output shape (output.shape[" << i + << "] is not input.shape[axes[" << i << "]])"; + return mlir::failure(); + } + + i++; + } + if (seenAxes.size() != (size_t)inputDimensions) { this->emitOpError() - << "output tensor should have inverted dimensions of input"; + << "has invalid axes attribute (doesn't contain all input axes)"; return mlir::failure(); } } + return mlir::success(); } diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/transpose.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/transpose.mlir new file mode 100644 index 000000000..45c96a6df --- /dev/null +++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/transpose.mlir @@ -0,0 +1,127 @@ +// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg %s 2>&1 | FileCheck %s + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1, d0)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>> 
+// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x2x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { + %0 = "FHELinalg.transpose"(%arg0) : (tensor<2x3x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> + return %0 : tensor<3x2x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x2x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x2x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<4x3x2x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<4x3x2x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<4x3x2x!FHE.eint<7>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x2x!FHE.eint<7>> { + %0 = "FHELinalg.transpose"(%arg0) : (tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x2x!FHE.eint<7>> + return %0 : tensor<4x3x2x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> { +// CHECK-NEXT: 
%[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x5x!FHE.eint<6>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<4x3x5x!FHE.eint<6>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>): +// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6> +// CHECK-NEXT: } -> tensor<4x3x5x!FHE.eint<6>> +// CHECK-NEXT: return %[[v1]] : tensor<4x3x5x!FHE.eint<6>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> { + %0 = "FHELinalg.transpose"(%arg0) { axes = [1, 0, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> + return %0 : tensor<4x3x5x!FHE.eint<6>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x5x3x!FHE.eint<6>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<4x5x3x!FHE.eint<6>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>): +// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6> +// CHECK-NEXT: } -> tensor<4x5x3x!FHE.eint<6>> +// CHECK-NEXT: return %[[v1]] : tensor<4x5x3x!FHE.eint<6>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> { + %0 = "FHELinalg.transpose"(%arg0) { axes = [1, 2, 0] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> + return %0 : tensor<4x5x3x!FHE.eint<6>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> 
(d0, d2, d1)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x5x4x!FHE.eint<6>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<3x5x4x!FHE.eint<6>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>): +// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6> +// CHECK-NEXT: } -> tensor<3x5x4x!FHE.eint<6>> +// CHECK-NEXT: return %[[v1]] : tensor<3x5x4x!FHE.eint<6>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> { + %0 = "FHELinalg.transpose"(%arg0) { axes = [0, 2, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> + return %0 : tensor<3x5x4x!FHE.eint<6>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x3x4x!FHE.eint<6>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<5x3x4x!FHE.eint<6>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>): +// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6> +// CHECK-NEXT: } -> tensor<5x3x4x!FHE.eint<6>> +// CHECK-NEXT: return %[[v1]] : tensor<5x3x4x!FHE.eint<6>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> { + %0 = "FHELinalg.transpose"(%arg0) { axes = [2, 0, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> + return %0 : tensor<5x3x4x!FHE.eint<6>> +} + +// 
----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2, d1)> + +// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x2x4x3x!FHE.eint<6>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<5x2x4x3x!FHE.eint<6>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>): +// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6> +// CHECK-NEXT: } -> tensor<5x2x4x3x!FHE.eint<6>> +// CHECK-NEXT: return %[[v1]] : tensor<5x2x4x3x!FHE.eint<6>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> { + %0 = "FHELinalg.transpose"(%arg0) { axes = [3, 0, 2, 1] } : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> + return %0 : tensor<5x2x4x3x!FHE.eint<6>> +} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir index 2f0ae3b21..8a457fe62 100644 --- a/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir +++ b/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir @@ -493,3 +493,48 @@ func.func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> return %c : tensor<5x4x3x!FHE.eint<6>> } + +// CHECK-LABEL: @transpose_eint_3D_axes_102(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> +func.func @transpose_eint_3D_axes_102(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [1, 0, 2]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> + // 
CHECK-NEXT: return %[[v0]] : tensor<4x3x5x!FHE.eint<6>> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) { axes = [1, 0, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> + return %c : tensor<4x3x5x!FHE.eint<6>> +} + +// CHECK-LABEL: @transpose_eint_3D_axes_120(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> +func.func @transpose_eint_3D_axes_120(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [1, 2, 0]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> + // CHECK-NEXT: return %[[v0]] : tensor<4x5x3x!FHE.eint<6>> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) { axes = [1, 2, 0] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> + return %c : tensor<4x5x3x!FHE.eint<6>> +} + +// CHECK-LABEL: @transpose_eint_3D_axes_021(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> +func.func @transpose_eint_3D_axes_021(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [0, 2, 1]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> + // CHECK-NEXT: return %[[v0]] : tensor<3x5x4x!FHE.eint<6>> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) { axes = [0, 2, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> + return %c : tensor<3x5x4x!FHE.eint<6>> +} + +// CHECK-LABEL: @transpose_eint_3D_axes_201(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> +func.func @transpose_eint_3D_axes_201(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [2, 0, 1]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> + // CHECK-NEXT: return %[[v0]] : tensor<5x3x4x!FHE.eint<6>> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) { axes = [2, 0, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> + return 
%c : tensor<5x3x4x!FHE.eint<6>> +} + +// CHECK-LABEL: @transpose_eint_4D_axes_3021(%arg0: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> +func.func @transpose_eint_4D_axes_3021(%arg0: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> { + // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [3, 0, 2, 1]} : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> + // CHECK-NEXT: return %[[v0]] : tensor<5x2x4x3x!FHE.eint<6>> + // CHECK-NEXT: } + %c = "FHELinalg.transpose"(%arg0) { axes = [3, 0, 2, 1] } : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> + return %c : tensor<5x2x4x3x!FHE.eint<6>> +} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/transpose.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/transpose.invalid.mlir index ba9609f44..d8c0cb0de 100644 --- a/compiler/tests/check_tests/Dialect/FHELinalg/transpose.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/FHELinalg/transpose.invalid.mlir @@ -26,3 +26,33 @@ func.func @transpose_eint(%arg0: tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FH } // ----- + +func.func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> { + // expected-error @+1 {{'FHELinalg.transpose' op has invalid axes attribute (doesn't have 3 elements)}} + %c = "FHELinalg.transpose"(%arg0) { axes = [0] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> + return %c : tensor<4x3x5x!FHE.eint<6>> +} + +// ----- + +func.func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> { + // expected-error @+1 {{'FHELinalg.transpose' op has invalid axes attribute (axes[1] isn't in range [0, 2])}} + %c = "FHELinalg.transpose"(%arg0) { axes = [1, 5, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> + return %c : tensor<4x3x5x!FHE.eint<6>> +} + +// ----- + +func.func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x10x!FHE.eint<6>> { + // expected-error @+1 
{{'FHELinalg.transpose' op has invalid output shape (output.shape[2] is not input.shape[axes[2]])}} + %c = "FHELinalg.transpose"(%arg0) { axes = [1, 0, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x10x!FHE.eint<6>> + return %c : tensor<4x3x10x!FHE.eint<6>> +} + +// ----- + +func.func @transpose_eint(%arg0: tensor<2x2x2x!FHE.eint<6>>) -> tensor<2x2x2x!FHE.eint<6>> { + // expected-error @+1 {{'FHELinalg.transpose' op has invalid axes attribute (doesn't contain all input axes)}} + %c = "FHELinalg.transpose"(%arg0) { axes = [0, 1, 0] } : (tensor<2x2x2x!FHE.eint<6>>) -> tensor<2x2x2x!FHE.eint<6>> + return %c : tensor<2x2x2x!FHE.eint<6>> +} diff --git a/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml b/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml index 435cd9bd7..ed08c2e1b 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml +++ b/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml @@ -236,6 +236,100 @@ tests: - tensor: [1, 3, 5, 2, 4, 6] shape: [2, 3] --- +description: transpose3d +program: | + func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x3x2x!FHE.eint<6>> { + %1 = "FHELinalg.transpose"(%input): (tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x3x2x!FHE.eint<6>> + return %1 : tensor<4x3x2x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + shape: [2, 3, 4] + outputs: + - tensor: [0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23] + shape: [4, 3, 2] +--- +description: transpose3d_axes_102 +program: | + func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x2x4x!FHE.eint<6>> { + %1 = "FHELinalg.transpose"(%input) { axes = [1, 0, 2] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x2x4x!FHE.eint<6>> + return %1 : tensor<3x2x4x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + shape: [2, 3, 4] + 
outputs: + - tensor: [0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23] + shape: [3, 2, 4] +--- +description: transpose3d_axes_120 +program: | + func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x4x2x!FHE.eint<6>> { + %1 = "FHELinalg.transpose"(%input) { axes = [1, 2, 0] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x4x2x!FHE.eint<6>> + return %1 : tensor<3x4x2x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + shape: [2, 3, 4] + outputs: + - tensor: [0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23] + shape: [3, 4, 2] +--- +description: transpose3d_axes_021 +program: | + func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<2x4x3x!FHE.eint<6>> { + %1 = "FHELinalg.transpose"(%input) { axes = [0, 2, 1] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<2x4x3x!FHE.eint<6>> + return %1 : tensor<2x4x3x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + shape: [2, 3, 4] + outputs: + - tensor: [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23] + shape: [2, 4, 3] +--- +description: transpose3d_axes_201 +program: | + func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x2x3x!FHE.eint<6>> { + %1 = "FHELinalg.transpose"(%input) { axes = [2, 0, 1] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x2x3x!FHE.eint<6>> + return %1 : tensor<4x2x3x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + shape: [2, 3, 4] + outputs: + - tensor: [0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23] + shape: [4, 2, 3] +--- +description: transpose4d_axes_3021 +program: | + func.func @main(%input: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> { + %1 = 
"FHELinalg.transpose"(%input) { axes = [3, 0, 2, 1] } : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> + return %1 : tensor<5x2x4x3x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, + 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, + 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 117, 118, 119] + shape: [2, 3, 4, 5] + outputs: + - tensor: [0, 20, 40, 5, 25, 45, 10, 30, 50, 15, 35, 55, 60, 80, 100, 65, 85, 105, 70, 90, 110, 75, 95, + 115, 1, 21, 41, 6, 26, 46, 11, 31, 51, 16, 36, 56, 61, 81, 101, 66, 86, 106, 71, 91, 111, 76, + 96, 116, 2, 22, 42, 7, 27, 47, 12, 32, 52, 17, 37, 57, 62, 82, 102, 67, 87, 107, 72, 92, 112, + 77, 97, 117, 3, 23, 43, 8, 28, 48, 13, 33, 53, 18, 38, 58, 63, 83, 103, 68, 88, 108, 73, 93, 113, + 78, 98, 118, 4, 24, 44, 9, 29, 49, 14, 34, 54, 19, 39, 59, 64, 84, 104, 69, 89, 109, 74, 94, 114, + 79, 99, 119] + shape: [5, 2, 4, 3] +--- description: conv2dWithGroup1C program: | func.func @main(%input: tensor<1x6x4x4x!FHE.eint<5>>, %weight: tensor<6x1x2x2xi6>) -> tensor<1x6x3x3x!FHE.eint<5>> {