feat(compiler): add an asynchronous interface for bootstrap and keyswitch using std::promise/future.
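The runtime wrappers added below offload each keyswitch/bootstrap call to a detached thread, which fulfills a std::promise with a descriptor of the output buffer; the wrapper returns the heap-allocated std::future across the C interface as an opaque void *, and the lowered code later calls an await wrapper that blocks on that future, copies the result into its destination, and frees the future. A minimal, self-contained sketch of that pattern follows — names and the plain int payload are illustrative only (the actual wrappers pass memref descriptors and call into the Concrete engines):

// Minimal sketch of the promise/future offload pattern used by the runtime
// wrappers in this commit. All names here are hypothetical.
#include <future>
#include <iostream>
#include <thread>

// Worker: performs the (stand-in) computation and fulfills the promise.
static void async_worker(int input, std::promise<int> promise) {
  promise.set_value(input * 2); // stand-in for keyswitch/bootstrap
}

// Offload entry point: returns the heap-allocated future as an opaque
// pointer, mirroring memref_*_async_lwe_u64 returning void *.
void *offload(int input) {
  std::promise<int> promise;
  auto *fut = new std::future<int>(promise.get_future());
  std::thread worker(async_worker, input, std::move(promise));
  worker.detach(); // the caller synchronises through the future, not join()
  return static_cast<void *>(fut);
}

// Await entry point: blocks on the future, retrieves the value and releases
// the future, mirroring memref_await_future.
int await_result(void *opaque) {
  auto *fut = static_cast<std::future<int> *>(opaque);
  int value = fut->get(); // blocks until the worker calls set_value
  delete fut;
  return value;
}

int main() {
  void *fut = offload(21);
  std::cout << await_result(fut) << "\n"; // prints 42
  return 0;
}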
@@ -13,6 +13,7 @@
#include <mlir/Interfaces/SideEffectInterfaces.h>

#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"

#define GET_OP_CLASSES
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h.inc"

@@ -8,6 +8,8 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"

include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.td"
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
include "concretelang/Dialect/RT/IR/RTDialect.td"
include "concretelang/Dialect/RT/IR/RTTypes.td"

class BConcrete_Op<string mnemonic, list<Trait> traits = []> :
    Op<BConcrete_Dialect, mnemonic, traits>;

@@ -125,4 +127,33 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
  let results = (outs 2DTensorOf<[I64]>:$result);
}

def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
    BConcrete_Op<"keyswitch_lwe_buffer_async_offload"> {
  let arguments = (ins
    // LweKeySwitchKeyType:$keyswitch_key,
    1DTensorOf<[I64]>:$ciphertext,
    I32Attr:$level,
    I32Attr:$baseLog
  );
  let results = (outs RT_Future : $result);
}

def BConcrete_BootstrapLweBufferAsyncOffloadOp :
    BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> {
  let arguments = (ins
    // LweBootstrapKeyType:$bootstrap_key,
    1DTensorOf<[I64]>:$input_ciphertext,
    1DTensorOf<[I64]>:$accumulator,
    I32Attr:$level,
    I32Attr:$baseLog
  );
  let results = (outs RT_Future : $result);
}

def BConcrete_AwaitFutureOp :
    BConcrete_Op<"await_future"> {
  let arguments = (ins RT_Future : $future);
  let results = (outs 1DTensorOf<[I64]>:$result);
}

#endif

@@ -14,8 +14,8 @@
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createAddRuntimeContext();

std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps();
std::unique_ptr<OperationPass<ModuleOp>> createAsyncOffload();
} // namespace concretelang
} // namespace mlir

@@ -21,4 +21,9 @@ def EliminateCRTOps : Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp">
  let constructor = "mlir::concretelang::createEliminateCRTOpsPass()";
}

def AsyncOffload : Pass<"async-offload", "mlir::ModuleOp"> {
  let summary = "Replace keyswitch and bootstrap operations by async versions and add synchronisation.";
  let constructor = "mlir::concretelang::createAsyncOffload()";
}

#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

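The pass summary above describes a local rewrite: each buffered keyswitch or bootstrap op is replaced by its *_async_offload counterpart, and an await is inserted in front of the (single) use of its result. A hypothetical illustration in generic MLIR syntax — the op names follow the ODS definitions earlier in this diff, but the dialect prefix, tensor sizes, attribute values, and the printed form of the RT future type are assumptions:

  // Before async-offload (hypothetical IR):
  %r = "BConcrete.keyswitch_lwe_buffer"(%ct) {baseLog = 4 : i32, level = 3 : i32}
       : (tensor<1025xi64>) -> tensor<601xi64>

  // After async-offload (hypothetical IR):
  %f = "BConcrete.keyswitch_lwe_buffer_async_offload"(%ct) {baseLog = 4 : i32, level = 3 : i32}
       : (tensor<1025xi64>) -> !RT.future<tensor<601xi64>>
  %r = "BConcrete.await_future"(%f)
       : (!RT.future<tensor<601xi64>>) -> tensor<601xi64>

The single-use restriction asserted by the pass keeps deallocation simple: the buffer behind the future is only released after the matching await has copied the result out.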
@@ -48,6 +48,11 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
                              uint64_t *ct0_aligned, uint64_t ct0_offset,
                              uint64_t ct0_size, uint64_t ct0_stride,
                              mlir::concretelang::RuntimeContext *context);
void *memref_keyswitch_async_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context);

void memref_bootstrap_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -56,6 +61,20 @@ void memref_bootstrap_lwe_u64(
    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
    mlir::concretelang::RuntimeContext *context);
void *memref_bootstrap_async_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
    mlir::concretelang::RuntimeContext *context);

void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
                         uint64_t out_offset, uint64_t out_size,
                         uint64_t out_stride, void *future,
                         uint64_t *in_allocated, uint64_t *in_aligned,
                         uint64_t in_offset, uint64_t in_size,
                         uint64_t in_stride);

uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product);

@@ -52,6 +52,7 @@ struct CompilationOptions {
  bool autoParallelize;
  bool loopParallelize;
  bool dataflowParallelize;
  bool asyncOffload;
  bool optimizeConcrete;
  llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;

@@ -62,7 +63,7 @@ struct CompilationOptions {
  CompilationOptions()
      : v0FHEConstraints(llvm::None), verifyDiagnostics(false),
        autoParallelize(false), loopParallelize(false),
        dataflowParallelize(false), optimizeConcrete(true),
        dataflowParallelize(false), asyncOffload(false), optimizeConcrete(true),
        clientParametersFuncName(llvm::None),
        optimizerConfig(optimizer::DEFAULT_CONFIG){};

@@ -53,6 +53,10 @@ mlir::LogicalResult
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
                 std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
                                 mlir::ModuleOp &module,
                                 std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
                    std::function<bool(mlir::Pass *)> enablePass);

compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp (new file, 80 lines)
@@ -0,0 +1,80 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"

namespace {

struct AsyncOffloadPass : public AsyncOffloadBase<AsyncOffloadPass> {
  void runOnOperation() final;
};

void AsyncOffloadPass::runOnOperation() {
  auto module = getOperation();
  std::vector<mlir::Operation *> ops;

  module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweBufferOp op) {
    mlir::OpBuilder builder(op);
    mlir::Type futType =
        mlir::concretelang::RT::FutureType::get(op.getResult().getType());
    mlir::Value future = builder.create<
        mlir::concretelang::BConcrete::KeySwitchLweBufferAsyncOffloadOp>(
        op.getLoc(), mlir::TypeRange{futType}, op.getOperand(), op->getAttrs());

    assert(op.getResult().hasOneUse() &&
           "Single use assumed (for deallocation purposes - restriction can be "
           "lifted).");
    for (auto &use : op.getResult().getUses()) {
      builder.setInsertionPoint(use.getOwner());
      mlir::Value res =
          builder.create<mlir::concretelang::BConcrete::AwaitFutureOp>(
              use.getOwner()->getLoc(),
              mlir::TypeRange{op.getResult().getType()}, future);
      use.set(res);
    }
    ops.push_back(op);
  });
  module.walk([&](mlir::concretelang::BConcrete::BootstrapLweBufferOp op) {
    mlir::OpBuilder builder(op);
    mlir::Type futType =
        mlir::concretelang::RT::FutureType::get(op.getResult().getType());
    mlir::Value future = builder.create<
        mlir::concretelang::BConcrete::BootstrapLweBufferAsyncOffloadOp>(
        op.getLoc(), mlir::TypeRange{futType}, op.getOperands(),
        op->getAttrs());

    assert(op.getResult().hasOneUse() &&
           "Single use assumed (for deallocation purposes - restriction can be "
           "lifted).");
    for (auto &use : op.getResult().getUses()) {
      builder.setInsertionPoint(use.getOwner());
      mlir::Value res =
          builder.create<mlir::concretelang::BConcrete::AwaitFutureOp>(
              use.getOwner()->getLoc(),
              mlir::TypeRange{op.getResult().getType()}, future);
      use.set(res);
    }
    ops.push_back(op);
  });

  for (auto op : ops)
    op->erase();
}
} // namespace

namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createAsyncOffload() {
  return std::make_unique<AsyncOffloadPass>();
}
} // namespace concretelang
} // namespace mlir

@@ -85,6 +85,9 @@ char memref_mul_cleartext_lwe_ciphertext_u64[] =
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
char memref_await_future[] = "memref_await_future";
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
    "memref_expand_lut_in_trivial_glwe_ct_u64";

@@ -95,9 +98,10 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(

  auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
  auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
  auto futureType =
      mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
  auto contextType =
      mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());

  mlir::FunctionType funcType;

  if (funcName == memref_add_lwe_ciphertexts_u64) {

@@ -121,6 +125,18 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, memref1DType, memref1DType, contextType}, {});
  } else if (funcName == memref_keyswitch_async_lwe_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(), {memref1DType, memref1DType, contextType},
        {futureType});
  } else if (funcName == memref_bootstrap_async_lwe_u64) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, memref1DType, memref1DType, contextType}, {futureType});
  } else if (funcName == memref_await_future) {
    funcType = mlir::FunctionType::get(
        rewriter.getContext(),
        {memref1DType, futureType, memref1DType, memref1DType}, {});
  } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
    funcType = mlir::FunctionType::get(rewriter.getContext(),
                                       {

@@ -333,6 +349,154 @@ struct BufferizableGlweFromTableOpInterface
  }
};

template <typename Op, char const *funcName, bool withContext = false>
struct BufferizableWithAsyncCallOpInterface
    : public BufferizableOpInterface::ExternalModel<
          BufferizableWithAsyncCallOpInterface<Op, funcName, withContext>, Op> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {

    auto loc = op->getLoc();
    auto castOp = cast<Op>(op);

    // For now we always alloc for the result, we didn't have the in place
    // operators yet.
    auto resTensorType =
        castOp.result()
            .getType()
            .template cast<mlir::concretelang::RT::FutureType>()
            .getElementType()
            .template cast<mlir::TensorType>();

    auto outMemrefType = MemRefType::get(resTensorType.getShape(),
                                         resTensorType.getElementType());
    auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
    if (mlir::failed(outMemref)) {
      return mlir::failure();
    }

    // The first operand is the result
    mlir::SmallVector<mlir::Value, 3> operands{
        getCastedMemRef(rewriter, loc, *outMemref),
    };
    // For all tensor operand get the corresponding casted buffer
    for (auto &operand : op->getOpOperands()) {
      if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
        operands.push_back(operand.get());
      } else {
        auto memrefOperand =
            bufferization::getBuffer(rewriter, operand.get(), options);
        operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
      }
    }
    // Append the context argument
    if (withContext) {
      operands.push_back(getContextArgument(op));
    }

    // Insert forward declaration of the function
    if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
      return mlir::failure();
    }

    auto result = rewriter.create<mlir::func::CallOp>(
        loc, funcName,
        mlir::TypeRange{
            mlir::concretelang::RT::FutureType::get(rewriter.getIndexType())},
        operands);

    replaceOpWithBufferizedValues(rewriter, op, result.getResult(0));

    return success();
  }
};

template <typename Op, char const *funcName>
struct BufferizableWithSyncCallOpInterface
    : public BufferizableOpInterface::ExternalModel<
          BufferizableWithSyncCallOpInterface<Op, funcName>, Op> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {

    auto loc = op->getLoc();
    auto castOp = cast<Op>(op);

    auto resTensorType =
        castOp.result().getType().template cast<mlir::TensorType>();

    auto outMemrefType = MemRefType::get(resTensorType.getShape(),
                                         resTensorType.getElementType());
    auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
    if (mlir::failed(outMemref)) {
      return mlir::failure();
    }

    // The first operand is the result
    mlir::SmallVector<mlir::Value, 3> operands{
        getCastedMemRef(rewriter, loc, *outMemref),
    };
    // Then add the future operand
    operands.push_back(op->getOpOperand(0).get());
    // Finally add a dependence on the memref covered by the future to
    // prevent early deallocation
    auto def = op->getOpOperand(0).get().getDefiningOp();
    operands.push_back(def->getOpOperand(0).get());
    operands.push_back(def->getOpOperand(1).get());

    // Insert forward declaration of the function
    if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
      return mlir::failure();
    }

    rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
                                        operands);

    replaceOpWithBufferizedValues(rewriter, op, *outMemref);

    return success();
  }
};

} // namespace

void mlir::concretelang::BConcrete::

@@ -363,6 +527,17 @@ void mlir::concretelang::BConcrete::
    BConcrete::WopPBSCRTLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
                                        memref_wop_pbs_crt_buffer, true>>(*ctx);
    BConcrete::KeySwitchLweBufferAsyncOffloadOp::attachInterface<
        BufferizableWithAsyncCallOpInterface<
            BConcrete::KeySwitchLweBufferAsyncOffloadOp,
            memref_keyswitch_async_lwe_u64, true>>(*ctx);
    BConcrete::BootstrapLweBufferAsyncOffloadOp::attachInterface<
        BufferizableWithAsyncCallOpInterface<
            BConcrete::BootstrapLweBufferAsyncOffloadOp,
            memref_bootstrap_async_lwe_u64, true>>(*ctx);
    BConcrete::AwaitFutureOp::attachInterface<
        BufferizableWithSyncCallOpInterface<BConcrete::AwaitFutureOp,
                                            memref_await_future>>(*ctx);
    BConcrete::FillGlweFromTable::attachInterface<
        BufferizableGlweFromTableOpInterface>(*ctx);
  });

@@ -2,6 +2,7 @@ add_mlir_dialect_library(ConcretelangBConcreteTransforms
  BufferizableOpInterfaceImpl.cpp
  AddRuntimeContext.cpp
  EliminateCRTOps.cpp
  AsyncOffload.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete

compiler/lib/Runtime/AsyncOffload.cpp (new file, 93 lines)
@@ -0,0 +1,93 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#include "concretelang/ClientLib/Types.h"
#include "concretelang/Runtime/wrappers.h"
#include <assert.h>
#include <future>
#include <stdio.h>
#include <string.h>
#include <thread>

void async_keyswitch(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context,
    std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise) {
  CAPI_ASSERT_ERROR(
      default_engine_discard_keyswitch_lwe_ciphertext_u64_raw_ptr_buffers(
          get_engine(context), get_keyswitch_key_u64(context),
          out_aligned + out_offset, ct0_aligned + ct0_offset));
  promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{
      out_allocated, out_aligned, out_offset, out_size, out_stride});
}

void *memref_keyswitch_async_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context) {
  std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise;
  auto ret = new std::future<concretelang::clientlib::MemRefDescriptor<1>>(
      promise.get_future());
  std::thread offload_thread(async_keyswitch, out_allocated, out_aligned,
                             out_offset, out_size, out_stride, ct0_allocated,
                             ct0_aligned, ct0_offset, ct0_size, ct0_stride,
                             context, std::move(promise));
  offload_thread.detach();
  return (void *)ret;
}

void async_bootstrap(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
    mlir::concretelang::RuntimeContext *context,
    std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise) {
  CAPI_ASSERT_ERROR(
      fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
          get_fftw_engine(context), get_engine(context),
          get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
          ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset));
  promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{
      out_allocated, out_aligned, out_offset, out_size, out_stride});
}

void *memref_bootstrap_async_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
    mlir::concretelang::RuntimeContext *context) {
  std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise;
  auto ret = new std::future<concretelang::clientlib::MemRefDescriptor<1>>(
      promise.get_future());
  std::thread offload_thread(
      async_bootstrap, out_allocated, out_aligned, out_offset, out_size,
      out_stride, ct0_allocated, ct0_aligned, ct0_offset, ct0_size, ct0_stride,
      glwe_ct_allocated, glwe_ct_aligned, glwe_ct_offset, glwe_ct_size,
      glwe_ct_stride, context, std::move(promise));
  offload_thread.detach();
  return (void *)ret;
}

void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
                         uint64_t out_offset, uint64_t out_size,
                         uint64_t out_stride, void *fut, uint64_t *in_allocated,
                         uint64_t *in_aligned, uint64_t in_offset,
                         uint64_t in_size, uint64_t in_stride) {
  auto future =
      static_cast<std::future<concretelang::clientlib::MemRefDescriptor<1>> *>(
          fut);
  auto desc = future->get();
  memref_copy_one_rank(desc.allocated, desc.aligned, desc.offset, desc.sizes[0],
                       desc.strides[0], out_allocated, out_aligned, out_offset,
                       out_size, out_stride);
  delete future;
}

@@ -1,6 +1,7 @@
add_library(ConcretelangRuntime SHARED
  context.cpp
  wrappers.cpp
  AsyncOffload.cpp
  DFRuntime.cpp
  seeder.cpp
)

@@ -372,6 +372,17 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
    return errorDiag(
        "Lowering from Bufferized Concrete to canonical MLIR dialects failed");
  }

  // Make keyswitch and bootstrap asynchronous
  if (options.asyncOffload) {
    if (mlir::concretelang::pipeline::asyncOffload(mlirContext, module,
                                                   enablePass)
            .failed()) {
      return errorDiag("Converting Keyswitch and Bootstrap to asynchronous "
                       "operations failed.");
    }
  }

  if (target == Target::STD)
    return std::move(res);

@@ -264,6 +264,16 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
  return pm.run(module.getOperation());
}

mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
                                 mlir::ModuleOp &module,
                                 std::function<bool(mlir::Pass *)> enablePass) {
  mlir::PassManager pm(&context);
  pipelinePrinting("AsyncOffload", pm, context);
  addPotentiallyNestedPass(pm, mlir::concretelang::createAsyncOffload(),
                           enablePass);
  return pm.run(module.getOperation());
}

mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
                    std::function<bool(mlir::Pass *)> enablePass) {

@@ -162,6 +162,11 @@ llvm::cl::opt<bool> dataflowParallelize(
        "Generate (and execute if JIT) the program as a dataflow graph"),
    llvm::cl::init(false));

llvm::cl::opt<bool> asyncOffload(
    "async-offload",
    llvm::cl::desc("Use asynchronous interface for keyswitch and bootstrap."),
    llvm::cl::init(false));

llvm::cl::opt<std::string>
    funcName("funcname",
             llvm::cl::desc("Name of the function to compile, default 'main'"),