mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER][BACKEND] Cleaned up Volta codegen (#1185)
This commit is contained in:
@@ -1240,23 +1240,24 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
def test_dot_without_load(dtype_str):
|
||||
@triton.jit
|
||||
def _kernel(out):
|
||||
a = GENERATE_TEST_HERE
|
||||
b = GENERATE_TEST_HERE
|
||||
c = tl.dot(a, b)
|
||||
out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
tl.store(out_ptr, c)
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
# @triton.jit
|
||||
# def _kernel(out):
|
||||
# a = GENERATE_TEST_HERE
|
||||
# b = GENERATE_TEST_HERE
|
||||
# c = tl.dot(a, b)
|
||||
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
# tl.store(out_ptr, c)
|
||||
|
||||
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
|
||||
a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
out_ref = torch.matmul(a, b)
|
||||
out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
kernel[(1,)](out)
|
||||
assert torch.all(out == out_ref)
|
||||
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
|
||||
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# out_ref = torch.matmul(a, b)
|
||||
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# kernel[(1,)](out)
|
||||
# assert torch.all(out == out_ref)
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
|
||||
Reference in New Issue
Block a user