Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03122023

This commit is contained in:
Rohit Santhanam
2023-03-13 18:09:12 +00:00
91 changed files with 3492 additions and 3373 deletions

View File

@@ -0,0 +1,45 @@
import sys
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@triton.jit
def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
    # Copies BLOCK elements from X to Y and raises a device-side assertion
    # for every loaded element that is not equal to 0.
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    tl.device_assert(x == 0, "x != 0")
    tl.store(Y + offs, x)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
    # Same copy as the device_assert variant, but the check is written as a
    # plain Python `assert` statement inside the jitted function.
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    assert x == 0, "x != 0"
    tl.store(Y + offs, x)
@triton.jit
def kernel_static_assert(X, Y, BLOCK: tl.constexpr):
    # Copies BLOCK elements from X to Y; the assertion is on the constexpr
    # BLOCK itself (expected to be 128), not on the loaded data.
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    tl.static_assert(BLOCK == 128, "BLOCK != 128")
    tl.store(Y + offs, x)
def test_assert(func: str):
    """Launch the assert kernel selected by *func* and compare output to input.

    *func* is one of ``"device_assert"``, ``"assert"`` or ``"static_assert"``.
    Any other value launches nothing, so the final comparison (zeros against
    0..127) fails.
    """
    n = 128
    shape = (n, )
    x = torch.arange(0, n, dtype=torch.int32, device='cuda')
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernels = {
        "device_assert": kernel_device_assert,
        "assert": kernel_assert,
        "static_assert": kernel_static_assert,
    }
    selected = kernels.get(func)
    if selected is not None:
        selected[(1,)](x, y, BLOCK=n)
    assert_close(y, x)


if __name__ == "__main__":
    # The kernel flavour is taken from the first command-line argument.
    test_assert(sys.argv[1])

View File

@@ -0,0 +1,46 @@
import sys
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@triton.jit
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
    # Copies BLOCK elements from X to Y, printing the loaded values through
    # tl.device_print with an empty prefix.
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    tl.device_print("", x)
    tl.store(Y + offs, x)
@triton.jit
def kernel_print(X, Y, BLOCK: tl.constexpr):
    # Same copy as the device_print variant, but the values are printed with
    # the builtin `print` call written inside the jitted function.
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    print("", x)
    tl.store(Y + offs, x)
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
    # Copies BLOCK elements from X to Y and passes the loaded tensor to
    # tl.static_print.
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    tl.static_print(x)
    tl.store(Y + offs, x)
def test_print(func: str, data_type: str):
    """Launch the print kernel selected by *func* on 0..127 cast to *data_type*.

    *func* is one of ``"device_print"``, ``"print"`` or ``"static_print"``;
    *data_type* is the name of a ``torch`` dtype attribute. Any other *func*
    launches nothing, so the final comparison against the zero-filled output
    fails.
    """
    n = 128
    shape = (n, )
    # Values are generated as int32 first and then converted to the requested dtype.
    base = torch.arange(0, n, dtype=torch.int32, device='cuda')
    x = base.to(getattr(torch, data_type))
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernels = {
        "device_print": kernel_device_print,
        "print": kernel_print,
        "static_print": kernel_static_print,
    }
    selected = kernels.get(func)
    if selected is not None:
        selected[(1,)](x, y, BLOCK=n)
    assert_close(y, x)


if __name__ == "__main__":
    # argv[1] picks the kernel flavour, argv[2] the torch dtype name.
    test_print(sys.argv[1], sys.argv[2])

View File

@@ -1,56 +0,0 @@
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
# Maps a dtype name to the torch dtype object of the same name
# ("int64" resolves to the same object as torch.long).
torch_type = {
    name: getattr(torch, name)
    for name in (
        "bool",
        "int8",
        "uint8",
        "int16",
        "int32",
        "int64",
        "float16",
        "bfloat16",
        "float32",
        "float64",
    )
}
def get_tensor(shape, data_type, b_positive=False):
    """Return a CUDA tensor holding ``0 .. shape[0] - 1`` in the requested dtype.

    Args:
        shape: sequence whose first entry is the number of elements.
        data_type: key into ``torch_type`` naming the element dtype.
        b_positive: accepted for interface compatibility; it is not used.

    The integer and non-integer cases built the tensor with the same call,
    so the branch (and the dead ``x = None`` initialiser) is collapsed.
    """
    return torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
