mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03122023
This commit is contained in:
45
python/test/unit/language/assert_helper.py
Normal file
45
python/test/unit/language/assert_helper.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
assert x == 0, "x != 0"
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_static_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.static_assert(BLOCK == 128, "BLOCK != 128")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
def test_assert(func: str):
|
||||
shape = (128, )
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_assert":
|
||||
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_assert(sys.argv[1])
|
||||
46
python/test/unit/language/print_helper.py
Normal file
46
python/test/unit/language/print_helper.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.device_print("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
print("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.static_print(x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
def test_print(func: str, data_type: str):
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_print":
|
||||
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "print":
|
||||
kernel_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_print":
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_print(sys.argv[1], sys.argv[2])
|
||||
@@ -1,56 +0,0 @@
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
torch_type = {
|
||||
"bool": torch.bool,
|
||||
'int8': torch.int8,
|
||||
'uint8': torch.uint8,
|
||||
'int16': torch.int16,
|
||||
"int32": torch.int32,
|
||||
'int64': torch.long,
|
||||
'float16': torch.float16,
|
||||
'bfloat16': torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64
|
||||
}
|
||||
|
||||
|
||||
def get_tensor(shape, data_type, b_positive=False):
|
||||
x = None
|
||||
if data_type.startswith('int'):
|
||||
x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
|
||||
else:
|
||||
x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
|
||||
|
||||
return x
|
||||
|
||||
# @pytest.mark.parametrize('data_type',
|
||||
# [("int8"),
|
||||
# ('int16'),
|
||||
# ('int32'),
|
||||
# ("int64"),
|
||||
# ('float16'),
|
||||
# ("float32"),
|
||||
# ("float64")])
|
||||
|
||||
|
||||
def printf(data_type):
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.printf("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = get_tensor(shape, data_type)
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
kernel[(1,)](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
printf("float16")
|
||||
printf("int8")
|
||||
@@ -385,17 +385,22 @@ def test_where(dtype):
|
||||
@triton.jit
|
||||
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
TEST_POINTERS: tl.constexpr):
|
||||
TEST_POINTERS: tl.constexpr,
|
||||
TEST_SCALAR_POINTERS: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
decide = tl.load(cond_ptr + offsets, mask=mask)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
if TEST_SCALAR_POINTERS:
|
||||
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
|
||||
output = tl.load(ptr + offsets, mask=mask)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
SIZE = 1_000
|
||||
@@ -411,8 +416,12 @@ def test_where(dtype):
|
||||
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
|
||||
|
||||
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
if select_ptrs:
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
|
||||
z = np.where(cond[0], x, y)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
|
||||
|
||||
def test_where_broadcast():
|
||||
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
|
||||
def test_tensor_atomic_rmw_block(device="cuda"):
|
||||
shape = (8, 8)
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
|
||||
off0 = tl.arange(0, SHAPE0)
|
||||
off1 = tl.arange(0, SHAPE1)
|
||||
offs = off0[:, None] * SHAPE1 + off1[None, :]
|
||||
val = offs.to(tl.float32)
|
||||
x = X + offs
|
||||
tl.atomic_min(x, val)
|
||||
x = torch.ones((8, 8), device=device, dtype=torch.float32)
|
||||
kernel[(2,)](x, shape[0], shape[1])
|
||||
assert torch.min(x).item() == 0.0
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
|
||||
assert torch.all(output == ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f8_xf16_roundtrip(dtype):
|
||||
def test_load_store_same_ptr():
|
||||
@triton.jit()
|
||||
def kernel(in_out_ptr):
|
||||
pid = tl.program_id(axis=0)
|
||||
x = tl.load(in_out_ptr + pid)
|
||||
out = x * 2
|
||||
tl.store(in_out_ptr + pid, out)
|
||||
|
||||
for _ in range(1000):
|
||||
x = torch.ones((65536,), device="cuda", dtype=torch.float32)
|
||||
kernel[(65536,)](x, num_warps=32)
|
||||
assert torch.all(x == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
|
||||
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
check_type_supported(dtype)
|
||||
check_type_supported(out_dtype)
|
||||
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
|
||||
f8 = triton.reinterpret(f8_tensor, tl.float8)
|
||||
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
|
||||
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
f8_tensor[all_exp_ones] = 0
|
||||
f8 = triton.reinterpret(f8_tensor, in_dtype)
|
||||
n_elements = f8_tensor.numel()
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
|
||||
|
||||
def test_f16_to_f8_rounding():
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
|
||||
def test_f16_to_f8_rounding(in_dtype):
|
||||
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
||||
error is the minimum over all float8.
|
||||
Or the same explanation a bit mathier:
|
||||
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
|
||||
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
|
||||
n_elements = f16_input.numel()
|
||||
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
|
||||
abs_error = torch.abs(f16_input - f16_output)
|
||||
|
||||
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
|
||||
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
|
||||
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
|
||||
|
||||
@@ -1240,6 +1284,32 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
|
||||
def test_full(dtype_str):
|
||||
dtype = getattr(torch, dtype_str)
|
||||
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
@triton.jit
|
||||
def kernel_static(out):
|
||||
a = GENERATE_TEST_HERE
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
@triton.jit
|
||||
def kernel_dynamic(out, val, dtype: tl.constexpr):
|
||||
a = tl.full((128,), val, dtype)
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
|
||||
out_static = torch.zeros((128), dtype=dtype, device="cuda")
|
||||
kernel_static_patched[(1,)](out_static)
|
||||
out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
|
||||
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
|
||||
assert torch.all(out_static == 2)
|
||||
assert torch.all(out_dynamic == 2)
|
||||
|
||||
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
@@ -1409,6 +1479,28 @@ def test_vectorization(N):
|
||||
assert "ld.global.b32" in ptx
|
||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_hints", [False, True])
|
||||
def test_vectorization_hints(has_hints):
|
||||
src = torch.empty(1024, device='cuda')
|
||||
dst = torch.empty(1024, device='cuda')
|
||||
off = torch.zeros(1, device='cuda', dtype=torch.int32)
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
|
||||
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
offsets = offsets + tl.load(off)
|
||||
if HINT:
|
||||
tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024)
|
||||
x = tl.load(src + offsets, mask=offsets < N)
|
||||
tl.store(dst + offsets, x, mask=offsets < N)
|
||||
pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints)
|
||||
ptx = pgm.asm["ptx"]
|
||||
if has_hints:
|
||||
assert "ld.global.v4.b32" in ptx
|
||||
else:
|
||||
assert "ld.global.v4.b32" not in ptx
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
# ---------------
|
||||
@@ -1479,7 +1571,7 @@ def test_pointer_arguments(device):
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
@@ -1739,6 +1831,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||
# -----------------------
|
||||
|
||||
|
||||
def test_for_iv_int64():
|
||||
|
||||
@triton.jit
|
||||
def kernel(Out, lo, hi):
|
||||
acc = 0
|
||||
acc = acc.to(tl.int64)
|
||||
for i in range(lo, hi):
|
||||
acc += i
|
||||
tl.store(Out, acc)
|
||||
|
||||
lo = 2**35
|
||||
hi = 2**35 + 20
|
||||
out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
|
||||
kernel[(1,)](out, lo, hi)
|
||||
assert out[0] == sum(range(lo, hi))
|
||||
|
||||
|
||||
def test_if_else():
|
||||
|
||||
@triton.jit
|
||||
@@ -1972,3 +2081,42 @@ def test_load_scalar_with_mask():
|
||||
Out = torch.empty_like(Index, device='cuda')
|
||||
kernel[(1,)](Input, Index, Out, Index.numel())
|
||||
assert Out.data[0] == 0
|
||||
|
||||
|
||||
# This test is used to test our own PTX codegen for float16 and int16 conversions
|
||||
# maybe delete it later after ptxas has been fixed
|
||||
@pytest.mark.parametrize("dtype_str", ['float16', 'int16'])
|
||||
def test_ptx_cast(dtype_str):
|
||||
@triton.jit
|
||||
def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
|
||||
xoffset = tl.program_id(0) * XBLOCK
|
||||
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
||||
xmask = xindex < xnumel
|
||||
rbase = tl.arange(0, RBLOCK)[None, :]
|
||||
x0 = xindex
|
||||
_tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype)
|
||||
for roffset in range(0, rnumel, RBLOCK):
|
||||
rindex = roffset + rbase
|
||||
rmask = rindex < rnumel
|
||||
r1 = rindex
|
||||
tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype)
|
||||
tmp1 = 2
|
||||
tmp2 = tmp0 * tmp1
|
||||
tmp3 = tmp2.to(dtype)
|
||||
tmp5 = _tmp4 < tmp3
|
||||
_tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4)
|
||||
tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)
|
||||
|
||||
torch.manual_seed(123)
|
||||
if dtype_str == 'int16':
|
||||
torch_dtype = torch.int16
|
||||
triton_dtype = tl.int32
|
||||
else:
|
||||
torch_dtype = torch.float16
|
||||
triton_dtype = tl.float32
|
||||
|
||||
s0 = 4
|
||||
buf11 = -torch.ones((6 * s0, 197, 197), device='cuda', dtype=torch_dtype)
|
||||
buf14 = -torch.ones((s0, 6, 197, 197), device='cuda', dtype=torch_dtype)
|
||||
kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
|
||||
assert buf14.to(torch.float32).mean() == -2.0
|
||||
|
||||
@@ -385,17 +385,22 @@ def test_where(dtype):
|
||||
@triton.jit
|
||||
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
TEST_POINTERS: tl.constexpr):
|
||||
TEST_POINTERS: tl.constexpr,
|
||||
TEST_SCALAR_POINTERS: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
decide = tl.load(cond_ptr + offsets, mask=mask)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
if TEST_SCALAR_POINTERS:
|
||||
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
|
||||
output = tl.load(ptr + offsets, mask=mask)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
SIZE = 1_000
|
||||
@@ -411,8 +416,12 @@ def test_where(dtype):
|
||||
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
|
||||
|
||||
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
if select_ptrs:
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
|
||||
z = np.where(cond[0], x, y)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
|
||||
|
||||
def test_where_broadcast():
|
||||
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
|
||||
def test_tensor_atomic_rmw_block(device="cuda"):
|
||||
shape = (8, 8)
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
|
||||
off0 = tl.arange(0, SHAPE0)
|
||||
off1 = tl.arange(0, SHAPE1)
|
||||
offs = off0[:, None] * SHAPE1 + off1[None, :]
|
||||
val = offs.to(tl.float32)
|
||||
x = X + offs
|
||||
tl.atomic_min(x, val)
|
||||
x = torch.ones((8, 8), device=device, dtype=torch.float32)
|
||||
kernel[(2,)](x, shape[0], shape[1])
|
||||
assert torch.min(x).item() == 0.0
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
|
||||
assert torch.all(output == ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f8_xf16_roundtrip(dtype):
|
||||
def test_load_store_same_ptr():
|
||||
@triton.jit()
|
||||
def kernel(in_out_ptr):
|
||||
pid = tl.program_id(axis=0)
|
||||
x = tl.load(in_out_ptr + pid)
|
||||
out = x * 2
|
||||
tl.store(in_out_ptr + pid, out)
|
||||
|
||||
for _ in range(1000):
|
||||
x = torch.ones((65536,), device="cuda", dtype=torch.float32)
|
||||
kernel[(65536,)](x, num_warps=16)
|
||||
assert torch.all(x == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4]) # TODO: support tl.float8e5
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16]) # TODO: support torch.float32
|
||||
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
check_type_supported(dtype)
|
||||
check_type_supported(out_dtype)
|
||||
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
|
||||
f8 = triton.reinterpret(f8_tensor, tl.float8)
|
||||
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
|
||||
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
f8_tensor[all_exp_ones] = 0
|
||||
f8 = triton.reinterpret(f8_tensor, in_dtype)
|
||||
n_elements = f8_tensor.numel()
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
|
||||
|
||||
def test_f16_to_f8_rounding():
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
|
||||
def test_f16_to_f8_rounding(in_dtype):
|
||||
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
||||
error is the minimum over all float8.
|
||||
Or the same explanation a bit mathier:
|
||||
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
|
||||
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
|
||||
n_elements = f16_input.numel()
|
||||
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
|
||||
abs_error = torch.abs(f16_input - f16_output)
|
||||
|
||||
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
|
||||
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
|
||||
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
|
||||
|
||||
@@ -1036,8 +1080,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
[(dtype, shape, perm)
|
||||
# TODO: bfloat16
|
||||
for dtype in ['float16', 'float32']
|
||||
for shape in [(64, 64), (128, 128)]
|
||||
for perm in [(1, 0)]])
|
||||
for shape in [(64, 64), (128, 128)]
|
||||
for perm in [(1, 0)]])
|
||||
def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
||||
if torch.version.hip is not None:
|
||||
@@ -1248,6 +1292,51 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
|
||||
def test_full(dtype_str):
|
||||
dtype = getattr(torch, dtype_str)
|
||||
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
@triton.jit
|
||||
def kernel_static(out):
|
||||
a = GENERATE_TEST_HERE
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
@triton.jit
|
||||
def kernel_dynamic(out, val, dtype: tl.constexpr):
|
||||
a = tl.full((128,), val, dtype)
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
|
||||
out_static = torch.zeros((128), dtype=dtype, device="cuda")
|
||||
kernel_static_patched[(1,)](out_static)
|
||||
out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
|
||||
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
|
||||
assert torch.all(out_static == 2)
|
||||
assert torch.all(out_dynamic == 2)
|
||||
|
||||
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
# @triton.jit
|
||||
# def _kernel(out):
|
||||
# a = GENERATE_TEST_HERE
|
||||
# b = GENERATE_TEST_HERE
|
||||
# c = tl.dot(a, b)
|
||||
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
# tl.store(out_ptr, c)
|
||||
|
||||
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
|
||||
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# out_ref = torch.matmul(a, b)
|
||||
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
||||
# kernel[(1,)](out)
|
||||
# assert torch.all(out == out_ref)
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
# ---------------
|
||||
@@ -1426,7 +1515,7 @@ def test_pointer_arguments(device):
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
@@ -1686,6 +1775,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
|
||||
# -----------------------
|
||||
|
||||
|
||||
def test_for_iv_int64():
|
||||
|
||||
@triton.jit
|
||||
def kernel(Out, lo, hi):
|
||||
acc = 0
|
||||
acc = acc.to(tl.int64)
|
||||
for i in range(lo, hi):
|
||||
acc += i
|
||||
tl.store(Out, acc)
|
||||
|
||||
lo = 2**35
|
||||
hi = 2**35 + 20
|
||||
out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
|
||||
kernel[(1,)](out, lo, hi)
|
||||
assert out[0] == sum(range(lo, hi))
|
||||
|
||||
|
||||
def test_if_else():
|
||||
|
||||
@triton.jit
|
||||
@@ -1863,7 +1969,7 @@ else:
|
||||
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
|
||||
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
|
||||
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
|
||||
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
printf_path = os.path.join(dir_path, "printf_helper.py")
|
||||
|
||||
|
||||
def test_printf():
|
||||
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
|
||||
(outs, err) = proc.communicate()
|
||||
outs = outs.split()
|
||||
new_lines = set()
|
||||
for line in outs:
|
||||
try:
|
||||
value = int(float(line))
|
||||
new_lines.add(value)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
for i in range(128):
|
||||
assert i in new_lines
|
||||
assert len(new_lines) == 128
|
||||
53
python/test/unit/language/test_subprocess.py
Normal file
53
python/test/unit/language/test_subprocess.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
print_path = os.path.join(dir_path, "print_helper.py")
|
||||
assert_path = os.path.join(dir_path, "assert_helper.py")
|
||||
|
||||
# TODO: bfloat16 after LLVM-15
|
||||
func_types = ["device_assert", "assert", "static_assert"]
|
||||
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("func_type, data_type",
|
||||
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32")])
|
||||
def test_print(func_type: str, data_type: str):
|
||||
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
|
||||
outs, _ = proc.communicate()
|
||||
outs = outs.split()
|
||||
new_lines = set()
|
||||
for line in outs:
|
||||
try:
|
||||
value = line
|
||||
if func_type != "static_print":
|
||||
value = int(float(line))
|
||||
new_lines.add(value)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if func_type != "static_print":
|
||||
for i in range(128):
|
||||
assert i in new_lines
|
||||
assert len(new_lines) == 128
|
||||
else:
|
||||
assert len(new_lines) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("func_type", func_types)
|
||||
def test_assert(func_type: str):
|
||||
os.environ["TRITON_DEBUG"] = "1"
|
||||
proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
|
||||
_, errs = proc.communicate()
|
||||
errs = errs.splitlines()
|
||||
num_errs = 0
|
||||
for err in errs:
|
||||
if "x != 0" in err.decode("utf-8"):
|
||||
num_errs += 1
|
||||
os.environ["TRITON_DEBUG"] = "0"
|
||||
if func_type != "static_assert":
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
@@ -5,7 +5,8 @@ import triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability < 80")
|
||||
@@ -21,7 +22,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
p = torch.softmax(p.float(), dim=-1).to(dtype)
|
||||
# p = torch.exp(p)
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
@@ -38,6 +39,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
decimal = 1 if dtype == torch.bfloat16 else 2
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from collections import namedtuple
|
||||
|
||||
@@ -107,33 +106,6 @@ def test_specialize(mode):
|
||||
assert counter == target
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
pass
|
||||
|
||||
cache_str = None
|
||||
|
||||
def get_cache_str(*args, **kwargs):
|
||||
nonlocal cache_str
|
||||
cache_str = kwargs["repr"]
|
||||
triton.JITFunction.cache_hook = get_cache_str
|
||||
reset_tmp_dir()
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
kernel[(1, )](value, x)
|
||||
triton.JITFunction.cache_hook = None
|
||||
|
||||
cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
|
||||
spec_type = None if cache_str_match is None else cache_str_match.group(1)
|
||||
assert spec_type == value_type
|
||||
|
||||
|
||||
def test_constexpr_not_callable() -> None:
|
||||
@triton.jit
|
||||
def kernel(X, c: tl.constexpr):
|
||||
@@ -176,6 +148,26 @@ def test_jit_warmup_cache() -> None:
|
||||
assert len(kernel_add.cache) == 1
|
||||
|
||||
|
||||
def test_jit_debug() -> None:
|
||||
@triton.jit
|
||||
def kernel_add(a, b, o, N: tl.constexpr):
|
||||
idx = tl.arange(0, N)
|
||||
tl.device_assert(idx < 32, "idx < 32")
|
||||
tl.store(o + idx,
|
||||
tl.load(a + idx) + tl.load(b + idx))
|
||||
|
||||
device = torch.cuda.current_device()
|
||||
assert len(kernel_add.cache[device]) == 0
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
kernel_add.debug = False
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
kernel_add.debug = True
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 2
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
|
||||
@triton.jit
|
||||
def kernel_sub(a, b, o, N: tl.constexpr):
|
||||
|
||||
Reference in New Issue
Block a user