diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.h b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.h
index e9f6d88a0..06cc85d02 100644
--- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.h
+++ b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.h
@@ -13,6 +13,7 @@
 #include
 
 #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
+#include "concretelang/Dialect/RT/IR/RTTypes.h"
 
 #define GET_OP_CLASSES
 #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h.inc"
diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td
index dfb6a236e..e5ff8711e 100644
--- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td
+++ b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td
@@ -8,6 +8,8 @@
 include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.td"
 include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
+include "concretelang/Dialect/RT/IR/RTDialect.td"
+include "concretelang/Dialect/RT/IR/RTTypes.td"
 
 class BConcrete_Op<string mnemonic, list<Trait> traits = []> :
     Op<BConcrete_Dialect, mnemonic, traits>;
@@ -125,4 +127,33 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
+def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
+    BConcrete_Op<"keyswitch_lwe_buffer_async_offload"> {
+  let arguments = (ins
+    // LweKeySwitchKeyType:$keyswitch_key,
+    1DTensorOf<[I64]>:$ciphertext,
+    I32Attr:$level,
+    I32Attr:$baseLog
+  );
+  let results = (outs RT_Future : $result);
+}
+
+def BConcrete_BootstrapLweBufferAsyncOffloadOp :
+    BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> {
+  let arguments = (ins
+    // LweBootstrapKeyType:$bootstrap_key,
+    1DTensorOf<[I64]>:$input_ciphertext,
+    1DTensorOf<[I64]>:$accumulator,
+    I32Attr:$level,
+    I32Attr:$baseLog
+  );
+  let results = (outs RT_Future : $result);
+}
+
+def BConcrete_AwaitFutureOp :
+    BConcrete_Op<"await_future"> {
+  let arguments = (ins RT_Future : $future);
+  let results = (outs 1DTensorOf<[I64]>:$result);
+}
+
 #endif
diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h
index 9367fbb46..8ac318474 100644
--- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h
+++ b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h
@@ -14,8 +14,8 @@ namespace mlir {
 namespace concretelang {
 std::unique_ptr<OperationPass<func::FuncOp>> createAddRuntimeContext();
-
 std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps();
+std::unique_ptr<OperationPass<ModuleOp>> createAsyncOffload();
 } // namespace concretelang
 } // namespace mlir
diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td
index 54251a113..e22bef109 100644
--- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td
+++ b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td
@@ -21,4 +21,9 @@ def EliminateCRTOps : Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp">
   let constructor = "mlir::concretelang::createEliminateCRTOpsPass()";
 }
 
+def AsyncOffload : Pass<"async-offload", "mlir::ModuleOp"> {
+  let summary = "Replace keyswitch and bootstrap operations by async versions and add synchronisation.";
+  let constructor = "mlir::concretelang::createAsyncOffload()";
+}
+
 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
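Taken together, the three new ops split each blocking keyswitch/bootstrap into a launch that returns an RT.future and a separate await_future that yields the result tensor. The snippet below is a minimal, self-contained C++ sketch of that launch/await pattern using std::async; it only illustrates the intended execution model (fake_keyswitch is a hypothetical stand-in, not a real API), and is not the compiler's generated code — the actual runtime side uses std::promise and a detached std::thread, as shown in lib/Runtime/AsyncOffload.cpp later in this patch.

// Illustration only: the launch/await split that the new
// keyswitch_lwe_buffer_async_offload / await_future ops encode.
// `fake_keyswitch` is a hypothetical stand-in for the real cryptographic call.
#include <cstdint>
#include <future>
#include <vector>

static std::vector<uint64_t> fake_keyswitch(std::vector<uint64_t> ct) {
  for (auto &w : ct)
    w += 1; // placeholder for the real keyswitch computation
  return ct;
}

int main() {
  std::vector<uint64_t> ct(1025, 0);
  // "keyswitch_lwe_buffer_async_offload": start the work, keep a future.
  std::future<std::vector<uint64_t>> fut =
      std::async(std::launch::async, fake_keyswitch, ct);
  // ... independent work can overlap with the keyswitch here ...
  // "await_future": block only when the result is actually needed.
  std::vector<uint64_t> out = fut.get();
  return out.empty() ? 1 : 0;
}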
diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h
index 4b40380f6..c27670563 100644
--- a/compiler/include/concretelang/Runtime/wrappers.h
+++ b/compiler/include/concretelang/Runtime/wrappers.h
@@ -48,6 +48,11 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
                               uint64_t *ct0_aligned, uint64_t ct0_offset,
                               uint64_t ct0_size, uint64_t ct0_stride,
                               mlir::concretelang::RuntimeContext *context);
+void *memref_keyswitch_async_lwe_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context);
 
 void memref_bootstrap_lwe_u64(
     uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -56,6 +61,20 @@ void memref_bootstrap_lwe_u64(
     uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
     uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
     mlir::concretelang::RuntimeContext *context);
+void *memref_bootstrap_async_lwe_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
+    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
+    mlir::concretelang::RuntimeContext *context);
+
+void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
+                         uint64_t out_offset, uint64_t out_size,
+                         uint64_t out_stride, void *future,
+                         uint64_t *in_allocated, uint64_t *in_aligned,
+                         uint64_t in_offset, uint64_t in_size,
+                         uint64_t in_stride);
 
 uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product);
diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h
index c49da57c5..541f062d7 100644
--- a/compiler/include/concretelang/Support/CompilerEngine.h
+++ b/compiler/include/concretelang/Support/CompilerEngine.h
@@ -52,6 +52,7 @@ struct CompilationOptions {
   bool autoParallelize;
   bool loopParallelize;
   bool dataflowParallelize;
+  bool asyncOffload;
   bool optimizeConcrete;
   llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
@@ -62,7 +63,7 @@ CompilationOptions()
       : v0FHEConstraints(llvm::None), verifyDiagnostics(false),
         autoParallelize(false), loopParallelize(false),
-        dataflowParallelize(false), optimizeConcrete(true),
+        dataflowParallelize(false), asyncOffload(false), optimizeConcrete(true),
         clientParametersFuncName(llvm::None),
         optimizerConfig(optimizer::DEFAULT_CONFIG){};
diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h
index cbab04756..a5ef12ffb 100644
--- a/compiler/include/concretelang/Support/Pipeline.h
+++ b/compiler/include/concretelang/Support/Pipeline.h
@@ -53,6 +53,10 @@ mlir::LogicalResult optimizeConcrete(mlir::MLIRContext &context,
                                      mlir::ModuleOp &module,
                                      std::function<bool(mlir::Pass *)> enablePass);
 
+mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
+                                 mlir::ModuleOp &module,
+                                 std::function<bool(mlir::Pass *)> enablePass);
+
 mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context,
                                         mlir::ModuleOp &module,
                                         std::function<bool(mlir::Pass *)> enablePass);
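The new CompilationOptions::asyncOffload flag and the pipeline::asyncOffload entry point declared above are wired together in CompilerEngine.cpp further down in this patch. For a caller that drives the pipeline manually, the fragment below sketches the intended use; `mlirContext`, `module` and `enablePass` are assumed to be set up the same way the rest of the compiler does, so this is a non-standalone sketch, not a complete program.

// Sketch (not standalone): enabling the new option and running the step,
// mirroring the CompilerEngine.cpp hunk later in this patch. `mlirContext`,
// `module` and `enablePass` are assumed to exist in the caller.
mlir::concretelang::CompilationOptions options;
options.asyncOffload = true;

if (options.asyncOffload &&
    mlir::concretelang::pipeline::asyncOffload(mlirContext, module, enablePass)
        .failed()) {
  // Handle the failure as the engine does, e.g. emit a diagnostic.
}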
diff --git a/compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp b/compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp
new file mode 100644
index 000000000..4a0310639
--- /dev/null
+++ b/compiler/lib/Dialect/BConcrete/Transforms/AsyncOffload.cpp
@@ -0,0 +1,80 @@
+// Part of the Concrete Compiler Project, under the BSD3 License with Zama
+// Exceptions. See
+// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
+// for license information.
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
+#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
+#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"
+
+namespace {
+
+struct AsyncOffloadPass : public AsyncOffloadBase<AsyncOffloadPass> {
+  void runOnOperation() final;
+};
+
+void AsyncOffloadPass::runOnOperation() {
+  auto module = getOperation();
+  std::vector<mlir::Operation *> ops;
+
+  module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweBufferOp op) {
+    mlir::OpBuilder builder(op);
+    mlir::Type futType =
+        mlir::concretelang::RT::FutureType::get(op.getResult().getType());
+    mlir::Value future = builder.create<
+        mlir::concretelang::BConcrete::KeySwitchLweBufferAsyncOffloadOp>(
+        op.getLoc(), mlir::TypeRange{futType}, op.getOperand(), op->getAttrs());
+
+    assert(op.getResult().hasOneUse() &&
+           "Single use assumed (for deallocation purposes - restriction can be "
+           "lifted).");
+    for (auto &use : op.getResult().getUses()) {
+      builder.setInsertionPoint(use.getOwner());
+      mlir::Value res =
+          builder.create<mlir::concretelang::BConcrete::AwaitFutureOp>(
+              use.getOwner()->getLoc(),
+              mlir::TypeRange{op.getResult().getType()}, future);
+      use.set(res);
+    }
+    ops.push_back(op);
+  });
+  module.walk([&](mlir::concretelang::BConcrete::BootstrapLweBufferOp op) {
+    mlir::OpBuilder builder(op);
+    mlir::Type futType =
+        mlir::concretelang::RT::FutureType::get(op.getResult().getType());
+    mlir::Value future = builder.create<
+        mlir::concretelang::BConcrete::BootstrapLweBufferAsyncOffloadOp>(
+        op.getLoc(), mlir::TypeRange{futType}, op.getOperands(),
+        op->getAttrs());
+
+    assert(op.getResult().hasOneUse() &&
+           "Single use assumed (for deallocation purposes - restriction can be "
+           "lifted).");
+    for (auto &use : op.getResult().getUses()) {
+      builder.setInsertionPoint(use.getOwner());
+      mlir::Value res =
+          builder.create<mlir::concretelang::BConcrete::AwaitFutureOp>(
+              use.getOwner()->getLoc(),
+              mlir::TypeRange{op.getResult().getType()}, future);
+      use.set(res);
+    }
+    ops.push_back(op);
+  });
+
+  for (auto op : ops)
+    op->erase();
+}
+} // namespace
+
+namespace mlir {
+namespace concretelang {
+std::unique_ptr<OperationPass<ModuleOp>> createAsyncOffload() {
+  return std::make_unique<AsyncOffloadPass>();
+}
+} // namespace concretelang
+} // namespace mlir
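The pass walks every synchronous KeySwitchLweBufferOp/BootstrapLweBufferOp, emits the corresponding *_async_offload op producing an RT.future, and redirects the (assumed single) use through an await_future. Outside the full compilation pipeline it can be run like any other module pass; the fragment below is a sketch that assumes an MLIRContext and a parsed ModuleOp are already available, exactly as Pipeline.cpp later in this patch does.

// Sketch (not standalone): running the new pass directly with a PassManager,
// the same way Pipeline.cpp registers it via addPotentiallyNestedPass.
mlir::PassManager pm(&context);
pm.addPass(mlir::concretelang::createAsyncOffload());
if (mlir::failed(pm.run(module.getOperation()))) {
  // Report the failure; the pipeline wrapper returns a LogicalResult instead.
}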
"memref_expand_lut_in_trivial_glwe_ct_u64"; @@ -95,9 +98,10 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1); auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2); + auto futureType = + mlir::concretelang::RT::FutureType::get(rewriter.getIndexType()); auto contextType = mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); - mlir::FunctionType funcType; if (funcName == memref_add_lwe_ciphertexts_u64) { @@ -121,6 +125,18 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, memref1DType, contextType}, {}); + } else if (funcName == memref_keyswitch_async_lwe_u64) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), {memref1DType, memref1DType, contextType}, + {futureType}); + } else if (funcName == memref_bootstrap_async_lwe_u64) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {memref1DType, memref1DType, memref1DType, contextType}, {futureType}); + } else if (funcName == memref_await_future) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {memref1DType, futureType, memref1DType, memref1DType}, {}); } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) { funcType = mlir::FunctionType::get(rewriter.getContext(), { @@ -333,6 +349,154 @@ struct BufferizableGlweFromTableOpInterface } }; +template +struct BufferizableWithAsyncCallOpInterface + : public BufferizableOpInterface::ExternalModel< + BufferizableWithAsyncCallOpInterface, Op> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + return BufferRelation::None; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + + auto loc = op->getLoc(); + auto castOp = cast(op); + + // For now we always alloc for the result, we didn't have the in place + // operators yet. 
@@ -333,6 +349,154 @@ struct BufferizableGlweFromTableOpInterface
   }
 };
 
+template <typename Op, char *funcName, bool withContext>
+struct BufferizableWithAsyncCallOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          BufferizableWithAsyncCallOpInterface<Op, funcName, withContext>,
+          Op> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::None;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+
+    auto loc = op->getLoc();
+    auto castOp = cast<Op>(op);
+
+    // For now we always alloc for the result, we didn't have the in place
+    // operators yet.
+    auto resTensorType =
+        castOp.result()
+            .getType()
+            .template cast<mlir::concretelang::RT::FutureType>()
+            .getElementType()
+            .template cast<mlir::RankedTensorType>();
+
+    auto outMemrefType = MemRefType::get(resTensorType.getShape(),
+                                         resTensorType.getElementType());
+    auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
+    if (mlir::failed(outMemref)) {
+      return mlir::failure();
+    }
+
+    // The first operand is the result
+    mlir::SmallVector<mlir::Value> operands{
+        getCastedMemRef(rewriter, loc, *outMemref),
+    };
+    // For all tensor operand get the corresponding casted buffer
+    for (auto &operand : op->getOpOperands()) {
+      if (!operand.get().getType().isa<mlir::TensorType>()) {
+        operands.push_back(operand.get());
+      } else {
+        auto memrefOperand =
+            bufferization::getBuffer(rewriter, operand.get(), options);
+        operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
+      }
+    }
+    // Append the context argument
+    if (withContext) {
+      operands.push_back(getContextArgument(op));
+    }
+
+    // Insert forward declaration of the function
+    if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
+      return mlir::failure();
+    }
+
+    auto result = rewriter.create<mlir::func::CallOp>(
+        loc, funcName,
+        mlir::TypeRange{
+            mlir::concretelang::RT::FutureType::get(rewriter.getIndexType())},
+        operands);
+
+    replaceOpWithBufferizedValues(rewriter, op, result.getResult(0));
+
+    return success();
+  }
+};
+
+template <typename Op, char *funcName>
+struct BufferizableWithSyncCallOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          BufferizableWithSyncCallOpInterface<Op, funcName>, Op> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::None;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+
+    auto loc = op->getLoc();
+    auto castOp = cast<Op>(op);
+
+    auto resTensorType =
+        castOp.result().getType().template cast<mlir::RankedTensorType>();
+
+    auto outMemrefType = MemRefType::get(resTensorType.getShape(),
+                                         resTensorType.getElementType());
+    auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
+    if (mlir::failed(outMemref)) {
+      return mlir::failure();
+    }
+
+    // The first operand is the result
+    mlir::SmallVector<mlir::Value> operands{
+        getCastedMemRef(rewriter, loc, *outMemref),
+    };
+    // Then add the future operand
+    operands.push_back(op->getOpOperand(0).get());
+    // Finally add a dependence on the memref covered by the future to
+    // prevent early deallocation
+    auto def = op->getOpOperand(0).get().getDefiningOp();
+    operands.push_back(def->getOpOperand(0).get());
+    operands.push_back(def->getOpOperand(1).get());
+
+    // Insert forward declaration of the function
+    if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
+      return mlir::failure();
+    }
+
+    rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
+                                        operands);
+
+    replaceOpWithBufferizedValues(rewriter, op, *outMemref);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::concretelang::BConcrete::
@@ -363,6 +527,17 @@ void mlir::concretelang::BConcrete::
   BConcrete::WopPBSCRTLweBufferOp::attachInterface<
       BufferizableWithCallOpInterface>(*ctx);
+  BConcrete::KeySwitchLweBufferAsyncOffloadOp::attachInterface<
+      BufferizableWithAsyncCallOpInterface<
+          BConcrete::KeySwitchLweBufferAsyncOffloadOp,
+          memref_keyswitch_async_lwe_u64, true>>(*ctx);
+  BConcrete::BootstrapLweBufferAsyncOffloadOp::attachInterface<
+      BufferizableWithAsyncCallOpInterface<
+          BConcrete::BootstrapLweBufferAsyncOffloadOp,
+          memref_bootstrap_async_lwe_u64, true>>(*ctx);
+  BConcrete::AwaitFutureOp::attachInterface<
+      BufferizableWithSyncCallOpInterface<BConcrete::AwaitFutureOp,
+                                          memref_await_future>>(*ctx);
   BConcrete::FillGlweFromTable::attachInterface<
       BufferizableGlweFromTableOpInterface>(*ctx);
 });
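In the forward declarations above, the value returned by the async CAPI calls is typed as an RT.future for the purposes of the MLIR function type, but at runtime it is just an opaque pointer: the wrappers hand back a heap-allocated std::future as a void *, and memref_await_future consumes and frees it. The self-contained sketch below illustrates that opaque-handle convention with a plain uint64_t payload; make_handle/await_handle are illustrative names, not part of the runtime.

// Illustration of the opaque-handle convention used across the CAPI boundary:
// the producer returns a heap-allocated std::future as void *, and the
// awaiting side casts it back, reads the value, and deletes the future,
// much like memref_keyswitch_async_lwe_u64 / memref_await_future do.
#include <cstdint>
#include <future>
#include <iostream>

void *make_handle(uint64_t v) {
  auto *fut = new std::future<uint64_t>(
      std::async(std::launch::async, [v] { return v + 1; }));
  return static_cast<void *>(fut);
}

uint64_t await_handle(void *h) {
  auto *fut = static_cast<std::future<uint64_t> *>(h);
  uint64_t result = fut->get(); // blocks until the producer is done
  delete fut;                   // the await side owns and frees the handle
  return result;
}

int main() {
  void *handle = make_handle(41);
  std::cout << await_handle(handle) << "\n"; // prints 42
  return 0;
}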
diff --git a/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt b/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt
index cef63baf0..80d87aea5 100644
--- a/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt
+++ b/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(ConcretelangBConcreteTransforms
   BufferizableOpInterfaceImpl.cpp
   AddRuntimeContext.cpp
   EliminateCRTOps.cpp
+  AsyncOffload.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
diff --git a/compiler/lib/Runtime/AsyncOffload.cpp b/compiler/lib/Runtime/AsyncOffload.cpp
new file mode 100644
index 000000000..205888d7b
--- /dev/null
+++ b/compiler/lib/Runtime/AsyncOffload.cpp
@@ -0,0 +1,93 @@
+// Part of the Concrete Compiler Project, under the BSD3 License with Zama
+// Exceptions. See
+// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
+// for license information.
+
+#include "concretelang/ClientLib/Types.h"
+#include "concretelang/Runtime/wrappers.h"
+#include <future>
+#include <thread>
+#include
+#include
+#include
+
+void async_keyswitch(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context,
+    std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise) {
+  CAPI_ASSERT_ERROR(
+      default_engine_discard_keyswitch_lwe_ciphertext_u64_raw_ptr_buffers(
+          get_engine(context), get_keyswitch_key_u64(context),
+          out_aligned + out_offset, ct0_aligned + ct0_offset));
+  promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{
+      out_allocated, out_aligned, out_offset, out_size, out_stride});
+}
+
+void *memref_keyswitch_async_lwe_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, mlir::concretelang::RuntimeContext *context) {
+  std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise;
+  auto ret = new std::future<concretelang::clientlib::MemRefDescriptor<1>>(
+      promise.get_future());
+  std::thread offload_thread(async_keyswitch, out_allocated, out_aligned,
+                             out_offset, out_size, out_stride, ct0_allocated,
+                             ct0_aligned, ct0_offset, ct0_size, ct0_stride,
+                             context, std::move(promise));
+  offload_thread.detach();
+  return (void *)ret;
+}
+
+void async_bootstrap(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
+    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
+    mlir::concretelang::RuntimeContext *context,
+    std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise) {
+  CAPI_ASSERT_ERROR(
+      fftw_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
+          get_fftw_engine(context), get_engine(context),
+          get_fftw_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
+          ct0_aligned + ct0_offset, glwe_ct_aligned + glwe_ct_offset));
+  promise.set_value(concretelang::clientlib::MemRefDescriptor<1>{
+      out_allocated, out_aligned, out_offset, out_size, out_stride});
+}
+
+void *memref_bootstrap_async_lwe_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
+    uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
+    mlir::concretelang::RuntimeContext *context) {
+  std::promise<concretelang::clientlib::MemRefDescriptor<1>> promise;
+  auto ret = new std::future<concretelang::clientlib::MemRefDescriptor<1>>(
+      promise.get_future());
+  std::thread offload_thread(
+      async_bootstrap, out_allocated, out_aligned, out_offset, out_size,
+      out_stride, ct0_allocated, ct0_aligned, ct0_offset, ct0_size, ct0_stride,
+      glwe_ct_allocated, glwe_ct_aligned, glwe_ct_offset, glwe_ct_size,
+      glwe_ct_stride, context, std::move(promise));
+  offload_thread.detach();
+  return (void *)ret;
+}
+
+void memref_await_future(uint64_t *out_allocated, uint64_t *out_aligned,
+                         uint64_t out_offset, uint64_t out_size,
+                         uint64_t out_stride, void *fut, uint64_t *in_allocated,
+                         uint64_t *in_aligned, uint64_t in_offset,
+                         uint64_t in_size, uint64_t in_stride) {
+  auto future =
+      static_cast<std::future<concretelang::clientlib::MemRefDescriptor<1>> *>(
+          fut);
+  auto desc = future->get();
+  memref_copy_one_rank(desc.allocated, desc.aligned, desc.offset, desc.sizes[0],
+                       desc.strides[0], out_allocated, out_aligned, out_offset,
+                       out_size, out_stride);
+  delete future;
+}
diff --git a/compiler/lib/Runtime/CMakeLists.txt b/compiler/lib/Runtime/CMakeLists.txt
index 0a2f470dd..b680ed9d1 100644
--- a/compiler/lib/Runtime/CMakeLists.txt
+++ b/compiler/lib/Runtime/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_library(ConcretelangRuntime SHARED
   context.cpp
   wrappers.cpp
+  AsyncOffload.cpp
   DFRuntime.cpp
   seeder.cpp
 )
diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp
index 7dcdb36a4..642676ec8 100644
--- a/compiler/lib/Support/CompilerEngine.cpp
+++ b/compiler/lib/Support/CompilerEngine.cpp
@@ -372,6 +372,17 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
     return errorDiag(
         "Lowering from Bufferized Concrete to canonical MLIR dialects failed");
   }
+
+  // Make keyswitch and bootstrap asynchronous
+  if (options.asyncOffload) {
+    if (mlir::concretelang::pipeline::asyncOffload(mlirContext, module,
+                                                   enablePass)
+            .failed()) {
+      return errorDiag("Converting Keyswitch and Bootstrap to asynchronous "
+                       "operations failed.");
+    }
+  }
+
   if (target == Target::STD)
     return std::move(res);
diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp
index 321857d25..076addbda 100644
--- a/compiler/lib/Support/Pipeline.cpp
+++ b/compiler/lib/Support/Pipeline.cpp
@@ -264,6 +264,16 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
   return pm.run(module.getOperation());
 }
 
+mlir::LogicalResult asyncOffload(mlir::MLIRContext &context,
+                                 mlir::ModuleOp &module,
+                                 std::function<bool(mlir::Pass *)> enablePass) {
+  mlir::PassManager pm(&context);
+  pipelinePrinting("AsyncOffload", pm, context);
+  addPotentiallyNestedPass(pm, mlir::concretelang::createAsyncOffload(),
+                           enablePass);
+  return pm.run(module.getOperation());
+}
+
 mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context,
                                         mlir::ModuleOp &module,
                                         std::function<bool(mlir::Pass *)> enablePass) {
diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp
index 3ac1743c5..2bbdc9875 100644
--- a/compiler/src/main.cpp
+++ b/compiler/src/main.cpp
@@ -162,6 +162,11 @@ llvm::cl::opt<bool> dataflowParallelize(
         "Generate (and execute if JIT) the program as a dataflow graph"),
     llvm::cl::init(false));
 
+llvm::cl::opt<bool> asyncOffload(
+    "async-offload",
+    llvm::cl::desc("Use asynchronous interface for keyswitch and bootstrap."),
+    llvm::cl::init(false));
+
 llvm::cl::opt<std::string> funcName("funcname",
     llvm::cl::desc("Name of the function to compile, default 'main'"),