From 60412f7f6130a5e1a7f93239e3f37c475c7be254 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Fri, 24 Mar 2023 10:44:56 +0000 Subject: [PATCH] feat(compiler): add SDFG op generation for batched operations. --- .../concretelang/Dialect/SDFG/IR/SDFGOps.td | 14 +++- .../Runtime/stream_emulator_api.h | 33 ++++++++ .../SDFGToStreamEmulator.cpp | 83 +++++++++++++++++-- .../compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp | 16 ++++ .../BufferizableOpInterfaceImpl.cpp | 19 +++-- .../SDFGConvertibleOpInterfaceImpl.cpp | 42 ++++++++++ 6 files changed, 196 insertions(+), 11 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td index 7329e1fad..44904e492 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td @@ -89,10 +89,22 @@ def ProcessKindMulEintInt : I32EnumAttrCase<"mul_eint_int", 2>; def ProcessKindNegEint : I32EnumAttrCase<"neg_eint", 3>; def ProcessKindKeyswitch : I32EnumAttrCase<"keyswitch", 4>; def ProcessKindBootstrap : I32EnumAttrCase<"bootstrap", 5>; +def ProcessKindBatchAddEint : I32EnumAttrCase<"batched_add_eint", 6>; +def ProcessKindBatchAddEintInt : I32EnumAttrCase<"batched_add_eint_int", 7>; +def ProcessKindBatchAddEintIntCst : I32EnumAttrCase<"batched_add_eint_int_cst", 8>; +def ProcessKindBatchMulEintInt : I32EnumAttrCase<"batched_mul_eint_int", 9>; +def ProcessKindBatchMulEintIntCst : I32EnumAttrCase<"batched_mul_eint_int_cst", 10>; +def ProcessKindBatchNegEint : I32EnumAttrCase<"batched_neg_eint", 11>; +def ProcessKindBatchKeyswitch : I32EnumAttrCase<"batched_keyswitch", 12>; +def ProcessKindBatchBootstrap : I32EnumAttrCase<"batched_bootstrap", 13>; def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind", [ProcessKindAddEint, ProcessKindAddEintInt, ProcessKindMulEintInt, - ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap]> { + ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap, + ProcessKindBatchAddEint, ProcessKindBatchAddEintInt, + ProcessKindBatchAddEintIntCst, ProcessKindBatchMulEintInt, + ProcessKindBatchMulEintIntCst, ProcessKindBatchNegEint, + ProcessKindBatchKeyswitch, ProcessKindBatchBootstrap]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::concretelang::SDFG"; } diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h index a88a18902..367141bc8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/stream_emulator_api.h @@ -42,6 +42,27 @@ void stream_emulator_make_memref_bootstrap_lwe_u64_process( uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t output_size, uint32_t bsk_index, void *context); +void stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process( + void *dfg, void *sin1, void *sin2, void *sout); +void stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process( + void *dfg, void *sin1, void *sin2, void *sout); +void stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process( + void *dfg, void *sin1, void *sin2, void *sout); +void stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process( + void *dfg, void *sin1, void *sin2, void *sout); +void stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process( + void *dfg, void *sin1, void *sin2, void *sout); +void stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process( + void *dfg, void *sin1, void *sout); +void stream_emulator_make_memref_batched_keyswitch_lwe_u64_process( + void *dfg, void *sin1, void *sout, uint32_t level, uint32_t base_log, + uint32_t input_lwe_dim, uint32_t output_lwe_dim, uint32_t output_size, + void *context); +void stream_emulator_make_memref_batched_bootstrap_lwe_u64_process( + void *dfg, void *sin1, void *sin2, void *sout, uint32_t input_lwe_dim, + uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim, + uint32_t output_size, void *context); + void *stream_emulator_make_uint64_stream(const char *name, stream_type stype); void stream_emulator_put_uint64(void *stream, uint64_t e); uint64_t stream_emulator_get_uint64(void *stream); @@ -53,6 +74,18 @@ void stream_emulator_put_memref(void *stream, uint64_t *allocated, void stream_emulator_get_memref(void *stream, uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t out_stride); + +void *stream_emulator_make_memref_batch_stream(const char *name, + stream_type stype); +void stream_emulator_put_memref_batch(void *stream, uint64_t *allocated, + uint64_t *aligned, uint64_t offset, + uint64_t size0, uint64_t size1, + uint64_t stride0, uint64_t stride1); +void stream_emulator_get_memref_batch(void *stream, uint64_t *out_allocated, + uint64_t *out_aligned, + uint64_t out_offset, uint64_t out_size0, + uint64_t out_size1, uint64_t out_stride0, + uint64_t out_stride1); } #endif diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp index fb9153820..e120663d0 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp @@ -38,6 +38,31 @@ char stream_emulator_make_memref_keyswitch_lwe_u64_process[] = char stream_emulator_make_memref_bootstrap_lwe_u64_process[] = "stream_emulator_make_memref_bootstrap_lwe_u64_process"; +char stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process[] = + "stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process"; +char + stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process + [] = "stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_" + "u64_process"; +char + stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process + [] = "stream_emulator_make_memref_batched_add_plaintext_cst_lwe_" + "ciphertext_u64_process"; +char + stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process + [] = "stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_" + "u64_process"; +char + stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process + [] = "stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_" + "ciphertext_u64_process"; +char stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process[] = + "stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process"; +char stream_emulator_make_memref_batched_keyswitch_lwe_u64_process[] = + "stream_emulator_make_memref_batched_keyswitch_lwe_u64_process"; +char stream_emulator_make_memref_batched_bootstrap_lwe_u64_process[] = + "stream_emulator_make_memref_batched_bootstrap_lwe_u64_process"; + char stream_emulator_make_memref_stream[] = "stream_emulator_make_memref_stream"; char stream_emulator_put_memref[] = "stream_emulator_put_memref"; @@ -46,6 +71,10 @@ char stream_emulator_make_uint64_stream[] = char stream_emulator_put_uint64[] = "stream_emulator_put_uint64"; char stream_emulator_get_uint64[] = "stream_emulator_get_uint64"; +char stream_emulator_make_memref_batch_stream[] = + "stream_emulator_make_memref_batch_stream"; +char stream_emulator_put_memref_batch[] = "stream_emulator_put_memref_batch"; + mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) { std::vector shape(rank, mlir::ShapedType::kDynamic); return mlir::RankedTensorType::get(shape, rewriter.getI64Type()); @@ -164,7 +193,7 @@ struct LowerSDFGMakeProcess ::mlir::LogicalResult matchAndRewrite(mlir::concretelang::SDFG::MakeProcess mpOp, ::mlir::PatternRewriter &rewriter) const override { - const char *funcName; + const char *funcName = nullptr; mlir::SmallVector operands(mpOp->getOperands()); switch (mpOp.getType()) { case SDFG::ProcessKind::add_eint: @@ -181,8 +210,12 @@ struct LowerSDFGMakeProcess case SDFG::ProcessKind::neg_eint: funcName = stream_emulator_make_memref_negate_lwe_ciphertext_u64_process; break; + case SDFG::ProcessKind::batched_keyswitch: + funcName = stream_emulator_make_memref_batched_keyswitch_lwe_u64_process; + [[fallthrough]]; case SDFG::ProcessKind::keyswitch: - funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process; + if (funcName == nullptr) + funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process; // level operands.push_back(rewriter.create( mpOp.getLoc(), mpOp->getAttrOfType("level"))); @@ -206,8 +239,12 @@ struct LowerSDFGMakeProcess // context operands.push_back(getContextArgument(mpOp)); break; + case SDFG::ProcessKind::batched_bootstrap: + funcName = stream_emulator_make_memref_batched_bootstrap_lwe_u64_process; + [[fallthrough]]; case SDFG::ProcessKind::bootstrap: - funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process; + if (funcName == nullptr) + funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process; // input_lwe_dim operands.push_back(rewriter.create( mpOp.getLoc(), @@ -235,6 +272,30 @@ struct LowerSDFGMakeProcess // context operands.push_back(getContextArgument(mpOp)); break; + case SDFG::ProcessKind::batched_add_eint: + funcName = + stream_emulator_make_memref_batched_add_lwe_ciphertexts_u64_process; + break; + case SDFG::ProcessKind::batched_add_eint_int: + funcName = + stream_emulator_make_memref_batched_add_plaintext_lwe_ciphertext_u64_process; + break; + case SDFG::ProcessKind::batched_add_eint_int_cst: + funcName = + stream_emulator_make_memref_batched_add_plaintext_cst_lwe_ciphertext_u64_process; + break; + case SDFG::ProcessKind::batched_mul_eint_int: + funcName = + stream_emulator_make_memref_batched_mul_cleartext_lwe_ciphertext_u64_process; + break; + case SDFG::ProcessKind::batched_mul_eint_int_cst: + funcName = + stream_emulator_make_memref_batched_mul_cleartext_cst_lwe_ciphertext_u64_process; + break; + case SDFG::ProcessKind::batched_neg_eint: + funcName = + stream_emulator_make_memref_batched_negate_lwe_ciphertext_u64_process; + break; } if (insertGenericForwardDeclaration(mpOp, rewriter, funcName, mlir::ValueRange{operands}.getTypes(), @@ -274,7 +335,13 @@ struct LowerSDFGMakeStream assert(sType && "SDFG MakeStream operation should return a stream type"); if (sType.getElementType().isa()) { - funcName = stream_emulator_make_memref_stream; + if (sType.getElementType().dyn_cast().getRank() == 1) + funcName = stream_emulator_make_memref_stream; + else if (sType.getElementType().dyn_cast().getRank() == + 2) + funcName = stream_emulator_make_memref_batch_stream; + else + return ::mlir::failure(); } else { assert(sType.getElementType().isa() && "SDFG streams only support memrefs and integers."); @@ -314,7 +381,13 @@ struct LowerSDFGPut assert(sType && "SDFG Put operation must take a stream type as first parameter."); if (sType.getElementType().isa()) { - funcName = stream_emulator_put_memref; + if (sType.getElementType().dyn_cast().getRank() == 1) + funcName = stream_emulator_put_memref; + else if (sType.getElementType().dyn_cast().getRank() == + 2) + funcName = stream_emulator_put_memref_batch; + else + return ::mlir::failure(); } else { assert(sType.getElementType().isa() && "SDFG streams only support memrefs and integers."); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp index 801f61c0f..2ec8b3164 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp @@ -82,6 +82,22 @@ mlir::LogicalResult MakeProcess::verify() { return checkStreams(1, 1); case ProcessKind::bootstrap: return checkStreams(2, 1); + case ProcessKind::batched_add_eint: + return checkStreams(2, 1); + case ProcessKind::batched_add_eint_int: + return checkStreams(2, 1); + case ProcessKind::batched_add_eint_int_cst: + return checkStreams(2, 1); + case ProcessKind::batched_mul_eint_int: + return checkStreams(2, 1); + case ProcessKind::batched_mul_eint_int_cst: + return checkStreams(2, 1); + case ProcessKind::batched_neg_eint: + return checkStreams(1, 1); + case ProcessKind::batched_keyswitch: + return checkStreams(1, 1); + case ProcessKind::batched_bootstrap: + return checkStreams(2, 1); } return mlir::failure(); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp index 94058e8f8..8361f1928 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp @@ -69,11 +69,12 @@ mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc, } char stream_emulator_get_memref[] = "stream_emulator_get_memref"; +char stream_emulator_get_memref_batch[] = "stream_emulator_get_memref_batch"; -template +template struct BufferizableWithCallOpInterface : public BufferizableOpInterface::ExternalModel< - BufferizableWithCallOpInterface, Op> { + BufferizableWithCallOpInterface, Op> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return true; @@ -103,6 +104,13 @@ struct BufferizableWithCallOpInterface // able to avoid the copy depending on the stream semantics. auto resTensorType = op->getResultTypes()[0].template cast(); + char const *fname; + if (resTensorType.getRank() == 1) + fname = funcName; + else if (resTensorType.getRank() == 2) + fname = funcName_batch; + else + return mlir::failure(); auto outMemrefType = MemRefType::get(resTensorType.getShape(), resTensorType.getElementType()); auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); @@ -117,9 +125,9 @@ struct BufferizableWithCallOpInterface mlir::FunctionType funcType = mlir::FunctionType::get( rewriter.getContext(), mlir::ValueRange{operands}.getTypes(), mlir::TypeRange()); - if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) + if (insertForwardDeclaration(op, rewriter, fname, funcType).failed()) return ::mlir::failure(); - rewriter.create(loc, funcName, mlir::TypeRange{}, + rewriter.create(loc, fname, mlir::TypeRange{}, operands); replaceOpWithBufferizedValues(rewriter, op, *outMemref); @@ -133,7 +141,8 @@ void mlir::concretelang::SDFG::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, SDFG::SDFGDialect *dialect) { SDFG::Get::attachInterface< - BufferizableWithCallOpInterface>( + BufferizableWithCallOpInterface>( *ctx); }); } diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp index 6a7d5e418..dbbb25dbe 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp @@ -20,6 +20,15 @@ char mul_eint_int[] = "mul_eint_int"; char neg_eint[] = "neg_eint"; char keyswitch[] = "keyswitch"; char bootstrap[] = "bootstrap"; + +char batched_add_eint[] = "batched_add_eint"; +char batched_add_eint_int[] = "batched_add_eint_int"; +char batched_add_eint_int_cst[] = "batched_add_eint_int_cst"; +char batched_mul_eint_int[] = "batched_mul_eint_int"; +char batched_mul_eint_int_cst[] = "batched_mul_eint_int_cst"; +char batched_neg_eint[] = "batched_neg_eint"; +char batched_keyswitch[] = "batched_keyswitch"; +char batched_bootstrap[] = "batched_bootstrap"; } // namespace template @@ -87,6 +96,39 @@ void registerSDFGConvertibleOpInterfaceExternalModels( ReplaceWithProcessSDFGConversionInterface< mlir::concretelang::Concrete::BootstrapLweTensorOp, bootstrap, true>>(*ctx); + + mlir::concretelang::Concrete::BatchedAddLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::Concrete::BatchedAddLweTensorOp, + batched_add_eint>>(*ctx); + mlir::concretelang::Concrete::BatchedAddPlaintextLweTensorOp:: + attachInterface>(*ctx); + mlir::concretelang::Concrete::BatchedAddPlaintextCstLweTensorOp:: + attachInterface>(*ctx); + mlir::concretelang::Concrete::BatchedMulCleartextLweTensorOp:: + attachInterface>(*ctx); + mlir::concretelang::Concrete::BatchedMulCleartextCstLweTensorOp:: + attachInterface>(*ctx); + mlir::concretelang::Concrete::BatchedNegateLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::Concrete::BatchedNegateLweTensorOp, + batched_neg_eint>>(*ctx); + mlir::concretelang::Concrete::BatchedKeySwitchLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::Concrete::BatchedKeySwitchLweTensorOp, + batched_keyswitch, true>>(*ctx); + mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::Concrete::BatchedBootstrapLweTensorOp, + batched_bootstrap, true>>(*ctx); }); } } // namespace SDFG