Files
concrete/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp
youben11 f4166a4973 docs: use consistent style for comment blocks
prefix comment blocks with ///
2022-07-07 16:11:19 +01:00

303 lines
13 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iostream>
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteTypes.h>
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
#include <concretelang/Dialect/RT/IR/RTDialect.h>
#include <concretelang/Dialect/RT/IR/RTOps.h>
#include <concretelang/Dialect/RT/IR/RTTypes.h>
#include <concretelang/Support/math.h>
#include <mlir/IR/BuiltinOps.h>
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <llvm/IR/Instructions.h>
#include <llvm/Support/Compiler.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
#include <mlir/Dialect/LLVMIR/FunctionCallUtils.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>
#include <mlir/Transforms/RegionUtils.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/RT/Analysis/Autopar.h.inc>
namespace mlir {
namespace concretelang {
namespace {
/// Returns the LLVM pointer type `!llvm.ptr<i64>`, which this file uses as
/// the opaque handle type exchanged with the Dataflow Runtime functions.
mlir::Type getVoidPtrI64Type(ConversionPatternRewriter &rewriter) {
  mlir::Type i64Ty = rewriter.getI64Type();
  return mlir::LLVM::LLVMPointerType::get(i64Ty);
}
/// Looks up `funcName` in the module enclosing `op` and returns its
/// `llvm.func` declaration, creating a private declaration of type
/// `funcType` at the start of the module body when none exists.
///
/// If a symbol with that name already exists but is not private, it is
/// treated as a user function clashing with the Dataflow Runtime API: an
/// error is emitted on `op` and a null op is returned. Callers must check
/// the result before using it.
///
/// NOTE(review): an existing private declaration is returned as-is, without
/// checking that its type matches `funcType`.
LLVM::LLVMFuncOp getOrInsertFuncOpDecl(mlir::Operation *op,
                                       llvm::StringRef funcName,
                                       LLVM::LLVMFunctionType funcType,
                                       ConversionPatternRewriter &rewriter) {
  auto parentModule = op->getParentOfType<ModuleOp>();

  // Reuse an existing declaration only when it is private.
  if (auto existing = parentModule.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
    if (existing.isPrivate())
      return existing;
    op->emitError()
        << "the function \"" << funcName
        << "\" conflicts with the Dataflow Runtime API, please rename.";
    return nullptr;
  }

  // Declare the runtime function at the top of the module without
  // disturbing the caller's insertion point.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(parentModule.getBody());
  auto declaration =
      rewriter.create<LLVM::LLVMFuncOp>(op->getLoc(), funcName, funcType);
  declaration.setPrivate();
  return declaration;
}
/// Debug helper: emits a call to the variadic runtime function
/// `_dfr_print_debug` with `val` as its single argument, so the value can
/// be inspected in the generated code. It is not used by any pattern in
/// normal operation.
///
/// If the runtime declaration cannot be obtained (a non-private symbol with
/// the same name already exists), `getOrInsertFuncOpDecl` has already emitted
/// a diagnostic and no call is generated.
LLVM_ATTRIBUTE_UNUSED void
insertPrintDebugCall(ConversionPatternRewriter &rewriter, mlir::Operation *op,
                     Value val) {
  OpBuilder::InsertionGuard guard(rewriter);
  auto printFnType = LLVM::LLVMFunctionType::get(
      LLVM::LLVMVoidType::get(rewriter.getContext()), {}, /*isVariadic=*/true);
  auto printFnOp =
      getOrInsertFuncOpDecl(op, "_dfr_print_debug", printFnType, rewriter);
  // Building a CallOp from a null callee would dereference it.
  if (!printFnOp)
    return;
  rewriter.create<LLVM::CallOp>(op->getLoc(), printFnOp, val);
}
/// Lowers `RT::MakeReadyFutureOp` to a call to the runtime function
/// `_dfr_make_ready_future`.
///
/// The (converted) operand is stored into a freshly `malloc`ed buffer sized
/// for its type. The buffer pointer is passed to the runtime, and the call
/// replaces the op.
///
/// NOTE(review): the buffer is not freed here; presumably ownership passes
/// to the runtime or to a later deallocation op. Confirm.
struct MakeReadyFutureOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::MakeReadyFutureOp> {
  using ConvertOpToLLVMPattern<RT::MakeReadyFutureOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::MakeReadyFutureOp mrfOp,
                  RT::MakeReadyFutureOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    // Normally this function takes a pointer as parameter
    auto mrfFuncType = LLVM::LLVMFunctionType::get(getVoidPtrI64Type(rewriter),
                                                   {}, /*isVariadic=*/true);
    auto mrfFuncOp = getOrInsertFuncOpDecl(mrfOp, "_dfr_make_ready_future",
                                           mrfFuncType, rewriter);
    // A null decl means a conflicting symbol was found; the diagnostic has
    // already been emitted. Bail out before generating any IR.
    if (!mrfFuncOp)
      return mlir::failure();
    // In order to support non pointer types, we need to allocate
    // explicitly space that we can reference as a base for the
    // future.
    auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn(
        mrfOp->getParentOfType<ModuleOp>(), getIndexType());
    auto sizeBytes = getSizeInBytes(
        mrfOp.getLoc(), adaptor.getOperands().getTypes().front(), rewriter);
    auto results = mlir::LLVM::createLLVMCall(
        rewriter, mrfOp.getLoc(), allocFuncOp, {sizeBytes}, getVoidPtrType());
    // Reinterpret the raw malloc result as a pointer to the operand type.
    Value allocatedPtr = rewriter.create<mlir::LLVM::BitcastOp>(
        mrfOp.getLoc(),
        mlir::LLVM::LLVMPointerType::get(
            adaptor.getOperands().getTypes().front()),
        results[0]);
    rewriter.create<LLVM::StoreOp>(mrfOp.getLoc(),
                                   adaptor.getOperands().front(), allocatedPtr);
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(mrfOp, mrfFuncOp, allocatedPtr);
    return mlir::success();
  }
};
/// Lowers `RT::AwaitFutureOp` to a call to the runtime function
/// `_dfr_await_future(!llvm.ptr<i64>) -> !llvm.ptr<!llvm.ptr<i64>>`.
///
/// The returned pointer is bitcast to a pointer to the converted result
/// type, and the op is replaced by a load through it.
struct AwaitFutureOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::AwaitFutureOp> {
  using ConvertOpToLLVMPattern<RT::AwaitFutureOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::AwaitFutureOp afOp, RT::AwaitFutureOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    // Convert the result type up front. A failed conversion yields a null
    // type, which must not reach LLVMPointerType::get.
    Type resultType =
        (*getTypeConverter()).convertType(afOp.getResult().getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          afOp, "failed to convert the awaited result type");
    auto afFuncType = LLVM::LLVMFunctionType::get(
        mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)),
        {getVoidPtrI64Type(rewriter)});
    auto afFuncOp =
        getOrInsertFuncOpDecl(afOp, "_dfr_await_future", afFuncType, rewriter);
    // A null decl means a conflicting symbol was found; the diagnostic has
    // already been emitted.
    if (!afFuncOp)
      return failure();
    auto afCallOp = rewriter.create<LLVM::CallOp>(afOp.getLoc(), afFuncOp,
                                                  adaptor.getOperands());
    Value futVal = rewriter.create<mlir::LLVM::BitcastOp>(
        afOp.getLoc(), mlir::LLVM::LLVMPointerType::get(resultType),
        afCallOp.getResult(0));
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(afOp, futVal);
    return success();
  }
};
/// Lowers `RT::CreateAsyncTaskOp` to a call to the variadic runtime
/// function `_dfr_create_async_task` (declared as returning void). All
/// converted operands are forwarded unchanged.
struct CreateAsyncTaskOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::CreateAsyncTaskOp> {
  using ConvertOpToLLVMPattern<RT::CreateAsyncTaskOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::CreateAsyncTaskOp catOp,
                  RT::CreateAsyncTaskOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto catFuncType =
        LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true);
    auto catFuncOp = getOrInsertFuncOpDecl(catOp, "_dfr_create_async_task",
                                           catFuncType, rewriter);
    // A null decl means a conflicting symbol was found; the diagnostic has
    // already been emitted.
    if (!catFuncOp)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(catOp, catFuncOp,
                                              adaptor.getOperands());
    return success();
  }
};
/// Lowers `RT::DeallocateFutureOp` to a call to the runtime function
/// `_dfr_deallocate_future(!llvm.ptr<i64>) -> void`, forwarding the
/// converted operands.
struct DeallocateFutureOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::DeallocateFutureOp> {
  using ConvertOpToLLVMPattern<RT::DeallocateFutureOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::DeallocateFutureOp dfOp,
                  RT::DeallocateFutureOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dfFuncType = LLVM::LLVMFunctionType::get(
        getVoidType(), {getVoidPtrI64Type(rewriter)});
    auto dfFuncOp = getOrInsertFuncOpDecl(dfOp, "_dfr_deallocate_future",
                                          dfFuncType, rewriter);
    // A null decl means a conflicting symbol was found; the diagnostic has
    // already been emitted.
    if (!dfFuncOp)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(dfOp, dfFuncOp,
                                              adaptor.getOperands());
    return success();
  }
};
/// Lowers `RT::DeallocateFutureDataOp` to a call to the runtime function
/// `_dfr_deallocate_future_data(!llvm.ptr<i64>) -> void`, forwarding the
/// converted operands.
struct DeallocateFutureDataOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::DeallocateFutureDataOp> {
  using ConvertOpToLLVMPattern<
      RT::DeallocateFutureDataOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::DeallocateFutureDataOp dfdOp,
                  RT::DeallocateFutureDataOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dfdFuncType = LLVM::LLVMFunctionType::get(
        getVoidType(), {getVoidPtrI64Type(rewriter)});
    auto dfdFuncOp = getOrInsertFuncOpDecl(dfdOp, "_dfr_deallocate_future_data",
                                           dfdFuncType, rewriter);
    // A null decl means a conflicting symbol was found; the diagnostic has
    // already been emitted.
    if (!dfdFuncOp)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(dfdOp, dfdFuncOp,
                                              adaptor.getOperands());
    return success();
  }
};
/// Lowers `RT::BuildReturnPtrPlaceholderOp` to an `llvm.alloca` of one
/// `!llvm.ptr<i64>` slot. The resulting `!llvm.ptr<!llvm.ptr<i64>>` is the
/// location into which a task work function can write a pointer to its
/// allocated outputs.
struct BuildReturnPtrPlaceholderOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::BuildReturnPtrPlaceholderOp> {
  using ConvertOpToLLVMPattern<
      RT::BuildReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::BuildReturnPtrPlaceholderOp befOp,
                  RT::BuildReturnPtrPlaceholderOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    // The alloca's element count is the constant 1, expressed in the
    // converted (LLVM) index type.
    Type llvmIndexTy = getTypeConverter()->convertType(rewriter.getIndexType());
    Value elementCount = rewriter.create<arith::ConstantOp>(
        befOp.getLoc(), llvmIndexTy, rewriter.getIntegerAttr(llvmIndexTy, 1));
    Type slotPtrTy =
        mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter));
    rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(befOp, slotPtrTy, elementCount,
                                                0);
    return success();
  }
};
/// Lowers `RT::DerefReturnPtrPlaceholderOp` to an `llvm.load` through its
/// first converted operand. That operand is the pointer through which
/// results are read back from a task.
struct DerefReturnPtrPlaceholderOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::DerefReturnPtrPlaceholderOp> {
  using ConvertOpToLLVMPattern<
      RT::DerefReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::DerefReturnPtrPlaceholderOp drppOp,
                  RT::DerefReturnPtrPlaceholderOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value resultSlotPtr = adaptor.getOperands().front();
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(drppOp, resultSlotPtr);
    return success();
  }
};
/// Lowers `RT::DerefWorkFunctionArgumentPtrPlaceholderOp` to an `llvm.load`
/// through its first converted operand. That operand is the pointer used to
/// pass an argument into a task's work function.
struct DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering
    : public ConvertOpToLLVMPattern<
          RT::DerefWorkFunctionArgumentPtrPlaceholderOp> {
  using ConvertOpToLLVMPattern<
      RT::DerefWorkFunctionArgumentPtrPlaceholderOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult matchAndRewrite(
      RT::DerefWorkFunctionArgumentPtrPlaceholderOp dwfappOp,
      RT::DerefWorkFunctionArgumentPtrPlaceholderOp::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Value argumentPtr = adaptor.getOperands().front();
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(dwfappOp, argumentPtr);
    return success();
  }
};
/// Lowers `RT::WorkFunctionReturnOp` to an `llvm.store`. The first converted
/// operand is the value written; the last one is the destination pointer.
struct WorkFunctionReturnOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::WorkFunctionReturnOp> {
  using ConvertOpToLLVMPattern<
      RT::WorkFunctionReturnOp>::ConvertOpToLLVMPattern;
  mlir::LogicalResult
  matchAndRewrite(RT::WorkFunctionReturnOp wfrOp,
                  RT::WorkFunctionReturnOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operands = adaptor.getOperands();
    Value storedValue = operands.front();
    Value destinationPtr = operands.back();
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(wfrOp, storedValue,
                                               destinationPtr);
    return success();
  }
};
} // end anonymous namespace
} // namespace concretelang
} // namespace mlir
/// Registers every RT-dialect to LLVM-dialect lowering pattern defined in
/// this file into `patterns`, each constructed with `converter`.
void mlir::concretelang::populateRTToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<MakeReadyFutureOpInterfaceLowering>(converter);
  patterns.add<AwaitFutureOpInterfaceLowering>(converter);
  patterns.add<BuildReturnPtrPlaceholderOpInterfaceLowering>(converter);
  patterns.add<DerefReturnPtrPlaceholderOpInterfaceLowering>(converter);
  patterns.add<DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering>(
      converter);
  patterns.add<CreateAsyncTaskOpInterfaceLowering>(converter);
  patterns.add<DeallocateFutureOpInterfaceLowering>(converter);
  patterns.add<DeallocateFutureDataOpInterfaceLowering>(converter);
  patterns.add<WorkFunctionReturnOpInterfaceLowering>(converter);
}