# @pytest.mark.parametrize('data_type',
# [("int8"),
# ('int16'),
# ('int32'),
# ("int64"),
# ('float16'),
# ("float32"),
# ("float64")])
def printf(data_type):
    """Print 0..127 of *data_type* from a kernel and check the copy it produces."""
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        # Load BLOCK values, print them with an empty prefix, write them back out.
        offs = tl.arange(0, BLOCK)
        x = tl.load(X + offs)
        tl.printf("", x)
        tl.store(Y + offs, x)

    n = 128
    shape = (n, )
    src = get_tensor(shape, data_type)
    dst = torch.zeros(shape, dtype=src.dtype, device="cuda")
    kernel[(1,)](src, dst, BLOCK=n)
    assert_close(dst, src)


# Executed at import time: one floating-point and one integer run.
printf("float16")
printf("int8")

View File

@@ -385,17 +385,22 @@ def test_where(dtype):
@triton.jit
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr,
TEST_POINTERS: tl.constexpr):
TEST_POINTERS: tl.constexpr,
TEST_SCALAR_POINTERS: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
decide = tl.load(cond_ptr + offsets, mask=mask)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
if TEST_SCALAR_POINTERS:
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
output = tl.load(ptr + offsets, mask=mask)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
tl.store(output_ptr + offsets, output, mask=mask)
SIZE = 1_000
@@ -411,8 +416,12 @@ def test_where(dtype):
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
assert (z == to_numpy(z_tri)).all()
if select_ptrs:
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
z = np.where(cond[0], x, y)
assert (z == to_numpy(z_tri)).all()
def test_where_broadcast():
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_tensor_atomic_rmw_block(device="cuda"):
    """Check a block-shaped ``tl.atomic_min`` over a 2D tile.

    Two programs each apply an atomic min of the flattened element index
    onto a tensor of ones; afterwards the smallest element must be 0.0
    (index 0 wins against the initial 1.0).
    """
    shape = (8, 8)

    @triton.jit
    def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        # Row-major flattened offset of every element of the tile.
        offs = off0[:, None] * SHAPE1 + off1[None, :]
        val = offs.to(tl.float32)
        x = X + offs
        tl.atomic_min(x, val)

    # Allocate from `shape` so the buffer and the kernel extents cannot drift apart.
    x = torch.ones(shape, device=device, dtype=torch.float32)
    kernel[(2,)](x, shape[0], shape[1])
    assert torch.min(x).item() == 0.0
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
assert torch.all(output == ref)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
def test_load_store_same_ptr():
    """Repeatedly double a tensor in place, one element per program."""
    @triton.jit()
    def kernel(in_out_ptr):
        # Each program reads its own element and writes the doubled value
        # back to the same address.
        pid = tl.program_id(axis=0)
        loaded = tl.load(in_out_ptr + pid)
        doubled = loaded * 2
        tl.store(in_out_ptr + pid, doubled)

    n_elements = 65536
    for _ in range(1000):
        buf = torch.ones((n_elements,), device="cuda", dtype=torch.float32)
        kernel[(n_elements,)](buf, num_warps=32)
        assert torch.all(buf == 2)
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
check_type_supported(out_dtype)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
tl.store(output_ptr + offsets, output, mask=mask)
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
f8_tensor[all_exp_ones] = 0
f8 = triton.reinterpret(f8_tensor, in_dtype)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
assert torch.all(f8_tensor == f8_output_tensor)
def test_f16_to_f8_rounding():
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
def test_f16_to_f8_rounding(in_dtype):
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
abs_error = torch.abs(f16_input - f16_output)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
@@ -1240,6 +1284,32 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str):
    """Fill a 128-element tensor with 2 via ``tl.full``, statically and dynamically.

    The static kernel gets its fill expression patched into the source; the
    dynamic kernel receives the value and the Triton dtype as arguments.
    """
    dtype = getattr(torch, dtype_str)
    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

    @triton.jit
    def kernel_static(out):
        a = GENERATE_TEST_HERE
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    @triton.jit
    def kernel_dynamic(out, val, dtype: tl.constexpr):
        a = tl.full((128,), val, dtype)
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    fill_expr = f"tl.full((128,), 2, tl.{dtype_str})"
    kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': fill_expr})

    out_static = torch.zeros((128), dtype=dtype, device="cuda")
    kernel_static_patched[(1,)](out_static)

    out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
    kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))

    for filled in (out_static, out_dynamic):
        assert torch.all(filled == 2)
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
@@ -1409,6 +1479,28 @@ def test_vectorization(N):
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
@pytest.mark.parametrize("has_hints", [False, True])
def test_vectorization_hints(has_hints):
    """Check that 4-wide global loads appear in the PTX exactly when hints are given."""
    src = torch.empty(1024, device='cuda')
    dst = torch.empty(1024, device='cuda')
    off = torch.zeros(1, device='cuda', dtype=torch.int32)

    @triton.jit
    def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The extra offset is loaded at run time.
        offsets = offsets + tl.load(off)
        if HINT:
            # NOTE(review): the returned value is discarded, so this presumably
            # relies on the hints being attached to `offsets` in place - confirm.
            tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024)
        x = tl.load(src + offsets, mask=offsets < N)
        tl.store(dst + offsets, x, mask=offsets < N)

    compiled = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints)
    ptx = compiled.asm["ptx"]
    vectorized = "ld.global.v4.b32" in ptx
    assert vectorized if has_hints else not vectorized
