mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compiler): Add pass converting operations into SDFG processes
This adds a new pass `ExtractSDGOps`, which scans a function for operations that implement `SDFGConvertibleOpInterface`, replaces them with SDFG processes and constructs an SDFG graph around the processes. Initialization and teardown of the SDFG graph are embedded into the function and take place at the beginning of the function and before the function's terminator, respectively. The pass can be invoked using concretecompiler by specifying the new compilation option `--emit-sdfg-ops` or programmatically on a `CompilerEngine` using the new compilation option `extractSDFGOps`.
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<mlir::func::FuncOp>> createExtractSDFGOpsPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
#include "concretelang/Conversion/BConcreteToCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/ExtractSDFGOps/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
#include "concretelang/Conversion/LinalgExtras/Passes.h"
|
||||
@@ -24,6 +25,7 @@
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
|
||||
@@ -47,6 +47,13 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
|
||||
}
|
||||
|
||||
def ExtractSDFGOps : Pass<"extract-sdfg-ops", "::mlir::func::FuncOp"> {
|
||||
let summary = "Extracts SDFG ops and creates a static data flow graph";
|
||||
let description = [{ Extracts SDFG ops and creates a static data flow graph }];
|
||||
let constructor = "mlir::concretelang::createExtractSDFGOps()";
|
||||
let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"];
|
||||
}
|
||||
|
||||
def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the BConcrete dialect to CAPI calls";
|
||||
let description = [{ Lowers operations from the BConcrete dialect to CAPI calls }];
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H
|
||||
#define CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace concretelang {
|
||||
namespace SDFG {
|
||||
void registerSDFGConvertibleOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry);
|
||||
} // namespace SDFG
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -54,6 +54,7 @@ struct CompilationOptions {
|
||||
bool autoParallelize;
|
||||
bool loopParallelize;
|
||||
bool batchConcreteOps;
|
||||
bool emitSDFGOps;
|
||||
bool dataflowParallelize;
|
||||
bool optimizeConcrete;
|
||||
/// use GPU during execution by generating GPU operations if possible
|
||||
@@ -67,8 +68,8 @@ struct CompilationOptions {
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
|
||||
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
|
||||
clientParametersFuncName(llvm::None),
|
||||
emitSDFGOps(false), dataflowParallelize(false), optimizeConcrete(true),
|
||||
emitGPUOps(false), clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
|
||||
@@ -58,6 +58,10 @@ mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
@@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [
|
||||
"LLVMX86Info",
|
||||
];
|
||||
|
||||
const CONCRETE_COMPILER_LIBS: [&str; 31] = [
|
||||
const CONCRETE_COMPILER_LIBS: [&str; 33] = [
|
||||
"RTDialect",
|
||||
"RTDialectTransforms",
|
||||
"ConcretelangSupport",
|
||||
@@ -256,6 +256,7 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [
|
||||
"ConcretelangClientLib",
|
||||
"ConcretelangBConcreteTransforms",
|
||||
"ConcretelangSDFGInterfaces",
|
||||
"ConcretelangSDFGTransforms",
|
||||
"CONCRETELANGCAPISupport",
|
||||
"FHELinalgDialect",
|
||||
"ConcretelangInterfaces",
|
||||
@@ -272,7 +273,8 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [
|
||||
"FHEDialectAnalysis",
|
||||
"ConcreteDialect",
|
||||
"RTDialectAnalysis",
|
||||
"SDFGDialect"
|
||||
"SDFGDialect",
|
||||
"ExtractSDFGOps"
|
||||
];
|
||||
|
||||
fn main() {
|
||||
|
||||
@@ -6,5 +6,6 @@ add_subdirectory(ConcreteToBConcrete)
|
||||
add_subdirectory(BConcreteToCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LinalgExtras)
|
||||
add_subdirectory(ExtractSDFGOps)
|
||||
|
||||
add_mlir_library(ConcretelangConversion Tools.cpp LINK_LIBS PUBLIC MLIRIR)
|
||||
|
||||
17
compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt
Normal file
17
compiler/lib/Conversion/ExtractSDFGOps/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
add_mlir_dialect_library(
|
||||
ExtractSDFGOps
|
||||
ExtractSDFGOps.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
DEPENDS
|
||||
SDFGDialect
|
||||
ConcretelangSDFGInterfaces
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
SDFGDialect
|
||||
ConcretelangSDFGInterfaces
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms)
|
||||
|
||||
target_link_libraries(ExtractSDFGOps PUBLIC MLIRIR)
|
||||
212
compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp
Normal file
212
compiler/lib/Conversion/ExtractSDFGOps/ExtractSDFGOps.cpp
Normal file
@@ -0,0 +1,212 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
namespace SDFG = mlir::concretelang::SDFG;
|
||||
|
||||
namespace {
|
||||
enum class StreamMappingKind { ON_DEVICE, TO_DEVICE, SPLICE, TO_HOST, NONE };
|
||||
|
||||
SDFG::MakeStream makeStream(mlir::ImplicitLocOpBuilder &builder,
|
||||
SDFG::StreamKind kind, mlir::Type type,
|
||||
mlir::Value dfg, unsigned &streamNumber) {
|
||||
SDFG::StreamType streamType = builder.getType<SDFG::StreamType>(type);
|
||||
mlir::StringAttr name = builder.getStringAttr(llvm::Twine("stream") +
|
||||
llvm::Twine(streamNumber++));
|
||||
|
||||
return builder.create<SDFG::MakeStream>(streamType, dfg, name, kind);
|
||||
}
|
||||
|
||||
StreamMappingKind determineStreamMappingKind(mlir::Value v) {
|
||||
// Determine stream type for operands:
|
||||
//
|
||||
// - If an operand is produced by a non-convertible op, there
|
||||
// needs to be just a host-to-device stream
|
||||
//
|
||||
// - If an operand is produced by a convertible op and there
|
||||
// are no other consumers, a device-to-device stream may be
|
||||
// used
|
||||
//
|
||||
// - If an operand is produced by a convertible op and there is
|
||||
// at least one other non-convertible consumer, there needs
|
||||
// to be a device-to-host stream, and a host-to-device stream
|
||||
//
|
||||
if (llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
||||
v.getDefiningOp())) {
|
||||
// All convertible consumers?
|
||||
if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) {
|
||||
return !!llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
||||
o.getOwner());
|
||||
})) {
|
||||
return StreamMappingKind::ON_DEVICE;
|
||||
}
|
||||
// All non-convertible consumers?
|
||||
else if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) {
|
||||
return !llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
||||
o.getOwner());
|
||||
})) {
|
||||
return StreamMappingKind::TO_HOST;
|
||||
}
|
||||
// Mix of convertible and non-convertible users
|
||||
else {
|
||||
return StreamMappingKind::SPLICE;
|
||||
}
|
||||
} else {
|
||||
if (llvm::any_of(v.getUses(), [](mlir::OpOperand &o) {
|
||||
return !!llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
|
||||
o.getOwner());
|
||||
})) {
|
||||
return StreamMappingKind::TO_DEVICE;
|
||||
} else {
|
||||
return StreamMappingKind::NONE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void setInsertionPointAfterValueOrRestore(mlir::OpBuilder &builder,
|
||||
mlir::Value v,
|
||||
mlir::OpBuilder::InsertPoint &pos) {
|
||||
if (v.isa<mlir::BlockArgument>())
|
||||
builder.restoreInsertionPoint(pos);
|
||||
else
|
||||
builder.setInsertionPointAfterValue(v);
|
||||
}
|
||||
|
||||
struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {
|
||||
|
||||
ExtractSDFGOpsPass() {}
|
||||
|
||||
void runOnOperation() override {
|
||||
mlir::func::FuncOp func = getOperation();
|
||||
mlir::IRRewriter rewriter(func.getContext());
|
||||
|
||||
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processOutMapping;
|
||||
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processInMapping;
|
||||
mlir::DenseMap<mlir::Value, mlir::Value> replacementMapping;
|
||||
|
||||
llvm::SmallVector<SDFG::SDFGConvertibleOpInterface> convertibleOps;
|
||||
|
||||
unsigned streamNumber = 0;
|
||||
|
||||
func.walk([&](SDFG::SDFGConvertibleOpInterface op) {
|
||||
convertibleOps.push_back(op);
|
||||
});
|
||||
|
||||
if (convertibleOps.size() == 0)
|
||||
return;
|
||||
|
||||
// Insert Prelude
|
||||
rewriter.setInsertionPointToStart(&func.getBlocks().front());
|
||||
mlir::Value dfg = rewriter.create<SDFG::Init>(
|
||||
func.getLoc(), rewriter.getType<SDFG::DFGType>());
|
||||
SDFG::Start start = rewriter.create<SDFG::Start>(func.getLoc(), dfg);
|
||||
|
||||
rewriter.setInsertionPoint(func.getBlocks().front().getTerminator());
|
||||
rewriter.create<SDFG::Shutdown>(func.getLoc(), dfg);
|
||||
|
||||
mlir::ImplicitLocOpBuilder ilb(func.getLoc(), rewriter);
|
||||
|
||||
auto mapValueToStreams = [&](mlir::Value v) {
|
||||
if (processInMapping.find(v) != processInMapping.end() ||
|
||||
processOutMapping.find(v) != processOutMapping.end())
|
||||
return;
|
||||
|
||||
StreamMappingKind smk = determineStreamMappingKind(v);
|
||||
|
||||
SDFG::MakeStream prodOutStream;
|
||||
SDFG::MakeStream consInStream;
|
||||
|
||||
if (smk == StreamMappingKind::SPLICE ||
|
||||
smk == StreamMappingKind::TO_HOST) {
|
||||
ilb.setInsertionPoint(start);
|
||||
prodOutStream = makeStream(ilb, SDFG::StreamKind::device_to_host,
|
||||
v.getType(), dfg, streamNumber);
|
||||
processOutMapping.insert({v, prodOutStream});
|
||||
|
||||
ilb.setInsertionPointAfter(start);
|
||||
mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint();
|
||||
setInsertionPointAfterValueOrRestore(ilb, v, pos);
|
||||
mlir::Value newOutVal =
|
||||
ilb.create<SDFG::Get>(v.getType(), prodOutStream.getResult());
|
||||
replacementMapping.insert({v, newOutVal});
|
||||
} else if (smk == StreamMappingKind::ON_DEVICE) {
|
||||
ilb.setInsertionPoint(start);
|
||||
prodOutStream = makeStream(ilb, SDFG::StreamKind::on_device,
|
||||
v.getType(), dfg, streamNumber);
|
||||
processOutMapping.insert({v, prodOutStream});
|
||||
}
|
||||
|
||||
if (smk == StreamMappingKind::ON_DEVICE) {
|
||||
processInMapping.insert({v, prodOutStream});
|
||||
} else if (smk == StreamMappingKind::SPLICE ||
|
||||
smk == StreamMappingKind::TO_DEVICE ||
|
||||
smk == StreamMappingKind::ON_DEVICE) {
|
||||
ilb.setInsertionPoint(start);
|
||||
consInStream = makeStream(ilb, SDFG::StreamKind::host_to_device,
|
||||
v.getType(), dfg, streamNumber);
|
||||
processInMapping.insert({v, consInStream});
|
||||
|
||||
if (smk == StreamMappingKind::TO_DEVICE) {
|
||||
ilb.setInsertionPointAfter(start);
|
||||
mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint();
|
||||
setInsertionPointAfterValueOrRestore(ilb, v, pos);
|
||||
ilb.create<SDFG::Put>(consInStream.getResult(), v);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (SDFG::SDFGConvertibleOpInterface convertibleOp : convertibleOps) {
|
||||
llvm::SmallVector<mlir::Value> ins;
|
||||
llvm::SmallVector<mlir::Value> outs;
|
||||
ilb.setLoc(convertibleOp.getLoc());
|
||||
|
||||
for (mlir::Value res : convertibleOp->getResults()) {
|
||||
mapValueToStreams(res);
|
||||
outs.push_back(processOutMapping.find(res)->second.getResult());
|
||||
}
|
||||
|
||||
for (mlir::Value operand : convertibleOp->getOperands()) {
|
||||
mapValueToStreams(operand);
|
||||
ins.push_back(processInMapping.find(operand)->second.getResult());
|
||||
}
|
||||
|
||||
ilb.setInsertionPoint(start);
|
||||
SDFG::MakeProcess process = convertibleOp.convert(ilb, dfg, ins, outs);
|
||||
|
||||
assert(process && "Conversion to SDFG operation failed");
|
||||
}
|
||||
|
||||
for (auto it : replacementMapping) {
|
||||
it.first.replaceAllUsesWith(it.second);
|
||||
}
|
||||
|
||||
(void)mlir::simplifyRegions(rewriter, func->getRegions());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<OperationPass<mlir::func::FuncOp>> createExtractSDFGOpsPass() {
|
||||
return std::make_unique<ExtractSDFGOpsPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(Interfaces)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
18
compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt
Normal file
18
compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
add_mlir_dialect_library(
|
||||
ConcretelangSDFGTransforms
|
||||
SDFGConvertibleOpInterfaceImpl.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
SDFGDialect
|
||||
ConcretelangSDFGInterfaces
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
SDFGDialect
|
||||
ConcretelangSDFGInterfaces
|
||||
MLIRIR
|
||||
MLIRMemRefDialect
|
||||
MLIRPass
|
||||
MLIRTransforms)
|
||||
@@ -0,0 +1,89 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
|
||||
#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace SDFG {
|
||||
namespace {
|
||||
char add_eint[] = "add_eint";
|
||||
char add_eint_int[] = "add_eint_int";
|
||||
char mul_eint_int[] = "mul_eint_int";
|
||||
char neg_eint[] = "neg_eint";
|
||||
char keyswitch[] = "keyswitch";
|
||||
char bootstrap[] = "bootstrap";
|
||||
} // namespace
|
||||
|
||||
template <typename Op, char const *processName, bool copyAttributes = false>
|
||||
struct ReplaceWithProcessSDFGConversionInterface
|
||||
: public SDFGConvertibleOpInterface::ExternalModel<
|
||||
ReplaceWithProcessSDFGConversionInterface<Op, processName,
|
||||
copyAttributes>,
|
||||
Op> {
|
||||
MakeProcess convert(Operation *op, mlir::ImplicitLocOpBuilder &builder,
|
||||
::mlir::Value dfg, ::mlir::ValueRange inStreams,
|
||||
::mlir::ValueRange outStreams) const {
|
||||
llvm::SmallVector<mlir::Value> streams = llvm::to_vector(inStreams);
|
||||
streams.append(outStreams.begin(), outStreams.end());
|
||||
MakeProcess process = builder.create<MakeProcess>(
|
||||
*symbolizeProcessKind(processName), dfg, streams);
|
||||
|
||||
if (copyAttributes) {
|
||||
llvm::SmallVector<mlir::NamedAttribute> combinedAttrs =
|
||||
llvm::to_vector(op->getAttrs());
|
||||
|
||||
for (mlir::NamedAttribute attr : process->getAttrs()) {
|
||||
combinedAttrs.push_back(attr);
|
||||
}
|
||||
|
||||
process->setAttrs(combinedAttrs);
|
||||
}
|
||||
|
||||
return process;
|
||||
}
|
||||
};
|
||||
|
||||
void registerSDFGConvertibleOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx,
|
||||
BConcrete::BConcreteDialect *dialect) {
|
||||
mlir::concretelang::BConcrete::AddLweTensorOp::attachInterface<
|
||||
ReplaceWithProcessSDFGConversionInterface<
|
||||
mlir::concretelang::BConcrete::AddLweTensorOp, add_eint>>(*ctx);
|
||||
|
||||
mlir::concretelang::BConcrete::AddPlaintextLweTensorOp::attachInterface<
|
||||
ReplaceWithProcessSDFGConversionInterface<
|
||||
mlir::concretelang::BConcrete::AddPlaintextLweTensorOp,
|
||||
add_eint_int>>(*ctx);
|
||||
|
||||
mlir::concretelang::BConcrete::MulCleartextLweTensorOp::attachInterface<
|
||||
ReplaceWithProcessSDFGConversionInterface<
|
||||
mlir::concretelang::BConcrete::MulCleartextLweTensorOp,
|
||||
mul_eint_int>>(*ctx);
|
||||
|
||||
mlir::concretelang::BConcrete::NegateLweTensorOp::attachInterface<
|
||||
ReplaceWithProcessSDFGConversionInterface<
|
||||
mlir::concretelang::BConcrete::NegateLweTensorOp, neg_eint>>(*ctx);
|
||||
|
||||
mlir::concretelang::BConcrete::KeySwitchLweTensorOp::attachInterface<
|
||||
ReplaceWithProcessSDFGConversionInterface<
|
||||
mlir::concretelang::BConcrete::KeySwitchLweTensorOp, keyswitch,
|
||||
true>>(*ctx);
|
||||
|
||||
mlir::concretelang::BConcrete::BootstrapLweTensorOp::attachInterface<
|
||||
ReplaceWithProcessSDFGConversionInterface<
|
||||
mlir::concretelang::BConcrete::BootstrapLweTensorOp, bootstrap,
|
||||
true>>(*ctx);
|
||||
});
|
||||
}
|
||||
} // namespace SDFG
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -22,11 +22,13 @@ add_mlir_library(
|
||||
FHELinalgDialectTransforms
|
||||
FHETensorOpsToLinalg
|
||||
FHEToTFHE
|
||||
ExtractSDFGOps
|
||||
MLIRLowerableDialectsToLLVM
|
||||
FHEDialectAnalysis
|
||||
RTDialectAnalysis
|
||||
ConcretelangTransforms
|
||||
ConcretelangBConcreteTransforms
|
||||
ConcretelangSDFGTransforms
|
||||
ConcretelangSDFGInterfaces
|
||||
LinalgExtras
|
||||
ConcreteDialectTransforms
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
#include <concretelang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/SDFG/IR/SDFGDialect.h>
|
||||
#include <concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
@@ -81,6 +82,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect,
|
||||
mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>();
|
||||
BConcrete::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
|
||||
registry);
|
||||
@@ -392,6 +394,17 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return std::move(res);
|
||||
}
|
||||
|
||||
// Extract SDFG data flow graph from BConcrete representation
|
||||
|
||||
if (options.emitSDFGOps) {
|
||||
if (mlir::concretelang::pipeline::extractSDFGOps(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return errorDiag(
|
||||
"Extraction of SDFG operations from BConcrete representation failed");
|
||||
}
|
||||
}
|
||||
|
||||
// BConcrete -> Canonical dialects
|
||||
if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module,
|
||||
enablePass)
|
||||
|
||||
@@ -274,6 +274,16 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("extract SDFG ops from BConcrete", pm, context);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createExtractSDFGOpsPass(),
|
||||
enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
|
||||
@@ -173,6 +173,13 @@ llvm::cl::opt<bool> batchConcreteOps(
|
||||
"operations out of loop nests as batched operations"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> emitSDFGOps(
|
||||
"emit-sdfg-ops",
|
||||
llvm::cl::desc(
|
||||
"Extract operations supported by the SDFG dialect for static data flow"
|
||||
" graphs and emit them."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> dataflowParallelize(
|
||||
"parallelize-dataflow",
|
||||
llvm::cl::desc(
|
||||
@@ -305,6 +312,7 @@ cmdlineCompilationOptions() {
|
||||
options.loopParallelize = cmdline::loopParallelize;
|
||||
options.dataflowParallelize = cmdline::dataflowParallelize;
|
||||
options.batchConcreteOps = cmdline::batchConcreteOps;
|
||||
options.emitSDFGOps = cmdline::emitSDFGOps;
|
||||
options.optimizeConcrete = cmdline::optimizeConcrete;
|
||||
options.emitGPUOps = cmdline::emitGPUOps;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user