From fd362342f52bdce84a966b355ebcf86c27f0bdb6 Mon Sep 17 00:00:00 2001
From: Andi Drebes
Date: Mon, 21 Nov 2022 12:19:47 +0100
Subject: [PATCH] fix(compiler): Batching: Emit collapse/expand shape
 operations only for rank > 1

The batching pass hands the operands to the batched operation as a
flat, one-dimensional tensor produced by a `tensor.collapse_shape`
operation that collapses all dimensions of the original operand
tensor. Similarly, the result of the batched operation is expanded
back to the original shape afterwards using a `tensor.expand_shape`
operation.

The pass emits these `tensor.collapse_shape` and `tensor.expand_shape`
operations unconditionally, even for tensors that already have only a
single dimension. In some cases, this causes the verifiers of these
operations to fail, aborting the entire compilation process.

This patch makes the batching pass emit `tensor.collapse_shape` and
`tensor.expand_shape` for batched operands and batched results only if
the rank of the corresponding tensor is greater than one.
---
 compiler/lib/Transforms/Batching.cpp      | 49 ++++++++++++-------
 .../check_tests/Transforms/batching.mlir  | 21 ++++++++
 2 files changed, 51 insertions(+), 19 deletions(-)

diff --git a/compiler/lib/Transforms/Batching.cpp b/compiler/lib/Transforms/Batching.cpp
index f8f104e8b..4286383c8 100644
--- a/compiler/lib/Transforms/Batching.cpp
+++ b/compiler/lib/Transforms/Batching.cpp
@@ -649,16 +649,22 @@ public:
     mlir::RankedTensorType sliceType =
         slice.getType().cast<mlir::RankedTensorType>();
 
-    // Flatten the tensor with the batched operands, so that they can
-    // be passed as a one-dimensional tensor to the batched operation
+    mlir::Value flattenedSlice;
     mlir::ReassociationIndices indices;
-    for (int64_t i = 0; i < sliceType.getRank(); i++)
-      indices.push_back(i);
 
-    mlir::tensor::CollapseShapeOp flattenedSlice =
-        rewriter.create<mlir::tensor::CollapseShapeOp>(
-            targetExtractOp.getLoc(), slice,
-            llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    if (sliceType.getRank() == 1) {
+      flattenedSlice = slice;
+    } else {
+      // Flatten the tensor with the batched operands, so that they
+      // can be passed as a one-dimensional tensor to the batched
+      // operation
+      for (int64_t i = 0; i < sliceType.getRank(); i++)
+        indices.push_back(i);
+
+      flattenedSlice = rewriter.create<mlir::tensor::CollapseShapeOp>(
+          targetExtractOp.getLoc(), slice,
+          llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    }
 
     // Create the batched operation and pass flattened, batched
     // operands
@@ -666,18 +672,23 @@
     mlir::Value batchedOpResult =
         targetOp.createBatchedOperation(ilob, flattenedSlice);
 
-    // Restore original shape of the batched operands for the result
-    // of the batched operation. Dimensions, result from indexing with
-    // non-loop-IVs are collapsed.
-    mlir::Type expandedBatchResultType = mlir::RankedTensorType::get(
-        sliceType.getShape(), batchedOpResult.getType()
-                                  .dyn_cast<mlir::RankedTensorType>()
-                                  .getElementType());
+    mlir::Value expandedBatchResultTensor;
 
-    mlir::Value expandedBatchResultTensor =
-        rewriter.create<mlir::tensor::ExpandShapeOp>(
-            targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult,
-            llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    if (sliceType.getRank() == 1) {
+      expandedBatchResultTensor = batchedOpResult;
+    } else {
+      // Restore original shape of the batched operands for the result
+      // of the batched operation. Dimensions resulting from indexing
+      // with non-loop-IVs are collapsed.
+      mlir::Type expandedBatchResultType = mlir::RankedTensorType::get(
+          sliceType.getShape(), batchedOpResult.getType()
+                                    .dyn_cast<mlir::RankedTensorType>()
+                                    .getElementType());
+
+      expandedBatchResultTensor = rewriter.create<mlir::tensor::ExpandShapeOp>(
+          targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult,
+          llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    }
 
     // Collect all loop IVs from the extract op. These will be used to
     // index the batched result tensor within the loop for consumers
diff --git a/compiler/tests/check_tests/Transforms/batching.mlir b/compiler/tests/check_tests/Transforms/batching.mlir
index 648c468e2..0375784ad 100644
--- a/compiler/tests/check_tests/Transforms/batching.mlir
+++ b/compiler/tests/check_tests/Transforms/batching.mlir
@@ -31,6 +31,27 @@ func.func @batch_continuous_slice_keyswitch(%arg0: tensor<2x3x4x!Concrete.lwe_ci
 
 // -----
 
+// CHECK-LABEL: func.func @batch_continuous_slice_keyswitch_1dim(%arg0: tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>> {
+func.func @batch_continuous_slice_keyswitch_1dim(%arg0: tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %0 = bufferization.alloc_tensor() : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+  // CHECK: %[[V0:.*]] = "Concrete.batched_keyswitch_lwe"(%[[ARG:.*]]) {baseLog = 2 : i32, level = 5 : i32} : (tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>>
+  // CHECK: return %[[V0]]
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %0) -> (tensor<4x!Concrete.lwe_ciphertext<572,2>>) {
+    %2 = tensor.extract %arg0[%arg2] : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+    %3 = "Concrete.keyswitch_lwe"(%2) {baseLog = 2 : i32, level = 5 : i32} : (!Concrete.lwe_ciphertext<572,2>) -> !Concrete.lwe_ciphertext<572,2>
+    %4 = tensor.insert %3 into %arg3[%arg2] : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+    scf.yield %4 : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+  }
+  return %1 : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @batch_continuous_slice_bootstrap(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,2>> {
 func.func @batch_continuous_slice_bootstrap(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,2>> {
   %c2 = arith.constant 2 : index
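For illustration, the IR around the batched operation can be sketched as
below, with shapes taken from the tests above; this is a hand-written
sketch, not verbatim compiler output, and the SSA value names (`%slice`,
`%flat`, `%batched`, `%res`, `%slice1`, `%res1`) are hypothetical:

    // Rank > 1: the operands are flattened to a one-dimensional tensor,
    // the batched operation is applied to the flat tensor, and its
    // result is expanded back to the original shape.
    %flat = tensor.collapse_shape %slice [[0, 1, 2]]
        : tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>
        into tensor<24x!Concrete.lwe_ciphertext<572,2>>
    %batched = "Concrete.batched_keyswitch_lwe"(%flat)
        {baseLog = 2 : i32, level = 5 : i32}
        : (tensor<24x!Concrete.lwe_ciphertext<572,2>>)
        -> tensor<24x!Concrete.lwe_ciphertext<572,2>>
    %res = tensor.expand_shape %batched [[0, 1, 2]]
        : tensor<24x!Concrete.lwe_ciphertext<572,2>>
        into tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>

    // Rank == 1: the operand tensor is already flat, so with this patch
    // the batched operation consumes and produces it directly, without
    // a surrounding collapse_shape/expand_shape pair.
    %res1 = "Concrete.batched_keyswitch_lwe"(%slice1)
        {baseLog = 2 : i32, level = 5 : i32}
        : (tensor<4x!Concrete.lwe_ciphertext<572,2>>)
        -> tensor<4x!Concrete.lwe_ciphertext<572,2>>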