# ---------------
# test store
# ---------------
@@ -1479,7 +1571,7 @@ def test_pointer_arguments(device):
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@@ -1739,6 +1831,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
# -----------------------
def test_for_iv_int64():
    """Sum a loop range whose bounds (around 2**35) do not fit in 32 bits."""
    @triton.jit
    def kernel(Out, lo, hi):
        # Accumulate in int64 over the [lo, hi) induction range.
        acc = 0
        acc = acc.to(tl.int64)
        for i in range(lo, hi):
            acc += i
        tl.store(Out, acc)

    start = 2**35
    stop = start + 20
    expected = sum(range(start, stop))
    result = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
    kernel[(1,)](result, start, stop)
    assert result[0] == expected
def test_if_else():
@triton.jit
@@ -1972,3 +2081,42 @@ def test_load_scalar_with_mask():
Out = torch.empty_like(Index, device='cuda')
kernel[(1,)](Input, Index, Out, Index.numel())
assert Out.data[0] == 0
# This test is used to test our own PTX codegen for float16 and int16 conversions
# maybe delete it later after ptxas has been fixed
@pytest.mark.parametrize("dtype_str", ['float16', 'int16'])
def test_ptx_cast(dtype_str):
    # The kernel doubles every loaded element, keeps an element-wise running
    # maximum (seeded with -10000) and stores it. Inputs are all -1, so the
    # stored values are expected to be -2.
    @triton.jit
    def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        # Running maximum, started well below any value produced below.
        _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            # 197 is used as the per-x0 stride (presumably the row length - TODO confirm).
            tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype)
            tmp1 = 2
            tmp2 = tmp0 * tmp1
            tmp3 = tmp2.to(dtype)
            tmp5 = _tmp4 < tmp3
            # Take the new value only where it is in bounds and larger than the current maximum.
            _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4)
            tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)

    torch.manual_seed(123)
    # The buffers use the 16-bit torch dtype; the kernel computes in the 32-bit Triton dtype.
    if dtype_str == 'int16':
        torch_dtype = torch.int16
        triton_dtype = tl.int32
    else:
        torch_dtype = torch.float16
        triton_dtype = tl.float32

    s0 = 4
    buf11 = -torch.ones((6 * s0, 197, 197), device='cuda', dtype=torch_dtype)
    buf14 = -torch.ones((s0, 6, 197, 197), device='cuda', dtype=torch_dtype)
    # One program per x index: 1182 * s0 == 4728, with XBLOCK=1 and RBLOCK=256.
    kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
    assert buf14.to(torch.float32).mean() == -2.0

View File

@@ -385,17 +385,22 @@ def test_where(dtype):
@triton.jit
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr,
TEST_POINTERS: tl.constexpr):
TEST_POINTERS: tl.constexpr,
TEST_SCALAR_POINTERS: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
decide = tl.load(cond_ptr + offsets, mask=mask)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
if TEST_SCALAR_POINTERS:
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
output = tl.load(ptr + offsets, mask=mask)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
if TEST_POINTERS:
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
else:
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
output = tl.where(decide, a, b)
tl.store(output_ptr + offsets, output, mask=mask)
SIZE = 1_000
@@ -411,8 +416,12 @@ def test_where(dtype):
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
assert (z == to_numpy(z_tri)).all()
if select_ptrs:
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
z = np.where(cond[0], x, y)
assert (z == to_numpy(z_tri)).all()
def test_where_broadcast():
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_tensor_atomic_rmw_block(device="cuda"):
    """Check a block-shaped ``tl.atomic_min`` over a 2D tile.

    Two programs each apply an atomic min of the flattened element index
    onto a tensor of ones; afterwards the smallest element must be 0.0
    (index 0 wins against the initial 1.0).
    """
    shape = (8, 8)

    @triton.jit
    def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        # Row-major flattened offset of every element of the tile.
        offs = off0[:, None] * SHAPE1 + off1[None, :]
        val = offs.to(tl.float32)
        x = X + offs
        tl.atomic_min(x, val)

    # Allocate from `shape` so the buffer and the kernel extents cannot drift apart.
    x = torch.ones(shape, device=device, dtype=torch.float32)
    kernel[(2,)](x, shape[0], shape[1])
    assert torch.min(x).item() == 0.0
