[FRONTEND] Use inline asm for global timer and smid functions (#2143)

Simplify the code by using inline asm to implement globaltimer and smid
instead of relying on a bc file.
This commit is contained in:
Thomas
2023-08-20 22:56:37 -07:00
committed by GitHub
parent ad3e363a44
commit 54ca7fcb35
6 changed files with 14 additions and 35 deletions

View File

@@ -514,7 +514,7 @@ def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [Elementwise,
}];
let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic<AnyTypeOf<[TT_Type]>>:$args);
let results = (outs TT_Tensor:$result);
let results = (outs TT_Type:$result);
let assemblyFormat = [{

View File

@@ -1,17 +0,0 @@
; ~/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/bin/llvm-as ./src/extra/cuda.ll -o ./triton/language/extra/cuda.bc
; The line above is the llvm-as invocation that assembles this file into cuda.bc.
; The module defines two zero-argument helpers for the nvptx64-nvidia-cuda
; target; each wraps a single inline-asm `mov` from a PTX special register.
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "nvptx64-nvidia-cuda"
; globaltimer: returns the i64 produced by `mov.u64 $0, %globaltimer;`.
; The asm call carries `sideeffect`; the `=l` output constraint is bound to
; the i64 result.
define i64 @globaltimer() #0 {
%1 = call i64 asm sideeffect "mov.u64 $0, %globaltimer;", "=l"() nounwind
ret i64 %1
}
; smid: returns the i32 produced by `mov.u32 $0, %smid;`.
; Unlike globaltimer, this asm call is not marked `sideeffect`; the `=r`
; output constraint is bound to the i32 result.
define i32 @smid() #0 {
%1 = call i32 asm "mov.u32 $0, %smid;", "=r"() nounwind
ret i32 %1
}
; Attribute group shared by both functions: always inlined, never unwinds.
attributes #0 = { alwaysinline nounwind }

View File

@@ -3403,8 +3403,7 @@ def test_globaltimer(device):
out2 = to_triton(np.zeros((1,), dtype=np.int64), device=device)
h = kernel[(1,)](out1, out2)
assert out2[0] > 0
# 2 inlined globaltimers + one extra in the wrapper extern function
assert h.asm["ptx"].count("%globaltimer") == 3
assert h.asm["ptx"].count("%globaltimer") == 2
def test_smid(device):
@@ -3417,7 +3416,7 @@ def test_smid(device):
out = to_triton(np.zeros((1024,), dtype=np.int32), device=device)
h = kernel[(out.shape[0],)](out)
assert out.sort()[0].unique().shape[0] > 0
assert h.asm["ptx"].count("%smid") == 2
assert h.asm["ptx"].count("%smid") == 1
# -----------------------
# test layout conversions

View File

@@ -1811,6 +1811,7 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
is_pure = _constexpr_to_value(is_pure)
ret_shape = None
arg_types = []
res_ty = dtype
for i in range(len(dispatch_args)):
dispatch_args[i] = _to_tensor(dispatch_args[i], _builder)
arg_types.append(dispatch_args[i].dtype)
@@ -1825,10 +1826,10 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False)
ret_shape = broadcast_arg.shape
res_ty = block_type(dtype, ret_shape).to_ir(_builder)
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty, is_pure, pack)
return tensor(call, block_type(dtype, ret_shape))
ret_shape = broadcast_arg.shape
res_ty = block_type(dtype, ret_shape)
call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack)
return tensor(call, res_ty)
# -----------------------

Binary file not shown.

View File

@@ -1,19 +1,15 @@
import os
from .. import core
__path__ = os.path.dirname(os.path.abspath(__file__))
# globaltimer: builder-level wrapper that yields an int64 value read from the
# PTX `%globaltimer` special register. It takes no tensor arguments (empty
# argument list) and is declared non-pure (is_pure=False) in both variants below.
# NOTE(review): this is a diff view with the +/- markers stripped, so both the
# pre-change and post-change `return` statements appear back to back.
@core.extern
def globaltimer(_builder=None):
# Pre-change body (per the commit message, the variant being replaced):
# dispatches to the `globaltimer` symbol in the prebuilt cuda.bc next to this file.
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
{tuple(): ("globaltimer", core.dtype("int64")),
}, is_pure=False, _builder=_builder)
# Post-change body: emits the `mov.u64` directly as inline asm with a single
# `=l` output constraint, so no bitcode file is needed.
return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [],
dtype=core.int64, is_pure=False,
pack=1, _builder=_builder)
# smid: builder-level wrapper that yields an int32 value read from the PTX
# `%smid` special register. It takes no tensor arguments (empty argument list)
# and is declared pure (is_pure=True) in both variants below.
# NOTE(review): this is a diff view with the +/- markers stripped, so both the
# pre-change and post-change `return` statements appear back to back.
@core.extern
def smid(_builder=None):
# Pre-change body (per the commit message, the variant being replaced):
# dispatches to the `smid` symbol in the prebuilt cuda.bc next to this file.
return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [],
{tuple(): ("smid", core.dtype("int32")),
}, is_pure=True, _builder=_builder)
# Post-change body: emits the `mov.u32` directly as inline asm with a single
# `=r` output constraint, so no bitcode file is needed.
return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
dtype=core.int32, is_pure=True,
pack=1, _builder=_builder)