fix(compiler): Batching: Emit collapse/expand shape operations only for rank > 1

The batching pass provides the operands of the batched operation as a
flat, one-dimensional tensor produced by a `tensor.collapse_shape`
operation that collapses all dimensions of the original operand
tensor. Similarly, the result tensor of the batched operation is
afterwards expanded back to its original shape using a
`tensor.expand_shape` operation.
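
For illustration, here is a minimal sketch of the IR the pass builds
around a batched keyswitch (shapes borrowed from the regression tests
below; illustrative only, and the exact `tensor.collapse_shape` /
`tensor.expand_shape` assembly syntax may vary slightly across MLIR
versions):

```mlir
// Collapse all dimensions of the rank-3 operand tensor into a single
// reassociation group [0, 1, 2], yielding a flat rank-1 tensor.
%flat = tensor.collapse_shape %operands [[0, 1, 2]]
    : tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>
    into tensor<24x!Concrete.lwe_ciphertext<572,2>>

// The batched operation consumes the flat tensor in one shot instead
// of performing one keyswitch per loop iteration.
%res = "Concrete.batched_keyswitch_lwe"(%flat) {baseLog = 2 : i32, level = 5 : i32}
    : (tensor<24x!Concrete.lwe_ciphertext<572,2>>)
    -> tensor<24x!Concrete.lwe_ciphertext<572,2>>

// Expand the flat result back to the original rank-3 shape.
%out = tensor.expand_shape %res [[0, 1, 2]]
    : tensor<24x!Concrete.lwe_ciphertext<572,2>>
    into tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>
```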

The pass emits the `tensor.collapse_shape` and `tensor.expand_shape`
operations unconditionally, even for tensors that already have only a
single dimension. In some cases, this causes the verifiers of these
operations to fail, aborting the entire compilation process.
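
For an already one-dimensional operand, the unconditional code path
builds its reassociation indices from the single dimension and thus
produces a degenerate, identity-like reshape such as the following
sketch (illustrative, not verbatim compiler output):

```mlir
// Degenerate collapse: the single reassociation group [0] maps the
// one source dimension onto itself, leaving the type unchanged.
%flat = tensor.collapse_shape %operands [[0]]
    : tensor<4x!Concrete.lwe_ciphertext<572,2>>
    into tensor<4x!Concrete.lwe_ciphertext<572,2>>
```

As stated above, the verifiers reject such degenerate operations in
some cases, so the reshape must simply be omitted for rank-1 tensors.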

This patch lets the batching pass emit `tensor.collapse_shape` and
`tensor.expand_shape` for batched operands and batched results only if
the rank of the corresponding tensors is greater than one.
Author: Andi Drebes
Date:   2022-11-21 12:19:47 +01:00
Parent: 3c2a75186f
Commit: fd362342f5

2 changed files with 51 additions and 19 deletions

@@ -649,16 +649,22 @@ public:
     mlir::RankedTensorType sliceType =
         slice.getType().cast<mlir::RankedTensorType>();

-    // Flatten the tensor with the batched operands, so that they can
-    // be passed as a one-dimensional tensor to the batched operation
+    mlir::Value flattenedSlice;
     mlir::ReassociationIndices indices;

-    for (int64_t i = 0; i < sliceType.getRank(); i++)
-      indices.push_back(i);
-    mlir::tensor::CollapseShapeOp flattenedSlice =
-        rewriter.create<mlir::tensor::CollapseShapeOp>(
-            targetExtractOp.getLoc(), slice,
-            llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    if (sliceType.getRank() == 1) {
+      flattenedSlice = slice;
+    } else {
+      // Flatten the tensor with the batched operands, so that they
+      // can be passed as a one-dimensional tensor to the batched
+      // operation
+      for (int64_t i = 0; i < sliceType.getRank(); i++)
+        indices.push_back(i);
+
+      flattenedSlice = rewriter.create<mlir::tensor::CollapseShapeOp>(
+          targetExtractOp.getLoc(), slice,
+          llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    }

     // Create the batched operation and pass flattened, batched
     // operands
@@ -666,18 +672,23 @@ public:
     mlir::Value batchedOpResult =
         targetOp.createBatchedOperation(ilob, flattenedSlice);

-    // Restore original shape of the batched operands for the result
-    // of the batched operation. Dimensions, result from indexing with
-    // non-loop-IVs are collapsed.
-    mlir::Type expandedBatchResultType = mlir::RankedTensorType::get(
-        sliceType.getShape(), batchedOpResult.getType()
-                                  .dyn_cast<mlir::RankedTensorType>()
-                                  .getElementType());
-
-    mlir::Value expandedBatchResultTensor =
-        rewriter.create<mlir::tensor::ExpandShapeOp>(
-            targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult,
-            llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    mlir::Value expandedBatchResultTensor;
+
+    if (sliceType.getRank() == 1) {
+      expandedBatchResultTensor = batchedOpResult;
+    } else {
+      // Restore the original shape of the batched operands for the
+      // result of the batched operation. Dimensions resulting from
+      // indexing with non-loop-IVs are collapsed.
+      mlir::Type expandedBatchResultType = mlir::RankedTensorType::get(
+          sliceType.getShape(), batchedOpResult.getType()
+                                    .dyn_cast<mlir::RankedTensorType>()
+                                    .getElementType());
+
+      expandedBatchResultTensor = rewriter.create<mlir::tensor::ExpandShapeOp>(
+          targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult,
+          llvm::SmallVector<mlir::ReassociationIndices>{indices});
+    }

     // Collect all loop IVs from the extract op. These will be used to
     // index the batched result tensor within the loop for consumers

@@ -31,6 +31,27 @@ func.func @batch_continuous_slice_keyswitch(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>

 // -----

+// CHECK-LABEL: func.func @batch_continuous_slice_keyswitch_1dim(%arg0: tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>> {
+func.func @batch_continuous_slice_keyswitch_1dim(%arg0: tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = bufferization.alloc_tensor() : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+
+  // CHECK: %[[V0:.*]] = "Concrete.batched_keyswitch_lwe"(%[[ARG:.*]]) {baseLog = 2 : i32, level = 5 : i32} : (tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>>
+  // CHECK: return %[[V0]]
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %0) -> (tensor<4x!Concrete.lwe_ciphertext<572,2>>) {
+    %2 = tensor.extract %arg0[%arg2] : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+    %3 = "Concrete.keyswitch_lwe"(%2) {baseLog = 2 : i32, level = 5 : i32} : (!Concrete.lwe_ciphertext<572,2>) -> !Concrete.lwe_ciphertext<572,2>
+    %4 = tensor.insert %3 into %arg3[%arg2] : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+    scf.yield %4 : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+  }
+
+  return %1 : tensor<4x!Concrete.lwe_ciphertext<572,2>>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @batch_continuous_slice_bootstrap(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,2>> {
 func.func @batch_continuous_slice_bootstrap(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,2>> {
   %c2 = arith.constant 2 : index