mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Improve tl.full to accept both static and dynamic values (#1269)
This commit is contained in:
@@ -1240,6 +1240,32 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
|
||||
def test_full(dtype_str):
|
||||
dtype = getattr(torch, dtype_str)
|
||||
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
@triton.jit
|
||||
def kernel_static(out):
|
||||
a = GENERATE_TEST_HERE
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
@triton.jit
|
||||
def kernel_dynamic(out, val, dtype: tl.constexpr):
|
||||
a = tl.full((128,), val, dtype)
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
|
||||
out_static = torch.zeros((128), dtype=dtype, device="cuda")
|
||||
kernel_static_patched[(1,)](out_static)
|
||||
out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
|
||||
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
|
||||
assert torch.all(out_static == 2)
|
||||
assert torch.all(out_dynamic == 2)
|
||||
|
||||
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
|
||||
Reference in New Issue
Block a user