def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
assert torch.all(output == ref)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
def test_load_store_same_ptr():
    """Repeatedly double a tensor in place, one element per program."""
    @triton.jit()
    def kernel(in_out_ptr):
        # Each program reads its own element and writes the doubled value
        # back to the same address.
        pid = tl.program_id(axis=0)
        loaded = tl.load(in_out_ptr + pid)
        doubled = loaded * 2
        tl.store(in_out_ptr + pid, doubled)

    n_elements = 65536
    for _ in range(1000):
        buf = torch.ones((n_elements,), device="cuda", dtype=torch.float32)
        kernel[(n_elements,)](buf, num_warps=16)
        assert torch.all(buf == 2)
@pytest.mark.parametrize("in_dtype", [tl.float8e4]) # TODO: support tl.float8e5
@pytest.mark.parametrize("out_dtype", [torch.float16]) # TODO: support torch.float32
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)
check_type_supported(out_dtype)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
tl.store(output_ptr + offsets, output, mask=mask)
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
f8_tensor[all_exp_ones] = 0
f8 = triton.reinterpret(f8_tensor, in_dtype)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
assert torch.all(f8_tensor == f8_output_tensor)
def test_f16_to_f8_rounding():
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
def test_f16_to_f8_rounding(in_dtype):
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
abs_error = torch.abs(f16_input - f16_output)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
@@ -1036,8 +1080,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
[(dtype, shape, perm)
# TODO: bfloat16
for dtype in ['float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
@@ -1248,6 +1292,51 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str):
    """Fill a 128-element tensor with 2 via ``tl.full``, statically and dynamically.

    The static kernel gets its fill expression patched into the source; the
    dynamic kernel receives the value and the Triton dtype as arguments.
    """
    dtype = getattr(torch, dtype_str)
    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

    @triton.jit
    def kernel_static(out):
        a = GENERATE_TEST_HERE
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    @triton.jit
    def kernel_dynamic(out, val, dtype: tl.constexpr):
        a = tl.full((128,), val, dtype)
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    fill_expr = f"tl.full((128,), 2, tl.{dtype_str})"
    kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': fill_expr})

    out_static = torch.zeros((128), dtype=dtype, device="cuda")
    kernel_static_patched[(1,)](out_static)

    out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
    kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))

    for filled in (out_static, out_dynamic):
        assert torch.all(filled == 2)
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
# @triton.jit
# def _kernel(out):
# a = GENERATE_TEST_HERE
# b = GENERATE_TEST_HERE
# c = tl.dot(a, b)
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(out_ptr, c)
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# out_ref = torch.matmul(a, b)
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# kernel[(1,)](out)
# assert torch.all(out == out_ref)
# ---------------
# test arange
# ---------------
@@ -1426,7 +1515,7 @@ def test_pointer_arguments(device):
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@@ -1686,6 +1775,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
# -----------------------
def test_for_iv_int64():
    """Sum a loop range whose bounds (around 2**35) do not fit in 32 bits."""
    @triton.jit
    def kernel(Out, lo, hi):
        # Accumulate in int64 over the [lo, hi) induction range.
        acc = 0
        acc = acc.to(tl.int64)
        for i in range(lo, hi):
            acc += i
        tl.store(Out, acc)

    start = 2**35
    stop = start + 20
    expected = sum(range(start, stop))
    result = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
    kernel[(1,)](result, start, stop)
    assert result[0] == expected
def test_if_else():
@triton.jit
@@ -1863,7 +1969,7 @@ else:
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])

View File

