mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Add a performance test for reductions (#2125)
Also stop promoting integer types as it doesn't give better perf this will allow more vectorization oportuinity in the future.
This commit is contained in:
@@ -225,3 +225,59 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
|
||||
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)]
|
||||
print_perf(ms, cur_gpu_util, ref_gpu_util)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
|
||||
|
||||
|
||||
#######################
|
||||
# Reduction
|
||||
#######################
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
# run in a loop to only to make it compute bound.
|
||||
for i in range(100):
|
||||
x = tl.sum(x, axis=0) + y
|
||||
|
||||
tl.store(output_ptr + offsets, x, mask=mask)
|
||||
|
||||
|
||||
reduction_data = {
|
||||
'a100': {
|
||||
1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.015, 'int32': 0.031},
|
||||
1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.015, 'int32': 0.032},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('N', reduction_data[DEVICE_NAME].keys())
|
||||
@pytest.mark.parametrize("dtype_str", ['float16', 'float32', 'int16', 'int32'])
|
||||
def test_reductions(N, dtype_str):
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
torch.manual_seed(0)
|
||||
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int16': torch.int16, 'int32': torch.int32}[dtype_str]
|
||||
ref_gpu_util = reduction_data[DEVICE_NAME][N][dtype_str]
|
||||
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
|
||||
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
|
||||
z = torch.empty((N, ), dtype=dtype, device='cuda')
|
||||
if dtype == torch.float16 or dtype == torch.float32:
|
||||
x = torch.randn_like(z)
|
||||
y = torch.randn_like(z)
|
||||
else:
|
||||
info = torch.iinfo(dtype)
|
||||
x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench_cudagraph(fn)
|
||||
cur_gpu_perf = 100. * 2. * N / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
print_perf(ms, cur_gpu_util, ref_gpu_util)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
|
||||
|
||||
@@ -1485,8 +1485,6 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
def get_reduced_dtype(dtype_str, op):
|
||||
if op in ('argmin', 'argmax'):
|
||||
return 'int32'
|
||||
if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
|
||||
return 'int32'
|
||||
if dtype_str == 'bfloat16':
|
||||
return 'float32'
|
||||
return dtype_str
|
||||
|
||||
Reference in New Issue
Block a user