diff --git a/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h b/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h index acd3b0a91..aae7d3f4a 100644 --- a/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h +++ b/compiler/include/concretelang/Conversion/ConcreteToBConcrete/Pass.h @@ -12,7 +12,7 @@ namespace mlir { namespace concretelang { /// Create a pass to convert `Concrete` dialect to `BConcrete` dialect. std::unique_ptr> -createConvertConcreteToBConcretePass(bool loopParallelize); +createConvertConcreteToBConcretePass(bool loopParallelize, bool useGPU); } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Conversion/ConcreteToGPU/Pass.h b/compiler/include/concretelang/Conversion/ConcreteToGPU/Pass.h deleted file mode 100644 index cafb2ae02..000000000 --- a/compiler/include/concretelang/Conversion/ConcreteToGPU/Pass.h +++ /dev/null @@ -1,18 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef ZAMALANG_CONVERSION_CONCRETETOGPU_PASS_H_ -#define ZAMALANG_CONVERSION_CONCRETETOGPU_PASS_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace concretelang { -/// Create a pass to convert `Concrete` operations to GPU. -std::unique_ptr> createConvertConcreteToGPUPass(); -} // namespace concretelang -} // namespace mlir - -#endif diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index a0de4a566..4af9f0090 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -15,7 +15,6 @@ #include "concretelang/Conversion/BConcreteToCAPI/Pass.h" #include "concretelang/Conversion/ConcreteToBConcrete/Pass.h" -#include "concretelang/Conversion/ConcreteToGPU/Pass.h" #include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h" #include "concretelang/Conversion/FHEToTFHE/Pass.h" #include "concretelang/Conversion/LinalgExtras/Passes.h" diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index ebc3c930d..b43bf4d1f 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -54,13 +54,6 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> { let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"]; } -def ConcreteToGPU : Pass<"concrete-to-gpu", "mlir::ModuleOp"> { - let summary = "Transforms operations in the Concrete dialect to GPU"; - let description = [{ Transforms operations in the Concrete dialect to GPU }]; - let constructor = "mlir::concretelang::createConvertConcreteToGPUPass()"; - let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect"]; -} - def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> { let summary = "Lowers operations from MLIR lowerable dialects to LLVM"; let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()"; diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td index d18c226c4..5b3f720ef 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td +++ b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td @@ -73,17 +73,6 @@ def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> { let results = (outs 2DTensorOf<[I64]>:$result); } -def BConcrete_FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> { - let arguments = (ins - 1DTensorOf<[I64]>:$glwe, - I32Attr:$glweDimension, - I32Attr:$polynomialSize, - I32Attr:$outPrecision, - 1DTensorOf<[I64]>:$table - ); - let results = (outs); -} - def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> { let arguments = (ins // LweKeySwitchKeyType:$keyswitch_key, @@ -96,11 +85,14 @@ def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> { def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> { let arguments = (ins - // LweBootstrapKeyType:$bootstrap_key, 1DTensorOf<[I64]>:$input_ciphertext, - 1DTensorOf<[I64]>:$accumulator, - I32Attr:$level, - I32Attr:$baseLog + 1DTensorOf<[I64]>:$lookup_table, + I32:$inputLweDim, + I32:$polySize, + I32:$level, + I32:$baseLog, + I32:$glweDimension, + I32:$outPrecision ); let results = (outs 1DTensorOf<[I64]>:$result); } @@ -143,11 +135,14 @@ def BConcrete_KeySwitchLweBufferAsyncOffloadOp : def BConcrete_BootstrapLweBufferAsyncOffloadOp : BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> { let arguments = (ins - // LweBootstrapKeyType:$bootstrap_key, 1DTensorOf<[I64]>:$input_ciphertext, - 1DTensorOf<[I64]>:$accumulator, - I32Attr:$level, - I32Attr:$baseLog + 1DTensorOf<[I64]>:$lookup_table, + I32:$inputLweDim, + I32:$polySize, + I32:$level, + I32:$baseLog, + I32:$glweDimension, + I32:$outPrecision ); let results = (outs RT_Future : $result); } @@ -158,6 +153,9 @@ def BConcrete_AwaitFutureOp : let results = (outs 1DTensorOf<[I64]>:$result); } +// This is a different op in BConcrete just because of the way we are lowering to CAPI +// When the CAPI lowering is detached from bufferization, we can remove this op, and lower +// to the appropriate CAPI (gpu or cpu) depending on the useGPU compilation option def BConcrete_BootstrapLweGPUBufferOp : BConcrete_Op<"bootstrap_lwe_gpu_buffer"> { let arguments = (ins 1DTensorOf<[I64]>:$input_ciphertext, @@ -166,19 +164,10 @@ def BConcrete_BootstrapLweGPUBufferOp : BConcrete_Op<"bootstrap_lwe_gpu_buffer"> I32:$polySize, I32:$level, I32:$baseLog, - LLVM_PointerTo:$bsk + I32:$glweDimension, + I32:$outPrecision ); let results = (outs 1DTensorOf<[I64]>:$result); } -def BConcrete_MoveBskToGPUOp : BConcrete_Op<"move_bsk_to_gpu"> { - let arguments = (ins); - let results = (outs LLVM_PointerTo:$bsk); -} - -def BConcrete_FreeBskFromGPUOp : BConcrete_Op<"free_bsk_from_gpu"> { - let arguments = (ins LLVM_PointerTo:$bsk); - let results = (outs); -} - #endif diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index 0b72e042b..06e99fcf2 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -52,54 +52,22 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> { let results = (outs Concrete_LweCiphertextType:$result); } -def Concrete_GlweFromTable : Concrete_Op<"glwe_from_table", [NoSideEffect]> { - let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)"; - - let arguments = (ins 1DTensorOf<[I64]>:$table); - let results = (outs Concrete_GlweCiphertextType:$result); -} - def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> { let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table"; let arguments = (ins Concrete_LweCiphertextType:$input_ciphertext, - Concrete_GlweCiphertextType:$accumulator, - I32Attr:$level, - I32Attr:$baseLog - ); - let results = (outs Concrete_LweCiphertextType:$result); -} - -def Concrete_BootstrapLweGPUOp : Concrete_Op<"bootstrap_lwe_gpu"> { - let summary = "Bootstrap an LWE ciphertext in GPU using a lookup table"; - - let arguments = (ins - Concrete_LweCiphertextType:$input_ciphertext, - 1DTensorOf<[I64]>:$table, + 1DTensorOf<[I64]>:$lookup_table, I32:$inputLweDim, I32:$polySize, I32:$level, I32:$baseLog, - Concrete_GPUBsk:$bsk + I32:$glweDimension, + I32:$outPrecision ); let results = (outs Concrete_LweCiphertextType:$result); } -def Concrete_MoveBskToGPUOp : Concrete_Op<"move_bsk_to_gpu"> { - let summary = "Move bsk to GPU"; - - let arguments = (ins); - let results = (outs Concrete_GPUBsk:$bsk); -} - -def Concrete_FreeBskFromGPUOp : Concrete_Op<"free_bsk_from_gpu"> { - let summary = "Free bsk memory from GPU"; - - let arguments = (ins Concrete_GPUBsk:$bsk); - let results = (outs); -} - def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> { let summary = "Keyswitches a LWE ciphertext"; diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td index 5c597b527..12c64fefb 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td @@ -93,14 +93,4 @@ def Concrete_Context : Concrete_Type<"Context"> { }]; } -def Concrete_GPUBsk : Concrete_Type<"GPUBsk"> { - let mnemonic = "gpu_bsk"; - - let summary = "A bsk in GPU"; - - let description = [{ - A bootstrapping key in GPU memory - }]; -} - #endif diff --git a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index 6ac8b90a3..69f87730b 100644 --- a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -92,24 +92,18 @@ def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe"> { let results = (outs TFHE_GLWECipherTextType : $result); } -def TFHE_GLWEFromTableOp : TFHE_Op<"glwe_from_table"> { - let summary = - "Creates a GLWE ciphertext which is the trivial encrytion of a the input " - "table interpreted as a polynomial (to use later in a bootstrap)"; - - let arguments = (ins 1DTensorOf < [I64] > : $table); - let results = (outs TFHE_GLWECipherTextType : $result); -} - def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> { let summary = "Programmable bootstraping of a GLWE ciphertext with a lookup table"; let arguments = (ins TFHE_GLWECipherTextType : $ciphertext, - TFHE_GLWECipherTextType : $lookup_table, + 1DTensorOf<[I64]> : $lookup_table, + I32Attr : $inputLweDim, + I32Attr : $polySize, I32Attr : $level, - I32Attr : $baseLog + I32Attr : $baseLog, + I32Attr : $glweDimension ); let results = (outs TFHE_GLWECipherTextType : $result); diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index cdaea01d8..2650b2fed 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -10,6 +10,21 @@ extern "C" { +/// \brief Expands the input LUT +/// +/// It duplicates values as needed to fill mega cases, taking care of the +/// encoding and the half mega case shift in the process as well. All sizes +/// should be powers of 2. +/// +/// \param output where to write the expanded LUT +/// \param output_size +/// \param out_MESSAGE_BITS number of bits of message to be used +/// \param lut original LUT +/// \param lut_size +void encode_and_expand_lut(uint64_t *output, size_t output_size, + size_t out_MESSAGE_BITS, const uint64_t *lut, + size_t lut_size); + void memref_expand_lut_in_trivial_glwe_ct_u64( uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, @@ -58,15 +73,20 @@ void memref_bootstrap_lwe_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, - uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, - uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, + uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, + uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, + uint32_t base_log, uint32_t glwe_dim, uint32_t precision, mlir::concretelang::RuntimeContext *context); + void *memref_bootstrap_async_lwe_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, - uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, - uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, + uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, + uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, + uint32_t base_log, uint32_t glwe_dim, uint32_t precision, mlir::concretelang::RuntimeContext *context); void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned, @@ -131,7 +151,9 @@ void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned, /// \param poly_size polynomial size /// \param level level /// \param base_log base log -/// \param bsk pointer to bsk on GPU +/// \param glwe_dim +/// \param precision +/// \param context void memref_bootstrap_lwe_cuda_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, @@ -139,7 +161,8 @@ void memref_bootstrap_lwe_cuda_u64( uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, - uint32_t base_log, void *bsk); + uint32_t base_log, uint32_t glwe_dim, uint32_t precision, + mlir::concretelang::RuntimeContext *context); /// \brief Copy ciphertext from CPU to GPU using a single stream. /// @@ -174,13 +197,19 @@ void move_ct_to_cpu(uint64_t *out_allocated, uint64_t *out_aligned, /// \brief Copy bootstrapping key from CPU to GPU using a single stream. /// -/// It handles memory allocation on GPU. +/// It handles memory allocation on GPU, as well as conversion to the Fourier +/// domain. /// /// \param context +/// \param input_lwe_dim +/// \param poly_size +/// \param level +/// \param glwe_dim /// \param gpu_idx index of the GPU to use /// \return void* pointer to the GPU bsk void *move_bsk_to_gpu(mlir::concretelang::RuntimeContext *context, - uint32_t gpu_idx); + uint32_t input_lwe_dim, uint32_t poly_size, + uint32_t level, uint32_t glwe_dim, uint32_t gpu_idx); /// \brief Free gpu memory. /// diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 68838e7ef..21c0276c5 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -47,7 +47,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - bool parallelizeLoops); + bool parallelizeLoops, bool useGPU); mlir::LogicalResult optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, @@ -57,10 +57,6 @@ mlir::LogicalResult asyncOffload(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); -mlir::LogicalResult -transformsConcreteToGPU(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass); - mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); diff --git a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp index 8a94654cf..8080c9b66 100644 --- a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp +++ b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp @@ -3,88 +3,10 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include #include #include #include "concretelang/Conversion/Passes.h" -#include "concretelang/Conversion/Tools.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" -#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" - -char move_bsk_to_gpu[] = "move_bsk_to_gpu"; -char free_from_gpu[] = "free_from_gpu"; - -/// \brief Rewrites `BConcrete.move_bsk_to_gpu` into a CAPI call to -/// `move_bsk_to_gpu` -/// -/// Also insert the forward declaration of `move_bsk_to_gpu` -struct MoveBskOpPattern : public mlir::OpRewritePattern< - mlir::concretelang::BConcrete::MoveBskToGPUOp> { - MoveBskOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::BConcrete::MoveBskToGPUOp moveBskOp, - ::mlir::PatternRewriter &rewriter) const override { - - auto ctx = getContextArgument(moveBskOp); - - mlir::SmallVector operands{ctx}; - - // Insert forward declaration of the function - auto contextType = - mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), {contextType}, - {mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type())}); - if (insertForwardDeclaration(moveBskOp, rewriter, move_bsk_to_gpu, funcType) - .failed()) { - return mlir::failure(); - } - - rewriter.replaceOpWithNewOp( - moveBskOp, move_bsk_to_gpu, moveBskOp.getResult().getType(), operands); - - return ::mlir::success(); - }; -}; - -/// \brief Rewrites `BConcrete.free_bsk_from_gpu` into a CAPI call to -/// `free_from_gpu` -/// -/// Also insert the forward declaration of `free_from_gpu` -struct FreeBskOpPattern : public mlir::OpRewritePattern< - mlir::concretelang::BConcrete::FreeBskFromGPUOp> { - FreeBskOpPattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern< - mlir::concretelang::BConcrete::FreeBskFromGPUOp>(context, benefit) { - } - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::BConcrete::FreeBskFromGPUOp freeBskOp, - ::mlir::PatternRewriter &rewriter) const override { - - mlir::SmallVector operands{freeBskOp.bsk()}; - - // Insert forward declaration of the function - auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - {mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type())}, {}); - if (insertForwardDeclaration(freeBskOp, rewriter, free_from_gpu, funcType) - .failed()) { - return mlir::failure(); - } - - rewriter.replaceOpWithNewOp( - freeBskOp, free_from_gpu, mlir::TypeRange({}), operands); - - return ::mlir::success(); - }; -}; namespace { struct BConcreteToCAPIPass : public BConcreteToCAPIBase { @@ -98,12 +20,6 @@ void BConcreteToCAPIPass::runOnOperation() { mlir::ConversionTarget target(getContext()); mlir::RewritePatternSet patterns(&getContext()); - target.addIllegalOp(); - target.addLegalDialect(); - - patterns.insert(&getContext()); - patterns.insert(&getContext()); - // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { this->signalPassFailure(); diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 135514075..6cc321895 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -3,7 +3,6 @@ add_subdirectory(TFHEGlobalParametrization) add_subdirectory(TFHEToConcrete) add_subdirectory(FHETensorOpsToLinalg) add_subdirectory(ConcreteToBConcrete) -add_subdirectory(ConcreteToGPU) add_subdirectory(BConcreteToCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(LinalgExtras) diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index 3a1a0d45c..0e467be69 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -48,11 +48,12 @@ struct ConcreteToBConcretePass : public ConcreteToBConcreteBase { void runOnOperation() final; ConcreteToBConcretePass() = delete; - ConcreteToBConcretePass(bool loopParallelize) - : loopParallelize(loopParallelize){}; + ConcreteToBConcretePass(bool loopParallelize, bool useGPU) + : loopParallelize(loopParallelize), useGPU(useGPU){}; private: bool loopParallelize; + bool useGPU; }; } // namespace @@ -65,10 +66,6 @@ class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter { public: ConcreteToBConcreteTypeConverter() { addConversion([](mlir::Type type) { return type; }); - addConversion([&](mlir::concretelang::Concrete::GPUBskType type) { - return mlir::LLVM::LLVMPointerType::get( - mlir::IntegerType::get(type.getContext(), 64)); - }); addConversion([&](mlir::concretelang::Concrete::PlaintextType type) { return mlir::IntegerType::get(type.getContext(), 64); }); @@ -310,43 +307,6 @@ struct MulCleartextLweCiphertextOpPattern }; }; -struct GlweFromTablePattern : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::GlweFromTable> { - GlweFromTablePattern(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op, - ::mlir::PatternRewriter &rewriter) const override { - ConcreteToBConcreteTypeConverter converter; - auto resultTy = - op.result() - .getType() - .cast(); - - auto newResultTy = - converter.convertType(resultTy).cast(); - // %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)] - // : tensor - mlir::Value init = - rewriter.replaceOpWithNewOp( - op, newResultTy, mlir::ValueRange{}); - - // "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, - // %tlu) - auto polySize = resultTy.getPolynomialSize(); - auto glweDimension = resultTy.getGlweDimension(); - auto outPrecision = resultTy.getP(); - - rewriter.create( - op.getLoc(), init, glweDimension, polySize, outPrecision, op.table()); - - return ::mlir::success(); - }; -}; - struct ExtractSliceOpPattern : public mlir::OpRewritePattern { ExtractSliceOpPattern(::mlir::MLIRContext *context, @@ -915,22 +875,22 @@ void ConcreteToBConcretePass::runOnOperation() { LowToBConcrete, - LowToBConcrete, - LowToBConcrete, - LowToBConcrete, - LowToBConcrete, LowToBConcrete>(&getContext()); - patterns.insert(&getContext()); + if (this->useGPU) { + patterns.insert>( + &getContext()); + } else { + patterns.insert< + LowToBConcrete>( + &getContext()); + } // Add patterns to rewrite tensor operators that works on encrypted // tensors @@ -1058,8 +1018,8 @@ void ConcreteToBConcretePass::runOnOperation() { namespace mlir { namespace concretelang { std::unique_ptr> -createConvertConcreteToBConcretePass(bool loopParallelize) { - return std::make_unique(loopParallelize); +createConvertConcreteToBConcretePass(bool loopParallelize, bool useGPU) { + return std::make_unique(loopParallelize, useGPU); } } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/Conversion/ConcreteToGPU/CMakeLists.txt b/compiler/lib/Conversion/ConcreteToGPU/CMakeLists.txt deleted file mode 100644 index cecf498e0..000000000 --- a/compiler/lib/Conversion/ConcreteToGPU/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_dialect_library(ConcreteToGPU - ConcreteToGPU.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete - - DEPENDS - ConcreteDialect - mlir-headers - - LINK_LIBS PUBLIC - MLIRIR - MLIRTransforms - ) - -target_link_libraries(ConcreteToGPU PUBLIC ConcreteDialect MLIRIR) diff --git a/compiler/lib/Conversion/ConcreteToGPU/ConcreteToGPU.cpp b/compiler/lib/Conversion/ConcreteToGPU/ConcreteToGPU.cpp deleted file mode 100644 index 17d0c90ee..000000000 --- a/compiler/lib/Conversion/ConcreteToGPU/ConcreteToGPU.cpp +++ /dev/null @@ -1,108 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#include -#include - -#include "concretelang/Conversion/Passes.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" -#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" - -/// This rewrite pattern transforms any instance of `Concrete.bootstrap_lwe` -/// into `Concrete.bootstrap_lwe_gpu`. It also inserts operations to allocate -/// memory, copy bsk into GPU, and free memory after bootstrapping. -struct BstOpPattern : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::BootstrapLweOp> { - BstOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp bstOp, - ::mlir::PatternRewriter &rewriter) const override { - - auto baselog = bstOp.baseLog(); - auto level = bstOp.level(); - mlir::Value ct = bstOp.input_ciphertext(); - - auto ctType = - ct.getType().cast(); - auto inputLweDim = ctType.getDimension(); - - auto outType = bstOp.getResult() - .getType() - .cast(); - auto outputLweDim = outType.getDimension(); - - // copy bsk into GPU - mlir::Value bskGPU = - rewriter - .create( - bstOp.getLoc(), mlir::concretelang::Concrete::GPUBskType::get( - rewriter.getContext())) - .getResult(); - - mlir::Value inputLweDimCst = rewriter.create( - bstOp.getLoc(), inputLweDim, 32); - mlir::Value polySizeCst = rewriter.create( - bstOp.getLoc(), outputLweDim, 32); - mlir::Value levelCst = - rewriter.create(bstOp.getLoc(), level, 32); - mlir::Value baselogCst = rewriter.create( - bstOp.getLoc(), baselog, 32); - - mlir::Type tableType = - mlir::RankedTensorType::get({4}, rewriter.getI64Type()); - mlir::Value tableCst = rewriter.create( - bstOp.getLoc(), - mlir::DenseIntElementsAttr::get( - tableType, {llvm::APInt(64, 0), llvm::APInt(64, 0), - llvm::APInt(64, 0), llvm::APInt(64, 0)})); - - rewriter - .replaceOpWithNewOp( - bstOp, outType, ct, tableCst, inputLweDimCst, polySizeCst, levelCst, - baselogCst, bskGPU); - - // free bsk memory from GPU - rewriter.create( - bstOp.getLoc(), bskGPU); - - return ::mlir::success(); - }; -}; - -namespace { -struct ConcreteToGPUPass : public ConcreteToGPUBase { - void runOnOperation() final; -}; -} // namespace - -void ConcreteToGPUPass::runOnOperation() { - auto op = this->getOperation(); - - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - - target.addLegalDialect(); - target.addIllegalOp(); - - patterns.insert(&getContext()); - - // Apply conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { - this->signalPassFailure(); - } -} - -namespace mlir { -namespace concretelang { -std::unique_ptr> createConvertConcreteToGPUPass() { - return std::make_unique(); -} -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp index 43b390a36..a28934863 100644 --- a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp +++ b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp @@ -82,12 +82,10 @@ public: /// becomes: /// /// ```mlir -/// %glwe_lut = "TFHE.glwe_from_table"(%lut) -/// : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}> /// %glwe_ks = "TFHE.keyswitch_glwe"(%ct) /// {baseLog = -1 : i32, level = -1 : i32} /// : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> -/// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut) +/// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %lut) /// {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, /// polynomialSize = -1 : i32} /// : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> @@ -107,10 +105,6 @@ struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern auto inputTy = converter.convertType(lutOp.a().getType()) .cast(); auto resultTy = converter.convertType(lutOp.getType()); - // %glwe_lut = "TFHE.glwe_from_table"(%lut) - auto glweLut = rewriter.create( - lutOp.getLoc(), resultTy, lutOp.lut()); - // %glwe_ks = "TFHE.keyswitch_glwe"(%ct) auto glweKs = rewriter.create( lutOp.getLoc(), inputTy, lutOp.a(), -1, -1); mlir::concretelang::convertOperandAndResultTypes( @@ -118,8 +112,8 @@ struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern return converter.convertType(t); }); // %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut) - rewriter.replaceOpWithNewOp(lutOp, resultTy, glweKs, - glweLut, -1, -1); + rewriter.replaceOpWithNewOp( + lutOp, resultTy, glweKs, lutOp.lut(), -1, -1, -1, -1, -1); return ::mlir::success(); }; }; diff --git a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index f7a232652..ce1bcacc2 100644 --- a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -152,15 +152,13 @@ struct BootstrapGLWEOpPattern auto newInputTy = converter.glweIntraPBSType(inputTy); auto outputTy = bsOp.result().getType().cast(); auto newOutputTy = converter.convertType(outputTy); - auto tableTy = - bsOp.lookup_table().getType().cast(); - auto newTableTy = converter.glweLookupTableType(tableTy); auto newOp = rewriter.replaceOpWithNewOp( bsOp, newOutputTy, bsOp.ciphertext(), bsOp.lookup_table(), - cryptoParameters.brLevel, cryptoParameters.brLogBase); + cryptoParameters.nSmall, cryptoParameters.getPolynomialSize(), + cryptoParameters.brLevel, cryptoParameters.brLogBase, + cryptoParameters.glweDimension); rewriter.startRootUpdate(newOp); newOp.ciphertext().setType(newInputTy); - newOp.lookup_table().setType(newTableTy); rewriter.finalizeRootUpdate(newOp); return mlir::success(); }; @@ -212,49 +210,6 @@ private: mlir::concretelang::V0Parameter &cryptoParameters; }; -/// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by -/// parametrize GLWE return type and pad the table if the precision has been -/// changed. -/// -/// Example: -/// -/// ```mlir -/// %lut = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64> -/// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<4xi64>) -> -/// !TFHE.glwe<{_,_,_}{2}> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %lut = arith.constant dense<[0, 1, 2, 3, 0, 1, 2, 3]> : tensor<8xi64> -/// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<8xi64>) -> -/// !TFHE.glwe<{_,_,_}{3}> -/// ``` -struct GLWEFromTablePattern - : public mlir::OpRewritePattern { - GLWEFromTablePattern(mlir::MLIRContext *context, - TFHEGlobalParametrizationTypeConverter &converter, - mlir::PatternBenefit benefit = - mlir::concretelang::DEFAULT_PATTERN_BENEFIT) - : mlir::OpRewritePattern(context, benefit), - converter(converter) {} - - mlir::LogicalResult - matchAndRewrite(TFHE::GLWEFromTableOp glweOp, - mlir::PatternRewriter &rewriter) const override { - auto outputTy = glweOp.result().getType().cast(); - auto newOutputTy = converter.glweLookupTableType(outputTy); - auto tableOp = glweOp.table(); - rewriter.replaceOpWithNewOp(glweOp, newOutputTy, - tableOp); - return mlir::success(); - }; - -private: - TFHEGlobalParametrizationTypeConverter &converter; -}; - template void populateWithTFHEOpTypeConversionPattern( mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, @@ -316,13 +271,6 @@ void TFHEGlobalParametrizationPass::runOnOperation() { patterns, converter); // Parametrize keyswitch bootstrap - patterns.add(&getContext(), converter); - target.addDynamicallyLegalOp( - [&](TFHE::GLWEFromTableOp op) { - return !op.getType() - .cast() - .hasUnparametrizedParameters(); - }); target.addLegalOp(); patterns.add(&getContext(), converter, cryptoParameters); diff --git a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index ccce14a7c..7796fc4b4 100644 --- a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -66,25 +66,6 @@ public: namespace { -struct GLWEFromTableOpPattern - : public mlir::OpRewritePattern { - GLWEFromTableOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern(context, benefit) {} - - ::mlir::LogicalResult - matchAndRewrite(TFHE::GLWEFromTableOp glweOp, - mlir::PatternRewriter &rewriter) const override { - auto oldTy = glweOp.getType().cast(); - auto newTy = rewriter.getType( - oldTy.getDimension(), oldTy.getPolynomialSize(), oldTy.getP()); - - rewriter.replaceOpWithNewOp(glweOp, newTy, - glweOp.table()); - return ::mlir::success(); - }; -}; - struct BootstrapGLWEOpPattern : public mlir::OpRewritePattern { BootstrapGLWEOpPattern(mlir::MLIRContext *context, @@ -98,21 +79,31 @@ struct BootstrapGLWEOpPattern mlir::PatternRewriter &rewriter) const override { mlir::Type resultType = converter.convertType(bsOp.getType()); + auto precision = bsOp.getType().cast().getP(); + + mlir::Value inputLweDimCst = rewriter.create( + bsOp.getLoc(), bsOp.inputLweDim(), 32); + mlir::Value polySizeCst = rewriter.create( + bsOp.getLoc(), bsOp.polySize(), 32); + mlir::Value levelCst = rewriter.create( + bsOp.getLoc(), bsOp.level(), 32); + mlir::Value baseLogCst = rewriter.create( + bsOp.getLoc(), bsOp.baseLog(), 32); + mlir::Value glweDimCst = rewriter.create( + bsOp.getLoc(), bsOp.glweDimension(), 32); + mlir::Value precisionCst = rewriter.create( + bsOp.getLoc(), precision, 32); + auto newOp = rewriter.replaceOpWithNewOp( - bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(), bsOp.level(), - bsOp.baseLog()); + bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(), + inputLweDimCst, polySizeCst, levelCst, baseLogCst, glweDimCst, + precisionCst); rewriter.startRootUpdate(newOp); - newOp.input_ciphertext().setType( converter.convertType(bsOp.ciphertext().getType())); - - auto oldTy = bsOp.lookup_table().getType().cast(); - auto newTy = rewriter.getType( - oldTy.getDimension(), oldTy.getPolynomialSize(), oldTy.getP()); - newOp.accumulator().setType(newTy); - rewriter.finalizeRootUpdate(newOp); + return ::mlir::success(); } @@ -170,6 +161,9 @@ void TFHEToConcretePass::runOnOperation() { // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect(); + // Legalize arith.constant operations introduced by some patterns + target.addLegalOp(); + // Make sure that no ops `linalg.generic` that have illegal types target.addDynamicallyLegalOp( @@ -201,7 +195,6 @@ void TFHEToConcretePass::runOnOperation() { patterns.add>(&getContext(), converter); - patterns.add(&getContext()); patterns.add(&getContext(), converter); patterns.add(&getContext(), converter); target.addDynamicallyLegalOp( diff --git a/compiler/lib/Conversion/Tools.cpp b/compiler/lib/Conversion/Tools.cpp index c8dcaa67e..004e0eeca 100644 --- a/compiler/lib/Conversion/Tools.cpp +++ b/compiler/lib/Conversion/Tools.cpp @@ -42,6 +42,9 @@ mlir::Value getContextArgument(mlir::Operation *op) { mlir::Block *block = op->getBlock(); while (block != nullptr) { if (llvm::isa(block->getParentOp())) { + block = &mlir::cast(block->getParentOp()) + .getBody() + .front(); auto context = std::find_if( block->getArguments().rbegin(), block->getArguments().rend(), diff --git a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp index 4ef763371..14ffc2e94 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -92,7 +92,6 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( auto contextType = mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); auto i32Type = rewriter.getI32Type(); - auto i64PointerType = mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type()); mlir::FunctionType funcType; @@ -114,17 +113,21 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, contextType}, {}); } else if (funcName == memref_bootstrap_lwe_u64) { - funcType = mlir::FunctionType::get( - rewriter.getContext(), - {memref1DType, memref1DType, memref1DType, contextType}, {}); + funcType = mlir::FunctionType::get(rewriter.getContext(), + {memref1DType, memref1DType, + memref1DType, i32Type, i32Type, i32Type, + i32Type, i32Type, i32Type, contextType}, + {}); } else if (funcName == memref_keyswitch_async_lwe_u64) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, contextType}, {futureType}); } else if (funcName == memref_bootstrap_async_lwe_u64) { - funcType = mlir::FunctionType::get( - rewriter.getContext(), - {memref1DType, memref1DType, memref1DType, contextType}, {futureType}); + funcType = mlir::FunctionType::get(rewriter.getContext(), + {memref1DType, memref1DType, + memref1DType, i32Type, i32Type, i32Type, + i32Type, i32Type, i32Type, contextType}, + {futureType}); } else if (funcName == memref_await_future) { funcType = mlir::FunctionType::get( rewriter.getContext(), @@ -133,7 +136,7 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( funcType = mlir::FunctionType::get(rewriter.getContext(), {memref1DType, memref1DType, memref1DType, i32Type, i32Type, i32Type, - i32Type, i64PointerType}, + i32Type, i32Type, i32Type, contextType}, {}); } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) { funcType = mlir::FunctionType::get(rewriter.getContext(), @@ -244,90 +247,6 @@ struct BufferizableWithCallOpInterface } }; -struct BufferizableGlweFromTableOpInterface - : public BufferizableOpInterface::ExternalModel< - BufferizableGlweFromTableOpInterface, BConcrete::FillGlweFromTable> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return false; - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::None; - } - - /// Bufferize GlweFromTable - /// ``` - /// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1, - /// polynomialSize=2048, outPrecision=3} : - /// (tensor<4096xi64>, tensor<32xi64>) -> () - /// ``` - /// - /// to - /// - /// ``` - /// %glweDim = arith.constant 1 : i32 - /// %polySize = arith.constant 2048 : i32 - /// %outPrecision = arith.constant 3 : i32 - /// %glwe_ = memref.cast %glwe : memref<4096xi64> to memref - /// %lut_ = memref.cast %lut : memref<32xi64> to memref - /// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim, - /// %outPrecision, %lut_) : - /// (tensor, i32, i32, tensor) -> () - /// ``` - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { - - auto loc = op->getLoc(); - auto castOp = cast(op); - - auto glweOp = - getCastedMemRef(rewriter, loc, - bufferization::getBuffer( - rewriter, castOp->getOpOperand(0).get(), options)); - auto lutOp = - getCastedMemRef(rewriter, loc, - bufferization::getBuffer( - rewriter, castOp->getOpOperand(1).get(), options)); - - auto polySizeOp = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(castOp.polynomialSize())); - auto glweDimensionOp = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(castOp.glweDimension())); - auto outPrecisionOp = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(castOp.outPrecision())); - - mlir::SmallVector operands{glweOp, polySizeOp, glweDimensionOp, - outPrecisionOp, lutOp}; - - // Insert forward declaration of the function - if (insertForwardDeclarationOfTheCAPI( - op, rewriter, memref_expand_lut_in_trivial_glwe_ct_u64) - .failed()) { - return mlir::failure(); - } - - rewriter.create( - loc, memref_expand_lut_in_trivial_glwe_ct_u64, mlir::TypeRange{}, - operands); - - replaceOpWithBufferizedValues(rewriter, op, {}); - - return success(); - } -}; - template struct BufferizableWithAsyncCallOpInterface : public BufferizableOpInterface::ExternalModel< @@ -565,7 +484,7 @@ void mlir::concretelang::BConcrete:: *ctx); BConcrete::BootstrapLweGPUBufferOp::attachInterface< BufferizableWithCallOpInterface>( + memref_bootstrap_lwe_cuda_u64, true>>( *ctx); BConcrete::KeySwitchLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); - BConcrete::FillGlweFromTable::attachInterface< - BufferizableGlweFromTableOpInterface>(*ctx); }); } diff --git a/compiler/lib/Runtime/AsyncOffload.cpp b/compiler/lib/Runtime/AsyncOffload.cpp index 205888d7b..1b47c4237 100644 --- a/compiler/lib/Runtime/AsyncOffload.cpp +++ b/compiler/lib/Runtime/AsyncOffload.cpp @@ -45,25 +45,44 @@ void async_bootstrap( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, - uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, - uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, + uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, + uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, + uint32_t base_log, uint32_t glwe_dim, uint32_t precision, mlir::concretelang::RuntimeContext *context, std::promise> promise) { + + uint64_t glwe_ct_size = poly_size * (glwe_dim + 1); + uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t)); + + std::vector expanded_tabulated_function_array(poly_size); + + encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size, + precision, tlu_aligned + tlu_offset, tlu_size); + + CAPI_ASSERT_ERROR( + default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers( + get_engine(context), glwe_ct, glwe_ct_size, + expanded_tabulated_function_array.data(), poly_size)); + CAPI_ASSERT_ERROR( fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers( get_fftw_engine(context), get_engine(context), get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset, - ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset)); + ct0_aligned + ct0_offset, glwe_ct)); promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{ out_allocated, out_aligned, out_offset, out_size, out_stride}); + free(glwe_ct); } void *memref_bootstrap_async_lwe_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, - uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, - uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, + uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, + uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, + uint32_t base_log, uint32_t glwe_dim, uint32_t precision, mlir::concretelang::RuntimeContext *context) { std::promise> promise; auto ret = new std::future>( @@ -71,8 +90,9 @@ void *memref_bootstrap_async_lwe_u64( std::thread offload_thread( async_bootstrap, out_allocated, out_aligned, out_offset, out_size, out_stride, ct0_allocated, ct0_aligned, ct0_offset, ct0_size, ct0_stride, - glwe_ct_allocated, glwe_ct_aligned, glwe_ct_offset, glwe_ct_size, - glwe_ct_stride, context, std::move(promise)); + tlu_allocated, tlu_aligned, tlu_offset, tlu_size, tlu_stride, + input_lwe_dim, poly_size, level, base_log, glwe_dim, precision, context, + std::move(promise)); offload_thread.detach(); return (void *)ret; } diff --git a/compiler/lib/Runtime/wrappers.cpp b/compiler/lib/Runtime/wrappers.cpp index 24067960f..4ccf8233f 100644 --- a/compiler/lib/Runtime/wrappers.cpp +++ b/compiler/lib/Runtime/wrappers.cpp @@ -23,9 +23,6 @@ DefaultEngine *get_levelled_engine() { return levelled_engine; } -// This helper function expands the input LUT into output, duplicating values as -// needed to fill mega cases, taking care of the encoding and the half mega case -// shift in the process as well. All sizes should be powers of 2. void encode_and_expand_lut(uint64_t *output, size_t output_size, size_t out_MESSAGE_BITS, const uint64_t *lut, size_t lut_size) { @@ -65,17 +62,15 @@ void encode_and_expand_lut(uint64_t *output, size_t output_size, typedef struct double2 { double x, y; } double2; - +// From concrete-cuda #include "bootstrap.h" #include "device.h" -void memref_keyswitch_lwe_cuda_u64(uint64_t *out_allocated, - uint64_t *out_aligned, uint64_t out_offset, - uint64_t out_size, uint64_t out_stride, - uint64_t *ct0_allocated, - uint64_t *ct0_aligned, uint64_t ct0_offset, - uint64_t ct0_size, uint64_t ct0_stride, - void *ksk_gpu) { +void memref_keyswitch_lwe_cuda_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context) { // TODO: GPU implementation } @@ -83,25 +78,37 @@ void *move_ct_to_gpu(uint64_t *ct_allocated, uint64_t *ct_aligned, uint64_t ct_offset, uint64_t ct_size, uint64_t ct_stride, uint32_t gpu_idx) { void *stream = cuda_create_stream(gpu_idx); - void *ct_gpu = cuda_malloc(ct_size * sizeof(uint64_t), gpu_idx); - cuda_memcpy_async_to_gpu(ct_gpu, ct_aligned + ct_offset, - ct_size * sizeof(uint64_t), stream, gpu_idx); + size_t buf_size = ct_size * sizeof(uint64_t); + void *ct_gpu = cuda_malloc(buf_size, gpu_idx); + cuda_memcpy_async_to_gpu(ct_gpu, ct_aligned + ct_offset, buf_size, stream, + gpu_idx); cuda_synchronize_device(gpu_idx); cuda_destroy_stream(stream, gpu_idx); return ct_gpu; } void *move_bsk_to_gpu(mlir::concretelang::RuntimeContext *context, - uint32_t gpu_idx = 0) { + uint32_t input_lwe_dim, uint32_t poly_size, + uint32_t level, uint32_t glwe_dim, uint32_t gpu_idx = 0) { void *stream = cuda_create_stream(gpu_idx); - LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context); - BufferView bskBuffer = bootstrap_buffer_lwe_u64(bsk); - void *bsk_gpu = cuda_malloc(bskBuffer.length, gpu_idx); - cuda_memcpy_async_to_gpu(bsk_gpu, (void *)bskBuffer.pointer, bskBuffer.length, - stream, gpu_idx); + LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context); + size_t bsk_buffer_len = + input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * poly_size * level; + size_t bsk_buffer_size = bsk_buffer_len * sizeof(uint64_t); + uint64_t *bsk_buffer = + (uint64_t *)aligned_alloc(U64_ALIGNMENT, bsk_buffer_size); + size_t fbsk_gpu_buffer_size = bsk_buffer_len * sizeof(double); + void *fbsk_gpu = cuda_malloc(fbsk_gpu_buffer_size, gpu_idx); + CAPI_ASSERT_ERROR( + default_engine_discard_convert_lwe_bootstrap_key_to_lwe_bootstrap_key_mut_view_u64_raw_ptr_buffers( + get_levelled_engine(), bsk, bsk_buffer)); + cuda_initialize_twiddles(poly_size, gpu_idx); + cuda_convert_lwe_bootstrap_key_64(fbsk_gpu, bsk_buffer, stream, gpu_idx, + input_lwe_dim, glwe_dim, level, poly_size); cuda_synchronize_device(gpu_idx); cuda_destroy_stream(stream, gpu_idx); - return bsk_gpu; + free(bsk_buffer); + return fbsk_gpu; } void move_ct_to_cpu(uint64_t *out_allocated, uint64_t *out_aligned, @@ -125,44 +132,58 @@ void memref_bootstrap_lwe_cuda_u64( uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, - uint32_t base_log, void *bsk_gpu) { - + uint32_t base_log, uint32_t glwe_dim, uint32_t precision, + mlir::concretelang::RuntimeContext *context) { + // we currently just use the first GPU, but this should be decided + // dynamically, or during compilation, in the future uint32_t gpu_idx = 0; void *stream = cuda_create_stream(gpu_idx); - + // move bsk to gpu + void *fbsk_gpu = move_bsk_to_gpu(context, input_lwe_dim, poly_size, level, + glwe_dim, gpu_idx); // move input ciphertext into gpu void *ct0_gpu = move_ct_to_gpu(ct0_allocated, ct0_aligned, ct0_offset, ct0_size, ct0_stride, gpu_idx); // move output ciphertext into gpu void *out_gpu = move_ct_to_gpu(out_allocated, out_aligned, out_offset, out_size, out_stride, gpu_idx); - // hardcoded values + // construct LUT GLWE ciphertext + uint64_t glwe_ct_len = poly_size * (glwe_dim + 1); + uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t); + uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size); + std::vector expanded_tabulated_function_array(poly_size); + encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size, + precision, tlu_aligned + tlu_offset, tlu_size); + CAPI_ASSERT_ERROR( + default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers( + get_levelled_engine(), glwe_ct, glwe_ct_len, + expanded_tabulated_function_array.data(), poly_size)); + // move test vector into gpu + void *test_vector_gpu = + cuda_malloc(poly_size * (glwe_dim + 1) * sizeof(uint64_t), gpu_idx); + cuda_memcpy_async_to_gpu(test_vector_gpu, (void *)glwe_ct, glwe_ct_size, + stream, gpu_idx); + // free LUT ciphertext (CPU) + free(glwe_ct); + // move test vector indexes into gpu uint32_t num_samples = 1, num_test_vectors = 1, lwe_idx = 0; void *test_vector_idxes = malloc(num_samples * sizeof(uint32_t)); ((uint32_t *)test_vector_idxes)[0] = 0; - void *test_vector = malloc(poly_size * sizeof(uint64_t)); - for (size_t i = 0; i < poly_size; i++) { - ((uint64_t *)test_vector)[i] = (uint64_t)1 << 61; - } - // move test vector into gpu - void *test_vector_gpu = cuda_malloc(poly_size * sizeof(uint64_t), gpu_idx); - cuda_memcpy_async_to_gpu(test_vector_gpu, test_vector, - poly_size * sizeof(uint64_t), stream, gpu_idx); - // move test vector indexes into gpu void *test_vector_idxes_gpu = cuda_malloc(num_samples * sizeof(uint32_t), gpu_idx); cuda_memcpy_async_to_gpu(test_vector_idxes_gpu, test_vector_idxes, num_samples * sizeof(uint32_t), stream, gpu_idx); // run gpu bootstrap cuda_bootstrap_low_latency_lwe_ciphertext_vector_64( - stream, out_gpu, test_vector_gpu, test_vector_idxes_gpu, ct0_gpu, bsk_gpu, - input_lwe_dim, poly_size, base_log, level, num_samples, num_test_vectors, - lwe_idx, cuda_get_max_shared_memory(gpu_idx)); + stream, out_gpu, test_vector_gpu, test_vector_idxes_gpu, ct0_gpu, + fbsk_gpu, input_lwe_dim, poly_size, base_log, level, num_samples, + num_test_vectors, lwe_idx, cuda_get_max_shared_memory(gpu_idx)); // copy output ciphertext back to cpu move_ct_to_cpu(out_allocated, out_aligned, out_offset, out_size, out_stride, out_gpu, out_size, gpu_idx); cuda_synchronize_device(gpu_idx); // free memory that we allocated on gpu + cuda_drop(fbsk_gpu, gpu_idx); cuda_drop(ct0_gpu, gpu_idx); cuda_drop(out_gpu, gpu_idx); cuda_drop(test_vector_gpu, gpu_idx); @@ -271,14 +292,31 @@ void memref_bootstrap_lwe_u64( uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, - uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, - uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, + uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, + uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, + uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, + uint32_t base_log, uint32_t glwe_dim, uint32_t precision, mlir::concretelang::RuntimeContext *context) { + + uint64_t glwe_ct_size = poly_size * (glwe_dim + 1); + uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t)); + + std::vector expanded_tabulated_function_array(poly_size); + + encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size, + precision, tlu_aligned + tlu_offset, tlu_size); + + CAPI_ASSERT_ERROR( + default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers( + get_levelled_engine(), glwe_ct, glwe_ct_size, + expanded_tabulated_function_array.data(), poly_size)); + CAPI_ASSERT_ERROR( fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers( get_fftw_engine(context), get_engine(context), get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset, - ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset)); + ct0_aligned + ct0_offset, glwe_ct)); + free(glwe_ct); } uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) { diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index e67380f48..fc5c0015f 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -317,14 +317,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return errorDiag("Optimizing Concrete failed"); } - // Transforming into GPU - if (this->compilerOptions.useGPU && - mlir::concretelang::pipeline::transformsConcreteToGPU(mlirContext, module, - this->enablePass) - .failed()) { - return errorDiag("Transforming Concrete to GPU failed"); - } - if (target == Target::CONCRETE) return std::move(res); @@ -363,7 +355,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { // Concrete -> BConcrete if (mlir::concretelang::pipeline::lowerConcreteToBConcrete( - mlirContext, module, this->enablePass, loopParallelize) + mlirContext, module, this->enablePass, loopParallelize, + options.useGPU) .failed()) { return StreamStringError( "Lowering from Concrete to Bufferized Concrete failed"); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 75abb4f8d..7ca12616f 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -239,26 +239,16 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } -mlir::LogicalResult -transformsConcreteToGPU(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass) { - mlir::PassManager pm(&context); - pipelinePrinting("ConcreteToGPU", pm, context); - addPotentiallyNestedPass( - pm, mlir::concretelang::createConvertConcreteToGPUPass(), enablePass); - return pm.run(module.getOperation()); -} - mlir::LogicalResult lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, - bool parallelizeLoops) { + bool parallelizeLoops, bool useGPU) { mlir::PassManager pm(&context); pipelinePrinting("ConcreteToBConcrete", pm, context); std::unique_ptr conversionPass = - mlir::concretelang::createConvertConcreteToBConcretePass( - parallelizeLoops); + mlir::concretelang::createConvertConcreteToBConcretePass(parallelizeLoops, + useGPU); bool passEnabled = enablePass(conversionPass.get()); diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 04da13418..76b0873b8 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -98,11 +98,11 @@ llvm::cl::opt "dialects. (Enabled by default)"), llvm::cl::init(true)); -llvm::cl::opt - useGPU("use-gpu", - llvm::cl::desc("enable/disable generating concrete GPU " - "operations (Disabled by default)"), - llvm::cl::init(false)); +llvm::cl::opt useGPU( + "use-gpu", + llvm::cl::desc( + "enable/disable generating GPU operations (Disabled by default)"), + llvm::cl::init(false)); llvm::cl::list passes( "passes", diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir index 871a2b348..687fee07b 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table.mlir @@ -1,15 +1,24 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s //CHECK: func.func @apply_lookup_table(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: tensor<16xi64>) -> tensor<1025xi64> { -//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2048xi64> -//CHECK: "BConcrete.fill_glwe_from_table"(%[[V0]], %[[A1]]) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 1024 : i32} : (tensor<2048xi64>, tensor<16xi64>) -> () +//CHECK: %[[C1:.*]] = arith.constant 600 : i32 +//CHECK: %[[C2:.*]] = arith.constant 1024 : i32 +//CHECK: %[[C3:.*]] = arith.constant 2 : i32 +//CHECK: %[[C4:.*]] = arith.constant 3 : i32 +//CHECK: %[[C5:.*]] = arith.constant 1 : i32 +//CHECK: %[[C6:.*]] = arith.constant 4 : i32 //CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64> -//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %[[V0]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<601xi64>, tensor<2048xi64>) -> tensor<1025xi64> +//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %arg1, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]]) : (tensor<601xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<1025xi64> //CHECK: return %[[V2]] : tensor<1025xi64> //CHECK: } func.func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> { - %0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1,1024,4> + %c1 = arith.constant 600 : i32 + %c2 = arith.constant 1024 : i32 + %c3 = arith.constant 2 : i32 + %c4 = arith.constant 3 : i32 + %c5 = arith.constant 1 : i32 + %c6 = arith.constant 4 : i32 %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4> - %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1,1024,4>) -> !Concrete.lwe_ciphertext<1024,4> + %2 = "Concrete.bootstrap_lwe"(%1, %arg1, %c1, %c2, %c3, %c4, %c5, %c6) : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<1024,4> return %2 : !Concrete.lwe_ciphertext<1024,4> } diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir index d93df777b..8e221d0b0 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/apply_lookup_table_cst.mlir @@ -1,17 +1,25 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s //CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { +//CHECK: %[[C1:.*]] = arith.constant 600 : i32 +//CHECK: %[[C2:.*]] = arith.constant 2048 : i32 +//CHECK: %[[C3:.*]] = arith.constant 4 : i32 +//CHECK: %[[C4:.*]] = arith.constant 5 : i32 +//CHECK: %[[C5:.*]] = arith.constant 1 : i32 //CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> -//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4096xi64> -//CHECK: "BConcrete.fill_glwe_from_table"(%[[V0]], %cst) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 2048 : i32} : (tensor<4096xi64>, tensor<16xi64>) -> () //CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64> -//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %[[V0]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<601xi64>, tensor<4096xi64>) -> tensor<2049xi64> +//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %cst, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C3]]) : (tensor<601xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<2049xi64> //CHECK: return %[[V2]] : tensor<2049xi64> //CHECK: } func.func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> { + %c1 = arith.constant 600 : i32 + %c2 = arith.constant 2048 : i32 + %c3 = arith.constant 4 : i32 + %c4 = arith.constant 5 : i32 + %c5 = arith.constant 1 : i32 + %c6 = arith.constant 4 : i32 %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> - %0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1, 2048,4> %1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4> - %2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1,2048,4>) -> !Concrete.lwe_ciphertext<2048,4> + %2 = "Concrete.bootstrap_lwe"(%1, %tlu, %c1, %c2, %c3, %c4, %c5, %c6) : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,4> return %2 : !Concrete.lwe_ciphertext<2048,4> } diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir index 3044dff73..32b5b6c06 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate.mlir @@ -1,9 +1,8 @@ // RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{2}>, %[[LUT:.*]]: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> { -// CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> // CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> -// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{3}>) -> !TFHE.glwe<{_,_,_}{3}> +// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[LUT]]) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> // CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{3}> func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> { %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>) diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate_cst.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate_cst.mlir index 65457547e..911d118f9 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate_cst.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHE/apply_univariate_cst.mlir @@ -2,9 +2,8 @@ //CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> { //CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> -//CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%cst) : (tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}> //CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> -//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> +//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}> //CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}> //CHECK-NEXT: } func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { diff --git a/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir b/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir index 61bc3e5ff..014ca3d40 100644 --- a/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir @@ -2,15 +2,13 @@ //CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> { //CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> -//CHECK: %[[V0:.*]] = "TFHE.glwe_from_table"(%cst) : (tensor<16xi64>) -> !TFHE.glwe<{2,1024,64}{4}> //CHECK: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = 4 : i32, level = 3 : i32} : (!TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{750,1,64}{4}> -//CHECK: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = 23 : i32, level = 1 : i32} : (!TFHE.glwe<{750,1,64}{4}>, !TFHE.glwe<{2,1024,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> +//CHECK: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = 23 : i32, glweDimension = 2 : i32, inputLweDim = 750 : i32, level = 1 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{750,1,64}{4}>, tensor<16xi64>) -> !TFHE.glwe<{2048,1,64}{4}> //CHECK: return %[[V2]] : !TFHE.glwe<{2048,1,64}{4}> //CHECK: } func.func @main(%arg0: !TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}> { %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> - %0 = "TFHE.glwe_from_table"(%cst) : (tensor<16xi64>) -> !TFHE.glwe<{_,_,_}{4}> %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}> - %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>, !TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}> + %2 = "TFHE.bootstrap_glwe"(%1, %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>, tensor<16xi64>) -> !TFHE.glwe<{_,_,_}{4}> return %2 : !TFHE.glwe<{_,_,_}{4}> } diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir index 3b91accfc..d24101b1d 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/bootstrap.mlir @@ -2,13 +2,17 @@ //CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<600,7>) -> !Concrete.lwe_ciphertext<1024,4> { //CHECK: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> -//CHECK: %[[V0:.*]] = "Concrete.glwe_from_table"(%cst) : (tensor<128xi64>) -> !Concrete.glwe_ciphertext<1,1024,7> -//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%[[A0]], %[[V0]]) {baseLog = 1 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<600,7>, !Concrete.glwe_ciphertext<1,1024,7>) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: %[[C1:.*]] = arith.constant 600 : i32 +//CHECK: %[[C2:.*]] = arith.constant 1024 : i32 +//CHECK: %[[C3:.*]] = arith.constant 3 : i32 +//CHECK: %[[C4:.*]] = arith.constant 1 : i32 +//CHECK: %[[C5:.*]] = arith.constant 1 : i32 +//CHECK: %[[C6:.*]] = arith.constant 4 : i32 +//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %cst, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]]) : (!Concrete.lwe_ciphertext<600,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<1024,4> //CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4> //CHECK: } func.func @bootstrap_lwe(%ciphertext: !TFHE.glwe<{600,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{4}> { %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> - %glwe_lut = "TFHE.glwe_from_table"(%cst) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> - %bootstraped = "TFHE.bootstrap_glwe"(%ciphertext, %glwe_lut) {baseLog = 1 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!TFHE.glwe<{600,1,64}{7}>, !TFHE.glwe<{1,1024,64}{7}>) -> !TFHE.glwe<{1024,1,64}{4}> + %bootstraped = "TFHE.bootstrap_glwe"(%ciphertext, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{600,1,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1024,1,64}{4}> return %bootstraped : !TFHE.glwe<{1024,1,64}{4}> } diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/glwe_from_table.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/glwe_from_table.mlir deleted file mode 100644 index ada1db3ff..000000000 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/glwe_from_table.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s - -//CHECK: func.func @glwe_from_table() { -//CHECK-NEXT: %[[V0:.*]] = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> -//CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[V0]]) : (tensor<128xi64>) -> !Concrete.glwe_ciphertext<1,1024,7> -//CHECK-NEXT: return -//CHECK-NEXT: } -func.func @glwe_from_table() { - %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> - %0 = "TFHE.glwe_from_table"(%cst) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> - return -} diff --git a/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir b/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir index e57b7d656..84fb12927 100644 --- a/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir +++ b/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir @@ -72,12 +72,12 @@ func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049 return %0 : tensor<5x2049xi64> } -//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<4096xi64>) -> tensor<2049xi64> { -//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[A0]], %[[A1]]) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<4096xi64>) -> tensor<2049xi64> +//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<2049xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2049xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } -func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64> { - %0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<4096xi64>) -> (tensor<2049xi64>) +func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<2049xi64> { + %0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2049xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> (tensor<2049xi64>) return %0 : tensor<2049xi64> } diff --git a/compiler/tests/check_tests/Dialect/Concrete/ops.mlir b/compiler/tests/check_tests/Dialect/Concrete/ops.mlir index 70135e8b3..6f10384c4 100644 --- a/compiler/tests/check_tests/Dialect/Concrete/ops.mlir +++ b/compiler/tests/check_tests/Dialect/Concrete/ops.mlir @@ -36,11 +36,11 @@ func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Co return %1: !Concrete.lwe_ciphertext<2048,7> } -// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> +// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> !Concrete.lwe_ciphertext<2048,7> +func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,7> // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> + %1 = "Concrete.bootstrap_lwe"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,7> return %1: !Concrete.lwe_ciphertext<2048,7> } diff --git a/compiler/tests/check_tests/Dialect/TFHE/ops.mlir b/compiler/tests/check_tests/Dialect/TFHE/ops.mlir index 1d0d9772c..278994f8c 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/ops.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/ops.mlir @@ -8,18 +8,11 @@ func.func @keyswitch_glwe(%arg0: !TFHE.glwe<{1,1024,64}{7}>) -> !TFHE.glwe<{1,52 return %0: !TFHE.glwe<{1,527,64}{7}> } -// CHECK: func.func @bootstrap_glwe(%[[GLWE:.*]]: !TFHE.glwe<{1,527,64}{7}>, %[[LUT:.*]]: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}> -func.func @bootstrap_glwe(%glwe: !TFHE.glwe<{1,527,64}{7}>, %lookup_table_glwe: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}> { - // CHECK-NEXT: %[[V0:.*]] = "TFHE.bootstrap_glwe"(%[[GLWE]], %[[LUT]]) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}> +// CHECK: func.func @bootstrap_glwe(%[[GLWE:.*]]: !TFHE.glwe<{1,527,64}{7}>, %[[LUT:.*]]: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> +func.func @bootstrap_glwe(%glwe: !TFHE.glwe<{1,527,64}{7}>, %lut: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> { + // CHECK-NEXT: %[[V0:.*]] = "TFHE.bootstrap_glwe"(%[[GLWE]], %[[LUT]]) {baseLog = 2 : i32, glweDimension = 1 : i32, inputLweDim = 527 : i32, level = 3 : i32, polySize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> // CHECK-NEXT: return %[[V0]] : !TFHE.glwe<{1,1024,64}{7}> - %0 = "TFHE.bootstrap_glwe"(%glwe, %lookup_table_glwe) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}> + %0 = "TFHE.bootstrap_glwe"(%glwe, %lut) {baseLog = 2 : i32, glweDimension = 1 : i32, inputLweDim = 527 : i32, level = 3 : i32, polySize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> return %0 : !TFHE.glwe<{1,1024,64}{7}> } -// CHECK: func.func @glwe_from_table(%[[LUT:.*]]: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> -func.func @glwe_from_table(%lookup_table: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> { - // CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> - // CHECK-NEXT: return %[[V0]] : !TFHE.glwe<{1,1024,64}{7}> - %0 = "TFHE.glwe_from_table"(%lookup_table) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> - return %0 : !TFHE.glwe<{1,1024,64}{7}> -}