From 9e16f31b87ecdec1045f2bea17f4450aa952c4c5 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 8 Nov 2022 11:40:39 +0100 Subject: [PATCH] refactor(bconcrete): Separate bufferization and CAPI call generation --- .../Conversion/BConcreteToCAPI/Pass.h | 3 +- .../Dialect/BConcrete/IR/BConcreteOps.td | 154 ++++- .../include/concretelang/Support/Pipeline.h | 2 +- .../BConcreteToCAPI/BConcreteToCAPI.cpp | 372 ++++++++++- .../ConcreteToBConcrete.cpp | 28 +- .../BConcrete/Transforms/AsyncOffload.cpp | 8 +- .../BufferizableOpInterfaceImpl.cpp | 608 ++---------------- .../BConcrete/Transforms/EliminateCRTOps.cpp | 44 +- compiler/lib/Support/CompilerEngine.cpp | 2 +- compiler/lib/Support/Pipeline.cpp | 8 +- compiler/src/main.cpp | 1 + .../ConcreteToBConcrete/add_lwe.mlir | 4 +- .../ConcreteToBConcrete/add_lwe_int.mlir | 6 +- .../apply_lookup_table.mlir | 4 +- .../apply_lookup_table_cst.mlir | 4 +- .../ConcreteToBConcrete/mul_lwe_int.mlir | 6 +- .../ConcreteToBConcrete/neg_lwe.mlir | 4 +- .../Dialect/BConcrete/ops_memref.mlir | 37 ++ .../BConcrete/{ops.mlir => ops_tensor.mlir} | 40 +- .../end_to_end_tests/end_to_end_jit_fhe.cc | 3 +- 20 files changed, 668 insertions(+), 670 deletions(-) create mode 100644 compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir rename compiler/tests/check_tests/Dialect/BConcrete/{ops.mlir => ops_tensor.mlir} (79%) diff --git a/compiler/include/concretelang/Conversion/BConcreteToCAPI/Pass.h b/compiler/include/concretelang/Conversion/BConcreteToCAPI/Pass.h index 0b28a4692..ebd9c3d91 100644 --- a/compiler/include/concretelang/Conversion/BConcreteToCAPI/Pass.h +++ b/compiler/include/concretelang/Conversion/BConcreteToCAPI/Pass.h @@ -11,7 +11,8 @@ namespace mlir { namespace concretelang { /// Create a pass to convert `BConcrete` dialect to CAPI calls. 
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createConvertBConcreteToCAPIPass();
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+createConvertBConcreteToCAPIPass(bool gpu);
 } // namespace concretelang
 } // namespace mlir
diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td
index ebf0a7d05..892354145 100644
--- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td
+++ b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td
@@ -15,15 +15,17 @@ include "concretelang/Dialect/RT/IR/RTTypes.td"
 class BConcrete_Op<string mnemonic, list<Trait> traits = []> : Op<BConcrete_Dialect, mnemonic, traits>;
 
