feat(compiler): Add pass converting operations into SDFG processes

This adds a new pass `ExtractSDGOps`, which scans a function for
operations that implement `SDFGConvertibleOpInterface`, replaces them
with SDFG processes and constructs an SDFG graph around the processes.

Initialization and teardown of the SDFG graph are embedded into the
function and take place at the beginning of the function and before
the function's terminator, respectively.

The pass can be invoked using concretecompiler by specifying the new
compilation option `--emit-sdfg-ops` or programmatically on a
`CompilerEngine` using the new compilation option `extractSDFGOps`.
This commit is contained in:
Andi Drebes
2022-11-30 10:35:03 +01:00
committed by Andi Drebes
parent 9ea6c0e8a3
commit 3da32560b7
17 changed files with 429 additions and 4 deletions

View File

@@ -0,0 +1,18 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_
#define CONCRETELANG_CONVERSION_EXTRACTSDFGOPS_PASS_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<mlir::func::FuncOp>> createExtractSDFGOpsPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -15,6 +15,7 @@
#include "concretelang/Conversion/BConcreteToCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
#include "concretelang/Conversion/ExtractSDFGOps/Pass.h"
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
#include "concretelang/Conversion/LinalgExtras/Passes.h"
@@ -24,6 +25,7 @@
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#define GEN_PASS_CLASSES

View File

@@ -47,6 +47,13 @@ def ConcreteToBConcrete : Pass<"concrete-to-bconcrete", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::linalg::LinalgDialect", "mlir::concretelang::Concrete::ConcreteDialect", "mlir::concretelang::BConcrete::BConcreteDialect"];
}
def ExtractSDFGOps : Pass<"extract-sdfg-ops", "::mlir::func::FuncOp"> {
let summary = "Extracts SDFG ops and creates a static data flow graph";
let description = [{ Extracts SDFG ops and creates a static data flow graph }];
let constructor = "mlir::concretelang::createExtractSDFGOps()";
let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"];
}
def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
let summary = "Lowers operations from the BConcrete dialect to CAPI calls";
let description = [{ Lowers operations from the BConcrete dialect to CAPI calls }];

View File

