[FRONTEND] Removed torch dependency and cleaned up testing (#1394)

`assert triton.testing.allclose` -> `torch.testing.assert_allclose`
`triton.testing.assert_almost_equal` -> `torch.testing.assert_allclose`
This commit is contained in:
Philippe Tillet
2023-03-23 22:37:21 -07:00
committed by GitHub
parent ff1d0377e0
commit fc7c0b0e43
14 changed files with 152 additions and 188 deletions

View File

@@ -6,6 +6,7 @@ import torch
import triton
import triton.language as tl
import triton.ops
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]]
@@ -96,7 +97,7 @@ def test_matmul(M, N, K, dtype_str):
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
@@ -152,7 +153,7 @@ def test_elementwise(N):
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
# Flash-Attention
@@ -200,4 +201,4 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
torch.testing.assert_allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)

View File

@@ -783,7 +783,7 @@ def test_atomic_cas():
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
triton.testing.assert_almost_equal(data, ref)
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
# ---------------
@@ -1214,8 +1214,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# numpy result
z_ref = x.transpose(*perm)
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
@@ -1477,7 +1477,7 @@ def test_arange(start, device='cuda'):
tl.store(z + off, val)
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
triton.testing.assert_almost_equal(z_tri, z_ref)
np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref))
# ---------------
# test load
@@ -1513,7 +1513,8 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
kernel[(1,)](input, output, input_size, output_size)
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# print((output - reference_out).nonzero())
torch.testing.assert_allclose(output, reference_out)
# Testing masked loads with an intermate copy to shared memory run.
@@ -1544,15 +1545,15 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K)
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N)
# Without a dot product the memory doesn't get promoted to shared.
o = tl.dot(x, w, out_dtype=tl.float32)
# Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N)
pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0],
@@ -1564,7 +1565,7 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
M=M, N=N, K=K)
reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out)
torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
@@ -1607,7 +1608,7 @@ def test_vectorization(N):
assert "ld.global.v4.b32" in ptx
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
# np.testing.assert_allclose(dst, src[:N])
@pytest.mark.parametrize("has_hints", [False, True])

View File

@@ -2,6 +2,34 @@ import pytest
import torch
import triton
import triton.ops
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
return ret
def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
if data is None:
data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
ref_ret = data
ref_ret = ref_ret * alpha + beta
ref_ret = ref_ret.half().to(dtype)
if trans:
ref_ret = ref_ret.t().requires_grad_()
ref_ret = ref_ret.detach().requires_grad_()
tri_ret = ref_ret.clone().detach().requires_grad_()
return ref_ret, tri_ret
def mask_tensor(x, mask, block, value=0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
return ret
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@@ -16,8 +44,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
is_sdd = MODE == "sdd"
is_dsd = MODE == "dsd"
is_dds = MODE == "dds"
do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK)
do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
do_mask = lambda x: mask_tensor(x, layout, BLOCK)
# create inputs
# create op
a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
@@ -32,9 +60,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
layout[1, 2, :] = 0
layout[1, :, 1] = 0
# create data
a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = triton.testing.make_pair(c_shape, dtype=DTYPE)
a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)
# compute [torch]
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
@@ -59,9 +87,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
da_tri = a_tri.grad
db_tri = b_tri.grad
# compare
triton.testing.assert_almost_equal(c_ref, c_tri)
triton.testing.assert_almost_equal(da_ref, da_tri)
triton.testing.assert_almost_equal(db_ref, db_tri)
torch.testing.assert_allclose(c_ref, c_tri)
torch.testing.assert_allclose(da_ref, da_tri)
torch.testing.assert_allclose(db_ref, db_tri)
configs = [
@@ -88,10 +116,10 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
layout[1, :, 1] = 0
# initialize data
a_shape = (Z, H, M, N)
a_ref, a_tri = triton.testing.make_pair(a_shape)
dout_ref, dout_tri = triton.testing.make_pair(a_shape)
a_ref, a_tri = make_pair(a_shape)
dout_ref, dout_tri = make_pair(a_shape)
# compute [torch]
a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
a_ref.retain_grad()
at_mask = torch.ones((M, N), device="cuda")
if is_causal:
@@ -100,19 +128,19 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
a_ref[M == 0] = float("-inf")
out_ref = torch.softmax(a_ref * scale, -1)
out_ref.backward(dout_ref)
out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK)
da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK)
out_ref = sparsify_tensor(out_ref, layout, BLOCK)
da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK)
# compute [triton]
a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK)
a_tri = sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK)
dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
# compare
triton.testing.assert_almost_equal(out_tri, out_ref)
triton.testing.assert_almost_equal(da_tri, da_ref)
torch.testing.assert_allclose(out_tri, out_ref)
torch.testing.assert_allclose(da_tri, da_ref)
@pytest.mark.parametrize("block", [16, 32, 64])
@@ -168,9 +196,9 @@ def test_attention_fwd_bwd(
# comparison
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
triton.testing.assert_almost_equal(loss, torch_loss)
torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0)
for g1, g2 in zip(grads, torch_grads):
triton.testing.assert_almost_equal(g1, g2)
torch.testing.assert_allclose(g1, g2)
@pytest.mark.parametrize("block", [16, 32, 64])

View File

@@ -2,6 +2,7 @@ import pytest
import torch
import triton
import triton.ops
@pytest.mark.parametrize("M, N, dtype, mode",
@@ -24,7 +25,7 @@ def test_op(M, N, dtype, mode):
tt_y = triton.ops.cross_entropy(x, idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
if mode == 'forward':
triton.testing.assert_almost_equal(th_y, tt_y)
torch.testing.assert_allclose(th_y, tt_y)
# backward pass
elif mode == 'backward':
dy = torch.randn_like(tt_y)
@@ -35,4 +36,4 @@ def test_op(M, N, dtype, mode):
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)
torch.testing.assert_allclose(th_dx, tt_dx)

View File

@@ -2,6 +2,7 @@ import pytest
import torch
import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@@ -38,8 +39,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype):
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
decimal = 1 if dtype == torch.bfloat16 else 2
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)

View File

@@ -52,4 +52,4 @@ def test_normalization_with_remat():
arg8_1 = torch.rand(64, device="cuda")
arg9_1 = torch.rand(64, device="cuda")
triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
triton.testing.allclose(buf16.mean(), buf14.mean().item(), atol=1e-7, rtol=0)
torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)

View File

@@ -4,6 +4,7 @@ import pytest
import torch
import triton
import triton.ops
@pytest.mark.parametrize(
@@ -95,4 +96,4 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
triton.testing.assert_almost_equal(th_c, tt_c)
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)