From 534e68305501a5ac75fc9720ada050e152521e7c Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Mon, 8 Aug 2022 12:04:45 +0200 Subject: [PATCH] fix: Lowering from ConcreteToBConcrete of from_elements on ND tensor of Concrete.lwe_ciphertext --- .../ConcreteToBConcrete.cpp | 64 +++++++++---------- .../tensor_from_elements.mlir | 33 ++++++++++ .../end_to_end_encrypted_tensor.yaml | 37 +++++++++++ 3 files changed, 100 insertions(+), 34 deletions(-) create mode 100644 compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index f6b8cdd94..f57192a07 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -629,31 +629,10 @@ struct InsertOpPattern : public mlir::OpRewritePattern { }; }; -/// This rewrite pattern transforms any instance of -/// `tensor.from_elements` operators that operates on tensor of lwe ciphertext. +/// FromElementsOpPatterns transform each tensor.from_elements that operates on +/// Concrete.lwe_ciphertext /// -/// Example: -/// -/// ```mlir -/// %0 = tensor.from_elements %e0, ..., %e(n-1) -/// : tensor> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %m = memref.alloc() : memref -/// %s0 = memref.subview %m[0, 0][1, lweDim+1][1, 1] : memref -/// %m0 = memref.buffer_cast %e0 : memref -/// memref.copy %m0, s0 : memref to memref -/// ... -/// %s(n-1) = memref.subview %m[(n-1), 0][1, lweDim+1][1, 1] -/// : memref -/// %m(n-1) = memref.buffer_cast %e(n-1) : memref -/// memref.copy %e(n-1), s(n-1) -/// : memref to memref -/// %0 = memref.tensor_load %m : memref -/// ``` +/// refs: check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir struct FromElementsOpPattern : public mlir::OpRewritePattern { FromElementsOpPattern(::mlir::MLIRContext *context, @@ -673,26 +652,33 @@ struct FromElementsOpPattern auto newTensorResultTy = converter.convertType(resultTy).cast(); + auto newRank = newTensorResultTy.getRank(); + auto newShape = newTensorResultTy.getShape(); mlir::Value tensor = rewriter.create( fromElementsOp.getLoc(), newTensorResultTy, mlir::ValueRange{}); - llvm::SmallVector sizes(1, + // sizes are [1, ..., 1, lweSize] + llvm::SmallVector sizes(newRank - 1, rewriter.getI64IntegerAttr(1)); - std::transform(newTensorResultTy.getShape().begin() + 1, - newTensorResultTy.getShape().end(), - std::back_inserter(sizes), - [&](auto v) { return rewriter.getI64IntegerAttr(v); }); + sizes.push_back( + rewriter.getI64IntegerAttr(*(newTensorResultTy.getShape().end() - 1))); + // strides are [1, ..., 1] llvm::SmallVector oneStrides( - newTensorResultTy.getShape().size(), rewriter.getI64IntegerAttr(1)); + newShape.size(), rewriter.getI64IntegerAttr(1)); - llvm::SmallVector offsets( - newTensorResultTy.getRank(), rewriter.getI64IntegerAttr(0)); + // start with offets [0, ..., 0] + llvm::SmallVector currentOffsets(newRank, 0); + // for each elements insert_slice with right offet for (auto elt : llvm::enumerate(fromElementsOp.elements())) { - offsets[0] = rewriter.getI64IntegerAttr(elt.index()); - + // Just create offsets as attributes + llvm::SmallVector offsets; + offsets.reserve(currentOffsets.size()); + std::transform(currentOffsets.begin(), currentOffsets.end(), + std::back_inserter(offsets), + [&](auto v) { return rewriter.getI64IntegerAttr(v); }); mlir::tensor::InsertSliceOp insOp = rewriter.create( fromElementsOp.getLoc(), @@ -708,6 +694,16 @@ struct FromElementsOpPattern }); tensor = insOp.getResult(); + + // Increment the offsets + for (auto i = newRank - 2; i >= 0; i--) { + if (currentOffsets[i] == newShape[i] - 1) { + currentOffsets[i] = 0; + continue; + } + currentOffsets[i]++; + break; + } } rewriter.replaceOp(fromElementsOp, tensor); diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir new file mode 100644 index 000000000..eb5801e5c --- /dev/null +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_from_elements.mlir @@ -0,0 +1,33 @@ +// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete --split-input-file %s 2>&1| FileCheck %s + +// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<6x2049xi64> { +// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<6x2049xi64> +// CHECK: %[[V1:.*]] = tensor.insert_slice %[[A0]] into %[[V0]][0, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64> +// CHECK: %[[V2:.*]] = tensor.insert_slice %[[A1]] into %[[V1]][1, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64> +// CHECK: %[[V3:.*]] = tensor.insert_slice %[[A2]] into %[[V2]][2, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64> +// CHECK: %[[V4:.*]] = tensor.insert_slice %[[A3]] into %[[V3]][3, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64> +// CHECK: %[[V5:.*]] = tensor.insert_slice %[[A4]] into %[[V4]][4, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64> +// CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][5, 0] [1, 2049] [1, 1] : tensor<2049xi64> into tensor<6x2049xi64> +// CHECK: return %[[V6]] : tensor<6x2049xi64> +// CHECK: } +func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<6x!Concrete.lwe_ciphertext<2048,4>> { + %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!Concrete.lwe_ciphertext<2048,4>> + return %0 : tensor<6x!Concrete.lwe_ciphertext<2048,4>> +} + +// ----- + +// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<2x3x2049xi64> { +// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x2049xi64> +// CHECK: %[[V1:.*]] = tensor.insert_slice %[[A0]] into %[[V0]][0, 0, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64> +// CHECK: %[[V2:.*]] = tensor.insert_slice %[[A1]] into %[[V1]][0, 1, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64> +// CHECK: %[[V3:.*]] = tensor.insert_slice %[[A2]] into %[[V2]][0, 2, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64> +// CHECK: %[[V4:.*]] = tensor.insert_slice %[[A3]] into %[[V3]][1, 0, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64> +// CHECK: %[[V5:.*]] = tensor.insert_slice %[[A4]] into %[[V4]][1, 1, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64> +// CHECK: %[[V6:.*]] = tensor.insert_slice %[[A5]] into %[[V5]][1, 2, 0] [1, 1, 2049] [1, 1, 1] : tensor<2049xi64> into tensor<2x3x2049xi64> +// CHECK: return %[[V6]] : tensor<2x3x2049xi64> +// CHECK: } +func.func @main(%arg0 : !Concrete.lwe_ciphertext<2048,4>, %arg1 : !Concrete.lwe_ciphertext<2048,4>, %arg2 : !Concrete.lwe_ciphertext<2048,4>, %arg3 : !Concrete.lwe_ciphertext<2048,4>, %arg4 : !Concrete.lwe_ciphertext<2048,4>, %arg5 : !Concrete.lwe_ciphertext<2048,4>) -> tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> { + %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> + return %0 : tensor<2x3x!Concrete.lwe_ciphertext<2048,4>> +} diff --git a/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml b/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml index 449f8e787..7d64960a3 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml +++ b/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml @@ -111,3 +111,40 @@ tests: - tensor: [63, 12, 7, 43, 52, 31, 32, 34, 22, 0, 0, 1, 2, 3, 4, 33, 34, 7, 8, 9] shape: [2,10] +--- +description: from_elements +program: | + func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>, %arg2: !FHE.eint<4>, %arg3: !FHE.eint<4>, %arg4: !FHE.eint<4>, %arg5: !FHE.eint<4>) -> tensor<6x!FHE.eint<4>> { + %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<6x!FHE.eint<4>> + return %0 : tensor<6x!FHE.eint<4>> + } +tests: + - inputs: + - scalar: 0 + - scalar: 1 + - scalar: 2 + - scalar: 3 + - scalar: 4 + - scalar: 5 + outputs: + - tensor: [0, 1, 2, 3, 4, 5] + shape: [6] +--- +description: from_elements_2D +program: | + func.func @main(%arg0 : !FHE.eint<4>, %arg1 : !FHE.eint<4>, %arg2 : !FHE.eint<4>, %arg3 : !FHE.eint<4>, %arg4 : !FHE.eint<4>, %arg5 : !FHE.eint<4>) -> tensor<2x3x!FHE.eint<4>> { + %0 = tensor.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<2x3x!FHE.eint<4>> + return %0 : tensor<2x3x!FHE.eint<4>> + } +tests: + - inputs: + - scalar: 0 + - scalar: 1 + - scalar: 2 + - scalar: 3 + - scalar: 4 + - scalar: 5 + outputs: + - tensor: [0, 1, 2, + 3, 4, 5] + shape: [2, 3]