feat(compiler): add lowering and bufferization for SDFG dialect, generate code to Stream Emulator API.

This commit is contained in:
Antoniu Pop
2022-11-30 09:41:34 +00:00
committed by Andi Drebes
parent 752f0feb75
commit 0dbb86bb36
14 changed files with 626 additions and 3 deletions

View File

@@ -20,6 +20,7 @@
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
#include "concretelang/Conversion/LinalgExtras/Passes.h"
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"

View File

@@ -61,6 +61,13 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"];
}
def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> {
let summary = "Lowers operations from the SDFG dialect to Stream Emulator calls";
let description = [{ Lowers operations from the SDFG dialect to Stream Emulator calls }];
let constructor = "mlir::concretelang::createConvertSDFGToStreamEmulatorPass()";
let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"];
}
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";

View File

@@ -0,0 +1,19 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_
#define ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `SDFG` dialect to Stream Emulator calls.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertSDFGToStreamEmulatorPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,19 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_DIALECT_SDFG_BUFFERIZABLEOPINTERFACEIMPL_H
#define CONCRETELANG_DIALECT_SDFG_BUFFERIZABLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace concretelang {
namespace SDFG {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace SDFG
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -66,6 +66,10 @@ mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,

View File

@@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [
"LLVMX86Info",
];
const CONCRETE_COMPILER_LIBS: [&str; 33] = [
const CONCRETE_COMPILER_LIBS: [&str; 34] = [
"RTDialect",
"RTDialectTransforms",
"ConcretelangSupport",
@@ -274,7 +274,8 @@ const CONCRETE_COMPILER_LIBS: [&str; 33] = [
"ConcreteDialect",
"RTDialectAnalysis",
"SDFGDialect",
"ExtractSDFGOps"
"ExtractSDFGOps",
"SDFGToStreamEmulator",
];
fn main() {

View File

@@ -4,6 +4,7 @@ add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(ConcreteToBConcrete)
add_subdirectory(BConcreteToCAPI)
add_subdirectory(SDFGToStreamEmulator)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(LinalgExtras)
add_subdirectory(ExtractSDFGOps)

View File

@@ -32,6 +32,7 @@
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/Analysis/Autopar.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
namespace {
struct MLIRLowerableDialectsToLLVMPass
@@ -155,7 +156,9 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::ContextType>() ||
type.isa<mlir::concretelang::RT::FutureType>()) {
type.isa<mlir::concretelang::RT::FutureType>() ||
type.isa<mlir::concretelang::SDFG::DFGType>() ||
type.isa<mlir::concretelang::SDFG::StreamType>()) {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(type.getContext(), 64));
}

View File

@@ -0,0 +1,14 @@
add_mlir_dialect_library(
SDFGToStreamEmulator
SDFGToStreamEmulator.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG
DEPENDS
SDFGDialect
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
MLIRTransforms)
target_link_libraries(SDFGToStreamEmulator PUBLIC SDFGDialect MLIRIR)

View File

@@ -0,0 +1,390 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
#include "concretelang/Runtime/stream_emulator_api.h"
namespace SDFG = mlir::concretelang::SDFG;
namespace {
struct SDFGToStreamEmulatorPass
: public SDFGToStreamEmulatorBase<SDFGToStreamEmulatorPass> {
void runOnOperation() final;
};
char stream_emulator_init[] = "stream_emulator_init";
char stream_emulator_run[] = "stream_emulator_run";
char stream_emulator_delete[] = "stream_emulator_delete";
char stream_emulator_make_memref_add_lwe_ciphertexts_u64_process[] =
"stream_emulator_make_memref_add_lwe_ciphertexts_u64_process";
char stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process[] =
"stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process";
char stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process[] =
"stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process";
char stream_emulator_make_memref_negate_lwe_ciphertext_u64_process[] =
"stream_emulator_make_memref_negate_lwe_ciphertext_u64_process";
char stream_emulator_make_memref_keyswitch_lwe_u64_process[] =
"stream_emulator_make_memref_keyswitch_lwe_u64_process";
char stream_emulator_make_memref_bootstrap_lwe_u64_process[] =
"stream_emulator_make_memref_bootstrap_lwe_u64_process";
char stream_emulator_make_memref_stream[] =
"stream_emulator_make_memref_stream";
char stream_emulator_put_memref[] = "stream_emulator_put_memref";
char stream_emulator_make_uint64_stream[] =
"stream_emulator_make_uint64_stream";
char stream_emulator_put_uint64[] = "stream_emulator_put_uint64";
char stream_emulator_get_uint64[] = "stream_emulator_get_uint64";
mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) {
std::vector<int64_t> shape(rank, -1);
return mlir::RankedTensorType::get(shape, rewriter.getI64Type());
}
mlir::Type makeDynamicTensorTypes(mlir::OpBuilder &rewriter, mlir::Type oldTy) {
if (auto ttype = oldTy.dyn_cast_or_null<mlir::TensorType>())
return getDynamicTensor(rewriter, ttype.getRank());
if (auto stTy = oldTy.dyn_cast_or_null<SDFG::StreamType>())
return SDFG::StreamType::get(
rewriter.getContext(),
makeDynamicTensorTypes(rewriter, stTy.getElementType()));
return oldTy;
}
mlir::LogicalResult insertGenericForwardDeclaration(mlir::Operation *op,
mlir::OpBuilder &rewriter,
llvm::StringRef funcName,
mlir::TypeRange opTys,
mlir::TypeRange resTys) {
mlir::SmallVector<mlir::Type> operands;
for (mlir::Type opTy : opTys)
operands.push_back(makeDynamicTensorTypes(rewriter, opTy));
mlir::SmallVector<mlir::Type> results;
for (mlir::Type resTy : resTys)
results.push_back(makeDynamicTensorTypes(rewriter, resTy));
mlir::FunctionType funcType =
mlir::FunctionType::get(rewriter.getContext(), operands, results);
return insertForwardDeclaration(op, rewriter, funcName, funcType);
}
void castDynamicTensorOps(mlir::Operation *op, mlir::OpBuilder &rewriter,
mlir::ValueRange operands,
mlir::SmallVector<mlir::Value> &newOps) {
for (auto val : operands) {
auto oldTy = val.getType();
if (auto ttype = oldTy.dyn_cast_or_null<mlir::TensorType>())
newOps.push_back(rewriter.create<mlir::tensor::CastOp>(
op->getLoc(), getDynamicTensor(rewriter, ttype.getRank()), val));
else
newOps.push_back(val);
}
}
struct LowerSDFGInit
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Init> {
LowerSDFGInit(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Init>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::Init initOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::FunctionType funcType = mlir::FunctionType::get(
rewriter.getContext(), {}, {SDFG::DFGType::get(rewriter.getContext())});
if (insertForwardDeclaration(initOp, rewriter, stream_emulator_init,
funcType)
.failed())
return ::mlir::failure();
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
initOp, stream_emulator_init,
mlir::TypeRange{SDFG::DFGType::get(rewriter.getContext())});
return ::mlir::success();
};
};
struct LowerSDFGStart
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Start> {
LowerSDFGStart(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Start>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::Start startOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::FunctionType funcType = mlir::FunctionType::get(
rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {});
if (insertForwardDeclaration(startOp, rewriter, stream_emulator_run,
funcType)
.failed())
return ::mlir::failure();
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
startOp, stream_emulator_run, mlir::TypeRange{},
startOp.getOperation()->getOperands());
return ::mlir::success();
};
};
struct LowerSDFGShutdown
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Shutdown> {
LowerSDFGShutdown(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Shutdown>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::Shutdown desOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::FunctionType funcType = mlir::FunctionType::get(
rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {});
if (insertForwardDeclaration(desOp, rewriter, stream_emulator_delete,
funcType)
.failed())
return ::mlir::failure();
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
desOp, stream_emulator_delete, mlir::TypeRange{},
desOp.getOperation()->getOperands());
return ::mlir::success();
};
};
struct LowerSDFGMakeProcess
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeProcess> {
LowerSDFGMakeProcess(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeProcess>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::MakeProcess mpOp,
::mlir::PatternRewriter &rewriter) const override {
const char *funcName;
mlir::SmallVector<mlir::Value> operands(mpOp->getOperands());
switch (mpOp.type()) {
case SDFG::ProcessKind::add_eint:
funcName = stream_emulator_make_memref_add_lwe_ciphertexts_u64_process;
break;
case SDFG::ProcessKind::add_eint_int:
funcName =
stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::mul_eint_int:
funcName =
stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::neg_eint:
funcName = stream_emulator_make_memref_negate_lwe_ciphertext_u64_process;
break;
case SDFG::ProcessKind::keyswitch:
funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
// level
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("level")));
// base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("baseLog")));
// lwe_dim_in
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("lwe_dim_in")));
// lwe_dim_out
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("lwe_dim_out")));
// context
operands.push_back(getContextArgument(mpOp));
break;
case SDFG::ProcessKind::bootstrap:
funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
// input_lwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("inputLweDim")));
// poly_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("polySize")));
// level
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("level")));
// base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("baseLog")));
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("glweDimension")));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
mpOp.getLoc(),
mpOp->getAttrOfType<mlir::IntegerAttr>("outPrecision")));
// context
operands.push_back(getContextArgument(mpOp));
break;
}
if (insertGenericForwardDeclaration(mpOp, rewriter, funcName,
mlir::ValueRange{operands}.getTypes(),
mpOp->getResultTypes())
.failed())
return ::mlir::failure();
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
mpOp, funcName, mpOp->getResultTypes(), operands);
return ::mlir::success();
};
};
struct LowerSDFGMakeStream
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeStream> {
LowerSDFGMakeStream(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeStream>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::MakeStream msOp,
::mlir::PatternRewriter &rewriter) const override {
const char *funcName;
stream_type t;
switch (msOp.type()) {
case SDFG::StreamKind::host_to_device:
t = TS_STREAM_TYPE_X86_TO_TOPO_LSAP;
break;
case SDFG::StreamKind::on_device:
t = TS_STREAM_TYPE_TOPO_TO_TOPO_LSAP;
break;
case SDFG::StreamKind::device_to_host:
t = TS_STREAM_TYPE_TOPO_TO_X86_LSAP;
break;
}
auto sType = msOp->getResultTypes()[0].dyn_cast_or_null<SDFG::StreamType>();
assert(sType && "SDFG MakeStream operation should return a stream type");
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
funcName = stream_emulator_make_memref_stream;
} else {
assert(sType.getElementType().isa<mlir::IntegerType>() &&
"SDFG streams only support memrefs and integers.");
funcName = stream_emulator_make_uint64_stream;
}
if (insertGenericForwardDeclaration(
msOp, rewriter, funcName,
{rewriter.getI64Type(), rewriter.getI64Type()},
msOp->getResultTypes())
.failed())
return ::mlir::failure();
mlir::Value nullStringPtr = rewriter.create<mlir::arith::ConstantOp>(
msOp.getLoc(), rewriter.getI64IntegerAttr(0));
mlir::Value streamTypeCst = rewriter.create<mlir::arith::ConstantOp>(
msOp.getLoc(), rewriter.getI64IntegerAttr((int)t));
auto callop = rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
msOp, funcName,
makeDynamicTensorTypes(rewriter, msOp->getResultTypes()[0]),
mlir::ValueRange{nullStringPtr, streamTypeCst});
for (auto &use : llvm::make_early_inc_range(msOp->getUses()))
use.set(callop.getResult(0));
return ::mlir::success();
};
};
struct LowerSDFGPut
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Put> {
LowerSDFGPut(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Put>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::Put putOp,
::mlir::PatternRewriter &rewriter) const override {
const char *funcName;
auto sType =
putOp->getOperandTypes()[0].dyn_cast_or_null<SDFG::StreamType>();
assert(sType &&
"SDFG Put operation must take a stream type as first parameter.");
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
funcName = stream_emulator_put_memref;
} else {
assert(sType.getElementType().isa<mlir::IntegerType>() &&
"SDFG streams only support memrefs and integers.");
funcName = stream_emulator_put_uint64;
}
if (insertGenericForwardDeclaration(putOp, rewriter, funcName,
putOp->getOperandTypes(),
putOp->getResultTypes())
.failed())
return ::mlir::failure();
mlir::SmallVector<mlir::Value> newOps;
castDynamicTensorOps(putOp, rewriter, putOp->getOperands(), newOps);
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
putOp, funcName, putOp->getResultTypes(), newOps);
return ::mlir::success();
};
};
struct LowerSDFGGet
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Get> {
LowerSDFGGet(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Get>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::SDFG::Get getOp,
::mlir::PatternRewriter &rewriter) const override {
const char *funcName;
auto sType =
getOp->getOperandTypes()[0].dyn_cast_or_null<SDFG::StreamType>();
assert(sType &&
"SDFG Get operation must take a stream type as first parameter.");
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
// TODO: SDFG.Get for memref streams is lowered during bufferization
// as returning a memref requires allocation for now
return ::mlir::success();
} else {
assert(sType.getElementType().isa<mlir::IntegerType>() &&
"SDFG streams only support memrefs and integers.");
funcName = stream_emulator_get_uint64;
}
if (insertGenericForwardDeclaration(getOp, rewriter, funcName,
getOp->getOperandTypes(),
getOp->getResultTypes())
.failed())
return ::mlir::failure();
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
getOp, funcName, getOp->getResultTypes(), getOp->getOperands());
return ::mlir::success();
};
};
} // namespace
void SDFGToStreamEmulatorPass::runOnOperation() {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<LowerSDFGInit, LowerSDFGStart, LowerSDFGShutdown,
LowerSDFGMakeProcess, LowerSDFGMakeStream, LowerSDFGPut,
LowerSDFGGet>(&getContext());
target.addIllegalOp<SDFG::Init, SDFG::Start, SDFG::Shutdown,
SDFG::MakeProcess, SDFG::MakeStream, SDFG::Put>();
// All BConcrete ops are legal after the conversion
target.addLegalDialect<mlir::concretelang::BConcrete::BConcreteDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalOp<mlir::func::ReturnOp, mlir::func::FuncOp,
mlir::func::CallOp, SDFG::Get, mlir::tensor::CastOp>();
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertSDFGToStreamEmulatorPass() {
return std::make_unique<SDFGToStreamEmulatorPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,139 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
#include "concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h"
#include "concretelang/Support/CompilerEngine.h"
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/AffineMap.h>
#include <mlir/IR/BuiltinTypes.h>
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;
namespace SDFG = mlir::concretelang::SDFG;
namespace mlir {
namespace concretelang {
namespace SDFG {
namespace {} // namespace
} // namespace SDFG
} // namespace concretelang
} // namespace mlir
namespace {
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
}
return mlir::MemRefType::get(
shape, rewriter.getI64Type(),
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
}
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
mlir::Value value) {
mlir::Type valueType = value.getType();
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
return rewriter.create<mlir::memref::CastOp>(
loc,
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
value);
} else {
return value;
}
}
char stream_emulator_get_memref[] = "stream_emulator_get_memref";
template <typename Op, char const *funcName>
struct BufferizableWithCallOpInterface
: public BufferizableOpInterface::ExternalModel<
BufferizableWithCallOpInterface<Op, funcName>, Op> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto loc = op->getLoc();
// TODO: For now we allocate for the result of GET but we might be
// able to avoid the copy depending on the stream semantics.
auto resTensorType =
op->getResultTypes()[0].template cast<mlir::TensorType>();
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
resTensorType.getElementType());
auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
if (mlir::failed(outMemref)) {
return mlir::failure();
}
// The last operand is the result
mlir::SmallVector<mlir::Value> operands(op->getOperands());
operands.push_back(getCastedMemRef(rewriter, loc, *outMemref));
mlir::FunctionType funcType = mlir::FunctionType::get(
rewriter.getContext(), mlir::ValueRange{operands}.getTypes(),
mlir::TypeRange());
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed())
return ::mlir::failure();
rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
operands);
replaceOpWithBufferizedValues(rewriter, op, *outMemref);
return success();
}
};
} // namespace
void mlir::concretelang::SDFG::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SDFG::SDFGDialect *dialect) {
SDFG::Get::attachInterface<
BufferizableWithCallOpInterface<SDFG::Get, stream_emulator_get_memref>>(
*ctx);
});
}

