mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): add lowering and bufferization for SDFG dialect, generate code to Stream Emulator API.
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
#include "concretelang/Conversion/LinalgExtras/Passes.h"
|
||||
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
|
||||
@@ -61,6 +61,13 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"];
|
||||
}
|
||||
|
||||
def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the SDFG dialect to Stream Emulator calls";
|
||||
let description = [{ Lowers operations from the SDFG dialect to Stream Emulator calls }];
|
||||
let constructor = "mlir::concretelang::createConvertSDFGToStreamEmulatorPass()";
|
||||
let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `SDFG` dialect to Stream Emulator calls.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertSDFGToStreamEmulatorPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
#define CONCRETELANG_DIALECT_SDFG_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace concretelang {
|
||||
namespace SDFG {
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
} // namespace SDFG
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -66,6 +66,10 @@ mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
|
||||
@@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [
|
||||
"LLVMX86Info",
|
||||
];
|
||||
|
||||
const CONCRETE_COMPILER_LIBS: [&str; 33] = [
|
||||
const CONCRETE_COMPILER_LIBS: [&str; 34] = [
|
||||
"RTDialect",
|
||||
"RTDialectTransforms",
|
||||
"ConcretelangSupport",
|
||||
@@ -274,7 +274,8 @@ const CONCRETE_COMPILER_LIBS: [&str; 33] = [
|
||||
"ConcreteDialect",
|
||||
"RTDialectAnalysis",
|
||||
"SDFGDialect",
|
||||
"ExtractSDFGOps"
|
||||
"ExtractSDFGOps",
|
||||
"SDFGToStreamEmulator",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
|
||||
@@ -4,6 +4,7 @@ add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(ConcreteToBConcrete)
|
||||
add_subdirectory(BConcreteToCAPI)
|
||||
add_subdirectory(SDFGToStreamEmulator)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LinalgExtras)
|
||||
add_subdirectory(ExtractSDFGOps)
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/RT/Analysis/Autopar.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
|
||||
namespace {
|
||||
struct MLIRLowerableDialectsToLLVMPass
|
||||
@@ -155,7 +156,9 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::ContextType>() ||
|
||||
type.isa<mlir::concretelang::RT::FutureType>()) {
|
||||
type.isa<mlir::concretelang::RT::FutureType>() ||
|
||||
type.isa<mlir::concretelang::SDFG::DFGType>() ||
|
||||
type.isa<mlir::concretelang::SDFG::StreamType>()) {
|
||||
return mlir::LLVM::LLVMPointerType::get(
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
}
|
||||
|
||||
14
compiler/lib/Conversion/SDFGToStreamEmulator/CMakeLists.txt
Normal file
14
compiler/lib/Conversion/SDFGToStreamEmulator/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
add_mlir_dialect_library(
|
||||
SDFGToStreamEmulator
|
||||
SDFGToStreamEmulator.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG
|
||||
DEPENDS
|
||||
SDFGDialect
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms)
|
||||
|
||||
target_link_libraries(SDFGToStreamEmulator PUBLIC SDFGDialect MLIRIR)
|
||||
@@ -0,0 +1,390 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
#include "concretelang/Runtime/stream_emulator_api.h"
|
||||
|
||||
namespace SDFG = mlir::concretelang::SDFG;
|
||||
|
||||
namespace {
|
||||
struct SDFGToStreamEmulatorPass
|
||||
: public SDFGToStreamEmulatorBase<SDFGToStreamEmulatorPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
char stream_emulator_init[] = "stream_emulator_init";
|
||||
char stream_emulator_run[] = "stream_emulator_run";
|
||||
char stream_emulator_delete[] = "stream_emulator_delete";
|
||||
char stream_emulator_make_memref_add_lwe_ciphertexts_u64_process[] =
|
||||
"stream_emulator_make_memref_add_lwe_ciphertexts_u64_process";
|
||||
char stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process[] =
|
||||
"stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process";
|
||||
char stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process[] =
|
||||
"stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process";
|
||||
char stream_emulator_make_memref_negate_lwe_ciphertext_u64_process[] =
|
||||
"stream_emulator_make_memref_negate_lwe_ciphertext_u64_process";
|
||||
char stream_emulator_make_memref_keyswitch_lwe_u64_process[] =
|
||||
"stream_emulator_make_memref_keyswitch_lwe_u64_process";
|
||||
char stream_emulator_make_memref_bootstrap_lwe_u64_process[] =
|
||||
"stream_emulator_make_memref_bootstrap_lwe_u64_process";
|
||||
|
||||
char stream_emulator_make_memref_stream[] =
|
||||
"stream_emulator_make_memref_stream";
|
||||
char stream_emulator_put_memref[] = "stream_emulator_put_memref";
|
||||
char stream_emulator_make_uint64_stream[] =
|
||||
"stream_emulator_make_uint64_stream";
|
||||
char stream_emulator_put_uint64[] = "stream_emulator_put_uint64";
|
||||
char stream_emulator_get_uint64[] = "stream_emulator_get_uint64";
|
||||
|
||||
mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
return mlir::RankedTensorType::get(shape, rewriter.getI64Type());
|
||||
}
|
||||
|
||||
mlir::Type makeDynamicTensorTypes(mlir::OpBuilder &rewriter, mlir::Type oldTy) {
|
||||
if (auto ttype = oldTy.dyn_cast_or_null<mlir::TensorType>())
|
||||
return getDynamicTensor(rewriter, ttype.getRank());
|
||||
if (auto stTy = oldTy.dyn_cast_or_null<SDFG::StreamType>())
|
||||
return SDFG::StreamType::get(
|
||||
rewriter.getContext(),
|
||||
makeDynamicTensorTypes(rewriter, stTy.getElementType()));
|
||||
return oldTy;
|
||||
}
|
||||
|
||||
mlir::LogicalResult insertGenericForwardDeclaration(mlir::Operation *op,
|
||||
mlir::OpBuilder &rewriter,
|
||||
llvm::StringRef funcName,
|
||||
mlir::TypeRange opTys,
|
||||
mlir::TypeRange resTys) {
|
||||
mlir::SmallVector<mlir::Type> operands;
|
||||
for (mlir::Type opTy : opTys)
|
||||
operands.push_back(makeDynamicTensorTypes(rewriter, opTy));
|
||||
mlir::SmallVector<mlir::Type> results;
|
||||
for (mlir::Type resTy : resTys)
|
||||
results.push_back(makeDynamicTensorTypes(rewriter, resTy));
|
||||
|
||||
mlir::FunctionType funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(), operands, results);
|
||||
return insertForwardDeclaration(op, rewriter, funcName, funcType);
|
||||
}
|
||||
|
||||
void castDynamicTensorOps(mlir::Operation *op, mlir::OpBuilder &rewriter,
|
||||
mlir::ValueRange operands,
|
||||
mlir::SmallVector<mlir::Value> &newOps) {
|
||||
for (auto val : operands) {
|
||||
auto oldTy = val.getType();
|
||||
if (auto ttype = oldTy.dyn_cast_or_null<mlir::TensorType>())
|
||||
newOps.push_back(rewriter.create<mlir::tensor::CastOp>(
|
||||
op->getLoc(), getDynamicTensor(rewriter, ttype.getRank()), val));
|
||||
else
|
||||
newOps.push_back(val);
|
||||
}
|
||||
}
|
||||
|
||||
struct LowerSDFGInit
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Init> {
|
||||
LowerSDFGInit(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Init>(context,
|
||||
benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::SDFG::Init initOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::FunctionType funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {}, {SDFG::DFGType::get(rewriter.getContext())});
|
||||
if (insertForwardDeclaration(initOp, rewriter, stream_emulator_init,
|
||||
funcType)
|
||||
.failed())
|
||||
return ::mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
initOp, stream_emulator_init,
|
||||
mlir::TypeRange{SDFG::DFGType::get(rewriter.getContext())});
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerSDFGStart
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Start> {
|
||||
LowerSDFGStart(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Start>(context,
|
||||
benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::SDFG::Start startOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::FunctionType funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {});
|
||||
if (insertForwardDeclaration(startOp, rewriter, stream_emulator_run,
|
||||
funcType)
|
||||
.failed())
|
||||
return ::mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
startOp, stream_emulator_run, mlir::TypeRange{},
|
||||
startOp.getOperation()->getOperands());
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerSDFGShutdown
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Shutdown> {
|
||||
LowerSDFGShutdown(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Shutdown>(context,
|
||||
benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::SDFG::Shutdown desOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::FunctionType funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {});
|
||||
if (insertForwardDeclaration(desOp, rewriter, stream_emulator_delete,
|
||||
funcType)
|
||||
.failed())
|
||||
return ::mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
desOp, stream_emulator_delete, mlir::TypeRange{},
|
||||
desOp.getOperation()->getOperands());
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerSDFGMakeProcess
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeProcess> {
|
||||
LowerSDFGMakeProcess(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeProcess>(
|
||||
context, benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::SDFG::MakeProcess mpOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
const char *funcName;
|
||||
mlir::SmallVector<mlir::Value> operands(mpOp->getOperands());
|
||||
switch (mpOp.type()) {
|
||||
case SDFG::ProcessKind::add_eint:
|
||||
funcName = stream_emulator_make_memref_add_lwe_ciphertexts_u64_process;
|
||||
break;
|
||||
case SDFG::ProcessKind::add_eint_int:
|
||||
funcName =
|
||||
stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process;
|
||||
break;
|
||||
case SDFG::ProcessKind::mul_eint_int:
|
||||
funcName =
|
||||
stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process;
|
||||
break;
|
||||
case SDFG::ProcessKind::neg_eint:
|
||||
funcName = stream_emulator_make_memref_negate_lwe_ciphertext_u64_process;
|
||||
break;
|
||||
case SDFG::ProcessKind::keyswitch:
|
||||
funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process;
|
||||
// level
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("level")));
|
||||
// base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("baseLog")));
|
||||
// lwe_dim_in
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("lwe_dim_in")));
|
||||
// lwe_dim_out
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("lwe_dim_out")));
|
||||
// context
|
||||
operands.push_back(getContextArgument(mpOp));
|
||||
break;
|
||||
case SDFG::ProcessKind::bootstrap:
|
||||
funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process;
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("inputLweDim")));
|
||||
// poly_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("polySize")));
|
||||
// level
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("level")));
|
||||
// base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(), mpOp->getAttrOfType<mlir::IntegerAttr>("baseLog")));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("glweDimension")));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
mpOp.getLoc(),
|
||||
mpOp->getAttrOfType<mlir::IntegerAttr>("outPrecision")));
|
||||
// context
|
||||
operands.push_back(getContextArgument(mpOp));
|
||||
break;
|
||||
}
|
||||
if (insertGenericForwardDeclaration(mpOp, rewriter, funcName,
|
||||
mlir::ValueRange{operands}.getTypes(),
|
||||
mpOp->getResultTypes())
|
||||
.failed())
|
||||
return ::mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
mpOp, funcName, mpOp->getResultTypes(), operands);
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerSDFGMakeStream
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeStream> {
|
||||
LowerSDFGMakeStream(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::MakeStream>(
|
||||
context, benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::SDFG::MakeStream msOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
const char *funcName;
|
||||
|
||||
stream_type t;
|
||||
switch (msOp.type()) {
|
||||
case SDFG::StreamKind::host_to_device:
|
||||
t = TS_STREAM_TYPE_X86_TO_TOPO_LSAP;
|
||||
break;
|
||||
case SDFG::StreamKind::on_device:
|
||||
t = TS_STREAM_TYPE_TOPO_TO_TOPO_LSAP;
|
||||
break;
|
||||
case SDFG::StreamKind::device_to_host:
|
||||
t = TS_STREAM_TYPE_TOPO_TO_X86_LSAP;
|
||||
break;
|
||||
}
|
||||
auto sType = msOp->getResultTypes()[0].dyn_cast_or_null<SDFG::StreamType>();
|
||||
assert(sType && "SDFG MakeStream operation should return a stream type");
|
||||
|
||||
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
|
||||
funcName = stream_emulator_make_memref_stream;
|
||||
} else {
|
||||
assert(sType.getElementType().isa<mlir::IntegerType>() &&
|
||||
"SDFG streams only support memrefs and integers.");
|
||||
funcName = stream_emulator_make_uint64_stream;
|
||||
}
|
||||
if (insertGenericForwardDeclaration(
|
||||
msOp, rewriter, funcName,
|
||||
{rewriter.getI64Type(), rewriter.getI64Type()},
|
||||
msOp->getResultTypes())
|
||||
.failed())
|
||||
return ::mlir::failure();
|
||||
mlir::Value nullStringPtr = rewriter.create<mlir::arith::ConstantOp>(
|
||||
msOp.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
mlir::Value streamTypeCst = rewriter.create<mlir::arith::ConstantOp>(
|
||||
msOp.getLoc(), rewriter.getI64IntegerAttr((int)t));
|
||||
auto callop = rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
msOp, funcName,
|
||||
makeDynamicTensorTypes(rewriter, msOp->getResultTypes()[0]),
|
||||
mlir::ValueRange{nullStringPtr, streamTypeCst});
|
||||
for (auto &use : llvm::make_early_inc_range(msOp->getUses()))
|
||||
use.set(callop.getResult(0));
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerSDFGPut
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Put> {
|
||||
LowerSDFGPut(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Put>(context,
|
||||
benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::SDFG::Put putOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
const char *funcName;
|
||||
auto sType =
|
||||
putOp->getOperandTypes()[0].dyn_cast_or_null<SDFG::StreamType>();
|
||||
assert(sType &&
|
||||
"SDFG Put operation must take a stream type as first parameter.");
|
||||
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
|
||||
funcName = stream_emulator_put_memref;
|
||||
} else {
|
||||
assert(sType.getElementType().isa<mlir::IntegerType>() &&
|
||||
"SDFG streams only support memrefs and integers.");
|
||||
funcName = stream_emulator_put_uint64;
|
||||
}
|
||||
if (insertGenericForwardDeclaration(putOp, rewriter, funcName,
|
||||
putOp->getOperandTypes(),
|
||||
putOp->getResultTypes())
|
||||
.failed())
|
||||
return ::mlir::failure();
|
||||
mlir::SmallVector<mlir::Value> newOps;
|
||||
castDynamicTensorOps(putOp, rewriter, putOp->getOperands(), newOps);
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
putOp, funcName, putOp->getResultTypes(), newOps);
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerSDFGGet
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::SDFG::Get> {
|
||||
LowerSDFGGet(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::SDFG::Get>(context,
|
||||
benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::SDFG::Get getOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
const char *funcName;
|
||||
auto sType =
|
||||
getOp->getOperandTypes()[0].dyn_cast_or_null<SDFG::StreamType>();
|
||||
assert(sType &&
|
||||
"SDFG Get operation must take a stream type as first parameter.");
|
||||
if (sType.getElementType().isa<mlir::RankedTensorType>()) {
|
||||
// TODO: SDFG.Get for memref streams is lowered during bufferization
|
||||
// as returning a memref requires allocation for now
|
||||
return ::mlir::success();
|
||||
} else {
|
||||
assert(sType.getElementType().isa<mlir::IntegerType>() &&
|
||||
"SDFG streams only support memrefs and integers.");
|
||||
funcName = stream_emulator_get_uint64;
|
||||
}
|
||||
if (insertGenericForwardDeclaration(getOp, rewriter, funcName,
|
||||
getOp->getOperandTypes(),
|
||||
getOp->getResultTypes())
|
||||
.failed())
|
||||
return ::mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
|
||||
getOp, funcName, getOp->getResultTypes(), getOp->getOperands());
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void SDFGToStreamEmulatorPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
patterns.insert<LowerSDFGInit, LowerSDFGStart, LowerSDFGShutdown,
|
||||
LowerSDFGMakeProcess, LowerSDFGMakeStream, LowerSDFGPut,
|
||||
LowerSDFGGet>(&getContext());
|
||||
|
||||
target.addIllegalOp<SDFG::Init, SDFG::Start, SDFG::Shutdown,
|
||||
SDFG::MakeProcess, SDFG::MakeStream, SDFG::Put>();
|
||||
// All BConcrete ops are legal after the conversion
|
||||
target.addLegalDialect<mlir::concretelang::BConcrete::BConcreteDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addLegalOp<mlir::func::ReturnOp, mlir::func::FuncOp,
|
||||
mlir::func::CallOp, SDFG::Get, mlir::tensor::CastOp>();
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertSDFGToStreamEmulatorPass() {
|
||||
return std::make_unique<SDFGToStreamEmulatorPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -0,0 +1,139 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
#include "concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include <mlir/IR/AffineExpr.h>
|
||||
#include <mlir/IR/AffineMap.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
using namespace mlir::tensor;
|
||||
|
||||
namespace SDFG = mlir::concretelang::SDFG;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace SDFG {
|
||||
namespace {} // namespace
|
||||
} // namespace SDFG
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
|
||||
}
|
||||
return mlir::MemRefType::get(
|
||||
shape, rewriter.getI64Type(),
|
||||
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
|
||||
}
|
||||
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
|
||||
mlir::Value value) {
|
||||
mlir::Type valueType = value.getType();
|
||||
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
|
||||
return rewriter.create<mlir::memref::CastOp>(
|
||||
loc,
|
||||
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
|
||||
value);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
char stream_emulator_get_memref[] = "stream_emulator_get_memref";
|
||||
|
||||
template <typename Op, char const *funcName>
|
||||
struct BufferizableWithCallOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
BufferizableWithCallOpInterface<Op, funcName>, Op> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// TODO: For now we allocate for the result of GET but we might be
|
||||
// able to avoid the copy depending on the stream semantics.
|
||||
auto resTensorType =
|
||||
op->getResultTypes()[0].template cast<mlir::TensorType>();
|
||||
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
|
||||
resTensorType.getElementType());
|
||||
auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
|
||||
if (mlir::failed(outMemref)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// The last operand is the result
|
||||
mlir::SmallVector<mlir::Value> operands(op->getOperands());
|
||||
operands.push_back(getCastedMemRef(rewriter, loc, *outMemref));
|
||||
|
||||
mlir::FunctionType funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), mlir::ValueRange{operands}.getTypes(),
|
||||
mlir::TypeRange());
|
||||
if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed())
|
||||
return ::mlir::failure();
|
||||
rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
|
||||
operands);
|
||||
replaceOpWithBufferizedValues(rewriter, op, *outMemref);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::concretelang::SDFG::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, SDFG::SDFGDialect *dialect) {
|
||||
SDFG::Get::attachInterface<
|
||||
BufferizableWithCallOpInterface<SDFG::Get, stream_emulator_get_memref>>(
|
||||
*ctx);
|
||||
});
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
add_mlir_dialect_library(
|
||||
ConcretelangSDFGTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
SDFGConvertibleOpInterfaceImpl.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
|
||||
@@ -12,6 +13,10 @@ add_mlir_dialect_library(
|
||||
PUBLIC
|
||||
SDFGDialect
|
||||
ConcretelangSDFGInterfaces
|
||||
ConcretelangConversion
|
||||
MLIRArithmeticDialect
|
||||
MLIRBufferizationDialect
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
MLIRMemRefDialect
|
||||
MLIRPass
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
#include <concretelang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/SDFG/IR/SDFGDialect.h>
|
||||
#include <concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
@@ -83,6 +84,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>();
|
||||
BConcrete::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry);
|
||||
SDFG::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
|
||||
registry);
|
||||
@@ -417,6 +419,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
"Lowering from Bufferized Concrete to canonical MLIR dialects failed");
|
||||
}
|
||||
|
||||
// SDFG -> Canonical dialects
|
||||
if (mlir::concretelang::pipeline::lowerSDFGToStd(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from SDFG to canonical MLIR dialects failed");
|
||||
}
|
||||
|
||||
if (target == Target::STD)
|
||||
return std::move(res);
|
||||
|
||||
|
||||
@@ -296,6 +296,17 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("SDFGToStd", pm, context);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertSDFGToStreamEmulatorPass(),
|
||||
enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
|
||||
Reference in New Issue
Block a user