mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor(compiler): Remove async offloading of BS/KS
This commit is contained in:
@@ -127,7 +127,6 @@ def BConcrete_BatchedBootstrapLweTensorOp : BConcrete_Op<"batched_bootstrap_lwe_
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
// TODO(16bits): hack
|
||||
def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor"> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$ciphertext,
|
||||
@@ -151,31 +150,6 @@ def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor"> {
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_KeySwitchLweTensorAsyncOffloadOp :
|
||||
BConcrete_Op<"keyswitch_lwe_tensor_async_offload"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$ciphertext,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
);
|
||||
let results = (outs RT_Future : $result);
|
||||
}
|
||||
|
||||
def BConcrete_BootstrapLweTensorAsyncOffloadOp :
|
||||
BConcrete_Op<"bootstrap_lwe_tensor_async_offload"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision
|
||||
);
|
||||
let results = (outs RT_Future : $result);
|
||||
}
|
||||
|
||||
// BConcrete memref operators /////////////////////////////////////////////////
|
||||
|
||||
def BConcrete_LweBuffer : MemRefRankOf<[I64], [1]>;
|
||||
@@ -263,7 +237,6 @@ def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(16bits): hack
|
||||
def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
BConcrete_LweCRTBuffer:$result,
|
||||
@@ -287,14 +260,4 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
|
||||
);
|
||||
}
|
||||
|
||||
// TODO
|
||||
|
||||
|
||||
|
||||
def BConcrete_AwaitFutureOp :
|
||||
BConcrete_Op<"await_future"> {
|
||||
let arguments = (ins RT_Future : $future);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -15,7 +15,6 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAddRuntimeContext();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAsyncOffload();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -16,14 +16,10 @@ def AddRuntimeContext : Pass<"add-runtime-context", "mlir::ModuleOp"> {
|
||||
let constructor = "mlir::concretelang::createAddRuntimeContext()";
|
||||
}
|
||||
|
||||
def EliminateCRTOps : Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp"> {
|
||||
def EliminateCRTOps
|
||||
: Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp"> {
|
||||
let summary = "Eliminate the crt bconcrete operators.";
|
||||
let constructor = "mlir::concretelang::createEliminateCRTOpsPass()";
|
||||
}
|
||||
|
||||
def AsyncOffload : Pass<"async-offload", "mlir::ModuleOp"> {
|
||||
let summary = "Replace keyswitch and bootstrap operations by async versions and add synchronisation.";
|
||||
let constructor = "mlir::concretelang::createAsyncOffload()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
||||
|
||||
@@ -55,7 +55,6 @@ struct CompilationOptions {
|
||||
bool loopParallelize;
|
||||
bool batchConcreteOps;
|
||||
bool dataflowParallelize;
|
||||
bool asyncOffload;
|
||||
bool optimizeConcrete;
|
||||
/// use GPU during execution by generating GPU operations if possible
|
||||
bool emitGPUOps;
|
||||
@@ -68,8 +67,8 @@ struct CompilationOptions {
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
|
||||
dataflowParallelize(false), asyncOffload(false), optimizeConcrete(true),
|
||||
emitGPUOps(false), clientParametersFuncName(llvm::None),
|
||||
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
|
||||
clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
|
||||
@@ -58,10 +58,6 @@ mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct AsyncOffloadPass : public AsyncOffloadBase<AsyncOffloadPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
void AsyncOffloadPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
std::vector<mlir::Operation *> ops;
|
||||
|
||||
module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweTensorOp op) {
|
||||
mlir::OpBuilder builder(op);
|
||||
mlir::Type futType =
|
||||
mlir::concretelang::RT::FutureType::get(op.getResult().getType());
|
||||
mlir::Value future = builder.create<
|
||||
mlir::concretelang::BConcrete::KeySwitchLweTensorAsyncOffloadOp>(
|
||||
op.getLoc(), mlir::TypeRange{futType}, op.getOperand(), op->getAttrs());
|
||||
|
||||
assert(op.getResult().hasOneUse() &&
|
||||
"Single use assumed (for deallocation purposes - restriction can be "
|
||||
"lifted).");
|
||||
for (auto &use : op.getResult().getUses()) {
|
||||
builder.setInsertionPoint(use.getOwner());
|
||||
mlir::Value res =
|
||||
builder.create<mlir::concretelang::BConcrete::AwaitFutureOp>(
|
||||
use.getOwner()->getLoc(),
|
||||
mlir::TypeRange{op.getResult().getType()}, future);
|
||||
use.set(res);
|
||||
}
|
||||
ops.push_back(op);
|
||||
});
|
||||
module.walk([&](mlir::concretelang::BConcrete::BootstrapLweTensorOp op) {
|
||||
mlir::OpBuilder builder(op);
|
||||
mlir::Type futType =
|
||||
mlir::concretelang::RT::FutureType::get(op.getResult().getType());
|
||||
mlir::Value future = builder.create<
|
||||
mlir::concretelang::BConcrete::BootstrapLweTensorAsyncOffloadOp>(
|
||||
op.getLoc(), mlir::TypeRange{futType}, op.getOperands(),
|
||||
op->getAttrs());
|
||||
|
||||
assert(op.getResult().hasOneUse() &&
|
||||
"Single use assumed (for deallocation purposes - restriction can be "
|
||||
"lifted).");
|
||||
for (auto &use : op.getResult().getUses()) {
|
||||
builder.setInsertionPoint(use.getOwner());
|
||||
mlir::Value res =
|
||||
builder.create<mlir::concretelang::BConcrete::AwaitFutureOp>(
|
||||
use.getOwner()->getLoc(),
|
||||
mlir::TypeRange{op.getResult().getType()}, future);
|
||||
use.set(res);
|
||||
}
|
||||
ops.push_back(op);
|
||||
});
|
||||
|
||||
for (auto op : ops)
|
||||
op->erase();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAsyncOffload() {
|
||||
return std::make_unique<AsyncOffloadPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -3,7 +3,6 @@ add_mlir_dialect_library(
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
AddRuntimeContext.cpp
|
||||
EliminateCRTOps.cpp
|
||||
AsyncOffload.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
|
||||
DEPENDS
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
#include <assert.h>
|
||||
#include <future>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <thread>
|
||||
|
||||
void async_keyswitch(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context,
|
||||
std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise) {
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_keyswitch_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_engine(context), get_keyswitch_key_u64(context),
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset));
|
||||
promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{
|
||||
out_allocated, out_aligned, out_offset, {out_size}, {out_stride}});
|
||||
}
|
||||
|
||||
void *memref_keyswitch_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context) {
|
||||
std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise;
|
||||
auto ret = new std::future<concretelang::clientlib::MemRefDescriptor<1>>(
|
||||
promise.get_future());
|
||||
std::thread offload_thread(async_keyswitch, out_allocated, out_aligned,
|
||||
out_offset, out_size, out_stride, ct0_allocated,
|
||||
ct0_aligned, ct0_offset, ct0_size, ct0_stride,
|
||||
context, std::move(promise));
|
||||
offload_thread.detach();
|
||||
return (void *)ret;
|
||||
}
|
||||
|
||||
void async_bootstrap(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context,
|
||||
std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise) {
|
||||
|
||||
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
|
||||
|
||||
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
|
||||
|
||||
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
|
||||
precision, tlu_aligned + tlu_offset, tlu_size);
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_engine(context), glwe_ct, glwe_ct_size,
|
||||
expanded_tabulated_function_array.data(), poly_size));
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
|
||||
get_fft_engine(context), get_engine(context),
|
||||
get_fft_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, glwe_ct));
|
||||
promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{
|
||||
out_allocated, out_aligned, out_offset, {out_size}, {out_stride}});
|
||||
free(glwe_ct);
|
||||
}
|
||||
|
||||
void *memref_bootstrap_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise;
|
||||
auto ret = new std::future<concretelang::clientlib::MemRefDescriptor<1>>(
|
||||
promise.get_future());
|
||||
std::thread offload_thread(
|
||||
async_bootstrap, out_allocated, out_aligned, out_offset, out_size,
|
||||
out_stride, ct0_allocated, ct0_aligned, ct0_offset, ct0_size, ct0_stride,
|
||||
tlu_allocated, tlu_aligned, tlu_offset, tlu_size, tlu_stride,
|
||||
input_lwe_dim, poly_size, level, base_log, glwe_dim, precision, context,
|
||||
std::move(promise));
|
||||
offload_thread.detach();
|
||||
return (void *)ret;
|
||||
}
|
||||
|
||||
void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, void *fut, uint64_t *in_allocated,
|
||||
uint64_t *in_aligned, uint64_t in_offset,
|
||||
uint64_t in_size, uint64_t in_stride) {
|
||||
auto future =
|
||||
static_cast<std::future<concretelang::clientlib::MemRefDescriptor<1>> *>(
|
||||
fut);
|
||||
auto desc = future->get();
|
||||
memref_copy_one_rank(desc.allocated, desc.aligned, desc.offset, desc.sizes[0],
|
||||
desc.strides[0], out_allocated, out_aligned, out_offset,
|
||||
out_size, out_stride);
|
||||
delete future;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp AsyncOffload.cpp DFRuntime.cpp seeder.cpp)
|
||||
add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp seeder.cpp)
|
||||
|
||||
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
|
||||
target_link_libraries(ConcretelangRuntime PRIVATE HPX::hpx HPX::iostreams_component)
|
||||
|
||||
@@ -399,16 +399,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
"Lowering from Bufferized Concrete to canonical MLIR dialects failed");
|
||||
}
|
||||
|
||||
// Make keyswitch and bootstrap asynchronous
|
||||
if (options.asyncOffload) {
|
||||
if (mlir::concretelang::pipeline::asyncOffload(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Converting Keyswitch and Bootstrap to asynchronous "
|
||||
"operations failed.");
|
||||
}
|
||||
}
|
||||
|
||||
if (target == Target::STD)
|
||||
return std::move(res);
|
||||
|
||||
|
||||
@@ -274,16 +274,6 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("AsyncOffload", pm, context);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createAsyncOffload(),
|
||||
enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
|
||||
@@ -179,11 +179,6 @@ llvm::cl::opt<bool> dataflowParallelize(
|
||||
"Generate (and execute if JIT) the program as a dataflow graph"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> asyncOffload(
|
||||
"async-offload",
|
||||
llvm::cl::desc("Use asynchronous interface for keyswitch and bootstrap."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<std::string>
|
||||
funcName("funcname",
|
||||
llvm::cl::desc("Name of the function to compile, default 'main'"),
|
||||
@@ -300,7 +295,6 @@ cmdlineCompilationOptions() {
|
||||
options.loopParallelize = cmdline::loopParallelize;
|
||||
options.dataflowParallelize = cmdline::dataflowParallelize;
|
||||
options.batchConcreteOps = cmdline::batchConcreteOps;
|
||||
options.asyncOffload = cmdline::asyncOffload;
|
||||
options.optimizeConcrete = cmdline::optimizeConcrete;
|
||||
options.emitGPUOps = cmdline::emitGPUOps;
|
||||
|
||||
|
||||
@@ -153,8 +153,6 @@ std::string printEndToEndDesc(const testing::TestParamInfo<TestParam> desc) {
|
||||
std::ostringstream opt;
|
||||
if (options.loopParallelize)
|
||||
opt << "_loop";
|
||||
if (options.asyncOffload)
|
||||
opt << "_async";
|
||||
if (options.dataflowParallelize)
|
||||
opt << "_dataflow";
|
||||
if (options.emitGPUOps)
|
||||
@@ -239,12 +237,6 @@ mlir::concretelang::CompilationOptions loopOptions() {
|
||||
return o;
|
||||
}
|
||||
|
||||
mlir::concretelang::CompilationOptions asyncOptions() {
|
||||
mlir::concretelang::CompilationOptions o("main");
|
||||
o.asyncOffload = true;
|
||||
return o;
|
||||
}
|
||||
|
||||
mlir::concretelang::CompilationOptions dataflowOptions() {
|
||||
mlir::concretelang::CompilationOptions o("main");
|
||||
o.dataflowParallelize = true;
|
||||
|
||||
Reference in New Issue
Block a user