mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Removed torch dependency and cleaned up testing (#1394)
`assert triton.testing.allclose` -> `torch.testing.assert_allclose` `triton.testing.assert_almost_equal` -> `torch.testing.assert_allclose`
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.ops
|
||||
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
|
||||
|
||||
DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]]
|
||||
@@ -96,7 +97,7 @@ def test_matmul(M, N, K, dtype_str):
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
|
||||
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
|
||||
#######################
|
||||
@@ -152,7 +153,7 @@ def test_elementwise(N):
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
|
||||
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
#######################
|
||||
# Flash-Attention
|
||||
@@ -200,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
|
||||
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
|
||||
cur_gpu_util = cur_gpu_perf / max_gpu_perf
|
||||
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
|
||||
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
|
||||
|
||||
@@ -783,7 +783,7 @@ def test_atomic_cas():
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
triton.testing.assert_almost_equal(data, ref)
|
||||
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
|
||||
|
||||
|
||||
# ---------------
|
||||
@@ -1214,8 +1214,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# numpy result
|
||||
z_ref = x.transpose(*perm)
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
|
||||
# parse ptx to make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
@@ -1477,7 +1477,7 @@ def test_arange(start, device='cuda'):
|
||||
tl.store(z + off, val)
|
||||
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
|
||||
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref))
|
||||
|
||||
# ---------------
|
||||
# test load
|
||||
@@ -1513,7 +1513,8 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
||||
kernel[(1,)](input, output, input_size, output_size)
|
||||
|
||||
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
|
||||
triton.testing.allclose(output, reference_out)
|
||||
# print((output - reference_out).nonzero())
|
||||
torch.testing.assert_allclose(output, reference_out)
|
||||
|
||||
# Testing masked loads with an intermate copy to shared memory run.
|
||||
|
||||
@@ -1544,15 +1545,15 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
|
||||
|
||||
# Load inputs.
|
||||
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
|
||||
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
|
||||
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K)
|
||||
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N)
|
||||
|
||||
# Without a dot product the memory doesn't get promoted to shared.
|
||||
o = tl.dot(x, w, out_dtype=tl.float32)
|
||||
|
||||
# Store output
|
||||
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
|
||||
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
|
||||
tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N)
|
||||
|
||||
pgm = _kernel[(1,)](in1, in2, out,
|
||||
in1.stride()[0],
|
||||
@@ -1564,7 +1565,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
M=M, N=N, K=K)
|
||||
|
||||
reference_out = torch.matmul(in1, in2)
|
||||
triton.testing.allclose(out, reference_out)
|
||||
torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
||||
@@ -1607,7 +1608,7 @@ def test_vectorization(N):
|
||||
assert "ld.global.v4.b32" in ptx
|
||||
else:
|
||||
assert "ld.global.b32" in ptx
|
||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||
# np.testing.assert_allclose(dst, src[:N])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_hints", [False, True])
|
||||
|
||||
@@ -2,6 +2,34 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
|
||||
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
|
||||
return ret
|
||||
|
||||
|
||||
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
|
||||
if data is None:
|
||||
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
|
||||
ref_ret = data
|
||||
ref_ret = ref_ret * alpha + beta
|
||||
ref_ret = ref_ret.half().to(dtype)
|
||||
if trans:
|
||||
ref_ret = ref_ret.t().requires_grad_()
|
||||
ref_ret = ref_ret.detach().requires_grad_()
|
||||
tri_ret = ref_ret.clone().detach().requires_grad_()
|
||||
return ref_ret, tri_ret
|
||||
|
||||
|
||||
def mask_tensor(x, mask, block, value=0):
|
||||
ret = x.clone()
|
||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
|
||||
return ret
|
||||
|
||||
|
||||
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
|
||||
@@ -16,8 +44,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
is_sdd = MODE == "sdd"
|
||||
is_dsd = MODE == "dsd"
|
||||
is_dds = MODE == "dds"
|
||||
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
|
||||
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
|
||||
do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
|
||||
do_mask = lambda x: mask_tensor(x, layout, BLOCK)
|
||||
# create inputs
|
||||
# create op
|
||||
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
|
||||
@@ -32,9 +60,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
layout[1, 2, :] = 0
|
||||
layout[1, :, 1] = 0
|
||||
# create data
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
|
||||
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
|
||||
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
|
||||
a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
|
||||
b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
|
||||
dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)
|
||||
# compute [torch]
|
||||
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
|
||||
a_ref = do_mask(a_ref) if is_dsd else a_ref
|
||||
@@ -59,9 +87,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
|
||||
da_tri = a_tri.grad
|
||||
db_tri = b_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(c_ref, c_tri)
|
||||
triton.testing.assert_almost_equal(da_ref, da_tri)
|
||||
triton.testing.assert_almost_equal(db_ref, db_tri)
|
||||
torch.testing.assert_allclose(c_ref, c_tri)
|
||||
torch.testing.assert_allclose(da_ref, da_tri)
|
||||
torch.testing.assert_allclose(db_ref, db_tri)
|
||||
|
||||
|
||||
configs = [
|
||||
@@ -88,10 +116,10 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
|
||||
layout[1, :, 1] = 0
|
||||
# initialize data
|
||||
a_shape = (Z, H, M, N)
|
||||
a_ref, a_tri = triton.testing.make_pair(a_shape)
|
||||
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
|
||||
a_ref, a_tri = make_pair(a_shape)
|
||||
dout_ref, dout_tri = make_pair(a_shape)
|
||||
# compute [torch]
|
||||
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
|
||||
a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
|
||||
a_ref.retain_grad()
|
||||
at_mask = torch.ones((M, N), device="cuda")
|
||||
if is_causal:
|
||||
@@ -100,19 +128,19 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
|
||||
a_ref[M == 0] = float("-inf")
|
||||
out_ref = torch.softmax(a_ref * scale, -1)
|
||||
out_ref.backward(dout_ref)
|
||||
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
|
||||
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
|
||||
out_ref = sparsify_tensor(out_ref, layout, BLOCK)
|
||||
da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK)
|
||||
# compute [triton]
|
||||
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
|
||||
a_tri = sparsify_tensor(a_tri, layout, BLOCK)
|
||||
a_tri.retain_grad()
|
||||
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
|
||||
dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
|
||||
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
|
||||
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
|
||||
out_tri.backward(dout_tri)
|
||||
da_tri = a_tri.grad
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(out_tri, out_ref)
|
||||
triton.testing.assert_almost_equal(da_tri, da_ref)
|
||||
torch.testing.assert_allclose(out_tri, out_ref)
|
||||
torch.testing.assert_allclose(da_tri, da_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
@@ -168,9 +196,9 @@ def test_attention_fwd_bwd(
|
||||
|
||||
# comparison
|
||||
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
|
||||
triton.testing.assert_almost_equal(loss, torch_loss)
|
||||
torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0)
|
||||
for g1, g2 in zip(grads, torch_grads):
|
||||
triton.testing.assert_almost_equal(g1, g2)
|
||||
torch.testing.assert_allclose(g1, g2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block", [16, 32, 64])
|
||||
|
||||
@@ -2,6 +2,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, dtype, mode",
|
||||
@@ -24,7 +25,7 @@ def test_op(M, N, dtype, mode):
|
||||
tt_y = triton.ops.cross_entropy(x, idx)
|
||||
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
|
||||
if mode == 'forward':
|
||||
triton.testing.assert_almost_equal(th_y, tt_y)
|
||||
torch.testing.assert_allclose(th_y, tt_y)
|
||||
# backward pass
|
||||
elif mode == 'backward':
|
||||
dy = torch.randn_like(tt_y)
|
||||
@@ -35,4 +36,4 @@ def test_op(M, N, dtype, mode):
|
||||
x.grad.zero_()
|
||||
th_y.backward(dy)
|
||||
th_dx = x.grad.clone()
|
||||
triton.testing.assert_almost_equal(th_dx, tt_dx)
|
||||
torch.testing.assert_allclose(th_dx, tt_dx)
|
||||
|
||||
@@ -2,6 +2,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
@@ -38,8 +39,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
decimal = 1 if dtype == torch.bfloat16 else 2
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
|
||||
torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
|
||||
@@ -52,4 +52,4 @@ def test_normalization_with_remat():
|
||||
arg8_1 = torch.rand(64, device="cuda")
|
||||
arg9_1 = torch.rand(64, device="cuda")
|
||||
triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
|
||||
triton.testing.allclose(buf16.mean(), buf14.mean().item(), atol=1e-7, rtol=0)
|
||||
torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -95,4 +96,4 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
# run test
|
||||
th_c = torch.matmul(a, b)
|
||||
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
|
||||
triton.testing.assert_almost_equal(th_c, tt_c)
|
||||
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
|
||||
|
||||
Reference in New Issue
Block a user