fix: Lowering from ConcreteToBConcrete of from_elements on ND tensor of Concrete.lwe_ciphertext

This commit is contained in:
Quentin Bourgerie
2022-08-08 12:04:45 +02:00
parent f91c74d7dd
commit 534e683055
3 changed files with 100 additions and 34 deletions

View File

@@ -629,31 +629,10 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
};
};
/// This rewrite pattern transforms any instance of
/// `tensor.from_elements` operators that operates on tensor of lwe ciphertext.
/// FromElementsOpPatterns transform each tensor.from_elements that operates on
/// Concrete.lwe_ciphertext
///
/// Example:
///
/// ```mlir
/// %0 = tensor.from_elements %e0, ..., %e(n-1)
/// : tensor<Nx!Concrete.lwe_ciphertext<lweDim,p>>
/// ```
///
/// becomes:
///
/// ```mlir
/// %m = memref.alloc() : memref<NxlweDim+1xi64>
/// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref<lweDim+1xi64>
/// %m0 = memref.buffer_cast %e0 : memref<lweDim+1xi64>
/// memref.copy %m0, s0 : memref<lweDim+1xi64> to memref<lweDim+1xi64>
/// ...
/// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1]
/// : memref<lweDim+1xi64>
/// %m(n-1) = memref.buffer_cast %e(n-1) : memref<lweDim+1xi64>
/// memref.copy %e(n-1), s(n-1)
/// : memref<lweDim+1xi64> to memref<lweDim+1xi64>
/// %0 = memref.tensor_load %m : memref<NxlweDim+1xi64>
/// ```
/// refs: check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir
struct FromElementsOpPattern
: public mlir::OpRewritePattern<mlir::tensor::FromElementsOp> {
FromElementsOpPattern(::mlir::MLIRContext *context,
@@ -673,26 +652,33 @@ struct FromElementsOpPattern
auto newTensorResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
auto newRank = newTensorResultTy.getRank();
auto newShape = newTensorResultTy.getShape();
mlir::Value tensor = rewriter.create<mlir::bufferization::AllocTensorOp>(
fromElementsOp.getLoc(), newTensorResultTy, mlir::ValueRange{});
llvm::SmallVector<mlir::OpFoldResult> sizes(1,
// sizes are [1, ..., 1, lweSize]
llvm::SmallVector<mlir::OpFoldResult> sizes(newRank - 1,
rewriter.getI64IntegerAttr(1));
std::transform(newTensorResultTy.getShape().begin() + 1,
newTensorResultTy.getShape().end(),
std::back_inserter(sizes),
[&](auto v) { return rewriter.getI64IntegerAttr(v); });
sizes.push_back(
rewriter.getI64IntegerAttr(*(newTensorResultTy.getShape().end() - 1)));
// strides are [1, ..., 1]
llvm::SmallVector<mlir::OpFoldResult> oneStrides(
newTensorResultTy.getShape().size(), rewriter.getI64IntegerAttr(1));
newShape.size(), rewriter.getI64IntegerAttr(1));
llvm::SmallVector<mlir::OpFoldResult> offsets(
newTensorResultTy.getRank(), rewriter.getI64IntegerAttr(0));
// start with offets [0, ..., 0]
llvm::SmallVector<int64_t> currentOffsets(newRank, 0);
// for each elements insert_slice with right offet
for (auto elt : llvm::enumerate(fromElementsOp.elements())) {
offsets[0] = rewriter.getI64IntegerAttr(elt.index());
// Just create offsets as attributes
llvm::SmallVector<mlir::OpFoldResult, 4> offsets;
offsets.reserve(currentOffsets.size());
std::transform(currentOffsets.begin(), currentOffsets.end(),
std::back_inserter(offsets),
[&](auto v) { return rewriter.getI64IntegerAttr(v); });
mlir::tensor::InsertSliceOp insOp =
rewriter.create<mlir::tensor::InsertSliceOp>(
fromElementsOp.getLoc(),
@@ -708,6 +694,16 @@ struct FromElementsOpPattern
});
tensor = insOp.getResult();
// Increment the offsets
for (auto i = newRank - 2; i >= 0; i--) {
if (currentOffsets[i] == newShape[i] - 1) {
currentOffsets[i] = 0;
continue;
}
currentOffsets[i]++;
break;
}
}
rewriter.replaceOp(fromElementsOp, tensor);

View File

@@ -0,0 +1,33 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete --split-input-file %s 2>&1| FileCheck %s
// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<6x2049xi64> {
// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<6x2049xi64>
// CHECK: %[[V1:.*]] = tensor.insert_slice %[[A0]] into %[[V0]][0, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
// CHECK: %[[V2:.*]] = tensor.insert_slice %[[A1]] into %[[V1]][1, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
// CHECK: %[[V3:.*]] = tensor.insert_slice %[[A2]] into %[[V2]][2, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
// CHECK: %[[V4:.*]] = tensor.insert_slice %[[A3]] into %[[V3]][3, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
// CHECK: %[[V5:.*]] = tensor.insert_slice %[[A4]] into %[[V4]][4, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
// CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][5, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64>
// CHECK: return %[[V6]] : tensor<6x2049xi64>
// CHECK: }
func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<6x!Concrete.lwe_ciphertext<2048,4>> {
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!Concrete.lwe_ciphertext<2048,4>>
return %0 : tensor<6x!Concrete.lwe_ciphertext<2048,4>>
}
// -----
// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<2x3x2049xi64> {
// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x2049xi64>
// CHECK: %[[V1:.*]] = tensor.insert_slice %[[A0]] into %[[V0]][0, 0, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
// CHECK: %[[V2:.*]] = tensor.insert_slice %[[A1]] into %[[V1]][0, 1, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
// CHECK: %[[V3:.*]] = tensor.insert_slice %[[A2]] into %[[V2]][0, 2, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
// CHECK: %[[V4:.*]] = tensor.insert_slice %[[A3]] into %[[V3]][1, 0, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
// CHECK: %[[V5:.*]] = tensor.insert_slice %[[A4]] into %[[V4]][1, 1, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
// CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][1, 2, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64>
// CHECK: return %[[V6]] : tensor<2x3x2049xi64>
// CHECK: }
func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> {
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>>
return %0 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>>
}

View File

@@ -111,3 +111,40 @@ tests:
- tensor: [63, 12, 7, 43, 52, 31, 32, 34, 22, 0,
0, 1, 2, 3, 4, 33, 34, 7, 8, 9]
shape: [2,10]
---
description: from_elements
program: |
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<4>, %arg3: !FHE.eint<4>, %arg4: !FHE.eint<4>, %arg5: !FHE.eint<4>) -> tensor<6x!FHE.eint<4>> {
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!FHE.eint<4>>
return %0 : tensor<6x!FHE.eint<4>>
}
tests:
- inputs:
- scalar: 0
- scalar: 1
- scalar: 2
- scalar: 3
- scalar: 4
- scalar: 5
outputs:
- tensor: [0, 1, 2, 3, 4, 5]
shape: [6]
---
description: from_elements_2D
program: |
func.func @main(%arg0 : !FHE.eint<4>, %arg1 : !FHE.eint<4>, %arg2 : !FHE.eint<4>, %arg3 : !FHE.eint<4>, %arg4 : !FHE.eint<4>, %arg5 : !FHE.eint<4>) -> tensor<2x3x!FHE.eint<4>> {
%0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!FHE.eint<4>>
return %0 : tensor<2x3x!FHE.eint<4>>
}
tests:
- inputs:
- scalar: 0
- scalar: 1
- scalar: 2
- scalar: 3
- scalar: 4
- scalar: 5
outputs:
- tensor: [0, 1, 2,
3, 4, 5]
shape: [2, 3]