-def BConcrete_AddLweBuffersOp : BConcrete_Op<"add_lwe_buffer"> {
-  let arguments = (ins
+// BConcrete tensor operators /////////////////////////////////////////////////
+
+def BConcrete_AddLweTensorOp : BConcrete_Op<"add_lwe_tensor"> {
+  let arguments = (ins
     1DTensorOf<[I64]>:$lhs,
     1DTensorOf<[I64]>:$rhs
   );
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
+def BConcrete_AddCRTLweTensorOp : BConcrete_Op<"add_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$lhs,
     2DTensorOf<[I64]>:$rhs,
@@ -32,12 +34,12 @@ def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> {
+def BConcrete_AddPlaintextLweTensorOp : BConcrete_Op<"add_plaintext_lwe_tensor"> {
   let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_buffer"> {
+def BConcrete_AddPlaintextCRTLweTensorOp : BConcrete_Op<"add_plaintext_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$lhs,
     AnyInteger:$rhs,
@@ -46,12 +48,12 @@ def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_b
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> {
+def BConcrete_MulCleartextLweTensorOp : BConcrete_Op<"mul_cleartext_lwe_tensor"> {
   let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_buffer"> {
+def BConcrete_MulCleartextCRTLweTensorOp : BConcrete_Op<"mul_cleartext_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$lhs,
     AnyInteger:$rhs,
@@ -60,12 +62,12 @@ def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_b
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
+def BConcrete_NegateLweTensorOp : BConcrete_Op<"negate_lwe_tensor"> {
   let arguments = (ins 1DTensorOf<[I64]>:$ciphertext);
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
+def BConcrete_NegateCRTLweTensorOp : BConcrete_Op<"negate_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$ciphertext,
     I64ArrayAttr:$crtDecomposition
@@ -73,7 +75,7 @@ def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
+def BConcrete_KeySwitchLweTensorOp : BConcrete_Op<"keyswitch_lwe_tensor"> {
   let arguments = (ins
     // LweKeySwitchKeyType:$keyswitch_key,
     1DTensorOf<[I64]>:$ciphertext,
@@ -85,7 +87,7 @@ def 
BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> { let results = (outs 1DTensorOf<[I64]>:$result); } -def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_buffer"> { +def BConcrete_BatchedKeySwitchLweTensorOp : BConcrete_Op<"batched_keyswitch_lwe_tensor"> { let arguments = (ins // LweKeySwitchKeyType:$keyswitch_key, 2DTensorOf<[I64]>:$ciphertext, @@ -97,7 +99,7 @@ def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_ let results = (outs 2DTensorOf<[I64]>:$result); } -def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> { +def BConcrete_BootstrapLweTensorOp : BConcrete_Op<"bootstrap_lwe_tensor"> { let arguments = (ins 1DTensorOf<[I64]>:$input_ciphertext, 1DTensorOf<[I64]>:$lookup_table, @@ -111,7 +113,7 @@ def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> { let results = (outs 1DTensorOf<[I64]>:$result); } -def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_buffer"> { +def BConcrete_BatchedBootstrapLweTensorOp : BConcrete_Op<"batched_bootstrap_lwe_tensor"> { let arguments = (ins 2DTensorOf<[I64]>:$input_ciphertext, 1DTensorOf<[I64]>:$lookup_table, @@ -126,7 +128,7 @@ def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_ } // TODO(16bits): hack -def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> { +def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor"> { let arguments = (ins 2DTensorOf<[I64]>:$ciphertext, 1DTensorOf<[I64]>:$lookupTable, @@ -149,10 +151,9 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> { let results = (outs 2DTensorOf<[I64]>:$result); } -def BConcrete_KeySwitchLweBufferAsyncOffloadOp : - BConcrete_Op<"keyswitch_lwe_buffer_async_offload"> { +def BConcrete_KeySwitchLweTensorAsyncOffloadOp : + BConcrete_Op<"keyswitch_lwe_tensor_async_offload"> { let arguments = (ins - // LweKeySwitchKeyType:$keyswitch_key, 1DTensorOf<[I64]>:$ciphertext, I32Attr:$level, I32Attr:$baseLog @@ -160,8 +161,8 @@ def BConcrete_KeySwitchLweBufferAsyncOffloadOp : let results = (outs RT_Future : $result); } -def BConcrete_BootstrapLweBufferAsyncOffloadOp : - BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> { +def BConcrete_BootstrapLweTensorAsyncOffloadOp : + BConcrete_Op<"bootstrap_lwe_tensor_async_offload"> { let arguments = (ins 1DTensorOf<[I64]>:$input_ciphertext, 1DTensorOf<[I64]>:$lookup_table, @@ -175,6 +176,121 @@ def BConcrete_BootstrapLweBufferAsyncOffloadOp : let results = (outs RT_Future : $result); } +// BConcrete memref operators ///////////////////////////////////////////////// + +def BConcrete_LweBuffer : MemRefRankOf<[I64], [1]>; +def BConcrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>; +def BConcrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>; + +def BConcrete_AddLweBufferOp : BConcrete_Op<"add_lwe_buffer"> { + let arguments = (ins + BConcrete_LweBuffer:$result, + BConcrete_LweBuffer:$lhs, + BConcrete_LweBuffer:$rhs + ); +} + +def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> { + let arguments = (ins + BConcrete_LweBuffer:$result, + BConcrete_LweBuffer:$lhs, + I64:$rhs + ); +} + +def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> { + let arguments = (ins + BConcrete_LweBuffer:$result, + BConcrete_LweBuffer:$lhs, + I64:$rhs + ); +} + +def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> { + let arguments = (ins + BConcrete_LweBuffer:$result, + 
BConcrete_LweBuffer:$ciphertext + ); +} + +def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> { + let arguments = (ins + BConcrete_LweBuffer:$result, + BConcrete_LweBuffer:$ciphertext, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$lwe_dim_in, + I32Attr:$lwe_dim_out + ); +} + +def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_buffer"> { + let arguments = (ins + BConcrete_BatchLweBuffer:$result, + BConcrete_BatchLweBuffer:$ciphertext, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$lwe_dim_in, + I32Attr:$lwe_dim_out + ); +} + +def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> { + let arguments = (ins + BConcrete_LweBuffer:$result, + BConcrete_LweBuffer:$input_ciphertext, + MemRefRankOf<[I64], [1]>:$lookup_table, + I32Attr:$inputLweDim, + I32Attr:$polySize, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$glweDimension, + I32Attr:$outPrecision + ); +} + +def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_buffer"> { + let arguments = (ins + BConcrete_BatchLweBuffer:$result, + BConcrete_BatchLweBuffer:$input_ciphertext, + MemRefRankOf<[I64], [1]>:$lookup_table, + I32Attr:$inputLweDim, + I32Attr:$polySize, + I32Attr:$level, + I32Attr:$baseLog, + I32Attr:$glweDimension, + I32Attr:$outPrecision + ); +} + +// TODO(16bits): hack +def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> { + let arguments = (ins + BConcrete_LweCRTBuffer:$result, + BConcrete_LweCRTBuffer:$ciphertext, + MemRefRankOf<[I64], [1]>:$lookup_table, + // Bootstrap parameters + I32Attr : $bootstrapLevel, + I32Attr : $bootstrapBaseLog, + // Keyswitch parameters + I32Attr : $keyswitchLevel, + I32Attr : $keyswitchBaseLog, + // Packing keyswitch key parameters + I32Attr : $packingKeySwitchInputLweDimension, + I32Attr : $packingKeySwitchoutputPolynomialSize, + I32Attr : $packingKeySwitchLevel, + I32Attr : $packingKeySwitchBaseLog, + // Circuit bootstrap parameters + I32Attr : $circuitBootstrapLevel, + I32Attr : $circuitBootstrapBaseLog, + I64ArrayAttr:$crtDecomposition + ); +} + +// TODO + + + def BConcrete_AwaitFutureOp : BConcrete_Op<"await_future"> { let arguments = (ins RT_Future : $future); diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 4a2cce61c..8178d689f 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -69,7 +69,7 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - bool parallelizeLoops); + bool parallelizeLoops, bool gpu); mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, llvm::Module &module); diff --git a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp index 8080c9b66..e940b238a 100644 --- a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp +++ b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp @@ -7,29 +7,375 @@ #include #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Tools.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" namespace { -struct BConcreteToCAPIPass : public BConcreteToCAPIBase { - void runOnOperation() final; -}; -} // namespace -void BConcreteToCAPIPass::runOnOperation() { - 
auto op = this->getOperation();
+namespace BConcrete = mlir::concretelang::BConcrete;
+namespace arith = mlir::arith;
+namespace func = mlir::func;
+namespace memref = mlir::memref;
 
-  mlir::ConversionTarget target(getContext());
-  mlir::RewritePatternSet patterns(&getContext());
+char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64";
+char memref_add_plaintext_lwe_ciphertext_u64[] =
+    "memref_add_plaintext_lwe_ciphertext_u64";
+char memref_mul_cleartext_lwe_ciphertext_u64[] =
+    "memref_mul_cleartext_lwe_ciphertext_u64";
+char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
+char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
+char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
+char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
+char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
 
-  // Apply conversion
-  if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
-    this->signalPassFailure();
+char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
+char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
+char memref_await_future[] = "memref_await_future";
+char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
+char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
+char memref_expand_lut_in_trivial_glwe_ct_u64[] =
+    "memref_expand_lut_in_trivial_glwe_ct_u64";
+
+char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
+
+mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
+                                             size_t rank) {
+  std::vector<int64_t> shape(rank, -1);
+  mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
+  for (size_t i = 0; i < rank; i++) {
+    expr = expr +
+           (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
+  }
+  return mlir::MemRefType::get(
+      shape, rewriter.getI64Type(),
+      mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
+}
+
+// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
+mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
+  mlir::Type valueType = value.getType();
+
+  if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
+    return rewriter.create<mlir::memref::CastOp>(
+        value.getLoc(),
+        getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
+        value);
+  } else {
+    return value;
   }
 }
 
+mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
+    mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
+
+  auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
+  auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
+  auto futureType =
+      mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
+  auto contextType =
+      mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
+  auto i32Type = rewriter.getI32Type();
+
+  mlir::FunctionType funcType;
+
+  if (funcName == memref_add_lwe_ciphertexts_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {});
+  } else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref1DType, memref1DType, rewriter.getI64Type()}, {});
+  } else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref1DType, memref1DType, rewriter.getI64Type()}, {});
+  } else if (funcName == memref_negate_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType}, {});
+  } else if (funcName == memref_keyswitch_lwe_u64 ||
+             funcName == memref_keyswitch_lwe_cuda_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_bootstrap_lwe_u64 ||
+             funcName == memref_bootstrap_lwe_cuda_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType,
+                                        memref1DType, i32Type, i32Type, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_keyswitch_async_lwe_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(), {memref1DType, memref1DType, contextType},
+        {futureType});
+  } else if (funcName == memref_bootstrap_async_lwe_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType,
+                                        memref1DType, i32Type, i32Type, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {futureType});
+  } else if (funcName == memref_batched_keyswitch_lwe_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref2DType, memref2DType, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_batched_bootstrap_lwe_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref2DType, memref2DType,
+                                        memref1DType, i32Type, i32Type, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_await_future) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref1DType, futureType, memref1DType, memref1DType}, {});
+  } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {
+                                           memref1DType,
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           memref1DType,
+                                       },
+                                       {});
+  } else if (funcName == memref_wop_pbs_crt_buffer) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {
+                                           memref2DType,
+                                           memref2DType,
+                                           memref1DType,
+                                           memref1DType,
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           contextType,
+                                       },
+                                       {});
+  } else {
+    op->emitError("unknown external function") << funcName;
+    return mlir::failure();
+  }
+
+  return insertForwardDeclaration(op, rewriter, funcName, funcType);
+}
+
+template <typename BConcreteOp>
+void addNoOperands(BConcreteOp op, mlir::SmallVector<mlir::Value> &operands,
+                   mlir::RewriterBase &rewriter) {}
+
+template <typename BConcreteOp, char const *callee>
+struct BConcreteToCAPICallPattern : public mlir::OpRewritePattern<BConcreteOp> {
+  BConcreteToCAPICallPattern(
+      ::mlir::MLIRContext *context,
+      std::function<void(BConcreteOp, mlir::SmallVector<mlir::Value> &,
+                         mlir::RewriterBase &)>
+          addOperands = addNoOperands<BConcreteOp>,
+      mlir::PatternBenefit benefit = 1)
+      : ::mlir::OpRewritePattern<BConcreteOp>(context, benefit),
+        addOperands(addOperands) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(BConcreteOp bOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+
+    // Create the operands
+    mlir::SmallVector<mlir::Value> operands;
+    // For each memref operand get the corresponding casted buffer
+    for (auto &operand : bOp->getOpOperands()) {
+      mlir::Type type = operand.get().getType();
+      if (!type.isa<mlir::MemRefType>()) {
+        operands.push_back(operand.get());
+      } else {
+        operands.push_back(getCastedMemRef(rewriter, operand.get()));
+      }
+    }
+
+    // append additional argument
+    addOperands(bOp, operands, rewriter);
+
+    // Insert forward declaration of the function
+    if (insertForwardDeclarationOfTheCAPI(bOp, rewriter, callee).failed()) {
+      return mlir::failure();
+    }
+
+    rewriter.replaceOpWithNewOp<func::CallOp>(bOp, callee, mlir::TypeRange{},
+                                              operands);
+
+    return ::mlir::success();
+  };
+
+private:
+  std::function<void(BConcreteOp, mlir::SmallVector<mlir::Value> &,
+                     mlir::RewriterBase &)>
+      addOperands;
+};
+
+template <typename KeySwitchOp>
+void keyswitchAddOperands(KeySwitchOp op,
+                          mlir::SmallVector<mlir::Value> &operands,
+                          mlir::RewriterBase &rewriter) {
+  // level
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.levelAttr()));
+  // base_log
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
+  // lwe_dim_in
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_inAttr()));
+  // lwe_dim_out
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_outAttr()));
+  // context
+  operands.push_back(getContextArgument(op));
+}
+
+template <typename BootstrapOp>
+void bootstrapAddOperands(BootstrapOp op,
+                          mlir::SmallVector<mlir::Value> &operands,
+                          mlir::RewriterBase &rewriter) {
+  // input_lwe_dim
+  operands.push_back(rewriter.create<arith::ConstantOp>(
+      op.getLoc(), op.inputLweDimAttr()));
+  // poly_size
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
+  // level
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.levelAttr()));
+  // base_log
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
+  // glwe_dim
+  operands.push_back(rewriter.create<arith::ConstantOp>(
+      op.getLoc(), op.glweDimensionAttr()));
+  // out_precision
+  operands.push_back(rewriter.create<arith::ConstantOp>(
+      op.getLoc(), op.outPrecisionAttr()));
+  // context
+  operands.push_back(getContextArgument(op));
+}
+
+void wopPBSAddOperands(BConcrete::WopPBSCRTLweBufferOp op,
+                       mlir::SmallVector<mlir::Value> &operands,
+                       mlir::RewriterBase &rewriter) {
+  mlir::Type crtType = mlir::RankedTensorType::get(
+      {(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
+  std::vector<int64_t> values;
+  for (auto a : op.crtDecomposition()) {
+    values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
+  }
+  auto attr = rewriter.getI64TensorAttr(values);
+  auto x = rewriter.create<arith::ConstantOp>(op.getLoc(), attr, crtType);
+  auto globalMemref = mlir::bufferization::getGlobalFor(x, 0);
+  rewriter.eraseOp(x);
+  assert(!failed(globalMemref));
+
+  auto globalRef = rewriter.create<memref::GetGlobalOp>(
+      op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
+  operands.push_back(getCastedMemRef(rewriter, globalRef));
+
+  // lwe_small_size
+  operands.push_back(rewriter.create<arith::ConstantOp>(
+      op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
+  // cbs_level_count
+  operands.push_back(rewriter.create<arith::ConstantOp>(
+      op.getLoc(), op.circuitBootstrapLevelAttr()));
+  // cbs_base_log
+  operands.push_back(rewriter.create<arith::ConstantOp>(
+      op.getLoc(), op.circuitBootstrapBaseLogAttr()));
+  // polynomial_size
+  operands.push_back(rewriter.create<arith::ConstantOp>(
+      op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
+  // context
+  operands.push_back(getContextArgument(op));
+}
+
+struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
+
+  BConcreteToCAPIPass(bool gpu) : gpu(gpu) {}
+
+  void runOnOperation() override {
+    auto op = this->getOperation();
+
+    mlir::ConversionTarget target(getContext());
+    mlir::RewritePatternSet patterns(&getContext());
+
+    // Mark ops from the target dialects as legal operations
+    target.addLegalDialect<func::FuncDialect>();
+    target.addLegalDialect<memref::MemRefDialect>();
+    target.addLegalDialect<arith::ArithmeticDialect>();
+
+    // Make sure that no ops from `BConcrete` remain after the lowering
+    target.addIllegalDialect<BConcrete::BConcreteDialect>();
+
+    // Add patterns to transform BConcrete operators to CAPI call
+    patterns.add<BConcreteToCAPICallPattern<BConcrete::AddLweBufferOp,
+                                            memref_add_lwe_ciphertexts_u64>>(
+        &getContext());
+    patterns.add<
+        BConcreteToCAPICallPattern<BConcrete::AddPlaintextLweBufferOp,
+                                   memref_add_plaintext_lwe_ciphertext_u64>>(
+        &getContext());
+    patterns.add<
+        BConcreteToCAPICallPattern<BConcrete::MulCleartextLweBufferOp,
+                                   memref_mul_cleartext_lwe_ciphertext_u64>>(
+        &getContext());
+    patterns.add<BConcreteToCAPICallPattern<BConcrete::NegateLweBufferOp,
+                                            memref_negate_lwe_ciphertext_u64>>(
+        &getContext());
+    if (gpu) {
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
+                                              memref_keyswitch_lwe_cuda_u64>>(
+          &getContext(), keyswitchAddOperands<BConcrete::KeySwitchLweBufferOp>);
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
+                                              memref_bootstrap_lwe_cuda_u64>>(
+          &getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
+    } else {
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
+                                              memref_keyswitch_lwe_u64>>(
+          &getContext(), keyswitchAddOperands<BConcrete::KeySwitchLweBufferOp>);
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
+                                              memref_bootstrap_lwe_u64>>(
+          &getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
+      patterns.add<
+          BConcreteToCAPICallPattern<BConcrete::BatchedKeySwitchLweBufferOp,
+                                     memref_batched_keyswitch_lwe_u64>>(
+          &getContext(),
+          keyswitchAddOperands<BConcrete::BatchedKeySwitchLweBufferOp>);
+      patterns.add<
+          BConcreteToCAPICallPattern<BConcrete::BatchedBootstrapLweBufferOp,
+                                     memref_batched_bootstrap_lwe_u64>>(
+          &getContext(),
+          bootstrapAddOperands<BConcrete::BatchedBootstrapLweBufferOp>);
+    }
+
+    patterns.add<BConcreteToCAPICallPattern<BConcrete::WopPBSCRTLweBufferOp,
+                                            memref_wop_pbs_crt_buffer>>(
+        &getContext(), wopPBSAddOperands);
+
+    // Apply conversion
+    if (mlir::applyPartialConversion(op, target, std::move(patterns))
+            .failed()) {
+      this->signalPassFailure();
+    }
+  }
+
+private:
+  bool gpu;
+};
+
+} // namespace
+
 namespace mlir {
 namespace concretelang {
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createConvertBConcreteToCAPIPass() {
-  return std::make_unique<BConcreteToCAPIPass>();
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+createConvertBConcreteToCAPIPass(bool gpu) {
+  return std::make_unique<BConcreteToCAPIPass>(gpu);
 }
 } // namespace concretelang
 } // namespace mlir
diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp
index a0e9c993c..45a28ad37 100644
--- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp
+++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp
@@ -216,7 +216,7 @@ struct LowerKeySwitch : public mlir::OpRewritePattern<
         rewriter.getI32IntegerAttr(inputType.getDimension());
     mlir::Operation *bKeySwitchOp = rewriter.replaceOpWithNewOp<
-        mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(
+        mlir::concretelang::BConcrete::KeySwitchLweTensorOp>(
         ksOp, outType, ksOp.ciphertext(), ksOp.levelAttr(), ksOp.baseLogAttr(),
         inputDimAttr, outDimAttr);
 
@@ -261,7 +261,7 @@ struct LowerBatchedKeySwitch
         rewriter.getI32IntegerAttr(inputType.getDimension());
     mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp<
-        mlir::concretelang::BConcrete::BatchedKeySwitchLweBufferOp>(
+        mlir::concretelang::BConcrete::BatchedKeySwitchLweTensorOp>(
         bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(),
         bksOp.baseLogAttr(), inputDimAttr, outDimAttr);
 
@@ -293,7 +293,7 @@ struct LowerBootstrap : public mlir::OpRewritePattern<
     auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension());
     auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
     mlir::Operation *bBootstrapOp = rewriter.replaceOpWithNewOp<
-        mlir::concretelang::BConcrete::BootstrapLweBufferOp>(
+        mlir::concretelang::BConcrete::BootstrapLweTensorOp>(
         bsOp, outType, bsOp.input_ciphertext(), bsOp.lookup_table(),
         inputDimAttr, bsOp.polySizeAttr(), bsOp.levelAttr(), bsOp.baseLogAttr(),
         bsOp.glweDimensionAttr(), outputPrecisionAttr);
@@ -338,7 +338,7 @@ struct LowerBatchedBootstrap
     auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
 
     mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp<
-        mlir::concretelang::BConcrete::BatchedBootstrapLweBufferOp>(
+        mlir::concretelang::BConcrete::BatchedBootstrapLweTensorOp>(
         bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(),
         inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(),
         bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr);
@@ -385,7 +385,7 @@ struct AddPlaintextLweCiphertextOpPattern
     auto encoded = rewriter.create<mlir::arith::ShLIOp>(
         loc, rewriter.getI64Type(), castedInt, constantShiftOp);
     bConcreteOp =
-        rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweBufferOp>(
+        rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweTensorOp>(
             concreteOp, newResultTy,
             mlir::ValueRange{concreteOp.lhs(), encoded}, attributes);
   } else {
@@ -394,7 +394,7 @@ struct AddPlaintextLweCiphertextOpPattern 
newAttributes.push_back(rewriter.getNamedAttr( "crtDecomposition", rewriter.getI64ArrayAttr(crt))); bConcreteOp = - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( concreteOp, newResultTy, concreteOp.getOperation()->getOperands(), newAttributes); } @@ -436,7 +436,7 @@ struct MulCleartextLweCiphertextOpPattern mlir::Value castedInt = rewriter.create( loc, rewriter.getIntegerType(64), concreteOp.rhs()); bConcreteOp = - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( concreteOp, newResultTy, mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes); } else { @@ -444,7 +444,7 @@ struct MulCleartextLweCiphertextOpPattern newAttributes.push_back(rewriter.getNamedAttr( "crtDecomposition", rewriter.getI64ArrayAttr(crt))); bConcreteOp = - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( concreteOp, newResultTy, concreteOp.getOperation()->getOperands(), newAttributes); } @@ -1022,14 +1022,14 @@ void ConcreteToBConcretePass::runOnOperation() { LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch, LowerBatchedKeySwitch, LowToBConcrete, + mlir::concretelang::BConcrete::AddLweTensorOp, + BConcrete::AddCRTLweTensorOp>, AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern, LowToBConcrete, - LowToBConcrete>(&getContext()); + mlir::concretelang::BConcrete::NegateLweTensorOp, + BConcrete::NegateCRTLweTensorOp>, + LowToBConcrete>(&getContext()); // Add patterns to rewrite tensor operators that works on encrypted // tensors diff --git a/compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp b/compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp index 4a0310639..f6fa898ab 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp @@ -22,12 +22,12 @@ void AsyncOffloadPass::runOnOperation() { auto module = getOperation(); std::vector ops; - module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweBufferOp op) { + module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweTensorOp op) { mlir::OpBuilder builder(op); mlir::Type futType = mlir::concretelang::RT::FutureType::get(op.getResult().getType()); mlir::Value future = builder.create< - mlir::concretelang::BConcrete::KeySwitchLweBufferAsyncOffloadOp>( + mlir::concretelang::BConcrete::KeySwitchLweTensorAsyncOffloadOp>( op.getLoc(), mlir::TypeRange{futType}, op.getOperand(), op->getAttrs()); assert(op.getResult().hasOneUse() && @@ -43,12 +43,12 @@ void AsyncOffloadPass::runOnOperation() { } ops.push_back(op); }); - module.walk([&](mlir::concretelang::BConcrete::BootstrapLweBufferOp op) { + module.walk([&](mlir::concretelang::BConcrete::BootstrapLweTensorOp op) { mlir::OpBuilder builder(op); mlir::Type futType = mlir::concretelang::RT::FutureType::get(op.getResult().getType()); mlir::Value future = builder.create< - mlir::concretelang::BConcrete::BootstrapLweBufferAsyncOffloadOp>( + mlir::concretelang::BConcrete::BootstrapLweTensorAsyncOffloadOp>( op.getLoc(), mlir::TypeRange{futType}, op.getOperands(), op->getAttrs()); diff --git a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp index 372bc6a7a..c54fcf3a0 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -27,170 +27,13 @@ using namespace mlir; using namespace mlir::bufferization; using namespace mlir::tensor; -namespace BConcrete = 
mlir::concretelang::BConcrete; - -namespace mlir { -namespace concretelang { -namespace BConcrete { -namespace {} // namespace -} // namespace BConcrete -} // namespace concretelang -} // namespace mlir - namespace { -mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, - size_t rank) { - std::vector shape(rank, -1); - mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0); - for (size_t i = 0; i < rank; i++) { - expr = expr + - (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1)); - } - return mlir::MemRefType::get( - shape, rewriter.getI64Type(), - mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext())); -} +namespace BConcrete = mlir::concretelang::BConcrete; -// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` -mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc, - mlir::Value value) { - mlir::Type valueType = value.getType(); - if (auto memrefTy = valueType.dyn_cast_or_null()) { - return rewriter.create( - loc, - getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), - value); - } else { - return value; - } -} - -char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64"; -char memref_add_plaintext_lwe_ciphertext_u64[] = - "memref_add_plaintext_lwe_ciphertext_u64"; -char memref_mul_cleartext_lwe_ciphertext_u64[] = - "memref_mul_cleartext_lwe_ciphertext_u64"; -char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64"; -char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64"; -char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64"; -char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64"; -char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64"; -char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64"; -char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64"; -char memref_await_future[] = "memref_await_future"; -char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64"; -char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64"; -char memref_expand_lut_in_trivial_glwe_ct_u64[] = - "memref_expand_lut_in_trivial_glwe_ct_u64"; - -char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer"; - -mlir::LogicalResult insertForwardDeclarationOfTheCAPI( - mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) { - - auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1); - auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2); - auto futureType = - mlir::concretelang::RT::FutureType::get(rewriter.getIndexType()); - auto contextType = - mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); - auto i32Type = rewriter.getI32Type(); - - mlir::FunctionType funcType; - - if (funcName == memref_add_lwe_ciphertexts_u64) { - funcType = mlir::FunctionType::get( - rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {}); - } else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) { - funcType = mlir::FunctionType::get( - rewriter.getContext(), - {memref1DType, memref1DType, rewriter.getI64Type()}, {}); - } else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) { - funcType = mlir::FunctionType::get( - rewriter.getContext(), - {memref1DType, memref1DType, rewriter.getI64Type()}, {}); - } else if (funcName == memref_negate_lwe_ciphertext_u64) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - {memref1DType, memref1DType}, {}); - } else if (funcName == 
memref_keyswitch_lwe_u64 || - funcName == memref_keyswitch_lwe_cuda_u64) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - {memref1DType, memref1DType, i32Type, - i32Type, i32Type, i32Type, contextType}, - {}); - } else if (funcName == memref_batched_keyswitch_lwe_u64) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - {memref2DType, memref2DType, i32Type, - i32Type, i32Type, i32Type, contextType}, - {}); - } else if (funcName == memref_bootstrap_lwe_u64 || - funcName == memref_bootstrap_lwe_cuda_u64) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - {memref1DType, memref1DType, - memref1DType, i32Type, i32Type, i32Type, - i32Type, i32Type, i32Type, contextType}, - {}); - } else if (funcName == memref_batched_bootstrap_lwe_u64) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - {memref2DType, memref2DType, - memref1DType, i32Type, i32Type, i32Type, - i32Type, i32Type, i32Type, contextType}, - {}); - } else if (funcName == memref_keyswitch_async_lwe_u64) { - funcType = mlir::FunctionType::get( - rewriter.getContext(), {memref1DType, memref1DType, contextType}, - {futureType}); - } else if (funcName == memref_bootstrap_async_lwe_u64) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - {memref1DType, memref1DType, - memref1DType, i32Type, i32Type, i32Type, - i32Type, i32Type, i32Type, contextType}, - {futureType}); - } else if (funcName == memref_await_future) { - funcType = mlir::FunctionType::get( - rewriter.getContext(), - {memref1DType, futureType, memref1DType, memref1DType}, {}); - } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - { - memref1DType, - rewriter.getI32Type(), - rewriter.getI32Type(), - rewriter.getI32Type(), - memref1DType, - }, - {}); - } else if (funcName == memref_wop_pbs_crt_buffer) { - funcType = mlir::FunctionType::get(rewriter.getContext(), - { - memref2DType, - memref2DType, - memref1DType, - memref1DType, - rewriter.getI32Type(), - rewriter.getI32Type(), - rewriter.getI32Type(), - rewriter.getI32Type(), - contextType, - }, - {}); - } else { - op->emitError("unknwon external function") << funcName; - return mlir::failure(); - } - - return insertForwardDeclaration(op, rewriter, funcName, funcType); -} - -template -void pushAdditionalArgs(Op op, mlir::SmallVector &operands, - RewriterBase &rewriter); - -template -struct BufferizableWithCallOpInterface - : public BufferizableOpInterface::ExternalModel< - BufferizableWithCallOpInterface, Op> { +template +struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel< + TensorToMemrefOp, TensorOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; @@ -215,161 +58,7 @@ struct BufferizableWithCallOpInterface const BufferizationOptions &options) const { auto loc = op->getLoc(); - auto castOp = cast(op); - - // For now we always alloc for the result, we didn't have the in place - // operators yet. 
- auto resTensorType = - castOp.result().getType().template cast(); - - auto outMemrefType = MemRefType::get(resTensorType.getShape(), - resTensorType.getElementType()); - auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); - if (mlir::failed(outMemref)) { - return mlir::failure(); - } - - // The first operand is the result - mlir::SmallVector operands{ - getCastedMemRef(rewriter, loc, *outMemref), - }; - // For all tensor operand get the corresponding casted buffer - for (auto &operand : op->getOpOperands()) { - if (!operand.get().getType().isa()) { - operands.push_back(operand.get()); - } else { - auto memrefOperand = - bufferization::getBuffer(rewriter, operand.get(), options); - operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand)); - } - } - // Append additional argument - pushAdditionalArgs(castOp, operands, rewriter); - - // Insert forward declaration of the function - if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) { - return mlir::failure(); - } - - rewriter.create(loc, funcName, mlir::TypeRange{}, - operands); - - replaceOpWithBufferizedValues(rewriter, op, *outMemref); - - return success(); - } -}; - -template -struct BufferizableWithAsyncCallOpInterface - : public BufferizableOpInterface::ExternalModel< - BufferizableWithAsyncCallOpInterface, Op> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return false; - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::None; - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { - - auto loc = op->getLoc(); - auto castOp = cast(op); - - // For now we always alloc for the result, we didn't have the in place - // operators yet. 
- auto resTensorType = - castOp.result() - .getType() - .template cast() - .getElementType() - .template cast(); - - auto outMemrefType = MemRefType::get(resTensorType.getShape(), - resTensorType.getElementType()); - auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); - if (mlir::failed(outMemref)) { - return mlir::failure(); - } - - // The first operand is the result - mlir::SmallVector operands{ - getCastedMemRef(rewriter, loc, *outMemref), - }; - // For all tensor operand get the corresponding casted buffer - for (auto &operand : op->getOpOperands()) { - if (!operand.get().getType().isa()) { - operands.push_back(operand.get()); - } else { - auto memrefOperand = - bufferization::getBuffer(rewriter, operand.get(), options); - operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand)); - } - } - - // Append additional arguments - pushAdditionalArgs(castOp, operands, rewriter); - - // Insert forward declaration of the function - if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) { - return mlir::failure(); - } - - auto result = rewriter.create( - loc, funcName, - mlir::TypeRange{ - mlir::concretelang::RT::FutureType::get(rewriter.getIndexType())}, - operands); - - replaceOpWithBufferizedValues(rewriter, op, result.getResult(0)); - - return success(); - } -}; - -template -struct BufferizableWithSyncCallOpInterface - : public BufferizableOpInterface::ExternalModel< - BufferizableWithSyncCallOpInterface, Op> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return false; - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::None; - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { - - auto loc = op->getLoc(); - auto castOp = cast(op); + auto castOp = cast(op); auto resTensorType = castOp.result().getType().template cast(); @@ -383,23 +72,18 @@ struct BufferizableWithSyncCallOpInterface // The first operand is the result mlir::SmallVector operands{ - getCastedMemRef(rewriter, loc, *outMemref), + *outMemref, }; - // Then add the future operand - operands.push_back(op->getOpOperand(0).get()); - // Finally add a dependence on the memref covered by the future to - // prevent early deallocation - auto def = op->getOpOperand(0).get().getDefiningOp(); - operands.push_back(def->getOpOperand(0).get()); - operands.push_back(def->getOpOperand(1).get()); - - // Insert forward declaration of the function - if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) { - return mlir::failure(); + for (auto &operand : op->getOpOperands()) { + if (!operand.get().getType().isa()) { + operands.push_back(operand.get()); + } else { + operands.push_back( + bufferization::getBuffer(rewriter, operand.get(), options)); + } } - rewriter.create(loc, funcName, mlir::TypeRange{}, - operands); + rewriter.create(loc, mlir::TypeRange{}, operands, op->getAttrs()); replaceOpWithBufferizedValues(rewriter, op, *outMemref); @@ -407,239 +91,49 @@ struct BufferizableWithSyncCallOpInterface } }; -template <> -void pushAdditionalArgs(BConcrete::AddPlaintextLweBufferOp op, - mlir::SmallVector &operands, - RewriterBase 
&rewriter) {} -template <> -void pushAdditionalArgs(BConcrete::AddLweBuffersOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) {} -template <> -void pushAdditionalArgs(BConcrete::MulCleartextLweBufferOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) {} -template <> -void pushAdditionalArgs(BConcrete::NegateLweBufferOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) {} - -template <> -void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) { - // level - operands.push_back( - rewriter.create(op.getLoc(), op.levelAttr())); - // base_log - operands.push_back( - rewriter.create(op.getLoc(), op.baseLogAttr())); - // lwe_dim_in - operands.push_back(rewriter.create( - op.getLoc(), op.lwe_dim_inAttr())); - // lwe_dim_out - operands.push_back(rewriter.create( - op.getLoc(), op.lwe_dim_outAttr())); - // context - operands.push_back(getContextArgument(op)); -} - -template <> -void pushAdditionalArgs(BConcrete::BatchedKeySwitchLweBufferOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) { - // level - operands.push_back( - rewriter.create(op.getLoc(), op.levelAttr())); - // base_log - operands.push_back( - rewriter.create(op.getLoc(), op.baseLogAttr())); - // lwe_dim_in - operands.push_back(rewriter.create( - op.getLoc(), op.lwe_dim_inAttr())); - // lwe_dim_out - operands.push_back(rewriter.create( - op.getLoc(), op.lwe_dim_outAttr())); - // context - operands.push_back(getContextArgument(op)); -} - -template <> -void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) { - // input_lwe_dim - operands.push_back(rewriter.create( - op.getLoc(), op.inputLweDimAttr())); - // poly_size - operands.push_back( - rewriter.create(op.getLoc(), op.polySizeAttr())); - // level - operands.push_back( - rewriter.create(op.getLoc(), op.levelAttr())); - // base_log - operands.push_back( - rewriter.create(op.getLoc(), op.baseLogAttr())); - // glwe_dim - operands.push_back(rewriter.create( - op.getLoc(), op.glweDimensionAttr())); - // out_precision - operands.push_back(rewriter.create( - op.getLoc(), op.outPrecisionAttr())); - // context - operands.push_back(getContextArgument(op)); -} - -template <> -void pushAdditionalArgs(BConcrete::BatchedBootstrapLweBufferOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) { - // input_lwe_dim - operands.push_back(rewriter.create( - op.getLoc(), op.inputLweDimAttr())); - // poly_size - operands.push_back( - rewriter.create(op.getLoc(), op.polySizeAttr())); - // level - operands.push_back( - rewriter.create(op.getLoc(), op.levelAttr())); - // base_log - operands.push_back( - rewriter.create(op.getLoc(), op.baseLogAttr())); - // glwe_dim - operands.push_back(rewriter.create( - op.getLoc(), op.glweDimensionAttr())); - // out_precision - operands.push_back(rewriter.create( - op.getLoc(), op.outPrecisionAttr())); - // context - operands.push_back(getContextArgument(op)); -} - -template <> -void pushAdditionalArgs(BConcrete::KeySwitchLweBufferAsyncOffloadOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) { - // context - operands.push_back(getContextArgument(op)); -} - -template <> -void pushAdditionalArgs(BConcrete::BootstrapLweBufferAsyncOffloadOp op, - mlir::SmallVector &operands, - RewriterBase &rewriter) { - // input_lwe_dim - operands.push_back(rewriter.create( - op.getLoc(), op.inputLweDimAttr())); - // poly_size - operands.push_back( - rewriter.create(op.getLoc(), op.polySizeAttr())); - 
// level
-  operands.push_back(
-      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
-  // base_log
-  operands.push_back(
-      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
-  // glwe_dim
-  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
-      op.getLoc(), op.glweDimensionAttr()));
-  // out_precision
-  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
-      op.getLoc(), op.outPrecisionAttr()));
-  // context
-  operands.push_back(getContextArgument(op));
-}
-
-template <>
-void pushAdditionalArgs(BConcrete::WopPBSCRTLweBufferOp op,
-                        mlir::SmallVector<mlir::Value> &operands,
-                        RewriterBase &rewriter) {
-  mlir::Type crtType = mlir::RankedTensorType::get(
-      {(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
-  std::vector<int64_t> values;
-  for (auto a : op.crtDecomposition()) {
-    values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
-  }
-  auto attr = rewriter.getI64TensorAttr(values);
-  auto x = rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), attr, crtType);
-  auto globalMemref = bufferization::getGlobalFor(x, 0);
-  assert(!failed(globalMemref));
-
-  auto globalRef = rewriter.create<mlir::memref::GetGlobalOp>(
-      op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
-  operands.push_back(getCastedMemRef(rewriter, op.getLoc(), globalRef));
-
-  // lwe_small_size
-  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
-      op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
-  // cbs_level_count
-  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
-      op.getLoc(), op.circuitBootstrapLevelAttr()));
-  // cbs_base_log
-  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
-      op.getLoc(), op.circuitBootstrapBaseLogAttr()));
-  // polynomial_size
-  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
-      op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
-  // context
-  operands.push_back(getContextArgument(op));
-}
 
 } // namespace
 
 void mlir::concretelang::BConcrete::
     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx,
                             BConcrete::BConcreteDialect *dialect) {
-    BConcrete::AddLweBuffersOp::attachInterface<
-        BufferizableWithCallOpInterface<BConcrete::AddLweBuffersOp,
-                                        memref_add_lwe_ciphertexts_u64>>(*ctx);
-    BConcrete::AddPlaintextLweBufferOp::attachInterface<
-        BufferizableWithCallOpInterface<
-            BConcrete::AddPlaintextLweBufferOp,
-            memref_add_plaintext_lwe_ciphertext_u64>>(*ctx);
-    BConcrete::MulCleartextLweBufferOp::attachInterface<
-        BufferizableWithCallOpInterface<
-            BConcrete::MulCleartextLweBufferOp,
-            memref_mul_cleartext_lwe_ciphertext_u64>>(*ctx);
-    BConcrete::NegateLweBufferOp::attachInterface<
-        BufferizableWithCallOpInterface<BConcrete::NegateLweBufferOp,
-                                        memref_negate_lwe_ciphertext_u64>>(
+    // add_lwe_tensor => add_lwe_buffer
+    BConcrete::AddLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::AddLweTensorOp,
+                         BConcrete::AddLweBufferOp>>(*ctx);
+    // add_plaintext_lwe_tensor => add_plaintext_lwe_buffer
+    BConcrete::AddPlaintextLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::AddPlaintextLweTensorOp,
+                         BConcrete::AddPlaintextLweBufferOp>>(*ctx);
+    // mul_cleartext_lwe_tensor => mul_cleartext_lwe_buffer
+    BConcrete::MulCleartextLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::MulCleartextLweTensorOp,
+                         BConcrete::MulCleartextLweBufferOp>>(*ctx);
+    // negate_lwe_tensor => negate_lwe_buffer
+    BConcrete::NegateLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::NegateLweTensorOp,
+                         BConcrete::NegateLweBufferOp>>(*ctx);
+    // keyswitch_lwe_tensor => keyswitch_lwe_buffer
+    BConcrete::KeySwitchLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::KeySwitchLweTensorOp,
+                         BConcrete::KeySwitchLweBufferOp>>(*ctx);
+    // bootstrap_lwe_tensor => bootstrap_lwe_buffer
+    BConcrete::BootstrapLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::BootstrapLweTensorOp,
+                         BConcrete::BootstrapLweBufferOp>>(*ctx);
+    // batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer
+    BConcrete::BatchedKeySwitchLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::BatchedKeySwitchLweTensorOp,
+                         BConcrete::BatchedKeySwitchLweBufferOp>>(*ctx);
+    // batched_bootstrap_lwe_tensor => batched_bootstrap_lwe_buffer
+    BConcrete::BatchedBootstrapLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::BatchedBootstrapLweTensorOp,
+                         BConcrete::BatchedBootstrapLweBufferOp>>(*ctx);
+    // wop_pbs_crt_lwe_tensor => wop_pbs_crt_lwe_buffer
+    BConcrete::WopPBSCRTLweTensorOp::attachInterface<
+        TensorToMemrefOp<BConcrete::WopPBSCRTLweTensorOp,
+                         BConcrete::WopPBSCRTLweBufferOp>>(
         *ctx);
-    if (mlir::concretelang::getEmitGPUOption()) {
-      BConcrete::KeySwitchLweBufferOp::attachInterface<
-          BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
-                                          memref_keyswitch_lwe_cuda_u64>>(*ctx);
-      BConcrete::BootstrapLweBufferOp::attachInterface<
-          BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
-                                          memref_bootstrap_lwe_cuda_u64>>(*ctx);
-    } else {
-      BConcrete::KeySwitchLweBufferOp::attachInterface<
-          BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
-                                          memref_keyswitch_lwe_u64>>(*ctx);
-      BConcrete::BatchedKeySwitchLweBufferOp::attachInterface<
-          BufferizableWithCallOpInterface<
-              BConcrete::BatchedKeySwitchLweBufferOp,
-              memref_batched_keyswitch_lwe_u64>>(*ctx);
-      BConcrete::BootstrapLweBufferOp::attachInterface<
-          BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
-                                          memref_bootstrap_lwe_u64>>(*ctx);
-      BConcrete::BatchedBootstrapLweBufferOp::attachInterface<
-          BufferizableWithCallOpInterface<
-              BConcrete::BatchedBootstrapLweBufferOp,
-              memref_batched_bootstrap_lwe_u64>>(*ctx);
-    }
-    BConcrete::WopPBSCRTLweBufferOp::attachInterface<
-        BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
-                                        memref_wop_pbs_crt_buffer>>(*ctx);
-    BConcrete::KeySwitchLweBufferAsyncOffloadOp::attachInterface<
-        BufferizableWithAsyncCallOpInterface<
-            BConcrete::KeySwitchLweBufferAsyncOffloadOp,
-            memref_keyswitch_async_lwe_u64>>(*ctx);
-    BConcrete::BootstrapLweBufferAsyncOffloadOp::attachInterface<
-        BufferizableWithAsyncCallOpInterface<
-            BConcrete::BootstrapLweBufferAsyncOffloadOp,
-            memref_bootstrap_async_lwe_u64>>(*ctx);
-    BConcrete::AwaitFutureOp::attachInterface<
-        BufferizableWithSyncCallOpInterface<BConcrete::AwaitFutureOp,
-                                            memref_await_future>>(*ctx);
   });
 }
