mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: redesign GPU support
- unify CPU and GPU bootstrapping operations
- remove operations to build GLWE from table: this is now done in wrapper functions
- remove GPU memory management operations: this is now done in wrappers, but we will have to think about how to deal with it later in MLIR
This commit is contained in:
@@ -12,7 +12,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize);
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize, bool useGPU);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_CONCRETETOGPU_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_CONCRETETOGPU_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` operations to GPU.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertConcreteToGPUPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
#include "concretelang/Conversion/BConcreteToCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToGPU/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
#include "concretelang/Conversion/LinalgExtras/Passes.h"
|
||||
|
||||
@@ -54,13 +54,6 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"];
|
||||
}
|
||||
|
||||
def ConcreteToGPU : Pass<"concrete-to-gpu", "mlir::ModuleOp"> {
|
||||
let summary = "Transforms operations in the Concrete dialect to GPU";
|
||||
let description = [{ Transforms operations in the Concrete dialect to GPU }];
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToGPUPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
|
||||
@@ -73,17 +73,6 @@ def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$glwe,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$polynomialSize,
|
||||
I32Attr:$outPrecision,
|
||||
1DTensorOf<[I64]>:$table
|
||||
);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
// LweKeySwitchKeyType:$keyswitch_key,
|
||||
@@ -96,11 +85,14 @@ def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
|
||||
|
||||
def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
// LweBootstrapKeyType:$bootstrap_key,
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$accumulator,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
@@ -143,11 +135,14 @@ def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
|
||||
def BConcrete_BootstrapLweBufferAsyncOffloadOp :
|
||||
BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> {
|
||||
let arguments = (ins
|
||||
// LweBootstrapKeyType:$bootstrap_key,
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$accumulator,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs RT_Future : $result);
|
||||
}
|
||||
@@ -158,6 +153,9 @@ def BConcrete_AwaitFutureOp :
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
// This is a different op in BConcrete just because of the way we are lowering to CAPI.
|
||||
// When the CAPI lowering is detached from bufferization, we can remove this op, and lower
|
||||
// to the appropriate CAPI (gpu or cpu) depending on the useGPU compilation option
|
||||
def BConcrete_BootstrapLweGPUBufferOp : BConcrete_Op<"bootstrap_lwe_gpu_buffer"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
@@ -166,19 +164,10 @@ def BConcrete_BootstrapLweGPUBufferOp : BConcrete_Op<"bootstrap_lwe_gpu_buffer">
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
LLVM_PointerTo<I64>:$bsk
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_MoveBskToGPUOp : BConcrete_Op<"move_bsk_to_gpu"> {
|
||||
let arguments = (ins);
|
||||
let results = (outs LLVM_PointerTo<I64>:$bsk);
|
||||
}
|
||||
|
||||
def BConcrete_FreeBskFromGPUOp : BConcrete_Op<"free_bsk_from_gpu"> {
|
||||
let arguments = (ins LLVM_PointerTo<I64>:$bsk);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -52,54 +52,22 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_GlweFromTable : Concrete_Op<"glwe_from_table", [NoSideEffect]> {
|
||||
let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)";
|
||||
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$table);
|
||||
let results = (outs Concrete_GlweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
|
||||
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweCiphertextType:$input_ciphertext,
|
||||
Concrete_GlweCiphertextType:$accumulator,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweGPUOp : Concrete_Op<"bootstrap_lwe_gpu"> {
|
||||
let summary = "Bootstrap an LWE ciphertext in GPU using a lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweCiphertextType:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$table,
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
Concrete_GPUBsk:$bsk
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_MoveBskToGPUOp : Concrete_Op<"move_bsk_to_gpu"> {
|
||||
let summary = "Move bsk to GPU";
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs Concrete_GPUBsk:$bsk);
|
||||
}
|
||||
|
||||
def Concrete_FreeBskFromGPUOp : Concrete_Op<"free_bsk_from_gpu"> {
|
||||
let summary = "Free bsk memory from GPU";
|
||||
|
||||
let arguments = (ins Concrete_GPUBsk:$bsk);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
|
||||
let summary = "Keyswitches a LWE ciphertext";
|
||||
|
||||
|
||||
@@ -93,14 +93,4 @@ def Concrete_Context : Concrete_Type<"Context"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def Concrete_GPUBsk : Concrete_Type<"GPUBsk"> {
|
||||
let mnemonic = "gpu_bsk";
|
||||
|
||||
let summary = "A bsk in GPU";
|
||||
|
||||
let description = [{
|
||||
A bootstrapping key in GPU memory
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -92,24 +92,18 @@ def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe"> {
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
}
|
||||
|
||||
def TFHE_GLWEFromTableOp : TFHE_Op<"glwe_from_table"> {
|
||||
let summary =
|
||||
"Creates a GLWE ciphertext which is the trivial encrytion of a the input "
|
||||
"table interpreted as a polynomial (to use later in a bootstrap)";
|
||||
|
||||
let arguments = (ins 1DTensorOf < [I64] > : $table);
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
}
|
||||
|
||||
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> {
|
||||
let summary =
|
||||
"Programmable bootstraping of a GLWE ciphertext with a lookup table";
|
||||
|
||||
let arguments = (ins
|
||||
TFHE_GLWECipherTextType : $ciphertext,
|
||||
TFHE_GLWECipherTextType : $lookup_table,
|
||||
1DTensorOf<[I64]> : $lookup_table,
|
||||
I32Attr : $inputLweDim,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $level,
|
||||
I32Attr : $baseLog
|
||||
I32Attr : $baseLog,
|
||||
I32Attr : $glweDimension
|
||||
);
|
||||
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
|
||||
@@ -10,6 +10,21 @@
|
||||
|
||||
extern "C" {
|
||||
|
||||
/// \brief Expands the input LUT
|
||||
///
|
||||
/// It duplicates values as needed to fill mega cases, taking care of the
|
||||
/// encoding and the half mega case shift in the process as well. All sizes
|
||||
/// should be powers of 2.
|
||||
///
|
||||
/// \param output where to write the expanded LUT
|
||||
/// \param output_size
|
||||
/// \param out_MESSAGE_BITS number of bits of message to be used
|
||||
/// \param lut original LUT
|
||||
/// \param lut_size
|
||||
void encode_and_expand_lut(uint64_t *output, size_t output_size,
|
||||
size_t out_MESSAGE_BITS, const uint64_t *lut,
|
||||
size_t lut_size);
|
||||
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
@@ -58,15 +73,20 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void *memref_bootstrap_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
@@ -131,7 +151,9 @@ void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
/// \param poly_size polynomial size
|
||||
/// \param level level
|
||||
/// \param base_log base log
|
||||
/// \param bsk pointer to bsk on GPU
|
||||
/// \param glwe_dim
|
||||
/// \param precision
|
||||
/// \param context
|
||||
void memref_bootstrap_lwe_cuda_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
@@ -139,7 +161,8 @@ void memref_bootstrap_lwe_cuda_u64(
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, void *bsk);
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
/// \brief Copy ciphertext from CPU to GPU using a single stream.
|
||||
///
|
||||
@@ -174,13 +197,19 @@ void move_ct_to_cpu(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
|
||||
/// \brief Copy bootstrapping key from CPU to GPU using a single stream.
|
||||
///
|
||||
/// It handles memory allocation on GPU.
|
||||
/// It handles memory allocation on GPU, as well as conversion to the Fourier
|
||||
/// domain.
|
||||
///
|
||||
/// \param context
|
||||
/// \param input_lwe_dim
|
||||
/// \param poly_size
|
||||
/// \param level
|
||||
/// \param glwe_dim
|
||||
/// \param gpu_idx index of the GPU to use
|
||||
/// \return void* pointer to the GPU bsk
|
||||
void *move_bsk_to_gpu(mlir::concretelang::RuntimeContext *context,
|
||||
uint32_t gpu_idx);
|
||||
uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t glwe_dim, uint32_t gpu_idx);
|
||||
|
||||
/// \brief Free gpu memory.
|
||||
///
|
||||
|
||||
@@ -47,7 +47,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops);
|
||||
bool parallelizeLoops, bool useGPU);
|
||||
|
||||
mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
@@ -57,10 +57,6 @@ mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
transformsConcreteToGPU(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
@@ -3,88 +3,10 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
|
||||
char move_bsk_to_gpu[] = "move_bsk_to_gpu";
|
||||
char free_from_gpu[] = "free_from_gpu";
|
||||
|
||||
/// \brief Rewrites `BConcrete.move_bsk_to_gpu` into a CAPI call to
|
||||
/// `move_bsk_to_gpu`
|
||||
///
|
||||
/// Also inserts the forward declaration of `move_bsk_to_gpu`
|
||||
struct MoveBskOpPattern : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::MoveBskToGPUOp> {
|
||||
MoveBskOpPattern(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::BConcrete::MoveBskToGPUOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::MoveBskToGPUOp moveBskOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
auto ctx = getContextArgument(moveBskOp);
|
||||
|
||||
mlir::SmallVector<mlir::Value> operands{ctx};
|
||||
|
||||
// Insert forward declaration of the function
|
||||
auto contextType =
|
||||
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {contextType},
|
||||
{mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type())});
|
||||
if (insertForwardDeclaration(moveBskOp, rewriter, move_bsk_to_gpu, funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
moveBskOp, move_bsk_to_gpu, moveBskOp.getResult().getType(), operands);
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// \brief Rewrites `BConcrete.free_bsk_from_gpu` into a CAPI call to
|
||||
/// `free_from_gpu`
|
||||
///
|
||||
/// Also inserts the forward declaration of `free_from_gpu`
|
||||
struct FreeBskOpPattern : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::FreeBskFromGPUOp> {
|
||||
FreeBskOpPattern(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<
|
||||
mlir::concretelang::BConcrete::FreeBskFromGPUOp>(context, benefit) {
|
||||
}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::BConcrete::FreeBskFromGPUOp freeBskOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::SmallVector<mlir::Value> operands{freeBskOp.bsk()};
|
||||
|
||||
// Insert forward declaration of the function
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type())}, {});
|
||||
if (insertForwardDeclaration(freeBskOp, rewriter, free_from_gpu, funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
freeBskOp, free_from_gpu, mlir::TypeRange({}), operands);
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
|
||||
@@ -98,12 +20,6 @@ void BConcreteToCAPIPass::runOnOperation() {
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addIllegalOp<mlir::concretelang::BConcrete::MoveBskToGPUOp>();
|
||||
target.addLegalDialect<mlir::func::FuncDialect>();
|
||||
|
||||
patterns.insert<MoveBskOpPattern>(&getContext());
|
||||
patterns.insert<FreeBskOpPattern>(&getContext());
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
|
||||
@@ -3,7 +3,6 @@ add_subdirectory(TFHEGlobalParametrization)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(ConcreteToBConcrete)
|
||||
add_subdirectory(ConcreteToGPU)
|
||||
add_subdirectory(BConcreteToCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LinalgExtras)
|
||||
|
||||
@@ -48,11 +48,12 @@ struct ConcreteToBConcretePass
|
||||
: public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
|
||||
void runOnOperation() final;
|
||||
ConcreteToBConcretePass() = delete;
|
||||
ConcreteToBConcretePass(bool loopParallelize)
|
||||
: loopParallelize(loopParallelize){};
|
||||
ConcreteToBConcretePass(bool loopParallelize, bool useGPU)
|
||||
: loopParallelize(loopParallelize), useGPU(useGPU){};
|
||||
|
||||
private:
|
||||
bool loopParallelize;
|
||||
bool useGPU;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@@ -65,10 +66,6 @@ class ConcreteToBConcreteTypeConverter : public mlir::TypeConverter {
|
||||
public:
|
||||
ConcreteToBConcreteTypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](mlir::concretelang::Concrete::GPUBskType type) {
|
||||
return mlir::LLVM::LLVMPointerType::get(
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
});
|
||||
@@ -310,43 +307,6 @@ struct MulCleartextLweCiphertextOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
struct GlweFromTablePattern : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::GlweFromTable> {
|
||||
GlweFromTablePattern(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto resultTy =
|
||||
op.result()
|
||||
.getType()
|
||||
.cast<mlir::concretelang::Concrete::GlweCiphertextType>();
|
||||
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
|
||||
// : tensor<polynomialSize*(glweDimension+1), i64>
|
||||
mlir::Value init =
|
||||
rewriter.replaceOpWithNewOp<mlir::bufferization::AllocTensorOp>(
|
||||
op, newResultTy, mlir::ValueRange{});
|
||||
|
||||
// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension,
|
||||
// %tlu)
|
||||
auto polySize = resultTy.getPolynomialSize();
|
||||
auto glweDimension = resultTy.getGlweDimension();
|
||||
auto outPrecision = resultTy.getP();
|
||||
|
||||
rewriter.create<mlir::concretelang::BConcrete::FillGlweFromTable>(
|
||||
op.getLoc(), init, glweDimension, polySize, outPrecision, op.table());
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct ExtractSliceOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||
ExtractSliceOpPattern(::mlir::MLIRContext *context,
|
||||
@@ -915,22 +875,22 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
LowToBConcrete<mlir::concretelang::Concrete::KeySwitchLweOp,
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp,
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweGPUOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweGPUBufferOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweGPUBufferOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::MoveBskToGPUOp,
|
||||
mlir::concretelang::BConcrete::MoveBskToGPUOp,
|
||||
mlir::concretelang::BConcrete::MoveBskToGPUOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::FreeBskFromGPUOp,
|
||||
mlir::concretelang::BConcrete::FreeBskFromGPUOp,
|
||||
mlir::concretelang::BConcrete::FreeBskFromGPUOp>,
|
||||
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweBufferOp,
|
||||
BConcrete::WopPBSCRTLweBufferOp>>(&getContext());
|
||||
|
||||
patterns.insert<GlweFromTablePattern>(&getContext());
|
||||
if (this->useGPU) {
|
||||
patterns.insert<LowToBConcrete<
|
||||
mlir::concretelang::Concrete::BootstrapLweOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweGPUBufferOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweGPUBufferOp>>(
|
||||
&getContext());
|
||||
} else {
|
||||
patterns.insert<
|
||||
LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>>(
|
||||
&getContext());
|
||||
}
|
||||
|
||||
// Add patterns to rewrite tensor operators that work on encrypted
|
||||
// tensors
|
||||
@@ -1058,8 +1018,8 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize) {
|
||||
return std::make_unique<ConcreteToBConcretePass>(loopParallelize);
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize, bool useGPU) {
|
||||
return std::make_unique<ConcreteToBConcretePass>(loopParallelize, useGPU);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
add_mlir_dialect_library(ConcreteToGPU
|
||||
ConcreteToGPU.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete
|
||||
|
||||
DEPENDS
|
||||
ConcreteDialect
|
||||
mlir-headers
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(ConcreteToGPU PUBLIC ConcreteDialect MLIRIR)
|
||||
@@ -1,108 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
|
||||
/// This rewrite pattern transforms any instance of `Concrete.bootstrap_lwe`
|
||||
/// into `Concrete.bootstrap_lwe_gpu`. It also inserts operations to allocate
|
||||
/// memory, copy bsk into GPU, and free memory after bootstrapping.
|
||||
struct BstOpPattern : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BootstrapLweOp> {
|
||||
BstOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp bstOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
auto baselog = bstOp.baseLog();
|
||||
auto level = bstOp.level();
|
||||
mlir::Value ct = bstOp.input_ciphertext();
|
||||
|
||||
auto ctType =
|
||||
ct.getType().cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto inputLweDim = ctType.getDimension();
|
||||
|
||||
auto outType = bstOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto outputLweDim = outType.getDimension();
|
||||
|
||||
// copy bsk into GPU
|
||||
mlir::Value bskGPU =
|
||||
rewriter
|
||||
.create<mlir::concretelang::Concrete::MoveBskToGPUOp>(
|
||||
bstOp.getLoc(), mlir::concretelang::Concrete::GPUBskType::get(
|
||||
rewriter.getContext()))
|
||||
.getResult();
|
||||
|
||||
mlir::Value inputLweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bstOp.getLoc(), inputLweDim, 32);
|
||||
mlir::Value polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bstOp.getLoc(), outputLweDim, 32);
|
||||
mlir::Value levelCst =
|
||||
rewriter.create<mlir::arith::ConstantIntOp>(bstOp.getLoc(), level, 32);
|
||||
mlir::Value baselogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bstOp.getLoc(), baselog, 32);
|
||||
|
||||
mlir::Type tableType =
|
||||
mlir::RankedTensorType::get({4}, rewriter.getI64Type());
|
||||
mlir::Value tableCst = rewriter.create<mlir::arith::ConstantOp>(
|
||||
bstOp.getLoc(),
|
||||
mlir::DenseIntElementsAttr::get(
|
||||
tableType, {llvm::APInt(64, 0), llvm::APInt(64, 0),
|
||||
llvm::APInt(64, 0), llvm::APInt(64, 0)}));
|
||||
|
||||
rewriter
|
||||
.replaceOpWithNewOp<mlir::concretelang::Concrete::BootstrapLweGPUOp>(
|
||||
bstOp, outType, ct, tableCst, inputLweDimCst, polySizeCst, levelCst,
|
||||
baselogCst, bskGPU);
|
||||
|
||||
// free bsk memory from GPU
|
||||
rewriter.create<mlir::concretelang::Concrete::FreeBskFromGPUOp>(
|
||||
bstOp.getLoc(), bskGPU);
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct ConcreteToGPUPass : public ConcreteToGPUBase<ConcreteToGPUPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConcreteToGPUPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addLegalDialect<mlir::concretelang::Concrete::ConcreteDialect,
|
||||
mlir::arith::ArithmeticDialect>();
|
||||
target.addIllegalOp<mlir::concretelang::Concrete::BootstrapLweOp>();
|
||||
|
||||
patterns.insert<BstOpPattern>(&getContext());
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertConcreteToGPUPass() {
|
||||
return std::make_unique<ConcreteToGPUPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -82,12 +82,10 @@ public:
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %glwe_lut = "TFHE.glwe_from_table"(%lut)
|
||||
/// : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
/// %glwe_ks = "TFHE.keyswitch_glwe"(%ct)
|
||||
/// {baseLog = -1 : i32, level = -1 : i32}
|
||||
/// : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
/// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut)
|
||||
/// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %lut)
|
||||
/// {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32,
|
||||
/// polynomialSize = -1 : i32}
|
||||
/// : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) ->
|
||||
@@ -107,10 +105,6 @@ struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern
|
||||
auto inputTy = converter.convertType(lutOp.a().getType())
|
||||
.cast<TFHE::GLWECipherTextType>();
|
||||
auto resultTy = converter.convertType(lutOp.getType());
|
||||
// %glwe_lut = "TFHE.glwe_from_table"(%lut)
|
||||
auto glweLut = rewriter.create<TFHE::GLWEFromTableOp>(
|
||||
lutOp.getLoc(), resultTy, lutOp.lut());
|
||||
// %glwe_ks = "TFHE.keyswitch_glwe"(%ct)
|
||||
auto glweKs = rewriter.create<TFHE::KeySwitchGLWEOp>(
|
||||
lutOp.getLoc(), inputTy, lutOp.a(), -1, -1);
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
@@ -118,8 +112,8 @@ struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern
|
||||
return converter.convertType(t);
|
||||
});
|
||||
// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut)
|
||||
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(lutOp, resultTy, glweKs,
|
||||
glweLut, -1, -1);
|
||||
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
lutOp, resultTy, glweKs, lutOp.lut(), -1, -1, -1, -1, -1);
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
@@ -152,15 +152,13 @@ struct BootstrapGLWEOpPattern
|
||||
auto newInputTy = converter.glweIntraPBSType(inputTy);
|
||||
auto outputTy = bsOp.result().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newOutputTy = converter.convertType(outputTy);
|
||||
auto tableTy =
|
||||
bsOp.lookup_table().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newTableTy = converter.glweLookupTableType(tableTy);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
bsOp, newOutputTy, bsOp.ciphertext(), bsOp.lookup_table(),
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase);
|
||||
cryptoParameters.nSmall, cryptoParameters.getPolynomialSize(),
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase,
|
||||
cryptoParameters.glweDimension);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.ciphertext().setType(newInputTy);
|
||||
newOp.lookup_table().setType(newTableTy);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -212,49 +210,6 @@ private:
|
||||
mlir::concretelang::V0Parameter &cryptoParameters;
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by
|
||||
/// parametrize GLWE return type and pad the table if the precision has been
|
||||
/// changed.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %lut = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>
|
||||
/// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<4xi64>) ->
|
||||
/// !TFHE.glwe<{_,_,_}{2}>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %lut = arith.constant dense<[0, 1, 2, 3, 0, 1, 2, 3]> : tensor<8xi64>
|
||||
/// %0 = "TFHE.glwe_from_table" (%lut) : (tensor<8xi64>) ->
|
||||
/// !TFHE.glwe<{_,_,_}{3}>
|
||||
/// ```
|
||||
struct GLWEFromTablePattern
|
||||
: public mlir::OpRewritePattern<TFHE::GLWEFromTableOp> {
|
||||
GLWEFromTablePattern(mlir::MLIRContext *context,
|
||||
TFHEGlobalParametrizationTypeConverter &converter,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<TFHE::GLWEFromTableOp>(context, benefit),
|
||||
converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::GLWEFromTableOp glweOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto outputTy = glweOp.result().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newOutputTy = converter.glweLookupTableType(outputTy);
|
||||
auto tableOp = glweOp.table();
|
||||
rewriter.replaceOpWithNewOp<TFHE::GLWEFromTableOp>(glweOp, newOutputTy,
|
||||
tableOp);
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
TFHEGlobalParametrizationTypeConverter &converter;
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
void populateWithTFHEOpTypeConversionPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target,
|
||||
@@ -316,13 +271,6 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
patterns, converter);
|
||||
|
||||
// Parametrize keyswitch bootstrap
|
||||
patterns.add<GLWEFromTablePattern>(&getContext(), converter);
|
||||
target.addDynamicallyLegalOp<TFHE::GLWEFromTableOp>(
|
||||
[&](TFHE::GLWEFromTableOp op) {
|
||||
return !op.getType()
|
||||
.cast<TFHE::GLWECipherTextType>()
|
||||
.hasUnparametrizedParameters();
|
||||
});
|
||||
target.addLegalOp<mlir::arith::ConstantOp>();
|
||||
patterns.add<KeySwitchGLWEOpPattern>(&getContext(), converter,
|
||||
cryptoParameters);
|
||||
|
||||
@@ -66,25 +66,6 @@ public:
|
||||
|
||||
namespace {
|
||||
|
||||
struct GLWEFromTableOpPattern
|
||||
: public mlir::OpRewritePattern<TFHE::GLWEFromTableOp> {
|
||||
GLWEFromTableOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<TFHE::GLWEFromTableOp>(context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::GLWEFromTableOp glweOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto oldTy = glweOp.getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newTy = rewriter.getType<Concrete::GlweCiphertextType>(
|
||||
oldTy.getDimension(), oldTy.getPolynomialSize(), oldTy.getP());
|
||||
|
||||
rewriter.replaceOpWithNewOp<Concrete::GlweFromTable>(glweOp, newTy,
|
||||
glweOp.table());
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct BootstrapGLWEOpPattern
|
||||
: public mlir::OpRewritePattern<TFHE::BootstrapGLWEOp> {
|
||||
BootstrapGLWEOpPattern(mlir::MLIRContext *context,
|
||||
@@ -98,21 +79,31 @@ struct BootstrapGLWEOpPattern
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Type resultType = converter.convertType(bsOp.getType());
|
||||
|
||||
auto precision = bsOp.getType().cast<TFHE::GLWECipherTextType>().getP();
|
||||
|
||||
mlir::Value inputLweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.inputLweDim(), 32);
|
||||
mlir::Value polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.polySize(), 32);
|
||||
mlir::Value levelCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.level(), 32);
|
||||
mlir::Value baseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.baseLog(), 32);
|
||||
mlir::Value glweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.glweDimension(), 32);
|
||||
mlir::Value precisionCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), precision, 32);
|
||||
|
||||
auto newOp = rewriter.replaceOpWithNewOp<Concrete::BootstrapLweOp>(
|
||||
bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(), bsOp.level(),
|
||||
bsOp.baseLog());
|
||||
bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(),
|
||||
inputLweDimCst, polySizeCst, levelCst, baseLogCst, glweDimCst,
|
||||
precisionCst);
|
||||
|
||||
rewriter.startRootUpdate(newOp);
|
||||
|
||||
newOp.input_ciphertext().setType(
|
||||
converter.convertType(bsOp.ciphertext().getType()));
|
||||
|
||||
auto oldTy = bsOp.lookup_table().getType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newTy = rewriter.getType<Concrete::GlweCiphertextType>(
|
||||
oldTy.getDimension(), oldTy.getPolynomialSize(), oldTy.getP());
|
||||
newOp.accumulator().setType(newTy);
|
||||
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
@@ -170,6 +161,9 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
// Make sure that no ops from `TFHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
|
||||
|
||||
// Legalize arith.constant operations introduced by some patterns
|
||||
target.addLegalOp<mlir::arith::ConstantOp>();
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
|
||||
mlir::tensor::GenerateOp, mlir::scf::ForOp>(
|
||||
@@ -201,7 +195,6 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::TFHE::ZeroTensorGLWEOp,
|
||||
mlir::concretelang::Concrete::ZeroTensorLWEOp>>(&getContext(), converter);
|
||||
patterns.add<GLWEFromTableOpPattern>(&getContext());
|
||||
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter);
|
||||
patterns.add<WopPBSGLWEOpPattern>(&getContext(), converter);
|
||||
target.addDynamicallyLegalOp<Concrete::BootstrapLweOp>(
|
||||
|
||||
@@ -42,6 +42,9 @@ mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
mlir::Block *block = op->getBlock();
|
||||
while (block != nullptr) {
|
||||
if (llvm::isa<mlir::func::FuncOp>(block->getParentOp())) {
|
||||
block = &mlir::cast<mlir::func::FuncOp>(block->getParentOp())
|
||||
.getBody()
|
||||
.front();
|
||||
|
||||
auto context = std::find_if(
|
||||
block->getArguments().rbegin(), block->getArguments().rend(),
|
||||
|
||||
@@ -92,7 +92,6 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
auto contextType =
|
||||
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
|
||||
auto i32Type = rewriter.getI32Type();
|
||||
auto i64PointerType = mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type());
|
||||
|
||||
mlir::FunctionType funcType;
|
||||
|
||||
@@ -114,17 +113,21 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref1DType, memref1DType, contextType}, {});
|
||||
} else if (funcName == memref_bootstrap_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, memref1DType, contextType}, {});
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_keyswitch_async_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref1DType, memref1DType, contextType},
|
||||
{futureType});
|
||||
} else if (funcName == memref_bootstrap_async_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, memref1DType, contextType}, {futureType});
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{futureType});
|
||||
} else if (funcName == memref_await_future) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
@@ -133,7 +136,7 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i64PointerType},
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
@@ -244,90 +247,6 @@ struct BufferizableWithCallOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct BufferizableGlweFromTableOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
BufferizableGlweFromTableOpInterface, BConcrete::FillGlweFromTable> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
/// Bufferize GlweFromTable
|
||||
/// ```
|
||||
/// "BConcrete.fill_glwe_table"(%glwe, %lut) {glweDimension=1,
|
||||
/// polynomialSize=2048, outPrecision=3} :
|
||||
/// (tensor<4096xi64>, tensor<32xi64>) -> ()
|
||||
/// ```
|
||||
///
|
||||
/// to
|
||||
///
|
||||
/// ```
|
||||
/// %glweDim = arith.constant 1 : i32
|
||||
/// %polySize = arith.constant 2048 : i32
|
||||
/// %outPrecision = arith.constant 3 : i32
|
||||
/// %glwe_ = memref.cast %glwe : memref<4096xi64> to memref<?xi64>
|
||||
/// %lut_ = memref.cast %lut : memref<32xi64> to memref<?xi64>
|
||||
/// call @expand_lut_in_trivial_glwe_ct(%glwe, %polySize, %glweDim,
|
||||
/// %outPrecision, %lut_) :
|
||||
/// (tensor<?xi64>, i32, i32, tensor<?xi64>) -> ()
|
||||
/// ```
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto castOp = cast<BConcrete::FillGlweFromTable>(op);
|
||||
|
||||
auto glweOp =
|
||||
getCastedMemRef(rewriter, loc,
|
||||
bufferization::getBuffer(
|
||||
rewriter, castOp->getOpOperand(0).get(), options));
|
||||
auto lutOp =
|
||||
getCastedMemRef(rewriter, loc,
|
||||
bufferization::getBuffer(
|
||||
rewriter, castOp->getOpOperand(1).get(), options));
|
||||
|
||||
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI32IntegerAttr(castOp.polynomialSize()));
|
||||
auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI32IntegerAttr(castOp.glweDimension()));
|
||||
auto outPrecisionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI32IntegerAttr(castOp.outPrecision()));
|
||||
|
||||
mlir::SmallVector<mlir::Value> operands{glweOp, polySizeOp, glweDimensionOp,
|
||||
outPrecisionOp, lutOp};
|
||||
|
||||
// Insert forward declaration of the function
|
||||
if (insertForwardDeclarationOfTheCAPI(
|
||||
op, rewriter, memref_expand_lut_in_trivial_glwe_ct_u64)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(
|
||||
loc, memref_expand_lut_in_trivial_glwe_ct_u64, mlir::TypeRange{},
|
||||
operands);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, {});
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op, char const *funcName, bool withContext = false>
|
||||
struct BufferizableWithAsyncCallOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
@@ -565,7 +484,7 @@ void mlir::concretelang::BConcrete::
|
||||
*ctx);
|
||||
BConcrete::BootstrapLweGPUBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::BootstrapLweGPUBufferOp,
|
||||
memref_bootstrap_lwe_cuda_u64, false>>(
|
||||
memref_bootstrap_lwe_cuda_u64, true>>(
|
||||
*ctx);
|
||||
BConcrete::KeySwitchLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
|
||||
@@ -587,7 +506,5 @@ void mlir::concretelang::BConcrete::
|
||||
BConcrete::AwaitFutureOp::attachInterface<
|
||||
BufferizableWithSyncCallOpInterface<BConcrete::AwaitFutureOp,
|
||||
memref_await_future>>(*ctx);
|
||||
BConcrete::FillGlweFromTable::attachInterface<
|
||||
BufferizableGlweFromTableOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -45,25 +45,44 @@ void async_bootstrap(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context,
|
||||
std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise) {
|
||||
|
||||
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
|
||||
|
||||
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
|
||||
|
||||
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
|
||||
precision, tlu_aligned + tlu_offset, tlu_size);
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_engine(context), glwe_ct, glwe_ct_size,
|
||||
expanded_tabulated_function_array.data(), poly_size));
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
|
||||
get_fftw_engine(context), get_engine(context),
|
||||
get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset));
|
||||
ct0_aligned + ct0_offset, glwe_ct));
|
||||
promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{
|
||||
out_allocated, out_aligned, out_offset, out_size, out_stride});
|
||||
free(glwe_ct);
|
||||
}
|
||||
|
||||
void *memref_bootstrap_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise;
|
||||
auto ret = new std::future<concretelang::clientlib::MemRefDescriptor<1>>(
|
||||
@@ -71,8 +90,9 @@ void *memref_bootstrap_async_lwe_u64(
|
||||
std::thread offload_thread(
|
||||
async_bootstrap, out_allocated, out_aligned, out_offset, out_size,
|
||||
out_stride, ct0_allocated, ct0_aligned, ct0_offset, ct0_size, ct0_stride,
|
||||
glwe_ct_allocated, glwe_ct_aligned, glwe_ct_offset, glwe_ct_size,
|
||||
glwe_ct_stride, context, std::move(promise));
|
||||
tlu_allocated, tlu_aligned, tlu_offset, tlu_size, tlu_stride,
|
||||
input_lwe_dim, poly_size, level, base_log, glwe_dim, precision, context,
|
||||
std::move(promise));
|
||||
offload_thread.detach();
|
||||
return (void *)ret;
|
||||
}
|
||||
|
||||
@@ -23,9 +23,6 @@ DefaultEngine *get_levelled_engine() {
|
||||
return levelled_engine;
|
||||
}
|
||||
|
||||
// This helper function expands the input LUT into output, duplicating values as
|
||||
// needed to fill mega cases, taking care of the encoding and the half mega case
|
||||
// shift in the process as well. All sizes should be powers of 2.
|
||||
void encode_and_expand_lut(uint64_t *output, size_t output_size,
|
||||
size_t out_MESSAGE_BITS, const uint64_t *lut,
|
||||
size_t lut_size) {
|
||||
@@ -65,17 +62,15 @@ void encode_and_expand_lut(uint64_t *output, size_t output_size,
|
||||
typedef struct double2 {
|
||||
double x, y;
|
||||
} double2;
|
||||
|
||||
// From concrete-cuda
|
||||
#include "bootstrap.h"
|
||||
#include "device.h"
|
||||
|
||||
void memref_keyswitch_lwe_cuda_u64(uint64_t *out_allocated,
|
||||
uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride,
|
||||
uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
void *ksk_gpu) {
|
||||
void memref_keyswitch_lwe_cuda_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context) {
|
||||
// TODO: GPU implementation
|
||||
}
|
||||
|
||||
@@ -83,25 +78,37 @@ void *move_ct_to_gpu(uint64_t *ct_allocated, uint64_t *ct_aligned,
|
||||
uint64_t ct_offset, uint64_t ct_size, uint64_t ct_stride,
|
||||
uint32_t gpu_idx) {
|
||||
void *stream = cuda_create_stream(gpu_idx);
|
||||
void *ct_gpu = cuda_malloc(ct_size * sizeof(uint64_t), gpu_idx);
|
||||
cuda_memcpy_async_to_gpu(ct_gpu, ct_aligned + ct_offset,
|
||||
ct_size * sizeof(uint64_t), stream, gpu_idx);
|
||||
size_t buf_size = ct_size * sizeof(uint64_t);
|
||||
void *ct_gpu = cuda_malloc(buf_size, gpu_idx);
|
||||
cuda_memcpy_async_to_gpu(ct_gpu, ct_aligned + ct_offset, buf_size, stream,
|
||||
gpu_idx);
|
||||
cuda_synchronize_device(gpu_idx);
|
||||
cuda_destroy_stream(stream, gpu_idx);
|
||||
return ct_gpu;
|
||||
}
|
||||
|
||||
void *move_bsk_to_gpu(mlir::concretelang::RuntimeContext *context,
|
||||
uint32_t gpu_idx = 0) {
|
||||
uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t glwe_dim, uint32_t gpu_idx = 0) {
|
||||
void *stream = cuda_create_stream(gpu_idx);
|
||||
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
|
||||
BufferView bskBuffer = bootstrap_buffer_lwe_u64(bsk);
|
||||
void *bsk_gpu = cuda_malloc(bskBuffer.length, gpu_idx);
|
||||
cuda_memcpy_async_to_gpu(bsk_gpu, (void *)bskBuffer.pointer, bskBuffer.length,
|
||||
stream, gpu_idx);
|
||||
LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
|
||||
size_t bsk_buffer_len =
|
||||
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * poly_size * level;
|
||||
size_t bsk_buffer_size = bsk_buffer_len * sizeof(uint64_t);
|
||||
uint64_t *bsk_buffer =
|
||||
(uint64_t *)aligned_alloc(U64_ALIGNMENT, bsk_buffer_size);
|
||||
size_t fbsk_gpu_buffer_size = bsk_buffer_len * sizeof(double);
|
||||
void *fbsk_gpu = cuda_malloc(fbsk_gpu_buffer_size, gpu_idx);
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_convert_lwe_bootstrap_key_to_lwe_bootstrap_key_mut_view_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), bsk, bsk_buffer));
|
||||
cuda_initialize_twiddles(poly_size, gpu_idx);
|
||||
cuda_convert_lwe_bootstrap_key_64(fbsk_gpu, bsk_buffer, stream, gpu_idx,
|
||||
input_lwe_dim, glwe_dim, level, poly_size);
|
||||
cuda_synchronize_device(gpu_idx);
|
||||
cuda_destroy_stream(stream, gpu_idx);
|
||||
return bsk_gpu;
|
||||
free(bsk_buffer);
|
||||
return fbsk_gpu;
|
||||
}
|
||||
|
||||
void move_ct_to_cpu(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
@@ -125,44 +132,58 @@ void memref_bootstrap_lwe_cuda_u64(
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, void *bsk_gpu) {
|
||||
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
// we currently just use the first GPU, but this should be decided
|
||||
// dynamically, or during compilation, in the future
|
||||
uint32_t gpu_idx = 0;
|
||||
void *stream = cuda_create_stream(gpu_idx);
|
||||
|
||||
// move bsk to gpu
|
||||
void *fbsk_gpu = move_bsk_to_gpu(context, input_lwe_dim, poly_size, level,
|
||||
glwe_dim, gpu_idx);
|
||||
// move input ciphertext into gpu
|
||||
void *ct0_gpu = move_ct_to_gpu(ct0_allocated, ct0_aligned, ct0_offset,
|
||||
ct0_size, ct0_stride, gpu_idx);
|
||||
// move output ciphertext into gpu
|
||||
void *out_gpu = move_ct_to_gpu(out_allocated, out_aligned, out_offset,
|
||||
out_size, out_stride, gpu_idx);
|
||||
// hardcoded values
|
||||
// construct LUT GLWE ciphertext
|
||||
uint64_t glwe_ct_len = poly_size * (glwe_dim + 1);
|
||||
uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
|
||||
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
|
||||
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
|
||||
precision, tlu_aligned + tlu_offset, tlu_size);
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), glwe_ct, glwe_ct_len,
|
||||
expanded_tabulated_function_array.data(), poly_size));
|
||||
// move test vector into gpu
|
||||
void *test_vector_gpu =
|
||||
cuda_malloc(poly_size * (glwe_dim + 1) * sizeof(uint64_t), gpu_idx);
|
||||
cuda_memcpy_async_to_gpu(test_vector_gpu, (void *)glwe_ct, glwe_ct_size,
|
||||
stream, gpu_idx);
|
||||
// free LUT ciphertext (CPU)
|
||||
free(glwe_ct);
|
||||
// move test vector indexes into gpu
|
||||
uint32_t num_samples = 1, num_test_vectors = 1, lwe_idx = 0;
|
||||
void *test_vector_idxes = malloc(num_samples * sizeof(uint32_t));
|
||||
((uint32_t *)test_vector_idxes)[0] = 0;
|
||||
void *test_vector = malloc(poly_size * sizeof(uint64_t));
|
||||
for (size_t i = 0; i < poly_size; i++) {
|
||||
((uint64_t *)test_vector)[i] = (uint64_t)1 << 61;
|
||||
}
|
||||
// move test vector into gpu
|
||||
void *test_vector_gpu = cuda_malloc(poly_size * sizeof(uint64_t), gpu_idx);
|
||||
cuda_memcpy_async_to_gpu(test_vector_gpu, test_vector,
|
||||
poly_size * sizeof(uint64_t), stream, gpu_idx);
|
||||
// move test vector indexes into gpu
|
||||
void *test_vector_idxes_gpu =
|
||||
cuda_malloc(num_samples * sizeof(uint32_t), gpu_idx);
|
||||
cuda_memcpy_async_to_gpu(test_vector_idxes_gpu, test_vector_idxes,
|
||||
num_samples * sizeof(uint32_t), stream, gpu_idx);
|
||||
// run gpu bootstrap
|
||||
cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
stream, out_gpu, test_vector_gpu, test_vector_idxes_gpu, ct0_gpu, bsk_gpu,
|
||||
input_lwe_dim, poly_size, base_log, level, num_samples, num_test_vectors,
|
||||
lwe_idx, cuda_get_max_shared_memory(gpu_idx));
|
||||
stream, out_gpu, test_vector_gpu, test_vector_idxes_gpu, ct0_gpu,
|
||||
fbsk_gpu, input_lwe_dim, poly_size, base_log, level, num_samples,
|
||||
num_test_vectors, lwe_idx, cuda_get_max_shared_memory(gpu_idx));
|
||||
// copy output ciphertext back to cpu
|
||||
move_ct_to_cpu(out_allocated, out_aligned, out_offset, out_size, out_stride,
|
||||
out_gpu, out_size, gpu_idx);
|
||||
cuda_synchronize_device(gpu_idx);
|
||||
// free memory that we allocated on gpu
|
||||
cuda_drop(fbsk_gpu, gpu_idx);
|
||||
cuda_drop(ct0_gpu, gpu_idx);
|
||||
cuda_drop(out_gpu, gpu_idx);
|
||||
cuda_drop(test_vector_gpu, gpu_idx);
|
||||
@@ -271,14 +292,31 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
|
||||
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
|
||||
|
||||
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
|
||||
|
||||
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
|
||||
precision, tlu_aligned + tlu_offset, tlu_size);
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), glwe_ct, glwe_ct_size,
|
||||
expanded_tabulated_function_array.data(), poly_size));
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
|
||||
get_fftw_engine(context), get_engine(context),
|
||||
get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset));
|
||||
ct0_aligned + ct0_offset, glwe_ct));
|
||||
free(glwe_ct);
|
||||
}
|
||||
|
||||
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
|
||||
|
||||
@@ -317,14 +317,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return errorDiag("Optimizing Concrete failed");
|
||||
}
|
||||
|
||||
// Transforming into GPU
|
||||
if (this->compilerOptions.useGPU &&
|
||||
mlir::concretelang::pipeline::transformsConcreteToGPU(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Transforming Concrete to GPU failed");
|
||||
}
|
||||
|
||||
if (target == Target::CONCRETE)
|
||||
return std::move(res);
|
||||
|
||||
@@ -363,7 +355,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass, loopParallelize)
|
||||
mlirContext, module, this->enablePass, loopParallelize,
|
||||
options.useGPU)
|
||||
.failed()) {
|
||||
return StreamStringError(
|
||||
"Lowering from Concrete to Bufferized Concrete failed");
|
||||
|
||||
@@ -239,26 +239,16 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
transformsConcreteToGPU(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteToGPU", pm, context);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertConcreteToGPUPass(), enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops) {
|
||||
bool parallelizeLoops, bool useGPU) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteToBConcrete", pm, context);
|
||||
|
||||
std::unique_ptr<Pass> conversionPass =
|
||||
mlir::concretelang::createConvertConcreteToBConcretePass(
|
||||
parallelizeLoops);
|
||||
mlir::concretelang::createConvertConcreteToBConcretePass(parallelizeLoops,
|
||||
useGPU);
|
||||
|
||||
bool passEnabled = enablePass(conversionPass.get());
|
||||
|
||||
|
||||
@@ -98,11 +98,11 @@ llvm::cl::opt<bool>
|
||||
"dialects. (Enabled by default)"),
|
||||
llvm::cl::init<bool>(true));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
useGPU("use-gpu",
|
||||
llvm::cl::desc("enable/disable generating concrete GPU "
|
||||
"operations (Disabled by default)"),
|
||||
llvm::cl::init<bool>(false));
|
||||
llvm::cl::opt<bool> useGPU(
|
||||
"use-gpu",
|
||||
llvm::cl::desc(
|
||||
"enable/disable generating GPU operations (Disabled by default)"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::list<std::string> passes(
|
||||
"passes",
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @apply_lookup_table(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: tensor<16xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2048xi64>
|
||||
//CHECK: "BConcrete.fill_glwe_from_table"(%[[V0]], %[[A1]]) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 1024 : i32} : (tensor<2048xi64>, tensor<16xi64>) -> ()
|
||||
//CHECK: %[[C1:.*]] = arith.constant 600 : i32
|
||||
//CHECK: %[[C2:.*]] = arith.constant 1024 : i32
|
||||
//CHECK: %[[C3:.*]] = arith.constant 2 : i32
|
||||
//CHECK: %[[C4:.*]] = arith.constant 3 : i32
|
||||
//CHECK: %[[C5:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %[[C6:.*]] = arith.constant 4 : i32
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %[[V0]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (tensor<601xi64>, tensor<2048xi64>) -> tensor<1025xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %arg1, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]]) : (tensor<601xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V2]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%0 = "Concrete.glwe_from_table"(%arg1) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1,1024,4>
|
||||
%c1 = arith.constant 600 : i32
|
||||
%c2 = arith.constant 1024 : i32
|
||||
%c3 = arith.constant 2 : i32
|
||||
%c4 = arith.constant 3 : i32
|
||||
%c5 = arith.constant 1 : i32
|
||||
%c6 = arith.constant 4 : i32
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1,1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %arg1, %c1, %c2, %c3, %c4, %c5, %c6) : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %2 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
@@ -1,17 +1,25 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[C1:.*]] = arith.constant 600 : i32
|
||||
//CHECK: %[[C2:.*]] = arith.constant 2048 : i32
|
||||
//CHECK: %[[C3:.*]] = arith.constant 4 : i32
|
||||
//CHECK: %[[C4:.*]] = arith.constant 5 : i32
|
||||
//CHECK: %[[C5:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4096xi64>
|
||||
//CHECK: "BConcrete.fill_glwe_from_table"(%[[V0]], %cst) {glweDimension = 1 : i32, outPrecision = 4 : i32, polynomialSize = 2048 : i32} : (tensor<4096xi64>, tensor<16xi64>) -> ()
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %[[V0]]) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (tensor<601xi64>, tensor<4096xi64>) -> tensor<2049xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %cst, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C3]]) : (tensor<601xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V2]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {
|
||||
%c1 = arith.constant 600 : i32
|
||||
%c2 = arith.constant 2048 : i32
|
||||
%c3 = arith.constant 4 : i32
|
||||
%c4 = arith.constant 5 : i32
|
||||
%c5 = arith.constant 1 : i32
|
||||
%c6 = arith.constant 4 : i32
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%0 = "Concrete.glwe_from_table"(%tlu) {glweDimension = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32} : (tensor<16xi64>) -> !Concrete.glwe_ciphertext<1, 2048,4>
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, glweDimension = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<600,4>, !Concrete.glwe_ciphertext<1,2048,4>) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %tlu, %c1, %c2, %c3, %c4, %c5, %c6) : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
return %2 : !Concrete.lwe_ciphertext<2048,4>
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{2}>, %[[LUT:.*]]: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{3}>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[LUT]]) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{3}>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
|
||||
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A0
00000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%cst) : (tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: }
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
|
||||
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> {
|
||||
//CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
//CHECK: %[[V0:.*]] = "TFHE.glwe_from_table"(%cst) : (tensor<16xi64>) -> !TFHE.glwe<{2,1024,64}{4}>
|
||||
//CHECK: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = 4 : i32, level = 3 : i32} : (!TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{750,1,64}{4}>
|
||||
//CHECK: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[V0]]) {baseLog = 23 : i32, level = 1 : i32} : (!TFHE.glwe<{750,1,64}{4}>, !TFHE.glwe<{2,1024,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}>
|
||||
//CHECK: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = 23 : i32, glweDimension = 2 : i32, inputLweDim = 750 : i32, level = 1 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{750,1,64}{4}>, tensor<16xi64>) -> !TFHE.glwe<{2048,1,64}{4}>
|
||||
//CHECK: return %[[V2]] : !TFHE.glwe<{2048,1,64}{4}>
|
||||
//CHECK: }
|
||||
func.func @main(%arg0: !TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}> {
|
||||
%cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%0 = "TFHE.glwe_from_table"(%cst) : (tensor<16xi64>) -> !TFHE.glwe<{_,_,_}{4}>
|
||||
%1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}>
|
||||
%2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>, !TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}>
|
||||
%2 = "TFHE.bootstrap_glwe"(%1, %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>, tensor<16xi64>) -> !TFHE.glwe<{_,_,_}{4}>
|
||||
return %2 : !TFHE.glwe<{_,_,_}{4}>
|
||||
}
|
||||
|
||||
@@ -2,13 +2,17 @@
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<600,7>) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000
000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK: %[[V0:.*]] = "Concrete.glwe_from_table"(%cst) : (tensor<128xi64>) -> !Concrete.glwe_ciphertext<1,1024,7>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%[[A0]], %[[V0]]) {baseLog = 1 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<600,7>, !Concrete.glwe_ciphertext<1,1024,7>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: %[[C1:.*]] = arith.constant 600 : i32
|
||||
//CHECK: %[[C2:.*]] = arith.constant 1024 : i32
|
||||
//CHECK: %[[C3:.*]] = arith.constant 3 : i32
|
||||
//CHECK: %[[C4:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %[[C5:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %[[C6:.*]] = arith.constant 4 : i32
|
||||
//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %cst, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]]) : (!Concrete.lwe_ciphertext<600,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @bootstrap_lwe(%ciphertext: !TFHE.glwe<{600,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007
B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
%glwe_lut = "TFHE.glwe_from_table"(%cst) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
%bootstraped = "TFHE.bootstrap_glwe"(%ciphertext, %glwe_lut) {baseLog = 1 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!TFHE.glwe<{600,1,64}{7}>, !TFHE.glwe<{1,1024,64}{7}>) -> !TFHE.glwe<{1024,1,64}{4}>
|
||||
%bootstraped = "TFHE.bootstrap_glwe"(%ciphertext, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{600,1,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1024,1,64}{4}>
|
||||
return %bootstraped : !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @glwe_from_table() {
|
||||
//CHECK-NEXT: %[[V0:.*]] = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F000000000000007000000000000000710000000000000072000000000000007300000000000000740000000000000075000000000000007600000000000000770000000000000078000000000000007900000000000
0007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "Concrete.glwe_from_table"(%[[V0]]) : (tensor<128xi64>) -> !Concrete.glwe_ciphertext<1,1024,7>
|
||||
//CHECK-NEXT: return
|
||||
//CHECK-NEXT: }
|
||||
func.func @glwe_from_table() {
|
||||
%cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007
B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
%0 = "TFHE.glwe_from_table"(%cst) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
return
|
||||
}
|
||||
@@ -72,12 +72,12 @@ func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<4096xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[A0]], %[[A1]]) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<4096xi64>) -> tensor<2049xi64>
|
||||
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2049xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<4096xi64>) -> tensor<2049xi64> {
|
||||
%0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<4096xi64>) -> (tensor<2049xi64>)
|
||||
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<2049xi64> {
|
||||
%0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2049xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> (tensor<2049xi64>)
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
|
||||
@@ -36,11 +36,11 @@ func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Co
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
|
||||
@@ -8,18 +8,11 @@ func.func @keyswitch_glwe(%arg0: !TFHE.glwe<{1,1024,64}{7}>) -> !TFHE.glwe<{1,52
|
||||
return %0: !TFHE.glwe<{1,527,64}{7}>
|
||||
}
|
||||
|
||||
// CHECK: func.func @bootstrap_glwe(%[[GLWE:.*]]: !TFHE.glwe<{1,527,64}{7}>, %[[LUT:.*]]: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
func.func @bootstrap_glwe(%glwe: !TFHE.glwe<{1,527,64}{7}>, %lookup_table_glwe: !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "TFHE.bootstrap_glwe"(%[[GLWE]], %[[LUT]]) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
// CHECK: func.func @bootstrap_glwe(%[[GLWE:.*]]: !TFHE.glwe<{1,527,64}{7}>, %[[LUT:.*]]: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
func.func @bootstrap_glwe(%glwe: !TFHE.glwe<{1,527,64}{7}>, %lut: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "TFHE.bootstrap_glwe"(%[[GLWE]], %[[LUT]]) {baseLog = 2 : i32, glweDimension = 1 : i32, inputLweDim = 527 : i32, level = 3 : i32, polySize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
// CHECK-NEXT: return %[[V0]] : !TFHE.glwe<{1,1024,64}{7}>
|
||||
%0 = "TFHE.bootstrap_glwe"(%glwe, %lookup_table_glwe) {baseLog = 2 : i32, glweDimension = 1 : i32, level = 3 : i32, polynomialSize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, !TFHE.glwe<{1,527,64}{7}>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
%0 = "TFHE.bootstrap_glwe"(%glwe, %lut) {baseLog = 2 : i32, glweDimension = 1 : i32, inputLweDim = 527 : i32, level = 3 : i32, polySize = 2048 : i32} : (!TFHE.glwe<{1,527,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
return %0 : !TFHE.glwe<{1,1024,64}{7}>
|
||||
}
|
||||
|
||||
// CHECK: func.func @glwe_from_table(%[[LUT:.*]]: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
func.func @glwe_from_table(%lookup_table: tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "TFHE.glwe_from_table"(%[[LUT]]) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
// CHECK-NEXT: return %[[V0]] : !TFHE.glwe<{1,1024,64}{7}>
|
||||
%0 = "TFHE.glwe_from_table"(%lookup_table) : (tensor<128xi64>) -> !TFHE.glwe<{1,1024,64}{7}>
|
||||
return %0 : !TFHE.glwe<{1,1024,64}{7}>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user