diff --git a/compiler/include/concretelang/Conversion/ExtractSDFGOps/Pass.h b/compiler/include/concretelang/Conversion/ExtractSDFGOps/Pass.h new file mode 100644 index 000000000..ec61eed94 --- /dev/null +++ b/compiler/include/concretelang/Conversion/ExtractSDFGOps/Pass.h @@ -0,0 +1,18 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_ +#define CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace concretelang { +std::unique_ptr> createExtractSDFGOpsPass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index 4af9f0090..ab644af9b 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -15,6 +15,7 @@ #include "concretelang/Conversion/BConcreteToCAPI/Pass.h" #include "concretelang/Conversion/ConcreteToBConcrete/Pass.h" +#include "concretelang/Conversion/ExtractSDFGOps/Pass.h" #include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h" #include "concretelang/Conversion/FHEToTFHE/Pass.h" #include "concretelang/Conversion/LinalgExtras/Passes.h" @@ -24,6 +25,7 @@ #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #define GEN_PASS_CLASSES diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index b43bf4d1f..498fc4dfd 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -47,6 +47,13 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> { let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"]; } +def ExtractSDFGOps : Pass<"extract-sdfg-ops", "::mlir::func::FuncOp"> { + let summary = "Extracts SDFG ops and creates a static data flow graph"; + let description = [{ Extracts SDFG ops and creates a static data flow graph }]; + let constructor = "mlir::concretelang::createExtractSDFGOps()"; + let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"]; +} + def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> { let summary = "Lowers operations from the BConcrete dialect to CAPI calls"; let description = [{ Lowers operations from the BConcrete dialect to CAPI calls }]; diff --git a/compiler/include/concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h b/compiler/include/concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h new file mode 100644 index 000000000..d4ea3cb24 --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h @@ -0,0 +1,20 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H +#define CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace concretelang { +namespace SDFG { +void registerSDFGConvertibleOpInterfaceExternalModels( + DialectRegistry ®istry); +} // namespace SDFG +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 2d785f020..5e7c77e2f 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -54,6 +54,7 @@ struct CompilationOptions { bool autoParallelize; bool loopParallelize; bool batchConcreteOps; + bool emitSDFGOps; bool dataflowParallelize; bool optimizeConcrete; /// use GPU during execution by generating GPU operations if possible @@ -67,8 +68,8 @@ struct CompilationOptions { CompilationOptions() : v0FHEConstraints(llvm::None), verifyDiagnostics(false), autoParallelize(false), loopParallelize(false), batchConcreteOps(false), - dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false), - clientParametersFuncName(llvm::None), + emitSDFGOps(false), dataflowParallelize(false), optimizeConcrete(true), + emitGPUOps(false), clientParametersFuncName(llvm::None), optimizerConfig(optimizer::DEFAULT_CONFIG){}; CompilationOptions(std::string funcname) : CompilationOptions() { diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index 39d07e0ad..f42533d6d 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -58,6 +58,10 @@ mlir::LogicalResult optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); +mlir::LogicalResult +extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index 830dccb09..1269f52e3 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [ "LLVMX86Info", ]; -const CONCRETE_COMPILER_LIBS: [&str; 31] = [ +const CONCRETE_COMPILER_LIBS: [&str; 33] = [ "RTDialect", "RTDialectTransforms", "ConcretelangSupport", @@ -256,6 +256,7 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [ "ConcretelangClientLib", "ConcretelangBConcreteTransforms", "ConcretelangSDFGInterfaces", + "ConcretelangSDFGTransforms", "CONCRETELANGCAPISupport", "FHELinalgDialect", "ConcretelangInterfaces", @@ -272,7 +273,8 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [ "FHEDialectAnalysis", "ConcreteDialect", "RTDialectAnalysis", - "SDFGDialect" + "SDFGDialect", + "ExtractSDFGOps" ]; fn main() { diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index c39bf6b88..13db1bd1f 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -6,5 +6,6 @@ add_subdirectory(ConcreteToBConcrete) add_subdirectory(BConcreteToCAPI) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(LinalgExtras) +add_subdirectory(ExtractSDFGOps) add_mlir_library(ConcretelangConversion Tools.cpp LINK_LIBS PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt b/compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt new file mode 100644 index 000000000..fce834f6a --- /dev/null +++ b/compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library( + ExtractSDFGOps + ExtractSDFGOps.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE + DEPENDS + SDFGDialect + ConcretelangSDFGInterfaces + mlir-headers + LINK_LIBS + SDFGDialect + ConcretelangSDFGInterfaces + PUBLIC + MLIRIR + MLIRTransforms) + +target_link_libraries(ExtractSDFGOps PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp b/compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp new file mode 100644 index 000000000..817adb622 --- /dev/null +++ b/compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp @@ -0,0 +1,212 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Conversion/Passes.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" +#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" + +namespace SDFG = mlir::concretelang::SDFG; + +namespace { +enum class StreamMappingKind { ON_DEVICE, TO_DEVICE, SPLICE, TO_HOST, NONE }; + +SDFG::MakeStream makeStream(mlir::ImplicitLocOpBuilder &builder, + SDFG::StreamKind kind, mlir::Type type, + mlir::Value dfg, unsigned &streamNumber) { + SDFG::StreamType streamType = builder.getType(type); + mlir::StringAttr name = builder.getStringAttr(llvm::Twine("stream") + + llvm::Twine(streamNumber++)); + + return builder.create(streamType, dfg, name, kind); +} + +StreamMappingKind determineStreamMappingKind(mlir::Value v) { + // Determine stream type for operands: + // + // - If an operand is produced by a non-convertible op, there + // needs to be just a host-to-device stream + // + // - If an operand is produced by a convertible op and there + // are no other consumers, a device-to-device stream may be + // used + // + // - If an operand is produced by a convertible op and there is + // at least one other non-convertible consumer, there needs + // to be a device-to-host stream, and a host-to-device stream + // + if (llvm::dyn_cast_or_null( + v.getDefiningOp())) { + // All convertible consumers? + if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) { + return !!llvm::dyn_cast_or_null( + o.getOwner()); + })) { + return StreamMappingKind::ON_DEVICE; + } + // All non-convertible consumers? + else if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) { + return !llvm::dyn_cast_or_null( + o.getOwner()); + })) { + return StreamMappingKind::TO_HOST; + } + // Mix of convertible and non-convertible users + else { + return StreamMappingKind::SPLICE; + } + } else { + if (llvm::any_of(v.getUses(), [](mlir::OpOperand &o) { + return !!llvm::dyn_cast_or_null( + o.getOwner()); + })) { + return StreamMappingKind::TO_DEVICE; + } else { + return StreamMappingKind::NONE; + } + } +} + +void setInsertionPointAfterValueOrRestore(mlir::OpBuilder &builder, + mlir::Value v, + mlir::OpBuilder::InsertPoint &pos) { + if (v.isa()) + builder.restoreInsertionPoint(pos); + else + builder.setInsertionPointAfterValue(v); +} + +struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase { + + ExtractSDFGOpsPass() {} + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + mlir::IRRewriter rewriter(func.getContext()); + + mlir::DenseMap processOutMapping; + mlir::DenseMap processInMapping; + mlir::DenseMap replacementMapping; + + llvm::SmallVector convertibleOps; + + unsigned streamNumber = 0; + + func.walk([&](SDFG::SDFGConvertibleOpInterface op) { + convertibleOps.push_back(op); + }); + + if (convertibleOps.size() == 0) + return; + + // Insert Prelude + rewriter.setInsertionPointToStart(&func.getBlocks().front()); + mlir::Value dfg = rewriter.create( + func.getLoc(), rewriter.getType()); + SDFG::Start start = rewriter.create(func.getLoc(), dfg); + + rewriter.setInsertionPoint(func.getBlocks().front().getTerminator()); + rewriter.create(func.getLoc(), dfg); + + mlir::ImplicitLocOpBuilder ilb(func.getLoc(), rewriter); + + auto mapValueToStreams = [&](mlir::Value v) { + if (processInMapping.find(v) != processInMapping.end() || + processOutMapping.find(v) != processOutMapping.end()) + return; + + StreamMappingKind smk = determineStreamMappingKind(v); + + SDFG::MakeStream prodOutStream; + SDFG::MakeStream consInStream; + + if (smk == StreamMappingKind::SPLICE || + smk == StreamMappingKind::TO_HOST) { + ilb.setInsertionPoint(start); + prodOutStream = makeStream(ilb, SDFG::StreamKind::device_to_host, + v.getType(), dfg, streamNumber); + processOutMapping.insert({v, prodOutStream}); + + ilb.setInsertionPointAfter(start); + mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint(); + setInsertionPointAfterValueOrRestore(ilb, v, pos); + mlir::Value newOutVal = + ilb.create(v.getType(), prodOutStream.getResult()); + replacementMapping.insert({v, newOutVal}); + } else if (smk == StreamMappingKind::ON_DEVICE) { + ilb.setInsertionPoint(start); + prodOutStream = makeStream(ilb, SDFG::StreamKind::on_device, + v.getType(), dfg, streamNumber); + processOutMapping.insert({v, prodOutStream}); + } + + if (smk == StreamMappingKind::ON_DEVICE) { + processInMapping.insert({v, prodOutStream}); + } else if (smk == StreamMappingKind::SPLICE || + smk == StreamMappingKind::TO_DEVICE || + smk == StreamMappingKind::ON_DEVICE) { + ilb.setInsertionPoint(start); + consInStream = makeStream(ilb, SDFG::StreamKind::host_to_device, + v.getType(), dfg, streamNumber); + processInMapping.insert({v, consInStream}); + + if (smk == StreamMappingKind::TO_DEVICE) { + ilb.setInsertionPointAfter(start); + mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint(); + setInsertionPointAfterValueOrRestore(ilb, v, pos); + ilb.create(consInStream.getResult(), v); + } + } + }; + + for (SDFG::SDFGConvertibleOpInterface convertibleOp : convertibleOps) { + llvm::SmallVector ins; + llvm::SmallVector outs; + ilb.setLoc(convertibleOp.getLoc()); + + for (mlir::Value res : convertibleOp->getResults()) { + mapValueToStreams(res); + outs.push_back(processOutMapping.find(res)->second.getResult()); + } + + for (mlir::Value operand : convertibleOp->getOperands()) { + mapValueToStreams(operand); + ins.push_back(processInMapping.find(operand)->second.getResult()); + } + + ilb.setInsertionPoint(start); + SDFG::MakeProcess process = convertibleOp.convert(ilb, dfg, ins, outs); + + assert(process && "Conversion to SDFG operation failed"); + } + + for (auto it : replacementMapping) { + it.first.replaceAllUsesWith(it.second); + } + + (void)mlir::simplifyRegions(rewriter, func->getRegions()); + } +}; +} // namespace + +namespace mlir { +namespace concretelang { + +std::unique_ptr> createExtractSDFGOpsPass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/SDFG/CMakeLists.txt b/compiler/lib/Dialect/SDFG/CMakeLists.txt index 7927b605a..fa8842fb0 100644 --- a/compiler/lib/Dialect/SDFG/CMakeLists.txt +++ b/compiler/lib/Dialect/SDFG/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Interfaces) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt new file mode 100644 index 000000000..05bfa724c --- /dev/null +++ b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library( + ConcretelangSDFGTransforms + SDFGConvertibleOpInterfaceImpl.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG + DEPENDS + mlir-headers + SDFGDialect + ConcretelangSDFGInterfaces + LINK_LIBS + PUBLIC + SDFGDialect + ConcretelangSDFGInterfaces + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTransforms) diff --git a/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp b/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp new file mode 100644 index 000000000..0c98ec8fd --- /dev/null +++ b/compiler/lib/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.cpp @@ -0,0 +1,89 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace concretelang { +namespace SDFG { +namespace { +char add_eint[] = "add_eint"; +char add_eint_int[] = "add_eint_int"; +char mul_eint_int[] = "mul_eint_int"; +char neg_eint[] = "neg_eint"; +char keyswitch[] = "keyswitch"; +char bootstrap[] = "bootstrap"; +} // namespace + +template +struct ReplaceWithProcessSDFGConversionInterface + : public SDFGConvertibleOpInterface::ExternalModel< + ReplaceWithProcessSDFGConversionInterface, + Op> { + MakeProcess convert(Operation *op, mlir::ImplicitLocOpBuilder &builder, + ::mlir::Value dfg, ::mlir::ValueRange inStreams, + ::mlir::ValueRange outStreams) const { + llvm::SmallVector streams = llvm::to_vector(inStreams); + streams.append(outStreams.begin(), outStreams.end()); + MakeProcess process = builder.create( + *symbolizeProcessKind(processName), dfg, streams); + + if (copyAttributes) { + llvm::SmallVector combinedAttrs = + llvm::to_vector(op->getAttrs()); + + for (mlir::NamedAttribute attr : process->getAttrs()) { + combinedAttrs.push_back(attr); + } + + process->setAttrs(combinedAttrs); + } + + return process; + } +}; + +void registerSDFGConvertibleOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, + BConcrete::BConcreteDialect *dialect) { + mlir::concretelang::BConcrete::AddLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::AddLweTensorOp, add_eint>>(*ctx); + + mlir::concretelang::BConcrete::AddPlaintextLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::AddPlaintextLweTensorOp, + add_eint_int>>(*ctx); + + mlir::concretelang::BConcrete::MulCleartextLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::MulCleartextLweTensorOp, + mul_eint_int>>(*ctx); + + mlir::concretelang::BConcrete::NegateLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::NegateLweTensorOp, neg_eint>>(*ctx); + + mlir::concretelang::BConcrete::KeySwitchLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::KeySwitchLweTensorOp, keyswitch, + true>>(*ctx); + + mlir::concretelang::BConcrete::BootstrapLweTensorOp::attachInterface< + ReplaceWithProcessSDFGConversionInterface< + mlir::concretelang::BConcrete::BootstrapLweTensorOp, bootstrap, + true>>(*ctx); + }); +} +} // namespace SDFG +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 9aae93646..20f85eecc 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -22,11 +22,13 @@ add_mlir_library( FHELinalgDialectTransforms FHETensorOpsToLinalg FHEToTFHE + ExtractSDFGOps MLIRLowerableDialectsToLLVM FHEDialectAnalysis RTDialectAnalysis ConcretelangTransforms ConcretelangBConcreteTransforms + ConcretelangSDFGTransforms ConcretelangSDFGInterfaces LinalgExtras ConcreteDialectTransforms diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 5eca491b3..f5b2a6776 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -81,6 +82,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect, mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>(); BConcrete::registerBufferizableOpInterfaceExternalModels(registry); + SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); @@ -392,6 +394,17 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return std::move(res); } + // Extract SDFG data flow graph from BConcrete representation + + if (options.emitSDFGOps) { + if (mlir::concretelang::pipeline::extractSDFGOps(mlirContext, module, + enablePass) + .failed()) { + return errorDiag( + "Extraction of SDFG operations from BConcrete representation failed"); + } + } + // BConcrete -> Canonical dialects if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module, enablePass) diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 4d2837e31..ea12e60f2 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -274,6 +274,16 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("extract SDFG ops from BConcrete", pm, context); + addPotentiallyNestedPass(pm, mlir::concretelang::createExtractSDFGOpsPass(), + enablePass); + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index f92af6ca9..bcb726c4a 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -173,6 +173,13 @@ llvm::cl::opt batchConcreteOps( "operations out of loop nests as batched operations"), llvm::cl::init(false)); +llvm::cl::opt emitSDFGOps( + "emit-sdfg-ops", + llvm::cl::desc( + "Extract operations supported by the SDFG dialect for static data flow" + " graphs and emit them."), + llvm::cl::init(false)); + llvm::cl::opt dataflowParallelize( "parallelize-dataflow", llvm::cl::desc( @@ -305,6 +312,7 @@ cmdlineCompilationOptions() { options.loopParallelize = cmdline::loopParallelize; options.dataflowParallelize = cmdline::dataflowParallelize; options.batchConcreteOps = cmdline::batchConcreteOps; + options.emitSDFGOps = cmdline::emitSDFGOps; options.optimizeConcrete = cmdline::optimizeConcrete; options.emitGPUOps = cmdline::emitGPUOps;