#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"

/// LowLFHEOpToConcreteCAPICallPattern<Op> matches an `Op` operation and
/// replaces it with a call to the external function `funcName`, which is
/// expected to be linked in later. Private forward declarations of `funcName`
/// and of the allocation function `allocName` are inserted into the enclosing
/// symbol table if they are not already there.
///
/// The C signature of `funcName` is expected to be
/// `void funcName(int *err, out, arg0, arg1, ...)`. The pattern rewrites:
/// ```
/// out = op(arg0, arg1)
/// ```
/// to
/// ```
/// err = memref.alloca() : memref<index>
/// out = call allocName(err, lweSize)
/// call funcName(err, out, arg0, arg1)
/// ```
template <typename Op>
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
  /// Size argument passed to the allocation function `allocName`.
  // TODO: hard-coded for now; it should presumably be derived from the
  // ciphertext type / crypto parameters instead.
  static constexpr int64_t kLweSize = 2048;

  /// @param context   MLIR context the pattern is registered in.
  /// @param funcName  Name of the external C function implementing `Op`.
  /// @param allocName Name of the external C function allocating the result.
  /// @param benefit   Pattern benefit forwarded to `OpRewritePattern`.
  LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
                                     mlir::StringRef funcName,
                                     mlir::StringRef allocName,
                                     mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
        allocName(allocName) {}

  /// Ensures a private function named `funcName` is declared in the symbol
  /// table enclosing `op`. If the name is free, a private declaration of type
  /// `funcType` is inserted at the start of that symbol table.
  ///
  /// @return failure, after emitting an error on `op`, if `op` has no
  ///         enclosing symbol table or if `funcName` is already taken by
  ///         something other than a private function.
  static mlir::LogicalResult
  insertForwardDeclaration(Op op, mlir::PatternRewriter &rewriter,
                           llvm::StringRef funcName,
                           mlir::FunctionType funcType) {
    mlir::Operation *symbolTable =
        mlir::SymbolTable::getNearestSymbolTable(op);
    if (!symbolTable) {
      op.emitError() << "cannot declare the function \"" << funcName
                     << "\": no enclosing symbol table";
      return mlir::failure();
    }

    mlir::Operation *existing =
        mlir::SymbolTable::lookupSymbolIn(symbolTable, funcName);
    if (!existing) {
      // Insert the forward declaration at the top of the symbol table body;
      // the guard restores the rewriter's insertion point on scope exit.
      mlir::OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&symbolTable->getRegion(0).front());
      auto declaration = rewriter.create<mlir::FuncOp>(
          rewriter.getUnknownLoc(), funcName, funcType);
      declaration.setPrivate();
      return mlir::success();
    }

    // Reuse an existing symbol only if it is a private function. A public
    // function clashes with the C API, and a non-function symbol cannot be
    // the callee of the generated call. This must be a runtime check rather
    // than an assert so release builds reject such input too.
    auto existingFunc = mlir::dyn_cast<mlir::FuncOp>(existing);
    if (!existingFunc || !existingFunc.isPrivate()) {
      op.emitError() << "the function \"" << funcName
                     << "\" conflicts with the concrete C API, please rename";
      return mlir::failure();
    }
    return mlir::success();
  }

  /// Rewrites `op` into a call to `allocName` (producing the result) followed
  /// by a call to `funcName` (filling it in), as described above. Fails if
  /// either forward declaration cannot be inserted.
  mlir::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    mlir::MLIRContext *context = rewriter.getContext();
    mlir::Location loc = op.getLoc();
    mlir::Type resultType = op->getResultTypes().front();
    // Error flag handed to the C API: a rank-0 memref holding an index.
    auto errType = mlir::MemRefType::get({}, mlir::IndexType::get(context));

    // Declare `funcName : (err, result, operands...) -> ()`.
    mlir::SmallVector<mlir::Type, 4> funcInputs{errType, resultType};
    for (mlir::Type operandType : op->getOperandTypes())
      funcInputs.push_back(operandType);
    auto funcType = mlir::FunctionType::get(context, funcInputs, {});
    if (mlir::failed(
            insertForwardDeclaration(op, rewriter, funcName, funcType)))
      return mlir::failure();

    // Declare `allocName : (err, index) -> result`.
    auto allocType = mlir::FunctionType::get(
        context, {errType, rewriter.getIndexType()}, {resultType});
    if (mlir::failed(
            insertForwardDeclaration(op, rewriter, allocName, allocType)))
      return mlir::failure();

    // All reads of `op` happen before `replaceOp` below: replacing may erase
    // `op` immediately (the base PatternRewriter does), so touching it
    // afterwards would be a use-after-free.

    // The error flag is stack-allocated so that no heap buffer leaks per call.
    auto err = rewriter.create<mlir::memref::AllocaOp>(loc, errType);
    auto lweSize = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIndexAttr(kLweSize));
    mlir::SmallVector<mlir::Value, 2> allocOperands{err, lweSize};
    auto alloc = rewriter.create<mlir::CallOp>(loc, allocName, resultType,
                                               allocOperands);

    // Call the operator: funcName(err, allocated result, original operands...).
    mlir::SmallVector<mlir::Value, 4> callOperands{err, alloc.getResult(0)};
    for (mlir::Value operand : op->getOperands())
      callOperands.push_back(operand);
    rewriter.create<mlir::CallOp>(loc, funcName, mlir::TypeRange{},
                                  callOperands);

    // Uses of the original result now refer to the allocated value.
    rewriter.replaceOp(op, alloc->getResults());
    return mlir::success();
  }

