fix: correct reassociation in expand and collapse ops

This commit is contained in:
youben11
2022-03-01 10:02:57 +01:00
committed by Quentin Bourgerie
parent d06e0c0a59
commit 65e2e2f600
2 changed files with 42 additions and 3 deletions

View File

@@ -541,7 +541,7 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
((mlir::Type)converter.convertType(
(inRank ? shapeOp.src() : shapeOp.result()).getType()))
.cast<mlir::MemRefType>();
lweAssoc.push_back(reassocTy.getRank());
lweAssoc.push_back(reassocTy.getRank() - 1);
newReassocs.push_back(lweAssoc);
rewriter.replaceOpWithNewOp<ShapeOp>(shapeOp, newResultTy, shapeOp.src(),
@@ -874,9 +874,9 @@ void ConcreteToBConcretePass::runOnOperation() {
// Add patterns to rewrite some of memref ops that was introduced by the
// linalg bufferization of encrypted tensor (first conversion of this pass)
insertTensorShapeOpPattern<mlir::memref::ExpandShapeOp, true>(
insertTensorShapeOpPattern<mlir::memref::ExpandShapeOp, false>(
getContext(), patterns, target);
insertTensorShapeOpPattern<mlir::memref::CollapseShapeOp, false>(
insertTensorShapeOpPattern<mlir::memref::CollapseShapeOp, true>(
getContext(), patterns, target);
// Add patterns to rewrite linalg op to nested loops with views on

View File

@@ -0,0 +1,39 @@
// RUN: concretecompiler --split-input-file --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> {
// CHECK-NEXT: %0 = memref.buffer_cast %arg0 : memref<2x3x4x5x6x1025xi64>
// CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : memref<2x3x4x5x6x1025xi64> into memref<720x1025xi64>
// CHECK-NEXT: %2 = memref.tensor_load %1 : memref<720x1025xi64>
// CHECK-NEXT: return %2 : tensor<720x1025xi64>
func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<1024,4>> {
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>> into tensor<720x!Concrete.lwe_ciphertext<1024,4>>
return %0 : tensor<720x!Concrete.lwe_ciphertext<1024,4>>
}
// -----
// CHECK: func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> {
// CHECK-NEXT: %0 = memref.buffer_cast %arg0 : memref<2x3x5x1025xi64>
// CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2\], \[3\]\]]] : memref<2x3x5x1025xi64> into memref<30x1025xi64>
// CHECK-NEXT: %2 = memref.expand_shape %1 [[_:\[\[0, 1\], \[2\]\]]] : memref<30x1025xi64> into memref<5x6x1025xi64>
// CHECK-NEXT: %3 = memref.tensor_load %2 : memref<5x6x1025xi64>
// CHECK-NEXT: return %3 : tensor<5x6x1025xi64>
func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> {
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]] {MANP = 1 : ui1}: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>> into tensor<30x!Concrete.lwe_ciphertext<1024,4>>
%1 = linalg.tensor_expand_shape %0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
return %1 : tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
}
// -----
// CHECK: func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> {
// CHECK-NEXT: %0 = memref.buffer_cast %arg0 : memref<2x3x2x3x4x1025xi64>
// CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : memref<2x3x2x3x4x1025xi64> into memref<6x2x12x1025xi64>
// CHECK-NEXT: %2 = memref.tensor_load %1 : memref<6x2x12x1025xi64>
// CHECK-NEXT: return %2 : tensor<6x2x12x1025xi64>
func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> {
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]] {MANP = 1 : ui1}: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>> into tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
}