[TESTS] make performance regression testing less strict (#1231)

Philippe Tillet
2023-02-21 22:22:02 -08:00
committed by GitHub
parent 6bef0c2bd6
commit ba0198326e
2 changed files with 13 additions and 10 deletions

Changed file 1 of 2:

@@ -57,7 +57,7 @@ matmul_data = {
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34},
-(4096, 4096, 4096): {'float16': 0.80, 'float32': 0.75, 'int8': 0.46},
+(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
@@ -96,7 +96,7 @@ def test_matmul(M, N, K, dtype_str):
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
-triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
@@ -149,10 +149,10 @@ def test_elementwise(N):
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
-ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
+ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
-triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
# Flash-Attention
@@ -188,7 +188,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
-ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
+ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
# compute flops
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
@@ -200,4 +200,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
-triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
+assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
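
For a sense of how much slack the new check adds: the old numpy-backed assertion with decimal=2 requires abs(cur - ref) < 1.5e-2, while the new allclose call accepts deviations up to atol + rtol * max(cur, ref), roughly 0.05 for a reference utilization of 0.80. The sketch below is not part of the commit; the measured value is hypothetical and only illustrates the relaxed criterion.

# Illustration only: compare the old and new pass criteria on made-up numbers.
ref_gpu_util = 0.80   # reference utilization, as stored in matmul_data
cur_gpu_util = 0.77   # hypothetical measurement from a slightly noisy run

# Old criterion (numpy assert_array_almost_equal, decimal=2):
old_pass = abs(cur_gpu_util - ref_gpu_util) < 1.5 * 10 ** -2

# New criterion (triton.testing.allclose with atol=0.01, rtol=0.05):
new_pass = abs(cur_gpu_util - ref_gpu_util) <= 0.01 + 0.05 * max(cur_gpu_util, ref_gpu_util)

print(old_pass, new_pass)  # False True -- a ~4% relative dip no longer fails the test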

Changed file 2 of 2:

@@ -93,7 +93,11 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''):
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
-def allclose(x, y, tol=1e-2):
+def allclose(x, y, atol=0, rtol=1e-2):
+    if not isinstance(x, torch.Tensor):
+        x = torch.tensor(x)
+    if not isinstance(y, torch.Tensor):
+        y = torch.tensor(y)
if x.dtype != y.dtype:
raise RuntimeError(f'{x.dtype} did not match with {x.dtype}')
if x.shape != y.shape:
@@ -101,12 +105,11 @@ def allclose(x, y, tol=1e-2):
if x.dtype == torch.bool:
return torch.sum(x ^ y) == 0
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
-    tol = 0
+    rtol = 0
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
-err = torch.max(diff) / torch.max(x_max, y_max)
-return err <= tol
+return torch.max(diff) <= atol + rtol * torch.max(x_max, y_max)
def nvsmi(attrs):
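
A quick usage sketch of the updated helper, assuming it is imported from triton.testing as the regression tests do; the input values are made up. It also points out a design choice worth knowing: unlike torch.allclose, which bounds each element by atol + rtol * |other| elementwise, this helper compares the single largest absolute difference against one global bound scaled by the largest value in either input.

import torch
from triton.testing import allclose  # the helper changed in this commit

# Plain Python floats are accepted now: the new isinstance checks wrap them
# into 0-d tensors before the dtype/shape checks run.
ok = allclose(0.77, 0.80, atol=0.01, rtol=0.05)
print(ok)  # tensor(True): max|x - y| = 0.03 <= 0.01 + 0.05 * 0.80

# For integer dtypes, rtol is forced to 0, so with the default atol=0 the
# check degenerates to an exact match.
exact = allclose(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]))
print(exact)  # tensor(True)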