[FRONTEND] Improve tl.full to accept both static and dynamic values (#1269)

This commit is contained in:
Keren Zhou
2023-03-02 12:19:54 -08:00
committed by GitHub
parent d54745538b
commit 65e5a3bc24
3 changed files with 62 additions and 16 deletions

View File

@@ -1240,6 +1240,32 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str):
    """Check ``tl.full`` with a literal fill value and with a kernel-argument fill value.

    Two kernels each store a 128-element ``tl.full`` result into a zeroed
    CUDA buffer. The "static" kernel gets its fill expression through
    ``patch_kernel`` (placeholder replaced by ``tl.full((128,), 2, tl.<dtype>)``);
    the "dynamic" kernel receives the value and element type as launch
    arguments. Both buffers are then expected to contain 2 everywhere.
    """
    torch_dtype = getattr(torch, dtype_str)
    check_type_supported(torch_dtype)  # bfloat16 on cc < 80 will not be tested

    @triton.jit
    def kernel_static(out):
        # The right-hand side is a placeholder handed to patch_kernel below.
        filled = GENERATE_TEST_HERE
        dst = out + tl.arange(0, 128)[:]
        tl.store(dst, filled)

    @triton.jit
    def kernel_dynamic(out, val, dtype: tl.constexpr):
        filled = tl.full((128,), val, dtype)
        dst = out + tl.arange(0, 128)[:]
        tl.store(dst, filled)

    def zeroed_output():
        # Fresh all-zero buffer, so a kernel that stores nothing fails the check.
        return torch.zeros(128, dtype=torch_dtype, device="cuda")

    static_expr = f"tl.full((128,), 2, tl.{dtype_str})"
    kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': static_expr})

    static_result = zeroed_output()
    kernel_static_patched[(1,)](static_result)

    dynamic_result = zeroed_output()
    kernel_dynamic[(1,)](dynamic_result, 2, getattr(triton.language, dtype_str))

    # Static result is checked first, then the dynamic one.
    for result in (static_result, dynamic_result):
        assert torch.all(result == 2)
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):