From 160dfe838e82eadcebfaab7a0aac012bfdd89ea1 Mon Sep 17 00:00:00 2001 From: Jason Furmanek Date: Tue, 12 Dec 2023 19:30:01 +0000 Subject: [PATCH] ROCM IFU: Fix print and assert --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 39 +++++++++++-------- .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 15 ++++--- python/test/unit/language/print_helper.py | 3 +- python/test/unit/language/test_subprocess.py | 10 ++--- 4 files changed, 37 insertions(+), 30 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index ee89039af..4a3260ac9 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -17,7 +17,13 @@ Value llGetPid(int axis, Location loc, ModuleOp moduleOp, assert(axis >= 0); assert(axis < 3); assert(moduleOp); - +#ifdef USE_ROCM + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]); + return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId); +#else // It is not easy to get the compute capability here, so we use numCTAs to // decide the semantic of GetProgramIdOp. If numCTAs = 1, then // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to @@ -27,6 +33,7 @@ Value llGetPid(int axis, Location loc, ModuleOp moduleOp, std::string sreg = numCTAs == 1 ? "%ctaid." 
: "%clusterid."; sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z' return getSRegValue(rewriter, loc, sreg); +#endif } struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> { @@ -137,12 +144,14 @@ struct PrintOpConversion if (op.getNumOperands() == 0) { std::string formatStr; llvm::raw_string_ostream os(formatStr); +#ifdef USE_ROCM + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" << op.getPrefix().str(); + llPrintfHIP(loc, op->getParentOfType<ModuleOp>(), formatStr, + {pid[0], pid[1], pid[2]}, rewriter); +#else os << "pid (" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s"; -#ifdef USE_ROCM - llPrintfHIP(loc, op->getParentOfType<ModuleOp>(), formatStr, - {pid[0], pid[1], pid[2], prefixStr}, rewriter); -#else llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter); #endif } else { @@ -250,8 +259,12 @@ struct PrintOpConversion } os << ")"; +#ifdef USE_ROCM + os << op.getPrefix().str(); +#else os << "%s"; printfOperands.push_back(prefixStr); +#endif if (numOperands > 1) { os << "(operand " << operand << ") "; @@ -266,7 +279,7 @@ struct PrintOpConversion // printfOperands. But we don't want to create BLOCK_SIZE duplicate // strings, so we cache the Value. 
if (i == 0) { -#ifdef USE_ROCM +#ifdef USE_ROCM formatting = llPrintfHIP(op->getLoc(), op->getParentOfType<ModuleOp>(), formatStr, printfOperands, rewriter); #else @@ -616,20 +629,14 @@ struct GetProgramIdOpConversion matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { -#ifdef USE_ROCM - Location loc = op->getLoc(); - assert(op.getAxisAsInt() < 3); - - Value blockId = - rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]); - rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId); - return success(); -#else Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(), op->getParentOfType<ModuleOp>(), rewriter); +#ifdef USE_ROCM + rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, programId); +#else rewriter.replaceOp(op, programId); - return success(); #endif + return success(); } static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, mlir::gpu::Dimension::y, diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 53bb4cc03..ff8149125 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -1330,7 +1330,7 @@ protected: formatting.formatStrValue = LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline); formatting.formatStrSize = msgNewline.size_in_bytes(); - llPrintfHIP(loc, moduleOp, formatting, args, rewriter); + llPrintfHIP(loc, moduleOp, formatting, args, rewriter, stderr); return formatting; } @@ -1342,7 +1342,7 @@ protected: auto typeConverter = getTypeConverter(); mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); - mlir::Type i8Ptr = typeConverter->getPointerType(llvmI8); + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
rewriter, "__ockl_printf_append_string_n", LLVM::LLVMFunctionType::get( llvmI64, - {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); + {llvmI64, {ptrType}, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0); @@ -1405,12 +1405,11 @@ protected: Value arg = args[i]; if (auto floatType = arg.getType().dyn_cast<FloatType>()) { if (!floatType.isF64()) - arg = rewriter.create<LLVM::FPExtOp>( - loc, typeConverter->convertType(rewriter.getF64Type()), arg); - arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); + arg = fpext(typeConverter->convertType(rewriter.getF64Type()), arg); + arg = bitcast(arg, llvmI64); } if (arg.getType().getIntOrFloatBitWidth() != 64) - arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); + arg = zext(llvmI64, arg); arguments.push_back(arg); } @@ -1421,7 +1420,7 @@ protected: auto isLast = (bound == nArgs) ? oneI32 : zeroI32; arguments.push_back(isLast); - auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); + auto call = call(ocklAppendArgs, arguments); printfDesc = call.getResult(); } } diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index 6776f09c1..4ffbea768 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -68,7 +68,8 @@ def kernel_print_no_arg(): def test_print(func: str, data_type: str): - shape = (128, ) + #shape = (128, ) + shape = (256, ) x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type)) y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_print": diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 08bc63a3e..84fe2020d 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -35,22 +35,22 @@ def test_print(func_type: str, data_type: str): # pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) 
<prefix> (operand <n>) <elem> expected_lines = Counter() if func_type == "print" or func_type == "device_print": - for i in range(128): + for i in range(256): line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" if data_type.startswith("float"): line += ".000000" expected_lines[line] = 1 elif func_type == "static_print": - expected_lines[" int32[constexpr[128]]"] = 1 + expected_lines[" int32[constexpr[256]]"] = 1 elif func_type == "no_arg_print": - expected_lines["pid (0, 0, 0) idx (): 0"] = 128 + expected_lines["pid (0, 0, 0) idx (): 0"] = 256 elif func_type == "print_no_arg": - expected_lines["pid (0, 0, 0) no arg"] = 128 + expected_lines["pid (0, 0, 0) no arg"] = 256 elif func_type == "device_print_large": for i, j, k in itertools.product(range(2), range(64), range(128)): expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": - for i in range(128): + for i in range(256): expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1