feat: add axes argument to transpose

This commit is contained in:
Umut
2022-10-14 16:57:12 +02:00
parent 1d1dfc6b2b
commit 5f845bf9ff
7 changed files with 373 additions and 12 deletions

View File

@@ -950,6 +950,13 @@ def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
let description = [{
Performs a transpose operation on an N-dimensional tensor.
Attributes:
- axes: I64ArrayAttr = []
list of dimension to perform the transposition
contains a permutation of [0,1,..,N-1] where N is the number of axes
think of it as a way to rearrange axes (see the example below)
```mlir
"FHELinalg.transpose"(%a) : (tensor<n0xn1x...xnNxType>) -> tensor<nNx...xn1xn0xType>
```
@@ -963,10 +970,14 @@ def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
//
"FHELinalg.transpose"(%a) : (tensor<3x2xi7>) -> tensor<2x3xi7>
```
```mlir
"FHELinalg.transpose"(%a) { axes = [1, 3, 0, 2] } : (tensor<2x3x4x5xi7>) -> tensor<3x5x2x4xi7>
```
}];
// Args and result are both supposed to be of tensor of enccrypted integers, and the verifier does check that
let arguments = (ins AnyType:$tensor);
// Args and result are both supposed to be of tensor of encrypted integers, and the verifier does check that
let arguments = (ins AnyType:$tensor, DefaultValuedAttr<I64ArrayAttr, "{}">:$axes);
let results = (outs AnyType);
let hasVerifier = 1;

View File

@@ -1287,23 +1287,33 @@ struct TransposeToLinalgGeneric
auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>();
auto outputType = output.getType().dyn_cast<mlir::RankedTensorType>();
auto n_dim = inputType.getShape().size();
mlir::Location location = transposeOp.getLoc();
// Initialize empty tensor to fill with transpose result
mlir::Value zeroTensor =
rewriter.create<FHE::ZeroTensorOp>(location, outputType).getResult();
// Inverted dimensions to create a transposition
std::vector<unsigned int> perms = {};
auto n_dim = inputType.getShape().size();
for (int i = n_dim - 1; i >= 0; i--)
perms.push_back(i);
mlir::ArrayAttr axes = transposeOp.axes();
if (axes.empty()) {
for (int i = n_dim - 1; i >= 0; i--) {
perms.push_back(i);
}
} else {
for (mlir::Attribute axisAttribute : axes) {
int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
perms.push_back(axis);
}
}
llvm::SmallVector<mlir::Type, 1> resultTypes{zeroTensor.getType()};
auto ins = llvm::SmallVector<mlir::Value, 1>{input};
auto outs = llvm::SmallVector<mlir::Value, 1>{zeroTensor};
llvm::SmallVector<mlir::AffineMap, 2> maps{
mlir::AffineMap::getPermutationMap(perms, this->getContext()),
mlir::AffineMap::getMultiDimIdentityMap(n_dim, this->getContext()),
mlir::AffineMap::getPermutationMap(perms, this->getContext()),
};
auto iteratorTypes = parallelIteratorType(n_dim);
// The maps will be responsible for changing item positions, we just return

View File

@@ -1089,15 +1089,59 @@ mlir::LogicalResult TransposeOp::verify() {
<< "input and output tensors should have the same element type";
return mlir::failure();
}
size_t n_dims = tensorShapedTy.getShape().size();
for (size_t i = 0; i < n_dims; i++) {
if (tensorShapedTy.getDimSize(i) !=
resultShapedTy.getDimSize(n_dims - (i + 1))) {
llvm::ArrayRef<int64_t> inShape = tensorShapedTy.getShape();
llvm::ArrayRef<int64_t> outShape = resultShapedTy.getShape();
int64_t inputDimensions = (int64_t)inShape.size();
mlir::ArrayAttr axes = this->axes();
if (axes.empty()) {
for (int64_t i = 0; i < inputDimensions; i++) {
if (inShape[i] != outShape[inputDimensions - (i + 1)]) {
this->emitOpError()
<< "output tensor should have inverted dimensions of input";
return mlir::failure();
}
}
} else {
if (axes.size() != (size_t)inputDimensions) {
this->emitOpError() << "has invalid axes attribute (doesn't have "
<< inputDimensions << " elements)";
return mlir::failure();
}
auto seenAxes = std::unordered_set<int64_t>{};
size_t i = 0;
for (mlir::Attribute axisAttribute : axes) {
int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
bool axisIsValid = (0 <= axis) && (axis < inputDimensions);
if (!axisIsValid) {
this->emitOpError()
<< "has invalid axes attribute (axes[" << i << "] "
<< "isn't in range [0, " << inputDimensions - 1 << "])";
return mlir::failure();
}
seenAxes.insert(axis);
if (outShape[i] != inShape[axis]) {
this->emitOpError() << "has invalid output shape (output.shape[" << i
<< "] is not input.shape[axes[" << i << "]])";
return mlir::failure();
}
i++;
}
if (seenAxes.size() != (size_t)inputDimensions) {
this->emitOpError()
<< "output tensor should have inverted dimensions of input";
<< "has invalid axes attribute (doesn't contain all input axes)";
return mlir::failure();
}
}
return mlir::success();
}

View File

@@ -0,0 +1,127 @@
// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg %s 2>&1 | FileCheck %s
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x2x!FHE.eint<7>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<7>
// CHECK-NEXT: } -> tensor<3x2x!FHE.eint<7>>
// CHECK-NEXT: return %[[v1]] : tensor<3x2x!FHE.eint<7>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> {
%0 = "FHELinalg.transpose"(%arg0) : (tensor<2x3x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>>
return %0 : tensor<3x2x!FHE.eint<7>>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x2x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x2x!FHE.eint<7>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<4x3x2x!FHE.eint<7>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<7>
// CHECK-NEXT: } -> tensor<4x3x2x!FHE.eint<7>>
// CHECK-NEXT: return %[[v1]] : tensor<4x3x2x!FHE.eint<7>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x2x!FHE.eint<7>> {
%0 = "FHELinalg.transpose"(%arg0) : (tensor<2x3x4x!FHE.eint<7>>) -> tensor<4x3x2x!FHE.eint<7>>
return %0 : tensor<4x3x2x!FHE.eint<7>>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x5x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<4x3x5x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>):
// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6>
// CHECK-NEXT: } -> tensor<4x3x5x!FHE.eint<6>>
// CHECK-NEXT: return %[[v1]] : tensor<4x3x5x!FHE.eint<6>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> {
%0 = "FHELinalg.transpose"(%arg0) { axes = [1, 0, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>>
return %0 : tensor<4x3x5x!FHE.eint<6>>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x5x3x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<4x5x3x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>):
// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6>
// CHECK-NEXT: } -> tensor<4x5x3x!FHE.eint<6>>
// CHECK-NEXT: return %[[v1]] : tensor<4x5x3x!FHE.eint<6>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> {
%0 = "FHELinalg.transpose"(%arg0) { axes = [1, 2, 0] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>>
return %0 : tensor<4x5x3x!FHE.eint<6>>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<3x5x4x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<3x5x4x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>):
// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6>
// CHECK-NEXT: } -> tensor<3x5x4x!FHE.eint<6>>
// CHECK-NEXT: return %[[v1]] : tensor<3x5x4x!FHE.eint<6>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> {
%0 = "FHELinalg.transpose"(%arg0) { axes = [0, 2, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>>
return %0 : tensor<3x5x4x!FHE.eint<6>>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x3x4x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<5x3x4x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>):
// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6>
// CHECK-NEXT: } -> tensor<5x3x4x!FHE.eint<6>>
// CHECK-NEXT: return %[[v1]] : tensor<5x3x4x!FHE.eint<6>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> {
%0 = "FHELinalg.transpose"(%arg0) { axes = [2, 0, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>>
return %0 : tensor<5x3x4x!FHE.eint<6>>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2, d1)>
// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x2x4x3x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x5x!FHE.eint<6>>) outs(%[[v0]] : tensor<5x2x4x3x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<6>, %[[aa1:.*]]: !FHE.eint<6>):
// CHECK-NEXT: linalg.yield %[[aa0]] : !FHE.eint<6>
// CHECK-NEXT: } -> tensor<5x2x4x3x!FHE.eint<6>>
// CHECK-NEXT: return %[[v1]] : tensor<5x2x4x3x!FHE.eint<6>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> {
%0 = "FHELinalg.transpose"(%arg0) { axes = [3, 0, 2, 1] } : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>>
return %0 : tensor<5x2x4x3x!FHE.eint<6>>
}

View File

@@ -493,3 +493,48 @@ func.func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x
%c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
return %c : tensor<5x4x3x!FHE.eint<6>>
}
// CHECK-LABEL: @transpose_eint_3D_axes_102(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>>
func.func @transpose_eint_3D_axes_102(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [1, 0, 2]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>>
// CHECK-NEXT: return %[[v0]] : tensor<4x3x5x!FHE.eint<6>>
// CHECK-NEXT: }
%c = "FHELinalg.transpose"(%arg0) { axes = [1, 0, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>>
return %c : tensor<4x3x5x!FHE.eint<6>>
}
// CHECK-LABEL: @transpose_eint_3D_axes_120(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>>
func.func @transpose_eint_3D_axes_120(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [1, 2, 0]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>>
// CHECK-NEXT: return %[[v0]] : tensor<4x5x3x!FHE.eint<6>>
// CHECK-NEXT: }
%c = "FHELinalg.transpose"(%arg0) { axes = [1, 2, 0] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x5x3x!FHE.eint<6>>
return %c : tensor<4x5x3x!FHE.eint<6>>
}
// CHECK-LABEL: @transpose_eint_3D_axes_021(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>>
func.func @transpose_eint_3D_axes_021(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [0, 2, 1]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>>
// CHECK-NEXT: return %[[v0]] : tensor<3x5x4x!FHE.eint<6>>
// CHECK-NEXT: }
%c = "FHELinalg.transpose"(%arg0) { axes = [0, 2, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x5x4x!FHE.eint<6>>
return %c : tensor<3x5x4x!FHE.eint<6>>
}
// CHECK-LABEL: @transpose_eint_3D_axes_201(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>>
func.func @transpose_eint_3D_axes_201(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [2, 0, 1]} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>>
// CHECK-NEXT: return %[[v0]] : tensor<5x3x4x!FHE.eint<6>>
// CHECK-NEXT: }
%c = "FHELinalg.transpose"(%arg0) { axes = [2, 0, 1] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x3x4x!FHE.eint<6>>
return %c : tensor<5x3x4x!FHE.eint<6>>
}
// CHECK-LABEL: @transpose_eint_4D_axes_3021(%arg0: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>>
func.func @transpose_eint_4D_axes_3021(%arg0: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {axes = [3, 0, 2, 1]} : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>>
// CHECK-NEXT: return %[[v0]] : tensor<5x2x4x3x!FHE.eint<6>>
// CHECK-NEXT: }
%c = "FHELinalg.transpose"(%arg0) { axes = [3, 0, 2, 1] } : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>>
return %c : tensor<5x2x4x3x!FHE.eint<6>>
}

View File

@@ -26,3 +26,33 @@ func.func @transpose_eint(%arg0: tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FH
}
// -----
func.func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> {
// expected-error @+1 {{'FHELinalg.transpose' op has invalid axes attribute (doesn't have 3 elements)}}
%c = "FHELinalg.transpose"(%arg0) { axes = [0] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>>
return %c : tensor<4x3x5x!FHE.eint<6>>
}
// -----
func.func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>> {
// expected-error @+1 {{'FHELinalg.transpose' op has invalid axes attribute (axes[1] isn't in range [0, 2])}}
%c = "FHELinalg.transpose"(%arg0) { axes = [1, 5, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x5x!FHE.eint<6>>
return %c : tensor<4x3x5x!FHE.eint<6>>
}
// -----
func.func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x10x!FHE.eint<6>> {
// expected-error @+1 {{'FHELinalg.transpose' op has invalid output shape (output.shape[2] is not input.shape[axes[2]])}}
%c = "FHELinalg.transpose"(%arg0) { axes = [1, 0, 2] } : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<4x3x10x!FHE.eint<6>>
return %c : tensor<4x3x10x!FHE.eint<6>>
}
// -----
func.func @transpose_eint(%arg0: tensor<2x2x2x!FHE.eint<6>>) -> tensor<2x2x2x!FHE.eint<6>> {
// expected-error @+1 {{'FHELinalg.transpose' op has invalid axes attribute (doesn't contain all input axes)}}
%c = "FHELinalg.transpose"(%arg0) { axes = [0, 1, 0] } : (tensor<2x2x2x!FHE.eint<6>>) -> tensor<2x2x2x!FHE.eint<6>>
return %c : tensor<2x2x2x!FHE.eint<6>>
}

View File

@@ -236,6 +236,100 @@ tests:
- tensor: [1, 3, 5, 2, 4, 6]
shape: [2, 3]
---
description: transpose3d
program: |
func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x3x2x!FHE.eint<6>> {
%1 = "FHELinalg.transpose"(%input): (tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x3x2x!FHE.eint<6>>
return %1 : tensor<4x3x2x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
shape: [2, 3, 4]
outputs:
- tensor: [0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23]
shape: [4, 3, 2]
---
description: transpose3d_axes_102
program: |
func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x2x4x!FHE.eint<6>> {
%1 = "FHELinalg.transpose"(%input) { axes = [1, 0, 2] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x2x4x!FHE.eint<6>>
return %1 : tensor<3x2x4x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
shape: [2, 3, 4]
outputs:
- tensor: [0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23]
shape: [3, 2, 4]
---
description: transpose3d_axes_120
program: |
func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x4x2x!FHE.eint<6>> {
%1 = "FHELinalg.transpose"(%input) { axes = [1, 2, 0] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<3x4x2x!FHE.eint<6>>
return %1 : tensor<3x4x2x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
shape: [2, 3, 4]
outputs:
- tensor: [0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23]
shape: [3, 4, 2]
---
description: transpose3d_axes_021
program: |
func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<2x4x3x!FHE.eint<6>> {
%1 = "FHELinalg.transpose"(%input) { axes = [0, 2, 1] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<2x4x3x!FHE.eint<6>>
return %1 : tensor<2x4x3x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
shape: [2, 3, 4]
outputs:
- tensor: [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23]
shape: [2, 4, 3]
---
description: transpose3d_axes_201
program: |
func.func @main(%input: tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x2x3x!FHE.eint<6>> {
%1 = "FHELinalg.transpose"(%input) { axes = [2, 0, 1] } : (tensor<2x3x4x!FHE.eint<6>>) -> tensor<4x2x3x!FHE.eint<6>>
return %1 : tensor<4x2x3x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
shape: [2, 3, 4]
outputs:
- tensor: [0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23]
shape: [4, 2, 3]
---
description: transpose4d_axes_3021
program: |
func.func @main(%input: tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>> {
%1 = "FHELinalg.transpose"(%input) { axes = [3, 0, 2, 1] } : (tensor<2x3x4x5x!FHE.eint<6>>) -> tensor<5x2x4x3x!FHE.eint<6>>
return %1 : tensor<5x2x4x3x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96,
97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
117, 118, 119]
shape: [2, 3, 4, 5]
outputs:
- tensor: [0, 20, 40, 5, 25, 45, 10, 30, 50, 15, 35, 55, 60, 80, 100, 65, 85, 105, 70, 90, 110, 75, 95,
115, 1, 21, 41, 6, 26, 46, 11, 31, 51, 16, 36, 56, 61, 81, 101, 66, 86, 106, 71, 91, 111, 76,
96, 116, 2, 22, 42, 7, 27, 47, 12, 32, 52, 17, 37, 57, 62, 82, 102, 67, 87, 107, 72, 92, 112,
77, 97, 117, 3, 23, 43, 8, 28, 48, 13, 33, 53, 18, 38, 58, 63, 83, 103, 68, 88, 108, 73, 93, 113,
78, 98, 118, 4, 24, 44, 9, 29, 49, 14, 34, 54, 19, 39, 59, 64, 84, 104, 69, 89, 109, 74, 94, 114,
79, 99, 119]
shape: [5, 2, 4, 3]
---
description: conv2dWithGroup1C
program: |
func.func @main(%input: tensor<1x6x4x4x!FHE.eint<5>>, %weight: tensor<6x1x2x2xi6>) -> tensor<1x6x3x3x!FHE.eint<5>> {