mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Removed torch dependency and cleaned up testing (#1394)
`assert triton.testing.allclose` -> `torch.testing.assert_allclose` `triton.testing.assert_almost_equal` -> `torch.testing.assert_allclose`
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.ops
|
||||
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
|
||||
|
||||
DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]]
|
||||
@@ -96,7 +97,7 @@ def test_matmul(M, N, K, dtype_str):
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
|
||||
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
|
||||
#######################
|
||||
@@ -152,7 +153,7 @@ def test_elementwise(N):
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
|
||||
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
#######################
|
||||
# Flash-Attention
|
||||
@@ -200,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
|
||||
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
Reference in New Issue
Block a user