feat: add FHELinalg.transpose operation

Author: youben11
Date: 2022-03-22 09:28:34 +01:00
Committed by: Ayoub Benaissa
Parent: 6c7cd97e73
Commit: 3a4723a0b8
4 changed files with 125 additions and 0 deletions

@@ -1001,5 +1001,35 @@ $_state.addAttribute("dilations", dilations);
  }];
}

def TransposeOp : FHELinalg_Op<"transpose", []> {
  let summary = "Returns a tensor that contains the transposition of the input tensor.";

  let description = [{
    Performs a transpose operation on an N-dimensional tensor.

    ```mlir
    "FHELinalg.transpose"(%a) : (tensor<n0xn1x...xnNxType>) -> tensor<nNx...xn1xn0xType>
    ```

    Examples:

    ```mlir
    // Transpose the input tensor
    // [1,2]    [1, 3, 5]
    // [3,4] => [2, 4, 6]
    // [5,6]
    //
    "FHELinalg.transpose"(%a) : (tensor<3x2xi7>) -> tensor<2x3xi7>
    ```
  }];

  // Both the argument and the result are expected to be tensors of
  // encrypted integers; the verifier checks this.
  let arguments = (ins AnyType:$tensor);
  let results = (outs AnyType);

  let verifier = [{
    return ::mlir::concretelang::FHELinalg::verifyTranspose(*this);
  }];
}
#endif
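For intuition, the shape rule above says the element at input index (i0, ..., iN) lands at output index (iN, ..., i0). Here is a minimal, hypothetical C++ sketch of that semantics on a row-major buffer; it is an illustration only, not part of this commit, and `transposeAllDims` is a made-up name:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of FHELinalg.transpose semantics on a row-major
// buffer: the value at input index (i0, ..., iN) is written to output
// index (iN, ..., i0).
std::vector<int64_t> transposeAllDims(const std::vector<int64_t> &in,
                                      const std::vector<int64_t> &shape) {
  std::vector<int64_t> rshape(shape.rbegin(), shape.rend()); // output shape
  std::vector<int64_t> out(in.size());
  std::vector<int64_t> idx(shape.size(), 0); // multi-index into `shape`
  for (size_t flat = 0; flat < in.size(); ++flat) {
    // Row-major offset of the reversed multi-index in the reversed shape.
    int64_t off = 0;
    for (size_t d = 0; d < rshape.size(); ++d)
      off = off * rshape[d] + idx[idx.size() - 1 - d];
    out[off] = in[flat];
    // Advance `idx` like an odometer.
    for (size_t d = shape.size(); d-- > 0;) {
      if (++idx[d] < shape[d])
        break;
      idx[d] = 0;
    }
  }
  return out;
}

int main() {
  // The 3x2 example from the op description:
  // [1,2]    [1, 3, 5]
  // [3,4] => [2, 4, 6]
  // [5,6]
  std::vector<int64_t> out = transposeAllDims({1, 2, 3, 4, 5, 6}, {3, 2});
  assert((out == std::vector<int64_t>{1, 3, 5, 2, 4, 6}));
  return 0;
}
```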

@@ -1605,6 +1605,42 @@ void FhelinalgConv2DNchwFchwOp::getEffects(
outputBuffers);
}
/// Verify that a transpose's result type is its input type with the
/// dimensions reversed.
mlir::LogicalResult verifyTranspose(TransposeOp &transposeOp) {
  mlir::Type tensorTy = transposeOp.tensor().getType();
  if (!tensorTy.isa<RankedTensorType>()) {
    transposeOp.emitOpError() << "should have operand as tensor";
    return mlir::failure();
  }
  mlir::Type resultTy = transposeOp.getResult().getType();
  if (!resultTy.isa<RankedTensorType>()) {
    transposeOp.emitOpError() << "should have result as tensor";
    return mlir::failure();
  }
  auto tensorShapedTy = tensorTy.cast<mlir::ShapedType>();
  auto resultShapedTy = resultTy.cast<mlir::ShapedType>();
  // Transposition permutes dimensions; it never adds or removes any.
  if (tensorShapedTy.getShape().size() != resultShapedTy.getShape().size()) {
    transposeOp.emitOpError()
        << "input and output tensors should have the same number of dimensions";
    return mlir::failure();
  }
  // Transposition moves elements around but never changes their type.
  if (tensorShapedTy.getElementType() != resultShapedTy.getElementType()) {
    transposeOp.emitOpError()
        << "input and output tensors should have the same element type";
    return mlir::failure();
  }
  // The i-th input dimension must match the (n_dims - 1 - i)-th output
  // dimension, i.e. the output shape is the input shape reversed.
  size_t n_dims = tensorShapedTy.getShape().size();
  for (size_t i = 0; i < n_dims; i++) {
    if (tensorShapedTy.getDimSize(i) !=
        resultShapedTy.getDimSize(n_dims - (i + 1))) {
      transposeOp.emitOpError()
          << "output tensor should have inverted dimensions of input";
      return mlir::failure();
    }
  }
  return mlir::success();
}
} // namespace FHELinalg
} // namespace concretelang
} // namespace mlir
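The dimension-reversal loop is the core of the verifier. As a sanity check, here is a self-contained sketch of the same rule (a hypothetical helper, not compiler code; `isTransposedShape` is a made-up name), exercised with the shapes from the MLIR tests below:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone version of the shape rule enforced by
// verifyTranspose: the output shape must be the input shape reversed.
bool isTransposedShape(const std::vector<int64_t> &in,
                       const std::vector<int64_t> &out) {
  if (in.size() != out.size())
    return false; // rank mismatch
  for (size_t i = 0; i < in.size(); ++i)
    if (in[i] != out[in.size() - 1 - i])
      return false; // dimension not mirrored
  return true;
}

int main() {
  // Mirrors the valid and invalid shapes exercised by the MLIR tests below.
  std::cout << isTransposedShape({3, 4, 5}, {5, 4, 3})  // 1: valid
            << isTransposedShape({3, 4}, {5, 4, 3})     // 0: rank mismatch
            << isTransposedShape({3, 4, 6}, {5, 4, 3})  // 0: not reversed
            << "\n";
  return 0;
}
```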

@@ -352,3 +352,34 @@ func @conv2d_without_bias(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: ten
%1 = "FHELinalg.conv2d"(%input, %weight){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}

/////////////////////////////////////////////////
// FHELinalg.transpose
/////////////////////////////////////////////////

// CHECK-LABEL: @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>>
func @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> {
  // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>>
  // CHECK-NEXT: return %[[v0]] : tensor<10x2x!FHE.eint<6>>
  // CHECK-NEXT: }
  %c = "FHELinalg.transpose"(%arg0) : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>>
  return %c : tensor<10x2x!FHE.eint<6>>
}

// CHECK-LABEL: @transpose_int_2D(%arg0: tensor<2x10xi6>) -> tensor<10x2xi6>
func @transpose_int_2D(%arg0: tensor<2x10xi6>) -> tensor<10x2xi6> {
  // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<2x10xi6>) -> tensor<10x2xi6>
  // CHECK-NEXT: return %[[v0]] : tensor<10x2xi6>
  // CHECK-NEXT: }
  %c = "FHELinalg.transpose"(%arg0) : (tensor<2x10xi6>) -> tensor<10x2xi6>
  return %c : tensor<10x2xi6>
}

// CHECK-LABEL: @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
  // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
  // CHECK-NEXT: return %[[v0]] : tensor<5x4x3x!FHE.eint<6>>
  // CHECK-NEXT: }
  %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
  return %c : tensor<5x4x3x!FHE.eint<6>>
}

@@ -0,0 +1,28 @@
// RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s

// Incompatible element types
func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<7>> {
  // expected-error @+1 {{'FHELinalg.transpose' op input and output tensors should have the same element type}}
  %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<7>>
  return %c : tensor<5x4x3x!FHE.eint<7>>
}

// -----

// Incompatible ranks
func @transpose_eint(%arg0: tensor<3x4x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
  // expected-error @+1 {{'FHELinalg.transpose' op input and output tensors should have the same number of dimensions}}
  %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
  return %c : tensor<5x4x3x!FHE.eint<6>>
}

// -----

// Incompatible shapes
func @transpose_eint(%arg0: tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
  // expected-error @+1 {{'FHELinalg.transpose' op output tensor should have inverted dimensions of input}}
  %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
  return %c : tensor<5x4x3x!FHE.eint<6>>
}

// -----