// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. // See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define GEN_PASS_CLASSES #include namespace mlir { namespace zamalang { namespace { mlir::Type getVoidPtrI64Type(ConversionPatternRewriter &rewriter) { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(rewriter.getContext(), 64)); } LLVM::LLVMFuncOp getOrInsertFuncOpDecl(mlir::Operation *op, llvm::StringRef funcName, LLVM::LLVMFunctionType funcType, ConversionPatternRewriter &rewriter) { // Check if the function is already in the symbol table auto module = op->getParentOfType(); auto funcOp = module.lookupSymbol(funcName); if (!funcOp) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); funcOp = rewriter.create(op->getLoc(), funcName, funcType); funcOp.setPrivate(); } else { if (!funcOp.isPrivate()) { op->emitError() << "the function \"" << funcName << "\" conflicts with the Dataflow Runtime API, please rename."; return nullptr; } } return funcOp; } // This function is only needed for debug purposes to inspect values // in the generated code - it is therefore not generally in use. LLVM_ATTRIBUTE_UNUSED void insertPrintDebugCall(ConversionPatternRewriter &rewriter, mlir::Operation *op, Value val) { OpBuilder::InsertionGuard guard(rewriter); auto printFnType = LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(rewriter.getContext()), {}, /*isVariadic=*/true); auto printFnOp = getOrInsertFuncOpDecl(op, "_dfr_print_debug", printFnType, rewriter); rewriter.create(op->getLoc(), printFnOp, val); } struct MakeReadyFutureOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::MakeReadyFutureOp mrfOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::MakeReadyFutureOp::Adaptor transformed(operands); OpBuilder::InsertionGuard guard(rewriter); // Normally this function takes a pointer as parameter auto mrfFuncType = LLVM::LLVMFunctionType::get(getVoidPtrI64Type(rewriter), {}, /*isVariadic=*/true); auto mrfFuncOp = getOrInsertFuncOpDecl(mrfOp, "_dfr_make_ready_future", mrfFuncType, rewriter); // In order to support non pointer types, we need to allocate // explicitly space that we can reference as a base for the // future. auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn( mrfOp->getParentOfType(), getIndexType()); auto sizeBytes = getSizeInBytes( mrfOp.getLoc(), transformed.getOperands().getTypes().front(), rewriter); auto results = mlir::LLVM::createLLVMCall( rewriter, mrfOp.getLoc(), allocFuncOp, {sizeBytes}, getVoidPtrType()); Value allocatedPtr = rewriter.create( mrfOp.getLoc(), mlir::LLVM::LLVMPointerType::get( transformed.getOperands().getTypes().front()), results[0]); rewriter.create( mrfOp.getLoc(), transformed.getOperands().front(), allocatedPtr); rewriter.replaceOpWithNewOp(mrfOp, mrfFuncOp, allocatedPtr); return mlir::success(); } }; struct AwaitFutureOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::AwaitFutureOp afOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::AwaitFutureOp::Adaptor transformed(operands); OpBuilder::InsertionGuard guard(rewriter); auto afFuncType = LLVM::LLVMFunctionType::get( mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)), {getVoidPtrI64Type(rewriter)}); auto afFuncOp = getOrInsertFuncOpDecl(afOp, "_dfr_await_future", afFuncType, rewriter); auto afCallOp = rewriter.create(afOp.getLoc(), afFuncOp, transformed.getOperands()); Value futVal = rewriter.create( afOp.getLoc(), mlir::LLVM::LLVMPointerType::get( (*getTypeConverter()).convertType(afOp.getResult().getType())), afCallOp.getResult(0)); rewriter.replaceOpWithNewOp(afOp, futVal); return success(); } }; struct CreateAsyncTaskOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::CreateAsyncTaskOp catOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::CreateAsyncTaskOp::Adaptor transformed(operands); auto catFuncType = LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true); auto catFuncOp = getOrInsertFuncOpDecl(catOp, "_dfr_create_async_task", catFuncType, rewriter); rewriter.replaceOpWithNewOp(catOp, catFuncOp, transformed.getOperands()); return success(); } }; struct DeallocateFutureOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::DeallocateFutureOp dfOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::DeallocateFutureOp::Adaptor transformed(operands); auto dfFuncType = LLVM::LLVMFunctionType::get( getVoidType(), {getVoidPtrI64Type(rewriter)}); auto dfFuncOp = getOrInsertFuncOpDecl(dfOp, "_dfr_deallocate_future", dfFuncType, rewriter); rewriter.replaceOpWithNewOp(dfOp, dfFuncOp, transformed.getOperands()); return success(); } }; struct DeallocateFutureDataOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::DeallocateFutureDataOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::DeallocateFutureDataOp dfdOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::DeallocateFutureDataOp::Adaptor transformed(operands); auto dfdFuncType = LLVM::LLVMFunctionType::get( getVoidType(), {getVoidPtrI64Type(rewriter)}); auto dfdFuncOp = getOrInsertFuncOpDecl(dfdOp, "_dfr_deallocate_future_data", dfdFuncType, rewriter); rewriter.replaceOpWithNewOp(dfdOp, dfdFuncOp, transformed.getOperands()); return success(); } }; struct BuildReturnPtrPlaceholderOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::BuildReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::BuildReturnPtrPlaceholderOp befOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OpBuilder::InsertionGuard guard(rewriter); // BuildReturnPtrPlaceholder is a placeholder for generating a memory // location where a pointer to allocated memory can be written so // that we can return outputs from task work function. Value one = rewriter.create( befOp.getLoc(), (*getTypeConverter()).convertType(rewriter.getIndexType()), rewriter.getIntegerAttr( (*getTypeConverter()).convertType(rewriter.getIndexType()), 1)); rewriter.replaceOpWithNewOp( befOp, mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)), one, /*alignment=*/ rewriter.getIntegerAttr( (*getTypeConverter()).convertType(rewriter.getIndexType()), 0)); return success(); } }; struct DerefReturnPtrPlaceholderOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::DerefReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::DerefReturnPtrPlaceholderOp drppOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::DerefReturnPtrPlaceholderOp::Adaptor transformed(operands); // DerefReturnPtrPlaceholder is a placeholder for generating a // dereference operation for the pointer used to get results from // task. rewriter.replaceOpWithNewOp( drppOp, transformed.getOperands().front()); return success(); } }; struct DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering : public ConvertOpToLLVMPattern< RT::DerefWorkFunctionArgumentPtrPlaceholderOp> { using ConvertOpToLLVMPattern< RT::DerefWorkFunctionArgumentPtrPlaceholderOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::DerefWorkFunctionArgumentPtrPlaceholderOp dwfappOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::DerefWorkFunctionArgumentPtrPlaceholderOp::Adaptor transformed( operands); OpBuilder::InsertionGuard guard(rewriter); // DerefWorkFunctionArgumentPtrPlaceholderOp is a placeholder for // generating a dereference operation for the pointer used to pass // arguments to the task. rewriter.replaceOpWithNewOp( dwfappOp, transformed.getOperands().front()); return success(); } }; struct WorkFunctionReturnOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::WorkFunctionReturnOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::WorkFunctionReturnOp wfrOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RT::WorkFunctionReturnOp::Adaptor transformed(operands); rewriter.replaceOpWithNewOp( wfrOp, transformed.getOperands().front(), transformed.getOperands().back()); return success(); } }; } // end anonymous namespace } // namespace zamalang } // namespace mlir void mlir::zamalang::populateRTToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< MakeReadyFutureOpInterfaceLowering, AwaitFutureOpInterfaceLowering, BuildReturnPtrPlaceholderOpInterfaceLowering, DerefReturnPtrPlaceholderOpInterfaceLowering, DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering, CreateAsyncTaskOpInterfaceLowering, DeallocateFutureOpInterfaceLowering, DeallocateFutureDataOpInterfaceLowering, WorkFunctionReturnOpInterfaceLowering>(converter); // clang-format on }