@@ -0,0 +1,20 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H
#define CONCRETELANG_DIALECT_SDFG_SDFGCONVERTIBLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace concretelang {
namespace SDFG {
void registerSDFGConvertibleOpInterfaceExternalModels(
DialectRegistry &registry);
} // namespace SDFG
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -54,6 +54,7 @@ struct CompilationOptions {
bool autoParallelize;
bool loopParallelize;
bool batchConcreteOps;
bool emitSDFGOps;
bool dataflowParallelize;
bool optimizeConcrete;
/// use GPU during execution by generating GPU operations if possible
@@ -67,8 +68,8 @@ struct CompilationOptions {
CompilationOptions()
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
clientParametersFuncName(llvm::None),
emitSDFGOps(false), dataflowParallelize(false), optimizeConcrete(true),
emitGPUOps(false), clientParametersFuncName(llvm::None),
optimizerConfig(optimizer::DEFAULT_CONFIG){};
CompilationOptions(std::string funcname) : CompilationOptions() {

View File

@@ -58,6 +58,10 @@ mlir::LogicalResult
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

View File

@@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [
"LLVMX86Info",
];
const CONCRETE_COMPILER_LIBS: [&str; 31] = [
const CONCRETE_COMPILER_LIBS: [&str; 33] = [
"RTDialect",
"RTDialectTransforms",
"ConcretelangSupport",
@@ -256,6 +256,7 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [
"ConcretelangClientLib",
"ConcretelangBConcreteTransforms",
"ConcretelangSDFGInterfaces",
"ConcretelangSDFGTransforms",
"CONCRETELANGCAPISupport",
"FHELinalgDialect",
"ConcretelangInterfaces",
@@ -272,7 +273,8 @@ const CONCRETE_COMPILER_LIBS: [&str; 31] = [
"FHEDialectAnalysis",
"ConcreteDialect",
"RTDialectAnalysis",
"SDFGDialect"
"SDFGDialect",
"ExtractSDFGOps"
];
fn main() {

View File

@@ -6,5 +6,6 @@ add_subdirectory(ConcreteToBConcrete)
add_subdirectory(BConcreteToCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(LinalgExtras)
add_subdirectory(ExtractSDFGOps)
add_mlir_library(ConcretelangConversion Tools.cpp LINK_LIBS PUBLIC MLIRIR)

View File

@@ -0,0 +1,17 @@
add_mlir_dialect_library(
ExtractSDFGOps
ExtractSDFGOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
SDFGDialect
ConcretelangSDFGInterfaces
mlir-headers
LINK_LIBS
SDFGDialect
ConcretelangSDFGInterfaces
PUBLIC
MLIRIR
MLIRTransforms)
target_link_libraries(ExtractSDFGOps PUBLIC MLIRIR)

View File

@@ -0,0 +1,212 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
namespace SDFG = mlir::concretelang::SDFG;
namespace {
enum class StreamMappingKind { ON_DEVICE, TO_DEVICE, SPLICE, TO_HOST, NONE };
SDFG::MakeStream makeStream(mlir::ImplicitLocOpBuilder &builder,
SDFG::StreamKind kind, mlir::Type type,
mlir::Value dfg, unsigned &streamNumber) {
SDFG::StreamType streamType = builder.getType<SDFG::StreamType>(type);
mlir::StringAttr name = builder.getStringAttr(llvm::Twine("stream") +
llvm::Twine(streamNumber++));
return builder.create<SDFG::MakeStream>(streamType, dfg, name, kind);
}
StreamMappingKind determineStreamMappingKind(mlir::Value v) {
// Determine stream type for operands:
//
// - If an operand is produced by a non-convertible op, there
// needs to be just a host-to-device stream
//
// - If an operand is produced by a convertible op and there
// are no other consumers, a device-to-device stream may be
// used
//
// - If an operand is produced by a convertible op and there is
// at least one other non-convertible consumer, there needs
// to be a device-to-host stream, and a host-to-device stream
//
if (llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
v.getDefiningOp())) {
// All convertible consumers?
if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) {
return !!llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
o.getOwner());
})) {
return StreamMappingKind::ON_DEVICE;
}
// All non-convertible consumers?
else if (llvm::all_of(v.getUses(), [](mlir::OpOperand &o) {
return !llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
o.getOwner());
})) {
return StreamMappingKind::TO_HOST;
}
// Mix of convertible and non-convertible users
else {
return StreamMappingKind::SPLICE;
}
} else {
if (llvm::any_of(v.getUses(), [](mlir::OpOperand &o) {
return !!llvm::dyn_cast_or_null<SDFG::SDFGConvertibleOpInterface>(
o.getOwner());
})) {
return StreamMappingKind::TO_DEVICE;
} else {
return StreamMappingKind::NONE;
}
}
}
void setInsertionPointAfterValueOrRestore(mlir::OpBuilder &builder,
mlir::Value v,
mlir::OpBuilder::InsertPoint &pos) {
if (v.isa<mlir::BlockArgument>())
builder.restoreInsertionPoint(pos);
else
builder.setInsertionPointAfterValue(v);
}
struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {
ExtractSDFGOpsPass() {}
void runOnOperation() override {
mlir::func::FuncOp func = getOperation();
mlir::IRRewriter rewriter(func.getContext());
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processOutMapping;
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processInMapping;
mlir::DenseMap<mlir::Value, mlir::Value> replacementMapping;
llvm::SmallVector<SDFG::SDFGConvertibleOpInterface> convertibleOps;
unsigned streamNumber = 0;
func.walk([&](SDFG::SDFGConvertibleOpInterface op) {
convertibleOps.push_back(op);
});
if (convertibleOps.size() == 0)
return;
// Insert Prelude
rewriter.setInsertionPointToStart(&func.getBlocks().front());
mlir::Value dfg = rewriter.create<SDFG::Init>(
func.getLoc(), rewriter.getType<SDFG::DFGType>());
SDFG::Start start = rewriter.create<SDFG::Start>(func.getLoc(), dfg);
rewriter.setInsertionPoint(func.getBlocks().front().getTerminator());
rewriter.create<SDFG::Shutdown>(func.getLoc(), dfg);
mlir::ImplicitLocOpBuilder ilb(func.getLoc(), rewriter);
auto mapValueToStreams = [&](mlir::Value v) {
if (processInMapping.find(v) != processInMapping.end() ||
processOutMapping.find(v) != processOutMapping.end())
return;
StreamMappingKind smk = determineStreamMappingKind(v);
SDFG::MakeStream prodOutStream;
SDFG::MakeStream consInStream;
if (smk == StreamMappingKind::SPLICE ||
smk == StreamMappingKind::TO_HOST) {
ilb.setInsertionPoint(start);
prodOutStream = makeStream(ilb, SDFG::StreamKind::device_to_host,
v.getType(), dfg, streamNumber);
processOutMapping.insert({v, prodOutStream});
ilb.setInsertionPointAfter(start);
mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint();
setInsertionPointAfterValueOrRestore(ilb, v, pos);
mlir::Value newOutVal =
ilb.create<SDFG::Get>(v.getType(), prodOutStream.getResult());
replacementMapping.insert({v, newOutVal});
} else if (smk == StreamMappingKind::ON_DEVICE) {
ilb.setInsertionPoint(start);
prodOutStream = makeStream(ilb, SDFG::StreamKind::on_device,
v.getType(), dfg, streamNumber);
processOutMapping.insert({v, prodOutStream});
}
if (smk == StreamMappingKind::ON_DEVICE) {
processInMapping.insert({v, prodOutStream});
} else if (smk == StreamMappingKind::SPLICE ||
smk == StreamMappingKind::TO_DEVICE ||
smk == StreamMappingKind::ON_DEVICE) {
ilb.setInsertionPoint(start);
consInStream = makeStream(ilb, SDFG::StreamKind::host_to_device,
v.getType(), dfg, streamNumber);
processInMapping.insert({v, consInStream});
if (smk == StreamMappingKind::TO_DEVICE) {
ilb.setInsertionPointAfter(start);
mlir::OpBuilder::InsertPoint pos = ilb.saveInsertionPoint();
setInsertionPointAfterValueOrRestore(ilb, v, pos);
ilb.create<SDFG::Put>(consInStream.getResult(), v);
}
}
};
for (SDFG::SDFGConvertibleOpInterface convertibleOp : convertibleOps) {
llvm::SmallVector<mlir::Value> ins;
llvm::SmallVector<mlir::Value> outs;
ilb.setLoc(convertibleOp.getLoc());
for (mlir::Value res : convertibleOp->getResults()) {
mapValueToStreams(res);
outs.push_back(processOutMapping.find(res)->second.getResult());
}
for (mlir::Value operand : convertibleOp->getOperands()) {
mapValueToStreams(operand);
ins.push_back(processInMapping.find(operand)->second.getResult());
}
ilb.setInsertionPoint(start);
SDFG::MakeProcess process = convertibleOp.convert(ilb, dfg, ins, outs);
assert(process && "Conversion to SDFG operation failed");
}
for (auto it : replacementMapping) {
it.first.replaceAllUsesWith(it.second);
}
(void)mlir::simplifyRegions(rewriter, func->getRegions());
}
};
} // namespace
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<mlir::func::FuncOp>> createExtractSDFGOpsPass() {
return std::make_unique<ExtractSDFGOpsPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1,2 +1,3 @@
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,18 @@
add_mlir_dialect_library(
ConcretelangSDFGTransforms
SDFGConvertibleOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG
DEPENDS
mlir-headers
SDFGDialect
ConcretelangSDFGInterfaces
LINK_LIBS
PUBLIC
SDFGDialect
ConcretelangSDFGInterfaces
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRTransforms)

View File

@@ -0,0 +1,89 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace concretelang {
namespace SDFG {
namespace {
char add_eint[] = "add_eint";
char add_eint_int[] = "add_eint_int";
char mul_eint_int[] = "mul_eint_int";
char neg_eint[] = "neg_eint";
char keyswitch[] = "keyswitch";
char bootstrap[] = "bootstrap";
} // namespace
template <typename Op, char const *processName, bool copyAttributes = false>
struct ReplaceWithProcessSDFGConversionInterface
: public SDFGConvertibleOpInterface::ExternalModel<
ReplaceWithProcessSDFGConversionInterface<Op, processName,
copyAttributes>,
Op> {
MakeProcess convert(Operation *op, mlir::ImplicitLocOpBuilder &builder,
::mlir::Value dfg, ::mlir::ValueRange inStreams,
::mlir::ValueRange outStreams) const {
llvm::SmallVector<mlir::Value> streams = llvm::to_vector(inStreams);
streams.append(outStreams.begin(), outStreams.end());
MakeProcess process = builder.create<MakeProcess>(
*symbolizeProcessKind(processName), dfg, streams);
if (copyAttributes) {
llvm::SmallVector<mlir::NamedAttribute> combinedAttrs =
llvm::to_vector(op->getAttrs());
for (mlir::NamedAttribute attr : process->getAttrs()) {
combinedAttrs.push_back(attr);
}
process->setAttrs(combinedAttrs);
}
return process;
}
};
void registerSDFGConvertibleOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx,
BConcrete::BConcreteDialect *dialect) {
mlir::concretelang::BConcrete::AddLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::BConcrete::AddLweTensorOp, add_eint>>(*ctx);
mlir::concretelang::BConcrete::AddPlaintextLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::BConcrete::AddPlaintextLweTensorOp,
add_eint_int>>(*ctx);
mlir::concretelang::BConcrete::MulCleartextLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::BConcrete::MulCleartextLweTensorOp,
mul_eint_int>>(*ctx);
mlir::concretelang::BConcrete::NegateLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::BConcrete::NegateLweTensorOp, neg_eint>>(*ctx);
mlir::concretelang::BConcrete::KeySwitchLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::BConcrete::KeySwitchLweTensorOp, keyswitch,
true>>(*ctx);
mlir::concretelang::BConcrete::BootstrapLweTensorOp::attachInterface<
ReplaceWithProcessSDFGConversionInterface<
mlir::concretelang::BConcrete::BootstrapLweTensorOp, bootstrap,
true>>(*ctx);
});
}
} // namespace SDFG
} // namespace concretelang
} // namespace mlir