private:
  /// External C function implementing `Op`.
  std::string funcName;
  /// External C function allocating the result of `Op`.
  std::string allocName;
};

/// Populates `patterns` with all the patterns that rewrite LowLFHE operations
/// into calls to the corresponding functions of the Concrete C API.
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
  // NOTE(review): the template argument lists of `patterns.add` and of the
  // other templates in this part of the file are missing in this copy.
  // Presumably `add` is instantiated with LowLFHEOpToConcreteCAPICallPattern
  // for the LowLFHE ciphertext-addition op; restore/confirm against version
  // control. Assuming that pattern, the two strings are the C function
  // implementing the op and the C function allocating its result.
  patterns.add>(
      patterns.getContext(), "add_lwe_ciphertexts_u64",
      "allocate_lwe_ciphertext_u64");
}

namespace {
/// Module pass that lowers LowLFHE operations to calls into the Concrete C API.
// NOTE(review): `LowLFHEToConcreteCAPIBase` is presumably the
// tablegen-generated pass base (usually parameterized on this struct).
struct LowLFHEToConcreteCAPIPass : public LowLFHEToConcreteCAPIBase {
  void runOnOperation() final;
};
} // namespace

/// Lowers the whole module with a full dialect conversion, driven by the
/// patterns from `populateLowLFHEToConcreteCAPICall`. Signals pass failure if
/// the conversion fails.
void LowLFHEToConcreteCAPIPass::runOnOperation() {
  // Set up the conversion target. One dialect set is declared illegal and
  // another legal; presumably LowLFHE is illegal and the dialects the
  // patterns emit into are legal (the template arguments are not visible
  // here).
  mlir::ConversionTarget target(getContext());
  target.addIllegalDialect();
  target.addLegalDialect();

  // Set up the rewrite patterns.
  mlir::RewritePatternSet patterns(&getContext());
  populateLowLFHEToConcreteCAPICall(patterns);

  // Apply the conversion. A *full* conversion only succeeds if every
  // operation ends up legal for `target`.
  mlir::ModuleOp op = getOperation();
  if (mlir::applyFullConversion(op, target, std::move(patterns)).failed()) {
    this->signalPassFailure();
  }
}

namespace mlir {
namespace zamalang {
/// Creates a new instance of the LowLFHE -> Concrete C API lowering pass.
std::unique_ptr> createConvertLowLFHEToConcreteCAPIPass() {
  return std::make_unique();
}
} // namespace zamalang
} // namespace mlir