mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix: Lowering from ConcreteToBConcrete of from_elements on ND tensor of Concrete.lwe_ciphertext
This commit is contained in:
@@ -629,31 +629,10 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of
|
||||
/// `tensor.from_elements` operators that operates on tensor of lwe ciphertext.
|
||||
/// FromElementsOpPatterns transform each tensor.from_elements that operates on
|
||||
/// Concrete.lwe_ciphertext
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.from_elements %e0, ..., %e(n-1)
|
||||
/// : tensor<Nx!Concrete.lwe_ciphertext<lweDim,p>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %m = memref.alloc() : memref<NxlweDim+1xi64>
|
||||
/// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
|
||||
/// %m0 = memref.buffer_cast %e0 : memref<lweDim+1xi64>
|
||||
/// memref.copy %m0, s0 : memref<lweDim+1xi64> to memref<lweDim+1xi64>
|
||||
/// ...
|
||||
/// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1]
|
||||
/// : memref<lweDim+1xi64>
|
||||
/// %m(n-1) = memref.buffer_cast %e(n-1) : memref<lweDim+1xi64>
|
||||
/// memref.copy %e(n-1), s(n-1)
|
||||
/// : memref<lweDim+1xi64> to memref<lweDim+1xi64>
|
||||
/// %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
|
||||
/// ```
|
||||
/// refs: check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir
|
||||
struct FromElementsOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::tensor::FromElementsOp> {
|
||||
FromElementsOpPattern(::mlir::MLIRContext *context,
|
||||
@@ -673,26 +652,33 @@ struct FromElementsOpPattern
|
||||
|
||||
auto newTensorResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
auto newRank = newTensorResultTy.getRank();
|
||||
auto newShape = newTensorResultTy.getShape();
|
||||
|
||||
mlir::Value tensor = rewriter.create<mlir::bufferization::AllocTensorOp>(
|
||||
fromElementsOp.getLoc(), newTensorResultTy, mlir::ValueRange{});
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> sizes(1,
|
||||
// sizes are [1, ..., 1, lweSize]
|
||||
llvm::SmallVector<mlir::OpFoldResult> sizes(newRank - 1,
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
std::transform(newTensorResultTy.getShape().begin() + 1,
|
||||
newTensorResultTy.getShape().end(),
|
||||
std::back_inserter(sizes),
|
||||
[&](auto v) { return rewriter.getI64IntegerAttr(v); });
|
||||
sizes.push_back(
|
||||
rewriter.getI64IntegerAttr(*(newTensorResultTy.getShape().end() - 1)));
|
||||
|
||||
// strides are [1, ..., 1]
|
||||
llvm::SmallVector<mlir::OpFoldResult> oneStrides(
|
||||
newTensorResultTy.getShape().size(), rewriter.getI64IntegerAttr(1));
|
||||
newShape.size(), rewriter.getI64IntegerAttr(1));
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets(
|
||||
newTensorResultTy.getRank(), rewriter.getI64IntegerAttr(0));
|
||||
// start with offets [0, ..., 0]
|
||||
llvm::SmallVector<int64_t> currentOffsets(newRank, 0);
|
||||
|
||||
// for each elements insert_slice with right offet
|
||||
for (auto elt : llvm::enumerate(fromElementsOp.elements())) {
|
||||
offsets[0] = rewriter.getI64IntegerAttr(elt.index());
|
||||
|
||||
// Just create offsets as attributes
|
||||
llvm::SmallVector<mlir::OpFoldResult, 4> offsets;
|
||||
offsets.reserve(currentOffsets.size());
|
||||
std::transform(currentOffsets.begin(), currentOffsets.end(),
|
||||
std::back_inserter(offsets),
|
||||
[&](auto v) { return rewriter.getI64IntegerAttr(v); });
|
||||
mlir::tensor::InsertSliceOp insOp =
|
||||
rewriter.create<mlir::tensor::InsertSliceOp>(
|
||||
fromElementsOp.getLoc(),
|
||||
@@ -708,6 +694,16 @@ struct FromElementsOpPattern
|
||||
});
|
||||
|
||||
tensor = insOp.getResult();
|
||||
|
||||
// Increment the offsets
|
||||
for (auto i = newRank - 2; i >= 0; i--) {
|
||||
if (currentOffsets[i] == newShape[i] - 1) {
|
||||
currentOffsets[i] = 0;
|
||||
continue;
|
||||
}
|
||||
currentOffsets[i]++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(fromElementsOp, tensor);
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete --split-input-file %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<6x2049xi64> {
|
||||
// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<6x2049xi64>
|
||||
// CHECK: %[[V1:.*]] = tensor.insert_slice %[[A0]] into %[[V0]][0, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
|
||||
// CHECK: %[[V2:.*]] = tensor.insert_slice %[[A1]] into %[[V1]][1, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
|
||||
// CHECK: %[[V3:.*]] = tensor.insert_slice %[[A2]] into %[[V2]][2, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
|
||||
// CHECK: %[[V4:.*]] = tensor.insert_slice %[[A3]] into %[[V3]][3, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
|
||||
// CHECK: %[[V5:.*]] = tensor.insert_slice %[[A4]] into %[[V4]][4, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
|
||||
// CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][5, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
|
||||
// CHECK: return %[[V6]] : tensor<6x2049xi64>
|
||||
// CHECK: }
|
||||
func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<6x!Concrete.lwe_ciphertext<2048,4>> {
|
||||
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!Concrete.lwe_ciphertext<2048,4>>
|
||||
return %0 : tensor<6x!Concrete.lwe_ciphertext<2048,4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<2x3x2049xi64> {
|
||||
// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x2049xi64>
|
||||
// CHECK: %[[V1:.*]] = tensor.insert_slice %[[A0]] into %[[V0]][0, 0, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
|
||||
// CHECK: %[[V2:.*]] = tensor.insert_slice %[[A1]] into %[[V1]][0, 1, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
|
||||
// CHECK: %[[V3:.*]] = tensor.insert_slice %[[A2]] into %[[V2]][0, 2, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
|
||||
// CHECK: %[[V4:.*]] = tensor.insert_slice %[[A3]] into %[[V3]][1, 0, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
|
||||
// CHECK: %[[V5:.*]] = tensor.insert_slice %[[A4]] into %[[V4]][1, 1, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
|
||||
// CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][1, 2, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
|
||||
// CHECK: return %[[V6]] : tensor<2x3x2049xi64>
|
||||
// CHECK: }
|
||||
func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> {
|
||||
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>>
|
||||
return %0 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>>
|
||||
}
|
||||
@@ -111,3 +111,40 @@ tests:
|
||||
- tensor: [63, 12, 7, 43, 52, 31, 32, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 33, 34, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
---
|
||||
description: from_elements
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<4>, %arg3: !FHE.eint<4>, %arg4: !FHE.eint<4>, %arg5: !FHE.eint<4>) -> tensor<6x!FHE.eint<4>> {
|
||||
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!FHE.eint<4>>
|
||||
return %0 : tensor<6x!FHE.eint<4>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
- scalar: 2
|
||||
- scalar: 3
|
||||
- scalar: 4
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- tensor: [0, 1, 2, 3, 4, 5]
|
||||
shape: [6]
|
||||
---
|
||||
description: from_elements_2D
|
||||
program: |
|
||||
func.func @main(%arg0 : !FHE.eint<4>, %arg1 : !FHE.eint<4>, %arg2 : !FHE.eint<4>, %arg3 : !FHE.eint<4>, %arg4 : !FHE.eint<4>, %arg5 : !FHE.eint<4>) -> tensor<2x3x!FHE.eint<4>> {
|
||||
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!FHE.eint<4>>
|
||||
return %0 : tensor<2x3x!FHE.eint<4>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
- scalar: 2
|
||||
- scalar: 3
|
||||
- scalar: 4
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- tensor: [0, 1, 2,
|
||||
3, 4, 5]
|
||||
shape: [2, 3]
|
||||
|
||||
Reference in New Issue
Block a user