diff --git a/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp b/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp
index 3c3cf1161..a9b14d54a 100644
--- a/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp
+++ b/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp
@@ -273,16 +273,16 @@ struct BConcreteCRTBinaryOpPattern
 //     scf.yield %res : tensor
 //   }
 // ```
-struct AddPlaintextCRTLweBufferOpPattern
-    : public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp> {
-  AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context,
+struct AddPlaintextCRTLweTensorOpPattern
+    : public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp> {
+  AddPlaintextCRTLweTensorOpPattern(mlir::MLIRContext *context,
                                     mlir::PatternBenefit benefit = 1)
-      : mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp>(context,
+      : mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp>(context,
                                                                       benefit) {
   }
 
   mlir::LogicalResult
-  matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op,
+  matchAndRewrite(BConcrete::AddPlaintextCRTLweTensorOp op,
                   mlir::PatternRewriter &rewriter) const override {
     auto resultTy =
         ((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
@@ -381,7 +381,7 @@ struct AddPlaintextCRTLweBufferOpPattern
       auto blockArg1 = builder.create<mlir::tensor::ExtractOp>(loc, x_decomp, i);
       // %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
       //        : (tensor, i64) -> (tensor)
-      auto tmp = builder.create<BConcrete::AddPlaintextLweBufferOp>(
+      auto tmp = builder.create<BConcrete::AddPlaintextLweTensorOp>(
           loc, blockTy, blockArg0, blockArg1);
 
       // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
@@ -436,16 +436,16 @@ struct AddPlaintextCRTLweBufferOpPattern
 //     scf.yield %res : tensor
 //   }
 // ```
-struct MulCleartextCRTLweBufferOpPattern
-    : public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp> {
-  MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context,
+struct MulCleartextCRTLweTensorOpPattern
+    : public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp> {
+  MulCleartextCRTLweTensorOpPattern(mlir::MLIRContext *context,
                                     mlir::PatternBenefit benefit = 1)
-      : mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp>(context,
+      : mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp>(context,
                                                                       benefit) {
   }
 
   mlir::LogicalResult
-  matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op,
+ 
matchAndRewrite(BConcrete::MulCleartextCRTLweTensorOp op, mlir::PatternRewriter &rewriter) const override { auto resultTy = ((mlir::Type)op.getResult().getType()).cast(); @@ -494,7 +494,7 @@ struct MulCleartextCRTLweBufferOpPattern // %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x) // : (tensor, i64) -> (tensor) - auto tmp = builder.create( + auto tmp = builder.create( loc, blockTy, blockArg0, rhs); // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, @@ -520,22 +520,22 @@ void EliminateCRTOpsPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); // add_crt_lwe_buffers - target.addIllegalOp(); - patterns.add>( + target.addIllegalOp(); + patterns.add>( &getContext()); // add_plaintext_crt_lwe_buffers - target.addIllegalOp(); - patterns.add(&getContext()); + target.addIllegalOp(); + patterns.add(&getContext()); // mul_cleartext_crt_lwe_buffer - target.addIllegalOp(); - patterns.add(&getContext()); + target.addIllegalOp(); + patterns.add(&getContext()); - target.addIllegalOp(); - patterns.add>( + target.addIllegalOp(); + patterns.add>( &getContext()); // This dialect are used to transforms crt ops to bconcrete ops diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index d8650f7ed..c2e449f85 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -414,7 +414,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { // MLIR canonical dialects -> LLVM Dialect if (mlir::concretelang::pipeline::lowerStdToLLVMDialect( - mlirContext, module, enablePass, loopParallelize) + mlirContext, module, enablePass, loopParallelize, options.emitGPUOps) .failed()) { return errorDiag("Failed to lower to LLVM dialect"); } diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 98925c1a5..f4b23c7b7 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -293,15 +293,13 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, enablePass); addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(), enablePass); - addPotentiallyNestedPass( - pm, mlir::concretelang::createConvertBConcreteToCAPIPass(), enablePass); return pm.run(module.getOperation()); } mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - bool parallelizeLoops) { + bool parallelizeLoops, bool gpu) { mlir::PassManager pm(&context); pipelinePrinting("StdToLLVM", pm, context); @@ -345,6 +343,10 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass( pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createConvertBConcreteToCAPIPass(gpu), + enablePass); + // Convert to MLIR LLVM Dialect addPotentiallyNestedPass( pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(), diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 8c3b61241..1f1fb871c 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -300,6 +300,7 @@ cmdlineCompilationOptions() { options.loopParallelize = cmdline::loopParallelize; options.dataflowParallelize = cmdline::dataflowParallelize; options.batchConcreteOps = cmdline::batchConcreteOps; + options.asyncOffload = cmdline::asyncOffload; options.optimizeConcrete = cmdline::optimizeConcrete; options.emitGPUOps = cmdline::emitGPUOps; diff --git 
a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir index c2f5eec44..24521602b 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir @@ -1,7 +1,7 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s //CHECK: func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { @@ -10,7 +10,7 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: ! } //CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64> //CHECK: return %[[V0]] : tensor<5x2049xi64> //CHECK: } func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext, %arg1: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir index 4cfcb99a4..af33f3b12 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir @@ -6,7 +6,7 @@ //CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64 //CHECK: %c56_i64 = arith.constant 56 : i64 //CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64 -//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> //CHECK: return %[[V2]] : tensor<1025xi64> //CHECK: } func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { @@ -19,7 +19,7 @@ func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concr //CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64 //CHECK: %c59_i64 = arith.constant 59 : i64 //CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64 -//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> //CHECK: return %[[V2]] : tensor<1025xi64> //CHECK: } func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> { @@ -30,7 +30,7 @@ func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> ! 
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> { //CHECK: %c1_i8 = arith.constant 1 : i8 -//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64> +//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64> //CHECK: return %[[V0]] : tensor<5x1025xi64> //CHECK: } func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir index 1569b3418..617f861bb 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir @@ -1,8 +1,8 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s //CHECK: func.func @apply_lookup_table(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: tensor<16xi64>) -> tensor<1025xi64> { -//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64> -//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64> +//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64> +//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64> //CHECK: return %[[V2]] : tensor<1025xi64> //CHECK: } func.func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> { diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir index 93ff17d5e..ac1b09fde 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir @@ -2,8 +2,8 @@ //CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { //CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> -//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64> -//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64> +//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 
2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64> +//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64> //CHECK: return %[[V2]] : tensor<2049xi64> //CHECK: } func.func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> { diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir index eb70c0cae..0acc7f376 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir @@ -3,7 +3,7 @@ //CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i8 = arith.constant 1 : i8 //CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64 -//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> //CHECK: return %[[V1]] : tensor<1025xi64> //CHECK: } func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { @@ -14,7 +14,7 @@ func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concre //CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> { //CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64 -//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> //CHECK: return %[[V1]] : tensor<1025xi64> //CHECK: } func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> { @@ -24,7 +24,7 @@ func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !C //CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64> +//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64> //CHECK: return %[[V0]] : tensor<5x1025xi64> //CHECK: } func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext, %arg1: i5) -> !Concrete.lwe_ciphertext { diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir index 6157557e7..0cf067eeb 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir @@ -1,7 +1,7 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s //CHECK: func.func @neg_lwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64> +//CHECK: %[[V0:.*]] = 
"BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64> //CHECK: return %[[V0]] : tensor<1025xi64> //CHECK: } func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> { @@ -10,7 +10,7 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip } //CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64> +//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64> //CHECK: return %[[V0]] : tensor<5x1025xi64> //CHECK: } func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { diff --git a/compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir b/compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir new file mode 100644 index 000000000..91d822d91 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir @@ -0,0 +1,37 @@ +// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s + +func.func @add_lwe_ciphertexts(%arg0: memref<2049xi64>, %arg1: memref<2049xi64>, %result : memref<2049xi64>) { + //CHECK: "BConcrete.add_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> () + "BConcrete.add_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> () + return +} + +func.func @add_plaintext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) { + //CHECK: "BConcrete.add_plaintext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + "BConcrete.add_plaintext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + return +} + +func.func @mul_cleartext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) { + //CHECK: "BConcrete.mul_cleartext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + "BConcrete.mul_cleartext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> () + return +} + +func.func @negate_lwe_ciphertext(%arg0: memref<2049xi64>, %result: memref<2049xi64>) { + //CHECK: "BConcrete.negate_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>) -> () + "BConcrete.negate_lwe_buffer"(%result, %arg0) : (memref<2049xi64>, memref<2049xi64>) -> () + return +} + +func.func @bootstrap_lwe(%arg0: memref<2049xi64>, %arg1: memref<16xi64>, %result: memref<2049xi64>) { + //CHECK: "BConcrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> () + "BConcrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> () + return +} + +func.func @keyswitch_lwe(%arg0: memref<2049xi64>, %result: memref<2049xi64>) { + //CHECK: "BConcrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, 
memref<2049xi64>) -> () + "BConcrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> () + return +} diff --git a/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir b/compiler/tests/check_tests/Dialect/BConcrete/ops_tensor.mlir similarity index 79% rename from compiler/tests/check_tests/Dialect/BConcrete/ops.mlir rename to compiler/tests/check_tests/Dialect/BConcrete/ops_tensor.mlir index 5dc8fa7b5..cfed4b403 100644 --- a/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir +++ b/compiler/tests/check_tests/Dialect/BConcrete/ops_tensor.mlir @@ -1,91 +1,91 @@ // RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s //CHECK: func.func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.add_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>) + %0 = "BConcrete.add_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>) return %0 : tensor<2049xi64> } //CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64> //CHECK: return %[[V0]] : tensor<5x2049xi64> //CHECK: } func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> { - %0 = "BConcrete.add_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>) + %0 = "BConcrete.add_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>) return %0 : tensor<5x2049xi64> } //CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> { - %0 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>) + %0 = "BConcrete.add_plaintext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>) return %0 : tensor<2049xi64> } //CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : 
(tensor<5x2049xi64>, i64) -> tensor<5x2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64> //CHECK: return %[[V0]] : tensor<5x2049xi64> //CHECK: } func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> { - %0 = "BConcrete.add_plaintext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>) + %0 = "BConcrete.add_plaintext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>) return %0 : tensor<5x2049xi64> } //CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> { - %0 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>) + %0 = "BConcrete.mul_cleartext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>) return %0 : tensor<2049xi64> } //CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64> //CHECK: return %[[V0]] : tensor<5x2049xi64> //CHECK: } func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> { - %0 = "BConcrete.mul_cleartext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>) + %0 = "BConcrete.mul_cleartext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>) return %0 : tensor<5x2049xi64> } //CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.negate_lwe_buffer"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>) + %0 = "BConcrete.negate_lwe_tensor"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>) return %0 : tensor<2049xi64> } //CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64> //CHECK: return %[[V0]] : tensor<5x2049xi64> //CHECK: } func.func 
@negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> { - %0 = "BConcrete.negate_crt_lwe_buffer"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>) + %0 = "BConcrete.negate_crt_lwe_tensor"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>) return %0 : tensor<5x2049xi64> } //CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>) + %0 = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>) return %0 : tensor<2049xi64> } //CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64> +//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.keyswitch_lwe_buffer"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>) + %0 = "BConcrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>) return %0 : tensor<2049xi64> } diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc index a64a8cc1b..c5f4f1f9f 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc @@ -266,7 +266,8 @@ INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES( JitTest, {defaultOptions()}, mlir::concretelang::JITSupport()) std::vector allOptions{ - defaultOptions(), loopOptions(), asyncOptions(), + defaultOptions(), + loopOptions(), #ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED dataflowOptions(), #endif