From 3da32560b715e22beb19e99b061b05dfd47a7a17 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 30 Nov 2022 10:35:03 +0100 Subject: [PATCH] feat(compiler): Add pass converting operations into SDFG processes This adds a new pass `ExtractSDGOps`, which scans a function for operations that implement `SDFGConvertibleOpInterface`, replaces them with SDFG processes and constructs an SDFG graph around the processes. Initialization and teardown of the SDFG graph are embedded into the function and take place at the beginning of the function and before the function's terminator, respectively. The pass can be invoked using concretecompiler by specifying the new compilation option `--emit-sdfg-ops` or programmatically on a `CompilerEngine` using the new compilation option `extractSDFGOps`. --- .../Conversion/ExtractSDFGOps/Pass.h | 18 ++ .../include/concretelang/Conversion/Passes.h | 2 + .../include/concretelang/Conversion/Passes.td | 7 + .../SDFGConvertibleOpInterfaceImpl.h | 20 ++ .../concretelang/Support/CompilerEngine.h | 5 +- .../include/concretelang/Support/Pipeline.h | 4 + compiler/lib/Bindings/Rust/build.rs | 6 +- compiler/lib/Conversion/CMakeLists.txt | 1 + .../Conversion/ExtractSDFGOps/CMakeLists.txt | 17 ++ .../ExtractSDFGOps/ExtractSDFGOps.cpp | 212 ++++++++++++++++++ compiler/lib/Dialect/SDFG/CMakeLists.txt | 1 + .../Dialect/SDFG/Transforms/CMakeLists.txt | 18 ++ .../SDFGConvertibleOpInterfaceImpl.cpp | 89 ++++++++ compiler/lib/Support/CMakeLists.txt | 2 + compiler/lib/Support/CompilerEngine.cpp | 13 ++ compiler/lib/Support/Pipeline.cpp | 10 + compiler/src/main.cpp | 8 + 17 files changed, 429 insertions(+), 4 deletions(-) create mode 100644 compiler/include/concretelang/Conversion/ExtractSDFGOps/Pass.h create mode 100644 compiler/include/concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h create mode 100644 compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt create mode 100644 compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp create mode 100644 compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt create mode 100644 compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp diff --git a/compiler/include/concretelang/Conversion/ExtractSDFGOps/Pass.h b/compiler/include/concretelang/Conversion/ExtractSDFGOps/Pass.h new file mode 100644 index 000000000..ec61eed94 --- /dev/null +++ b/compiler/include/concretelang/Conversion/ExtractSDFGOps/Pass.h @@ -0,0 +1,18 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_ +#define CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace concretelang { +std::unique_ptr> createExtractSDFGOpsPass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index 4af9f0090..ab644af9b 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -15,6 +15,7 @@ #include "concretelang/Conversion/BConcreteToCAPI/Pass.h" #include "concretelang/Conversion/ConcreteToBConcrete/Pass.h" +#include "concretelang/Conversion/ExtractSDFGOps/Pass.h" #include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h" #include "concretelang/Conversion/FHEToTFHE/Pass.h" #include "concretelang/Conversion/LinalgExtras/Passes.h" @@ -24,6 +25,7 @@ #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #define GEN_PASS_CLASSES diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index b43bf4d1f..498fc4dfd 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -47,6 +47,13 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> { let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"]; } +def ExtractSDFGOps : Pass<"extract-sdfg-ops", "::mlir::func::FuncOp"> { + let summary = "Extracts SDFG ops and creates a static data flow graph"; + let description = [{ Extracts SDFG ops and creates a static data flow graph }]; + let constructor = "mlir::concretelang::createExtractSDFGOps()"; + let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"]; +} + def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> { let summary = "Lowers operations from the BConcrete dialect to CAPI calls"; let description = [{ Lowers operations from the BConcrete dialect to CAPI calls }]; diff --git a/compiler/include/concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h b/compiler/include/concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h new file mode 100644 index 000000000..d4ea3cb24 --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h @@ -0,0 +1,20 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H +#define CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace concretelang { +namespace SDFG { +void registerSDFGConvertibleOpInterfaceExternalModels( + DialectRegistry ®istry); +} // namespace SDFG +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 2d785f020..5e7c77e2f 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -54,6 +54,7 @@ struct CompilationOptions { bool autoParallelize; bool loopParallelize; bool batchConcreteOps; + bool emitSDFGOps; bool dataflowParallelize; bool optimizeConcrete; /// use GPU during execution by generating GPU operations if possible @@ -67,8 +68,8 @@ struct CompilationOptions { CompilationOptions() : v0FHEConstraints(llvm::None), verifyDiagnostics(false), autoParallelize(false), loopParallelize(false), batchConcreteOps(false), - dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false), - clientParametersFuncName(llvm::None), + emitSDFGOps(false), dataflowParallelize(false), optimizeConcrete(true), + emitGPUOps(false), clientParametersFuncName(llvm::None), optimizerConfig(optimizer::DEFAULT_CONFIG){}; CompilationOptions(std::string funcname) : CompilationOptions() { diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 39d07e0ad..f42533d6d 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -58,6 +58,10 @@ mlir::LogicalResult optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); +mlir::LogicalResult +extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index 830dccb09..1269f52e3 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [ "LLVMX86Info", ]; -const CONCRETE_COMPILER_LIBS: [&str; 31] = [ +const CONCRETE_COMPILER_LIBS: [&str; 33] = [ "RTDialect", "RTDialectTransforms", "ConcretelangSupport", @@ -256,6 +256,7 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [ "ConcretelangClientLib", "ConcretelangBConcreteTransforms", "ConcretelangSDFGInterfaces", + "ConcretelangSDFGTransforms", "CONCRETELANGCAPISupport", "FHELinalgDialect", "ConcretelangInterfaces", @@ -272,7 +273,8 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [ "FHEDialectAnalysis", "ConcreteDialect", "RTDialectAnalysis", - "SDFGDialect" + "SDFGDialect", + "ExtractSDFGOps" ]; fn main() { diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index c39bf6b88..13db1bd1f 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -6,5 +6,6 @@ add_subdirectory(ConcreteToBConcrete) add_subdirectory(BConcreteToCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(LinalgExtras) +add_subdirectory(ExtractSDFGOps) add_mlir_library(ConcretelangConversion Tools.cpp LINK_LIBS PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt b/compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt new file mode 100644 index 000000000..fce834f6a --- /dev/null +++ b/compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library( + ExtractSDFGOps + ExtractSDFGOps.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE + DEPENDS + SDFGDialect + ConcretelangSDFGInterfaces + mlir-headers + LINK_LIBS + SDFGDialect + ConcretelangSDFGInterfaces + PUBLIC + MLIRIR + MLIRTransforms) + +target_link_libraries(ExtractSDFGOps PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp b/compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp new file mode 100644 index 000000000..817adb622 --- /dev/null +++ b/compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp @@ -0,0 +1,212 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Conversion/Passes.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" +#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" + +namespace SDFG = mlir::concretelang::SDFG; + +namespace { +enum class StreamMappingKind { ON_DEVICE, TO_DEVICE, SPLICE, TO_HOST, NONE }; + +SDFG::MakeStream makeStream(mlir::ImplicitLocOpBuilder &builder, + SDFG::StreamKind kind, mlir::Type type, + mlir::Value dfg, unsigned &streamNumber) { + SDFG::StreamType streamType = builder.getType(type); + mlir::StringAttr name = builder.getStringAttr(llvm::Twine("stream") + + llvm::Twine(streamNumber++)); + + return builder.create(streamType, dfg, name, kind); +} + +StreamMappingKind determineStreamMappingKind(mlir::Value v) { + // Determine stream type for operands: + // + // - If an operand is produced by a non-convertible op, there + // needs to be just a host-to-device stream + // + // - If an operand is produced by a convertible op and there + // are no other consumers, a device-to-device stream may be + // used + // + // - If an operand is produced by a convertible op and there is + // at least one other non-convertible consumer, there needs + // to be a device-to-host stream, and a host-to-device stream + // + if (llvm::dyn_cast_or_null( + v.getDefiningOp())) { + // All convertible consumers? + if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) { + return !!llvm::dyn_cast_or_null( + o.getOwner()); + })) { + return StreamMappingKind::ON_DEVICE; + } + // All non-convertible consumers? + else if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) { + return !llvm::dyn_cast_or_null( + o.getOwner()); + })) { + return StreamMappingKind::TO_HOST; + } + // Mix of convertible and non-convertible users + else { + return StreamMappingKind::SPLICE; + } + } else { + if (llvm::any_of(v.getUses(), [](mlir::OpOperand &o) { + return !!llvm::dyn_cast_or_null( + o.getOwner()); + })) { + return StreamMappingKind::TO_DEVICE; + } else { + return StreamMappingKind::NONE; + } + } +} + +void setInsertionPointAfterValueOrRestore(mlir::OpBuilder &builder, + mlir::Value v, + mlir::OpBuilder::InsertPoint &pos) { + if (v.isa()) + builder.restoreInsertionPoint(pos); + else + builder.setInsertionPointAfterValue(v); +} + +struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase { + + ExtractSDFGOpsPass() {} + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + mlir::IRRewriter rewriter(func.getContext()); + + mlir::DenseMap processOutMapping; + mlir::DenseMap processInMapping; + mlir::DenseMap replacementMapping; + + llvm::SmallVector convertibleOps; + + unsigned streamNumber = 0; + + func.walk([&](SDFG::SDFGConvertibleOpInterface op) { + convertibleOps.push_back(op); + }); + + if (convertibleOps.size() == 0) + return; + + // Insert Prelude + rewriter.setInsertionPointToStart(&func.getBlocks().front()); + mlir::Value dfg = rewriter.create( + func.getLoc(), rewriter.getType()); + SDFG::Start start = rewriter.create(func.getLoc(), dfg); + + rewriter.setInsertionPoint(func.getBlocks().front().getTerminator()); + rewriter.create(func.getLoc(), dfg); + + mlir::ImplicitLocOpBuilder ilb(func.getLoc(), rewriter); + + auto mapValueToStreams = [&](mlir::Value v) { + if (processInMapping.find(v) != processInMapping.end() || + processOutMapping.find(v) != processOutMapping.end()) + return; + + StreamMappingKind smk = determineStreamMappingKind(v); + + SDFG::MakeStream prodOutStream; + SDFG::MakeStream consInStream; + + if (smk == StreamMappingKind::SPLICE || + smk == StreamMappingKind::TO_HOST) { + ilb.setInsertionPoint(start); + prodOutStream = makeStream(ilb, SDFG::StreamKind::device_to_host, + v.getType(), dfg, streamNumber); + processOutMapping.insert({v, prodOutStream}); + + ilb.setInsertionPointAfter(start); + mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint(); + setInsertionPointAfterValueOrRestore(ilb, v, pos); + mlir::Value newOutVal = + ilb.create(v.getType(), prodOutStream.getResult()); + replacementMapping.insert({v, newOutVal}); + } else if (smk == StreamMappingKind::ON_DEVICE) { + ilb.setInsertionPoint(start); + prodOutStream = makeStream(ilb, SDFG::StreamKind::on_device, + v.getType(), dfg, streamNumber); + processOutMapping.insert({v, prodOutStream}); + } + + if (smk == StreamMappingKind::ON_DEVICE) { + processInMapping.insert({v, prodOutStream}); + } else if (smk == StreamMappingKind::SPLICE || + smk == StreamMappingKind::TO_DEVICE || + smk == StreamMappingKind::ON_DEVICE) { + ilb.setInsertionPoint(start); + consInStream = makeStream(ilb, SDFG::StreamKind::host_to_device, + v.getType(), dfg, streamNumber); + processInMapping.insert({v, consInStream}); + + if (smk == StreamMappingKind::TO_DEVICE) { + ilb.setInsertionPointAfter(start); + mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint(); + setInsertionPointAfterValueOrRestore(ilb, v, pos); + ilb.create(consInStream.getResult(), v); + } + } + }; + + for (SDFG::SDFGConvertibleOpInterface convertibleOp : convertibleOps) { + llvm::SmallVector ins; + llvm::SmallVector outs; + ilb.setLoc(convertibleOp.getLoc()); + + for (mlir::Value res : convertibleOp->getResults()) { + mapValueToStreams(res); + outs.push_back(processOutMapping.find(res)->second.getResult()); + } + + for (mlir::Value operand : convertibleOp->getOperands()) { + mapValueToStreams(operand); + ins.push_back(processInMapping.find(operand)->second.getResult()); + } + + ilb.setInsertionPoint(start); + SDFG::MakeProcess process = convertibleOp.convert(ilb, dfg, ins, outs); + + assert(process && "Conversion to SDFG operation failed"); + } + + for (auto it : replacementMapping) { + it.first.replaceAllUsesWith(it.second); + } + + (void)mlir::simplifyRegions(rewriter, func->getRegions()); + } +}; +} // namespace + +namespace mlir { +namespace concretelang { + +std::unique_ptr> createExtractSDFGOpsPass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/SDFG/CMakeLists.txt b/compiler/lib/Dialect/SDFG/CMakeLists.txt index 7927b605a..fa8842fb0 100644 --- a/compiler/lib/Dialect/SDFG/CMakeLists.txt +++ b/compiler/lib/Dialect/SDFG/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Interfaces) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt new file mode 100644 index 000000000..05bfa724c --- /dev/null +++ b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library( + ConcretelangSDFGTransforms + SDFGConvertibleOpInterfaceImpl.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG + DEPENDS + mlir-headers + SDFGDialect + ConcretelangSDFGInterfaces + LINK_LIBS + PUBLIC + SDFGDialect + ConcretelangSDFGInterfaces + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTransforms) diff --git a/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp b/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp new file mode 100644 index 000000000..0c98ec8fd --- /dev/null +++ b/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp @@ -0,0 +1,89 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace concretelang { +namespace SDFG { +namespace { +char add_eint[] = "add_eint"; +char add_eint_int[] = "add_eint_int"; +char mul_eint_int[] = "mul_eint_int"; +char neg_eint[] = "neg_eint"; +char keyswitch[] = "keyswitch"; +char bootstrap[] = "bootstrap"; +} // namespace + +template +struct ReplaceWithProcessSDFGConversionInterface + : public SDFGConvertibleOpInterface::ExternalModel< + ReplaceWithProcessSDFGConversionInterface, + Op> { + MakeProcess convert(Operation *op, mlir::ImplicitLocOpBuilder &builder, + ::mlir::Value dfg, ::mlir::ValueRange inStreams, + ::mlir::ValueRange outStreams) const { + llvm::SmallVector streams = llvm::to_vector(inStreams); + streams.append(outStreams.begin(), outStreams.end()); + MakeProcess process = builder.create( + *symbolizeProcessKind(processName), dfg, streams); + + if (copyAttributes) { + llvm::SmallVector combinedAttrs = + llvm::to_vector(op->getAttrs()); + + for (mlir::NamedAttribute attr : process->getAttrs()) { + combinedAttrs.push_back(attr); + } + + process->setAttrs(combinedAttrs); + } + + return process; + } +}; + +void registerSDFGConvertibleOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, + BConcrete::BConcreteDialect *dialect) { + mlir::concretelang::BConcrete::AddLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::AddLweTensorOp, add_eint>>(*ctx); + + mlir::concretelang::BConcrete::AddPlaintextLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::AddPlaintextLweTensorOp, + add_eint_int>>(*ctx); + + mlir::concretelang::BConcrete::MulCleartextLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::MulCleartextLweTensorOp, + mul_eint_int>>(*ctx); + + mlir::concretelang::BConcrete::NegateLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::NegateLweTensorOp, neg_eint>>(*ctx); + + mlir::concretelang::BConcrete::KeySwitchLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::KeySwitchLweTensorOp, keyswitch, + true>>(*ctx); + + mlir::concretelang::BConcrete::BootstrapLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::BootstrapLweTensorOp, bootstrap, + true>>(*ctx); + }); +} +} // namespace SDFG +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 9aae93646..20f85eecc 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -22,11 +22,13 @@ add_mlir_library( FHELinalgDialectTransforms FHETensorOpsToLinalg FHEToTFHE + ExtractSDFGOps MLIRLowerableDialectsToLLVM FHEDialectAnalysis RTDialectAnalysis ConcretelangTransforms ConcretelangBConcreteTransforms + ConcretelangSDFGTransforms ConcretelangSDFGInterfaces LinalgExtras ConcreteDialectTransforms diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 5eca491b3..f5b2a6776 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -81,6 +82,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect, mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>(); BConcrete::registerBufferizableOpInterfaceExternalModels(registry); + SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); @@ -392,6 +394,17 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return std::move(res); } + // Extract SDFG data flow graph from BConcrete representation + + if (options.emitSDFGOps) { + if (mlir::concretelang::pipeline::extractSDFGOps(mlirContext, module, + enablePass) + .failed()) { + return errorDiag( + "Extraction of SDFG operations from BConcrete representation failed"); + } + } + // BConcrete -> Canonical dialects if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module, enablePass) diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 4d2837e31..ea12e60f2 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -274,6 +274,16 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("extract SDFG ops from BConcrete", pm, context); + addPotentiallyNestedPass(pm, mlir::concretelang::createExtractSDFGOpsPass(), + enablePass); + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index f92af6ca9..bcb726c4a 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -173,6 +173,13 @@ llvm::cl::opt batchConcreteOps( "operations out of loop nests as batched operations"), llvm::cl::init(false)); +llvm::cl::opt emitSDFGOps( + "emit-sdfg-ops", + llvm::cl::desc( + "Extract operations supported by the SDFG dialect for static data flow" + " graphs and emit them."), + llvm::cl::init(false)); + llvm::cl::opt dataflowParallelize( "parallelize-dataflow", llvm::cl::desc( @@ -305,6 +312,7 @@ cmdlineCompilationOptions() { options.loopParallelize = cmdline::loopParallelize; options.dataflowParallelize = cmdline::dataflowParallelize; options.batchConcreteOps = cmdline::batchConcreteOps; + options.emitSDFGOps = cmdline::emitSDFGOps; options.optimizeConcrete = cmdline::optimizeConcrete; options.emitGPUOps = cmdline::emitGPUOps;