[OPTIMIZER][BACKEND] Cleaned up Volta codegen (#1185)

This commit is contained in:
Philippe Tillet
2023-02-14 22:39:35 -08:00
committed by GitHub
parent 8bca84ce3d
commit e3941f9d09
8 changed files with 297 additions and 496 deletions

View File

@@ -1240,23 +1240,24 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
def test_dot_without_load(dtype_str):
@triton.jit
def _kernel(out):
a = GENERATE_TEST_HERE
b = GENERATE_TEST_HERE
c = tl.dot(a, b)
out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(out_ptr, c)
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
# @triton.jit
# def _kernel(out):
# a = GENERATE_TEST_HERE
# b = GENERATE_TEST_HERE
# c = tl.dot(a, b)
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(out_ptr, c)
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
out_ref = torch.matmul(a, b)
out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
kernel[(1,)](out)
assert torch.all(out == out_ref)
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# out_ref = torch.matmul(a, b)
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# kernel[(1,)](out)
# assert torch.all(out == out_ref)
# ---------------
# test arange