[FRONTEND][BACKEND] Add a performance test for reductions (#2125)

Also stop promoting integer types, as it doesn't give better perf; this
will allow more vectorization opportunity in the future.
This commit is contained in:
Thomas
2023-08-17 16:30:33 -07:00
committed by GitHub
parent 3fa6d51bc9
commit 387fc890a5
5 changed files with 75 additions and 13 deletions

View File

@@ -225,3 +225,59 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)]
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
#######################
# Reduction
#######################
@triton.jit
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
         BLOCK_SIZE: tl.constexpr):
    """Benchmark kernel: repeatedly reduce a block of x and add y element-wise.

    Each program instance handles one BLOCK_SIZE-wide slice of the inputs and
    writes the result of the final iteration back to output_ptr.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out lanes past n_elements so the tail block stays in bounds.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Repeat the reduction 100 times so the kernel is compute bound.
    for i in range(100):
        x = tl.sum(x, axis=0) + y
    tl.store(output_ptr + offsets, x, mask=mask)
# Reference GPU utilization (fraction of peak) keyed by device -> N -> dtype.
# NOTE(review): values appear to be hand-measured on A100; confirm against
# fresh runs before tightening tolerances.
reduction_data = {
    'a100': {
        1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.015, 'int32': 0.031},
        1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.015, 'int32': 0.032},
    }
}
@pytest.mark.parametrize('N', reduction_data[DEVICE_NAME].keys())
@pytest.mark.parametrize("dtype_str", ['float16', 'float32', 'int16', 'int32'])
def test_reductions(N, dtype_str):
    """Benchmark the _sum kernel and compare achieved GPU utilization to a
    recorded reference (reduction_data) within atol=0.02 / rtol=0.01."""
    # Run on a non-default stream — presumably required for CUDA-graph capture
    # by do_bench_cudagraph; confirm against triton.testing docs.
    stream = torch.cuda.Stream()
    torch.cuda.set_stream(stream)
    torch.manual_seed(0)
    dtype = {'float16': torch.float16, 'float32': torch.float32, 'int16': torch.int16, 'int32': torch.int32}[dtype_str]
    ref_gpu_util = reduction_data[DEVICE_NAME][N][dtype_str]
    # Sample the current SM clock so peak throughput reflects the live clock,
    # not the nominal boost clock.
    cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
    max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
    z = torch.empty((N, ), dtype=dtype, device='cuda')
    if dtype == torch.float16 or dtype == torch.float32:
        x = torch.randn_like(z)
        y = torch.randn_like(z)
    else:
        # randn has no integer support; draw uniformly over the dtype's range.
        info = torch.iinfo(dtype)
        x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
        y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
    grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
    fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)
    ms = triton.testing.do_bench_cudagraph(fn)
    # 100 kernel-loop iterations x 2 ops (tl.sum + add) per element, converted
    # from ops/ms to TFLOP/s (x 1e3 / 1e12 = 1e-9).
    cur_gpu_perf = 100. * 2. * N / ms * 1e-9
    cur_gpu_util = cur_gpu_perf / max_gpu_perf
    print_perf(ms, cur_gpu_util, ref_gpu_util)
    triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)

View File

@@ -1485,8 +1485,6 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
def get_reduced_dtype(dtype_str, op):
    """Return the dtype string produced by reducing *dtype_str* with *op*.

    Index-returning reductions (argmin/argmax) always yield 'int32';
    narrow integer inputs are promoted to 'int32' accumulators and
    'bfloat16' to 'float32'; everything else reduces in its own dtype.
    """
    if op in ('argmin', 'argmax'):
        return 'int32'
    promotions = {
        'int8': 'int32',
        'uint8': 'int32',
        'int16': 'int32',
        'uint16': 'int32',
        'bfloat16': 'float32',
    }
    return promotions.get(dtype_str, dtype_str)