mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor: redesign GPU support
- unify CPU and GPU bootstrapping operations
- remove operations to build GLWE from table: this is now done in wrapper functions
- remove GPU memory management operations: this is now done in wrappers, but we will have to think about how to deal with it later in MLIR
This commit is contained in:
@@ -12,7 +12,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize);
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize, bool useGPU);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_CONCRETETOGPU_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_CONCRETETOGPU_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` operations to GPU.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertConcreteToGPUPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
#include "concretelang/Conversion/BConcreteToCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToGPU/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
#include "concretelang/Conversion/LinalgExtras/Passes.h"
|
||||
|
||||
@@ -54,13 +54,6 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"];
|
||||
}
|
||||
|
||||
def ConcreteToGPU : Pass<"concrete-to-gpu", "mlir::ModuleOp"> {
|
||||
let summary = "Transforms operations in the Concrete dialect to GPU";
|
||||
let description = [{ Transforms operations in the Concrete dialect to GPU }];
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToGPUPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
|
||||
@@ -73,17 +73,6 @@ def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$glwe,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$polynomialSize,
|
||||
I32Attr:$outPrecision,
|
||||
1DTensorOf<[I64]>:$table
|
||||
);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
// LweKeySwitchKeyType:$keyswitch_key,
|
||||
@@ -96,11 +85,14 @@ def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
|
||||
|
||||
def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
// LweBootstrapKeyType:$bootstrap_key,
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$accumulator,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
@@ -143,11 +135,14 @@ def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
|
||||
def BConcrete_BootstrapLweBufferAsyncOffloadOp :
|
||||
BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> {
|
||||
let arguments = (ins
|
||||
// LweBootstrapKeyType:$bootstrap_key,
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$accumulator,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs RT_Future : $result);
|
||||
}
|
||||
@@ -158,6 +153,9 @@ def BConcrete_AwaitFutureOp :
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
// This is a separate op in BConcrete only because of the way we currently lower to the CAPI.
|
||||
// Once the CAPI lowering is detached from bufferization, we can remove this op and lower
|
||||
// to the appropriate CAPI (GPU or CPU) depending on the useGPU compilation option.
|
||||
def BConcrete_BootstrapLweGPUBufferOp : BConcrete_Op<"bootstrap_lwe_gpu_buffer"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
@@ -166,19 +164,10 @@ def BConcrete_BootstrapLweGPUBufferOp : BConcrete_Op<"bootstrap_lwe_gpu_buffer">
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
LLVM_PointerTo<I64>:$bsk
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_MoveBskToGPUOp : BConcrete_Op<"move_bsk_to_gpu"> {
|
||||
let arguments = (ins);
|
||||
let results = (outs LLVM_PointerTo<I64>:$bsk);
|
||||
}
|
||||
|
||||
def BConcrete_FreeBskFromGPUOp : BConcrete_Op<"free_bsk_from_gpu"> {
|
||||
let arguments = (ins LLVM_PointerTo<I64>:$bsk);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -52,54 +52,22 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_GlweFromTable : Concrete_Op<"glwe_from_table", [NoSideEffect]> {
|
||||
let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)";
|
||||
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$table);
|
||||
let results = (outs Concrete_GlweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
|
||||
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweCiphertextType:$input_ciphertext,
|
||||
Concrete_GlweCiphertextType:$accumulator,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweGPUOp : Concrete_Op<"bootstrap_lwe_gpu"> {
|
||||
let summary = "Bootstrap an LWE ciphertext in GPU using a lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweCiphertextType:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$table,
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
Concrete_GPUBsk:$bsk
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_MoveBskToGPUOp : Concrete_Op<"move_bsk_to_gpu"> {
|
||||
let summary = "Move bsk to GPU";
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs Concrete_GPUBsk:$bsk);
|
||||
}
|
||||
|
||||
def Concrete_FreeBskFromGPUOp : Concrete_Op<"free_bsk_from_gpu"> {
|
||||
let summary = "Free bsk memory from GPU";
|
||||
|
||||
let arguments = (ins Concrete_GPUBsk:$bsk);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
|
||||
let summary = "Keyswitches a LWE ciphertext";
|
||||
|
||||
|
||||
@@ -93,14 +93,4 @@ def Concrete_Context : Concrete_Type<"Context"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def Concrete_GPUBsk : Concrete_Type<"GPUBsk"> {
|
||||
let mnemonic = "gpu_bsk";
|
||||
|
||||
let summary = "A bsk in GPU";
|
||||
|
||||
let description = [{
|
||||
A bootstrapping key in GPU memory
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -92,24 +92,18 @@ def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe"> {
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
}
|
||||
|
||||
def TFHE_GLWEFromTableOp : TFHE_Op<"glwe_from_table"> {
|
||||
let summary =
|
||||
"Creates a GLWE ciphertext which is the trivial encrytion of a the input "
|
||||
"table interpreted as a polynomial (to use later in a bootstrap)";
|
||||
|
||||
let arguments = (ins 1DTensorOf < [I64] > : $table);
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
}
|
||||
|
||||
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> {
|
||||
let summary =
|
||||
"Programmable bootstraping of a GLWE ciphertext with a lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
TFHE_GLWECipherTextType : $ciphertext,
|
||||
TFHE_GLWECipherTextType : $lookup_table,
|
||||
1DTensorOf<[I64]> : $lookup_table,
|
||||
I32Attr : $inputLweDim,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $level,
|
||||
I32Attr : $baseLog
|
||||
I32Attr : $baseLog,
|
||||
I32Attr : $glweDimension
|
||||
);
|
||||
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
|
||||
@@ -10,6 +10,21 @@
|
||||
|
||||
extern "C" {
|
||||
|
||||
/// \brief Expands the input LUT
|
||||
///
|
||||
/// It duplicates values as needed to fill mega cases, taking care of the
|
||||
/// encoding and the half mega case shift in the process as well. All sizes
|
||||
/// should be powers of 2.
|
||||
///
|
||||
/// \param output where to write the expanded LUT
|
||||
/// \param output_size size of `output`; should be a power of 2 (presumably a number of uint64_t elements — TODO confirm)
|
||||
/// \param out_MESSAGE_BITS number of bits of message to be used
|
||||
/// \param lut original LUT
|
||||
/// \param lut_size size of `lut`; should be a power of 2 (presumably a number of uint64_t elements — TODO confirm)
|
||||
void encode_and_expand_lut(uint64_t *output, size_t output_size,
|
||||
size_t out_MESSAGE_BITS, const uint64_t *lut,
|
||||
size_t lut_size);
|
||||
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
@@ -58,15 +73,20 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void *memref_bootstrap_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
@@ -131,7 +151,9 @@ void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
/// \param poly_size polynomial size
|
||||
/// \param level level
|
||||
/// \param base_log base log
|
||||
/// \param bsk pointer to bsk on GPU
|
||||
/// \param glwe_dim GLWE dimension
|
||||
/// \param precision precision (presumably the output message precision, cf. the `outPrecision` operand — TODO confirm)
|
||||
/// \param context runtime context (`mlir::concretelang::RuntimeContext`)
|
||||
void memref_bootstrap_lwe_cuda_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
@@ -139,7 +161,8 @@ void memref_bootstrap_lwe_cuda_u64(
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, void *bsk);
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
/// \brief Copy ciphertext from CPU to GPU using a single stream.
|
||||
///
|
||||
@@ -174,13 +197,19 @@ void move_ct_to_cpu(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
|
||||
/// \brief Copy bootstrapping key from CPU to GPU using a single stream.
|
||||
///
|
||||
/// It handles memory allocation on GPU.
|
||||
/// It handles memory allocation on GPU, as well as conversion to the Fourier
|
||||
/// domain.
|
||||
///
|
||||
/// \param context runtime context (`mlir::concretelang::RuntimeContext`)
|
||||
/// \param input_lwe_dim input LWE dimension
|
||||
/// \param poly_size polynomial size
|
||||
/// \param level
|
||||
/// \param glwe_dim GLWE dimension
|
||||
/// \param gpu_idx index of the GPU to use
|
||||
/// \return void* pointer to the GPU bsk
|
||||
void *move_bsk_to_gpu(mlir::concretelang::RuntimeContext *context,
|
||||
uint32_t gpu_idx);
|
||||
uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t glwe_dim, uint32_t gpu_idx);
|
||||
|
||||
/// \brief Free gpu memory.
|
||||
///
|
||||
|
||||
@@ -47,7 +47,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops);
|
||||
bool parallelizeLoops, bool useGPU);
|
||||
|
||||
mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
@@ -57,10 +57,6 @@ mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
transformsConcreteToGPU(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
Reference in New Issue
Block a user