mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Use inline asm for global timer and smid functions (#2143)
Simplify the code by implementing globaltimer and smid with inline asm instead of relying on a precompiled LLVM bitcode (.bc) file.
This commit is contained in:
@@ -514,7 +514,7 @@ def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [Elementwise,
|
||||
}];
|
||||
|
||||
let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic<AnyTypeOf<[TT_Type]>>:$args);
|
||||
let results = (outs TT_Tensor:$result);
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
|
||||
let assemblyFormat = [{
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
; ~/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/bin/llvm-as ./src/extra/cuda.ll -o ./triton/language/extra/cuda.bc
|
||||
|
||||
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
|
||||
target triple = "nvptx64-nvidia-cuda"
|
||||
|
||||
|
||||
define i64 @globaltimer() #0 {
|
||||
%1 = call i64 asm sideeffect "mov.u64 $0, %globaltimer;", "=l"() nounwind
|
||||
ret i64 %1
|
||||
}
|
||||
|
||||
define i32 @smid() #0 {
|
||||
%1 = call i32 asm "mov.u32 $0, %smid;", "=r"() nounwind
|
||||
ret i32 %1
|
||||
}
|
||||
|
||||
attributes #0 = { alwaysinline nounwind }
|
||||
@@ -3403,8 +3403,7 @@ def test_globaltimer(device):
|
||||
out2 = to_triton(np.zeros((1,), dtype=np.int64), device=device)
|
||||
h = kernel[(1,)](out1, out2)
|
||||
assert out2[0] > 0
|
||||
# 2 inlined globaltimers + one extra in the wrapper extern function
|
||||
assert h.asm["ptx"].count("%globaltimer") == 3
|
||||
assert h.asm["ptx"].count("%globaltimer") == 2
|
||||
|
||||
|
||||
def test_smid(device):
|
||||
@@ -3417,7 +3416,7 @@ def test_smid(device):
|
||||
out = to_triton(np.zeros((1024,), dtype=np.int32), device=device)
|
||||
h = kernel[(out.shape[0],)](out)
|
||||
assert out.sort()[0].unique().shape[0] > 0
|
||||
assert h.asm["ptx"].count("%smid") == 2
|
||||
assert h.asm["ptx"].count("%smid") == 1
|
||||
|
||||
# -----------------------
|
||||
# test layout conversions
|
||||
|
||||
@@ -1811,6 +1811,7 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
is_pure = _constexpr_to_value(is_pure)
|
||||
ret_shape = None
|
||||
arg_types = []
|
||||
res_ty = dtype
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i] = _to_tensor(dispatch_args[i], _builder)
|
||||
arg_types.append(dispatch_args[i].dtype)
|
||||
@@ -1825,10 +1826,10 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
|
||||
ret_shape = broadcast_arg.shape
|
||||
res_ty = block_type(dtype, ret_shape).to_ir(_builder)
|
||||
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty, is_pure, pack)
|
||||
return tensor(call, block_type(dtype, ret_shape))
|
||||
ret_shape = broadcast_arg.shape
|
||||
res_ty = block_type(dtype, ret_shape)
|
||||
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
|
||||
return tensor(call, res_ty)
|
||||
|
||||
|
||||
# -----------------------
|
||||
|
||||
Binary file not shown.
@@ -1,19 +1,15 @@
|
||||
import os
|
||||
|
||||
from .. import core
|
||||
|
||||
__path__ = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
@core.extern
|
||||
def globaltimer(_builder=None):
|
||||
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
|
||||
{tuple(): ("globaltimer", core.dtype("int64")),
|
||||
}, is_pure=False, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
|
||||
dtype=core.int64, is_pure=False,
|
||||
pack=1, _builder=_builder)
|
||||
|
||||
|
||||
@core.extern
|
||||
def smid(_builder=None):
|
||||
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
|
||||
{tuple(): ("smid", core.dtype("int32")),
|
||||
}, is_pure=True, _builder=_builder)
|
||||
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
|
||||
dtype=core.int32, is_pure=True,
|
||||
pack=1, _builder=_builder)
|
||||
|
||||
Reference in New Issue
Block a user