View File

@@ -1,5 +1,6 @@
add_mlir_dialect_library(
ConcretelangSDFGTransforms
BufferizableOpInterfaceImpl.cpp
SDFGConvertibleOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
@@ -12,6 +13,10 @@ add_mlir_dialect_library(
PUBLIC
SDFGDialect
ConcretelangSDFGInterfaces
ConcretelangConversion
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRIR
MLIRMemRefDialect
MLIRPass

View File

@@ -35,6 +35,7 @@
#include <concretelang/Dialect/RT/IR/RTDialect.h>
#include <concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h>
#include <concretelang/Dialect/SDFG/IR/SDFGDialect.h>
#include <concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h>
#include <concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h>
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
#include <concretelang/Runtime/DFRuntime.hpp>
@@ -83,6 +84,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>();
BConcrete::registerBufferizableOpInterfaceExternalModels(registry);
SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry);
SDFG::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
@@ -417,6 +419,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
"Lowering from Bufferized Concrete to canonical MLIR dialects failed");
}
// SDFG -> Canonical dialects
if (mlir::concretelang::pipeline::lowerSDFGToStd(mlirContext, module,
enablePass)
.failed()) {
return errorDiag("Lowering from SDFG to canonical MLIR dialects failed");
}
if (target == Target::STD)
return std::move(res);

View File

@@ -296,6 +296,17 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("SDFGToStd", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertSDFGToStreamEmulatorPass(),
enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,