mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix(compiler): Batching: Emit collapse/expand shape operations only for rank > 1
The batching pass passes operands to the batched operation as a flat, one-dimensional vector produced through a `tensor.collapse_shape` operation collapsing all dimensions of the original tensor of operands. Similarly, the shape of the result vector of the batched operation is expanded to the original shape afterwards using a `tensor.expand_shape` operation. The pass emits the `tensor.collapse_shape` and `tensor.expand_shape` operations unconditionally, even for tensors that already have only a single dimension. This causes the verifiers of these operations to fail in some cases, aborting the entire compilation process. This patch lets the batching pass emit `tensor.collapse_shape` and `tensor.expand_shape` for batched operands and batched results only if the rank of the corresponding tensors is greater than one.
This commit is contained in:
@@ -649,16 +649,22 @@ public:
|
||||
mlir::RankedTensorType sliceType =
|
||||
slice.getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
// Flatten the tensor with the batched operands, so that they can
|
||||
// be passed as a one-dimensional tensor to the batched operation
|
||||
mlir::Value flattenedSlice;
|
||||
mlir::ReassociationIndices indices;
|
||||
for (int64_t i = 0; i < sliceType.getRank(); i++)
|
||||
indices.push_back(i);
|
||||
|
||||
mlir::tensor::CollapseShapeOp flattenedSlice =
|
||||
rewriter.create<mlir::tensor::CollapseShapeOp>(
|
||||
targetExtractOp.getLoc(), slice,
|
||||
llvm::SmallVector<mlir::ReassociationIndices>{indices});
|
||||
if (sliceType.getRank() == 1) {
|
||||
flattenedSlice = slice;
|
||||
} else {
|
||||
// Flatten the tensor with the batched operands, so that they
|
||||
// can be passed as a one-dimensional tensor to the batched
|
||||
// operation
|
||||
for (int64_t i = 0; i < sliceType.getRank(); i++)
|
||||
indices.push_back(i);
|
||||
|
||||
flattenedSlice = rewriter.create<mlir::tensor::CollapseShapeOp>(
|
||||
targetExtractOp.getLoc(), slice,
|
||||
llvm::SmallVector<mlir::ReassociationIndices>{indices});
|
||||
}
|
||||
|
||||
// Create the batched operation and pass flattened, batched
|
||||
// operands
|
||||
@@ -666,18 +672,23 @@ public:
|
||||
mlir::Value batchedOpResult =
|
||||
targetOp.createBatchedOperation(ilob, flattenedSlice);
|
||||
|
||||
// Restore original shape of the batched operands for the result
|
||||
// of the batched operation. Dimensions resulting from indexing with
|
||||
// non-loop-IVs are collapsed.
|
||||
mlir::Type expandedBatchResultType = mlir::RankedTensorType::get(
|
||||
sliceType.getShape(), batchedOpResult.getType()
|
||||
.dyn_cast<mlir::RankedTensorType>()
|
||||
.getElementType());
|
||||
mlir::Value expandedBatchResultTensor;
|
||||
|
||||
mlir::Value expandedBatchResultTensor =
|
||||
rewriter.create<mlir::tensor::ExpandShapeOp>(
|
||||
targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult,
|
||||
llvm::SmallVector<mlir::ReassociationIndices>{indices});
|
||||
if (sliceType.getRank() == 1) {
|
||||
expandedBatchResultTensor = batchedOpResult;
|
||||
} else {
|
||||
// Restore original shape of the batched operands for the result
|
||||
// of the batched operation. Dimensions resulting from indexing
|
||||
// with non-loop-IVs are collapsed.
|
||||
mlir::Type expandedBatchResultType = mlir::RankedTensorType::get(
|
||||
sliceType.getShape(), batchedOpResult.getType()
|
||||
.dyn_cast<mlir::RankedTensorType>()
|
||||
.getElementType());
|
||||
|
||||
expandedBatchResultTensor = rewriter.create<mlir::tensor::ExpandShapeOp>(
|
||||
targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult,
|
||||
llvm::SmallVector<mlir::ReassociationIndices>{indices});
|
||||
}
|
||||
|
||||
// Collect all loop IVs from the extract op. These will be used to
|
||||
// index the batched result tensor within the loop for consumers
|
||||
|
||||
@@ -31,6 +31,27 @@ func.func @batch_continuous_slice_keyswitch(%arg0: tensor<2x3x4x!Concrete.lwe_ci
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @batch_continuous_slice_keyswitch_1dim(%arg0: tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>> {
|
||||
func.func @batch_continuous_slice_keyswitch_1dim(%arg0: tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>> {
|
||||
%c4 = arith.constant 4 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
|
||||
%0 = bufferization.alloc_tensor() : tensor<4x!Concrete.lwe_ciphertext<572,2>>
|
||||
// CHECK: %[[V0:.*]] = "Concrete.batched_keyswitch_lwe"(%[[ARG:.*]]) {baseLog = 2 : i32, level = 5 : i32} : (tensor<4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<4x!Concrete.lwe_ciphertext<572,2>>
|
||||
// CHECK: return %[[V0]]
|
||||
|
||||
%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %0) -> (tensor<4x!Concrete.lwe_ciphertext<572,2>>) {
|
||||
%2 = tensor.extract %arg0[%arg2] : tensor<4x!Concrete.lwe_ciphertext<572,2>>
|
||||
%3 = "Concrete.keyswitch_lwe"(%2) {baseLog = 2 : i32, level = 5 : i32} : (!Concrete.lwe_ciphertext<572,2>) -> !Concrete.lwe_ciphertext<572,2>
|
||||
%4 = tensor.insert %3 into %arg3[%arg2] : tensor<4x!Concrete.lwe_ciphertext<572,2>>
|
||||
scf.yield %4 : tensor<4x!Concrete.lwe_ciphertext<572,2>>
|
||||
}
|
||||
return %1 : tensor<4x!Concrete.lwe_ciphertext<572,2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @batch_continuous_slice_bootstrap(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,2>> {
|
||||
func.func @batch_continuous_slice_bootstrap(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,2>> {
|
||||
%c2 = arith.constant 2 : index
|
||||
|
||||
Reference in New Issue
Block a user