View File

@@ -22,11 +22,13 @@ add_mlir_library(
FHELinalgDialectTransforms
FHETensorOpsToLinalg
FHEToTFHE
ExtractSDFGOps
MLIRLowerableDialectsToLLVM
FHEDialectAnalysis
RTDialectAnalysis
ConcretelangTransforms
ConcretelangBConcreteTransforms
ConcretelangSDFGTransforms
ConcretelangSDFGInterfaces
LinalgExtras
ConcreteDialectTransforms

View File

@@ -35,6 +35,7 @@
#include <concretelang/Dialect/RT/IR/RTDialect.h>
#include <concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h>
#include <concretelang/Dialect/SDFG/IR/SDFGDialect.h>
#include <concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h>
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/CompilerEngine.h>
@@ -81,6 +82,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect,
mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>();
BConcrete::registerBufferizableOpInterfaceExternalModels(registry);
SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
@@ -392,6 +394,17 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return std::move(res);
}
// Extract SDFG data flow graph from BConcrete representation
if (options.emitSDFGOps) {
if (mlir::concretelang::pipeline::extractSDFGOps(mlirContext, module,
enablePass)
.failed()) {
return errorDiag(
"Extraction of SDFG operations from BConcrete representation failed");
}
}
// BConcrete -> Canonical dialects
if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module,
enablePass)

View File

@@ -274,6 +274,16 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("extract SDFG ops from BConcrete", pm, context);
addPotentiallyNestedPass(pm, mlir::concretelang::createExtractSDFGOpsPass(),
enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {

View File

@@ -173,6 +173,13 @@ llvm::cl::opt<bool> batchConcreteOps(
"operations out of loop nests as batched operations"),
llvm::cl::init(false));
llvm::cl::opt<bool> emitSDFGOps(
"emit-sdfg-ops",
llvm::cl::desc(
"Extract operations supported by the SDFG dialect for static data flow"
" graphs and emit them."),
llvm::cl::init(false));
llvm::cl::opt<bool> dataflowParallelize(
"parallelize-dataflow",
llvm::cl::desc(
@@ -305,6 +312,7 @@ cmdlineCompilationOptions() {
options.loopParallelize = cmdline::loopParallelize;
options.dataflowParallelize = cmdline::dataflowParallelize;
options.batchConcreteOps = cmdline::batchConcreteOps;
options.emitSDFGOps = cmdline::emitSDFGOps;
options.optimizeConcrete = cmdline::optimizeConcrete;
options.emitGPUOps = cmdline::emitGPUOps;