diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index 02a87b0d6..87db50620 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -541,7 +541,7 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern { ((mlir::Type)converter.convertType( (inRank ? shapeOp.src() : shapeOp.result()).getType())) .cast(); - lweAssoc.push_back(reassocTy.getRank()); + lweAssoc.push_back(reassocTy.getRank() - 1); newReassocs.push_back(lweAssoc); rewriter.replaceOpWithNewOp(shapeOp, newResultTy, shapeOp.src(), @@ -874,9 +874,9 @@ void ConcreteToBConcretePass::runOnOperation() { // Add patterns to rewrite some of memref ops that was introduced by the // linalg bufferization of encrypted tensor (first conversion of this pass) - insertTensorShapeOpPattern( + insertTensorShapeOpPattern( getContext(), patterns, target); - insertTensorShapeOpPattern( + insertTensorShapeOpPattern( getContext(), patterns, target); // Add patterns to rewrite linalg op to nested loops with views on diff --git a/compiler/tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir b/compiler/tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir new file mode 100644 index 000000000..8ecdebd8b --- /dev/null +++ b/compiler/tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir @@ -0,0 +1,39 @@ +// RUN: concretecompiler --split-input-file --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + +// CHECK: func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> { +// CHECK-NEXT: %0 = memref.buffer_cast %arg0 : memref<2x3x4x5x6x1025xi64> +// CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : memref<2x3x4x5x6x1025xi64> into memref<720x1025xi64> +// CHECK-NEXT: %2 = memref.tensor_load %1 : memref<720x1025xi64> +// CHECK-NEXT: return %2 : tensor<720x1025xi64> +func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<1024,4>> { + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>> into tensor<720x!Concrete.lwe_ciphertext<1024,4>> + return %0 : tensor<720x!Concrete.lwe_ciphertext<1024,4>> +} + + +// ----- + +// CHECK: func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> { +// CHECK-NEXT: %0 = memref.buffer_cast %arg0 : memref<2x3x5x1025xi64> +// CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2\], \[3\]\]]] : memref<2x3x5x1025xi64> into memref<30x1025xi64> +// CHECK-NEXT: %2 = memref.expand_shape %1 [[_:\[\[0, 1\], \[2\]\]]] : memref<30x1025xi64> into memref<5x6x1025xi64> +// CHECK-NEXT: %3 = memref.tensor_load %2 : memref<5x6x1025xi64> +// CHECK-NEXT: return %3 : tensor<5x6x1025xi64> +func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> { + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]] {MANP = 1 : ui1}: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>> into tensor<30x!Concrete.lwe_ciphertext<1024,4>> + %1 = linalg.tensor_expand_shape %0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> + return %1 : tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> +} + + +// ----- + +// CHECK: func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> { +// CHECK-NEXT: %0 = memref.buffer_cast %arg0 : memref<2x3x2x3x4x1025xi64> +// CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : memref<2x3x2x3x4x1025xi64> into memref<6x2x12x1025xi64> +// CHECK-NEXT: %2 = memref.tensor_load %1 : memref<6x2x12x1025xi64> +// CHECK-NEXT: return %2 : tensor<6x2x12x1025xi64> +func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> { + %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]] {MANP = 1 : ui1}: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>> into tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> + return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> +} \ No newline at end of file