mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add FHELinalg.transpose operation
This commit is contained in:
@@ -1001,5 +1001,35 @@ $_state.addAttribute("dilations", dilations);
|
||||
}];
|
||||
}
|
||||
|
||||
def TransposeOp : FHELinalg_Op<"transpose", []> {
|
||||
let summary = "Returns a tensor that contains the transposition of the input tensor.";
|
||||
|
||||
let description = [{
|
||||
Performs a transpose operation on an N-dimensional tensor.
|
||||
|
||||
```mlir
|
||||
"FHELinalg.transpose"(%a) : (tensor<n0xn1x...xnNxType>) -> tensor<nNx...xn1xn0xType>
|
||||
```
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// Transpose the input tensor
|
||||
// [1,2] [1, 3, 5]
|
||||
// [3,4] => [2, 4, 6]
|
||||
// [5,6]
|
||||
//
|
||||
"FHELinalg.transpose"(%a) : (tensor<3x2xi7>) -> tensor<2x3xi7>
|
||||
```
|
||||
}];
|
||||
|
||||
// Args and result are both supposed to be of tensor of enccrypted integers, and the verifier does check that
|
||||
let arguments = (ins AnyType:$tensor);
|
||||
let results = (outs AnyType);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::FHELinalg::verifyTranspose(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1605,6 +1605,42 @@ void FhelinalgConv2DNchwFchwOp::getEffects(
|
||||
outputBuffers);
|
||||
}
|
||||
|
||||
/// Verify the transpose shapes
|
||||
mlir::LogicalResult verifyTranspose(TransposeOp &transposeOp) {
|
||||
mlir::Type tensorTy = ((mlir::Type)transposeOp.tensor().getType());
|
||||
if (!tensorTy.isa<RankedTensorType>()) {
|
||||
transposeOp->emitOpError() << "should have operand as tensor";
|
||||
return mlir::failure();
|
||||
}
|
||||
mlir::Type resultTy = ((mlir::Type)transposeOp.getResult().getType());
|
||||
if (!resultTy.isa<RankedTensorType>()) {
|
||||
transposeOp->emitOpError() << "should have result as tensor";
|
||||
return mlir::failure();
|
||||
}
|
||||
auto tensorShapedTy = tensorTy.dyn_cast_or_null<mlir::ShapedType>();
|
||||
auto resultShapedTy = resultTy.dyn_cast_or_null<mlir::ShapedType>();
|
||||
if (tensorShapedTy.getShape().size() != resultShapedTy.getShape().size()) {
|
||||
transposeOp.emitOpError()
|
||||
<< "input and output tensors should have the same number of dimensions";
|
||||
return mlir::failure();
|
||||
}
|
||||
if (tensorShapedTy.getElementType() != resultShapedTy.getElementType()) {
|
||||
transposeOp.emitOpError()
|
||||
<< "input and output tensors should have the same element type";
|
||||
return mlir::failure();
|
||||
}
|
||||
size_t n_dims = tensorShapedTy.getShape().size();
|
||||
for (size_t i = 0; i < n_dims; i++) {
|
||||
if (tensorShapedTy.getDimSize(i) !=
|
||||
resultShapedTy.getDimSize(n_dims - (i + 1))) {
|
||||
transposeOp.emitOpError()
|
||||
<< "output tensor should have inverted dimensions of input";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace FHELinalg
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -352,3 +352,34 @@ func @conv2d_without_bias(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: ten
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.transpose
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// CHECK-LABEL: @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>>
|
||||
func @transpose_eint_2D(%arg0: tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<10x2x!FHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
%c = "FHELinalg.transpose"(%arg0) : (tensor<2x10x!FHE.eint<6>>) -> tensor<10x2x!FHE.eint<6>>
|
||||
return %c : tensor<10x2x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_int_2D(%arg0: tensor<2x10xi6>) -> tensor<10x2xi6>
|
||||
func @transpose_int_2D(%arg0: tensor<2x10xi6>) -> tensor<10x2xi6> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<2x10xi6>) -> tensor<10x2xi6>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<10x2xi6>
|
||||
// CHECK-NEXT: }
|
||||
%c = "FHELinalg.transpose"(%arg0) : (tensor<2x10xi6>) -> tensor<10x2xi6>
|
||||
return %c : tensor<10x2xi6>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
// CHECK-NEXT: return %[[v0]] : tensor<5x4x3x!FHE.eint<6>>
|
||||
// CHECK-NEXT: }
|
||||
%c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
return %c : tensor<5x4x3x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
// RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// Incompatible types
|
||||
func @transpose_eint(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<7>> {
|
||||
// expected-error @+1 {{'FHELinalg.transpose' op input and output tensors should have the same element type}}
|
||||
%c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<7>>
|
||||
return %c : tensor<5x4x3x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Incompatible shapes
|
||||
func @transpose_eint(%arg0: tensor<3x4x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
|
||||
// expected-error @+1 {{'FHELinalg.transpose' op input and output tensors should have the same number of dimensions}}
|
||||
%c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
return %c : tensor<5x4x3x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Incompatible shapes
|
||||
func @transpose_eint(%arg0: tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> {
|
||||
// expected-error @+1 {{'FHELinalg.transpose' op output tensor should have inverted dimensions of input}}
|
||||
%c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x6x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>>
|
||||
return %c : tensor<5x4x3x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
Reference in New Issue
Block a user