// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define GEN_PASS_CLASSES #include namespace mlir { namespace concretelang { namespace { mlir::Type getVoidPtrI64Type(ConversionPatternRewriter &rewriter) { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(rewriter.getContext(), 64)); } LLVM::LLVMFuncOp getOrInsertFuncOpDecl(mlir::Operation *op, llvm::StringRef funcName, LLVM::LLVMFunctionType funcType, ConversionPatternRewriter &rewriter) { // Check if the function is already in the symbol table auto module = op->getParentOfType(); auto funcOp = module.lookupSymbol(funcName); if (!funcOp) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); funcOp = rewriter.create(op->getLoc(), funcName, funcType); funcOp.setPrivate(); } else { if (!funcOp.isPrivate()) { op->emitError() << "the function \"" << funcName << "\" conflicts with the Dataflow Runtime API, please rename."; return nullptr; } } return funcOp; } /// This function is only needed for debug purposes to inspect values /// in the generated code - it is therefore not generally in use. LLVM_ATTRIBUTE_UNUSED void insertPrintDebugCall(ConversionPatternRewriter &rewriter, mlir::Operation *op, Value val) { OpBuilder::InsertionGuard guard(rewriter); auto printFnType = LLVM::LLVMFunctionType::get( LLVM::LLVMVoidType::get(rewriter.getContext()), {}, /*isVariadic=*/true); auto printFnOp = getOrInsertFuncOpDecl(op, "_dfr_print_debug", printFnType, rewriter); rewriter.create(op->getLoc(), printFnOp, val); } struct MakeReadyFutureOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::MakeReadyFutureOp mrfOp, RT::MakeReadyFutureOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { OpBuilder::InsertionGuard guard(rewriter); // Normally this function takes a pointer as parameter auto mrfFuncType = LLVM::LLVMFunctionType::get(getVoidPtrI64Type(rewriter), {}, /*isVariadic=*/true); auto mrfFuncOp = getOrInsertFuncOpDecl(mrfOp, "_dfr_make_ready_future", mrfFuncType, rewriter); // In order to support non pointer types, we need to allocate // explicitly space that we can reference as a base for the // future. auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn( mrfOp->getParentOfType(), getIndexType(), getTypeConverter()->useOpaquePointers()); auto sizeBytes = getSizeInBytes( mrfOp.getLoc(), adaptor.getOperands().getTypes().front(), rewriter); auto results = rewriter.create(mrfOp.getLoc(), allocFuncOp, sizeBytes); Value allocatedPtr = rewriter.create( mrfOp.getLoc(), mlir::LLVM::LLVMPointerType::get( adaptor.getOperands().getTypes().front()), results.getResult()); rewriter.create(mrfOp.getLoc(), adaptor.getOperands().front(), allocatedPtr); SmallVector mrfOperands = {adaptor.getOperands()}; mrfOperands[0] = allocatedPtr; rewriter.replaceOpWithNewOp(mrfOp, mrfFuncOp, mrfOperands); return mlir::success(); } }; struct AwaitFutureOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::AwaitFutureOp afOp, RT::AwaitFutureOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { OpBuilder::InsertionGuard guard(rewriter); auto afFuncType = LLVM::LLVMFunctionType::get( mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)), {getVoidPtrI64Type(rewriter)}); auto afFuncOp = getOrInsertFuncOpDecl(afOp, "_dfr_await_future", afFuncType, rewriter); auto afCallOp = rewriter.create(afOp.getLoc(), afFuncOp, adaptor.getOperands()); Value futVal = rewriter.create( afOp.getLoc(), mlir::LLVM::LLVMPointerType::get( (*getTypeConverter()).convertType(afOp.getResult().getType())), afCallOp.getResult()); rewriter.replaceOpWithNewOp(afOp, futVal); return success(); } }; struct CreateAsyncTaskOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::CreateAsyncTaskOp catOp, RT::CreateAsyncTaskOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto catFuncType = LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true); auto catFuncOp = getOrInsertFuncOpDecl(catOp, "_dfr_create_async_task", catFuncType, rewriter); rewriter.replaceOpWithNewOp(catOp, catFuncOp, adaptor.getOperands()); return success(); } }; struct RegisterTaskWorkFunctionOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::RegisterTaskWorkFunctionOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::RegisterTaskWorkFunctionOp rtwfOp, RT::RegisterTaskWorkFunctionOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rtwfFuncType = LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true); auto rtwfFuncOp = getOrInsertFuncOpDecl( rtwfOp, "_dfr_register_work_function", rtwfFuncType, rewriter); rewriter.replaceOpWithNewOp(rtwfOp, rtwfFuncOp, adaptor.getOperands()); return success(); } }; struct DeallocateFutureOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::DeallocateFutureOp dfOp, RT::DeallocateFutureOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dfFuncType = LLVM::LLVMFunctionType::get( getVoidType(), {getVoidPtrI64Type(rewriter)}); auto dfFuncOp = getOrInsertFuncOpDecl(dfOp, "_dfr_deallocate_future", dfFuncType, rewriter); rewriter.replaceOpWithNewOp(dfOp, dfFuncOp, adaptor.getOperands()); return success(); } }; struct DeallocateFutureDataOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::DeallocateFutureDataOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::DeallocateFutureDataOp dfdOp, RT::DeallocateFutureDataOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dfdFuncType = LLVM::LLVMFunctionType::get( getVoidType(), {getVoidPtrI64Type(rewriter)}); auto dfdFuncOp = getOrInsertFuncOpDecl(dfdOp, "_dfr_deallocate_future_data", dfdFuncType, rewriter); rewriter.replaceOpWithNewOp(dfdOp, dfdFuncOp, adaptor.getOperands()); return success(); } }; struct BuildReturnPtrPlaceholderOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::BuildReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::BuildReturnPtrPlaceholderOp befOp, RT::BuildReturnPtrPlaceholderOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { OpBuilder::InsertionGuard guard(rewriter); // BuildReturnPtrPlaceholder is a placeholder for generating a memory // location where a pointer to allocated memory can be written so // that we can return outputs from task work function. Value one = rewriter.create( befOp.getLoc(), (*getTypeConverter()).convertType(rewriter.getIndexType()), rewriter.getIntegerAttr( (*getTypeConverter()).convertType(rewriter.getIndexType()), 1)); rewriter.replaceOpWithNewOp( befOp, mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)), one, 0); return success(); } }; struct DerefReturnPtrPlaceholderOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::DerefReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::DerefReturnPtrPlaceholderOp drppOp, RT::DerefReturnPtrPlaceholderOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // DerefReturnPtrPlaceholder is a placeholder for generating a // dereference operation for the pointer used to get results from // task. rewriter.replaceOpWithNewOp(drppOp, adaptor.getOperands().front()); return success(); } }; struct DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering : public ConvertOpToLLVMPattern< RT::DerefWorkFunctionArgumentPtrPlaceholderOp> { using ConvertOpToLLVMPattern< RT::DerefWorkFunctionArgumentPtrPlaceholderOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite( RT::DerefWorkFunctionArgumentPtrPlaceholderOp dwfappOp, RT::DerefWorkFunctionArgumentPtrPlaceholderOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { OpBuilder::InsertionGuard guard(rewriter); // DerefWorkFunctionArgumentPtrPlaceholderOp is a placeholder for // generating a dereference operation for the pointer used to pass // arguments to the task. rewriter.replaceOpWithNewOp(dwfappOp, adaptor.getOperands().front()); return success(); } }; struct WorkFunctionReturnOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< RT::WorkFunctionReturnOp>::ConvertOpToLLVMPattern; mlir::LogicalResult matchAndRewrite(RT::WorkFunctionReturnOp wfrOp, RT::WorkFunctionReturnOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( wfrOp, adaptor.getOperands().front(), adaptor.getOperands().back()); return success(); } }; } // end anonymous namespace } // namespace concretelang } // namespace mlir void mlir::concretelang::populateRTToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< MakeReadyFutureOpInterfaceLowering, AwaitFutureOpInterfaceLowering, BuildReturnPtrPlaceholderOpInterfaceLowering, DerefReturnPtrPlaceholderOpInterfaceLowering, DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering, CreateAsyncTaskOpInterfaceLowering, RegisterTaskWorkFunctionOpInterfaceLowering, DeallocateFutureOpInterfaceLowering, DeallocateFutureDataOpInterfaceLowering, WorkFunctionReturnOpInterfaceLowering>(converter); // clang-format on }