@@ -1,22 +0,0 @@
import os
import subprocess
import sys
# Directory of this test module (symlinks resolved) and the helper script next to it.
dir_path = os.path.split(os.path.realpath(__file__))[0]
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
    """Run the printf helper script and check its stdout holds exactly 0..127.

    Whitespace-separated tokens that cannot be parsed as numbers are
    reported with ``print`` and otherwise ignored.
    """
    completed = subprocess.run([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
    seen = set()
    for token in completed.stdout.split():
        try:
            seen.add(int(float(token)))
        except Exception as e:
            print(e)
    assert all(i in seen for i in range(128))
    assert len(seen) == 128

View File

@@ -0,0 +1,53 @@
import os
import subprocess
import sys
import pytest
# Helper scripts live next to this test module (symlinks resolved).
dir_path = os.path.split(os.path.realpath(__file__))[0]
print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")

# TODO: bfloat16 after LLVM-15
# Assert flavours and torch dtype names used to parametrize the tests below.
func_types = ["device_assert", "assert", "static_assert"]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
@pytest.mark.parametrize("func_type, data_type",
                         [("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32")])
def test_print(func_type: str, data_type: str):
    """Run the print helper in a subprocess and validate what it wrote to stdout.

    For the run-time print flavours every whitespace-separated token is
    parsed as a number and the distinct values must be exactly 0..127.
    For ``static_print`` the raw tokens are collected and exactly one
    distinct token is expected. Unparsable tokens are reported with
    ``print`` and otherwise ignored.
    """
    completed = subprocess.run([sys.executable, print_path, func_type, data_type],
                               stdout=subprocess.PIPE, shell=False)
    numeric = func_type != "static_print"
    seen = set()
    for token in completed.stdout.split():
        try:
            seen.add(int(float(token)) if numeric else token)
        except Exception as e:
            print(e)

    if not numeric:
        assert len(seen) == 1
        return
    assert all(i in seen for i in range(128))
    assert len(seen) == 128
@pytest.mark.parametrize("func_type", func_types)
def test_assert(func_type: str):
    """Run the assert helper with TRITON_DEBUG=1 and count its assertion messages.

    The helper's stderr is scanned line by line for ``"x != 0"``. The
    run-time flavours are expected to report 127 such lines; the
    ``static_assert`` flavour is expected to report none.
    """
    # Give TRITON_DEBUG to the child only. Mutating os.environ here would leak
    # into every later test in this process and could not be rolled back
    # faithfully if the subprocess call raised.
    child_env = dict(os.environ, TRITON_DEBUG="1")
    completed = subprocess.run([sys.executable, assert_path, func_type],
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                               shell=False, env=child_env)
    # The marker is plain ASCII, so it can be matched on the raw bytes.
    num_errs = sum(b"x != 0" in line for line in completed.stderr.splitlines())
    if func_type != "static_assert":
        assert num_errs == 127
    else:
        assert num_errs == 0

View File

@@ -5,7 +5,8 @@ import triton
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
@@ -21,7 +22,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
p = torch.softmax(p.float(), dim=-1).to(dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
@@ -38,6 +39,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
decimal = 1 if dtype == torch.bfloat16 else 2
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)

View File

@@ -1,6 +1,5 @@
import multiprocessing
import os
import re
import shutil
from collections import namedtuple
@@ -107,33 +106,6 @@ def test_specialize(mode):
assert counter == target
@pytest.mark.parametrize("value, value_type", [
    (-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
    (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
    """Check which integer type an integer kernel argument is specialized to.

    A cache hook captures the ``repr`` string reported for the launch; the
    type printed after ``VALUE:`` must equal *value_type*.
    """
    @triton.jit
    def kernel(VALUE, X):
        pass

    cache_str = None

    def get_cache_str(*args, **kwargs):
        nonlocal cache_str
        cache_str = kwargs["repr"]

    triton.JITFunction.cache_hook = get_cache_str
    try:
        reset_tmp_dir()
        x = torch.tensor([3.14159], device='cuda')
        kernel[(1, )](value, x)
    finally:
        # The hook is process-global; always remove it so a failing launch
        # cannot leave it installed for later tests.
        triton.JITFunction.cache_hook = None

    cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
    spec_type = None if cache_str_match is None else cache_str_match.group(1)
    assert spec_type == value_type
def test_constexpr_not_callable() -> None:
@triton.jit
def kernel(X, c: tl.constexpr):
@@ -176,6 +148,26 @@ def test_jit_warmup_cache() -> None:
assert len(kernel_add.cache) == 1
def test_jit_debug() -> None:
    """Check how toggling ``debug`` on a jitted kernel affects its cache size.

    After the first warmup there is one cache entry; warming up again with
    ``debug = False`` keeps it at one, and with ``debug = True`` a second
    entry appears.
    """
    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.device_assert(idx < 32, "idx < 32")
        tl.store(o + idx,
                 tl.load(a + idx) + tl.load(b + idx))

    device = torch.cuda.current_device()

    def warm_and_count():
        # Warm the kernel up with the same signature and report the cache size.
        kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
        return len(kernel_add.cache[device])

    assert len(kernel_add.cache[device]) == 0
    assert warm_and_count() == 1
    kernel_add.debug = False
    assert warm_and_count() == 1
    kernel_add.debug = True
    assert warm_and_count() == 2
def test_compile_in_subproc() -> None:
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):