# flake8: noqa: F821,F841
import itertools
import os
import re
from typing import Optional, Union

import numpy as np
import pytest
import torch
from numpy.random import RandomState

import triton
import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']


def _bitwidth(dtype: str) -> int:
    # ex.: "int64" -> 64
    return int(re.search(r'(\d+)$', dtype).group(1))
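

# Illustrative sanity check (added; not part of the upstream suite): `_bitwidth`
# simply parses the trailing digits of a dtype name, so 'bfloat16' maps to 16
# just like 'int16' does.
def test_bitwidth_helper():
    assert _bitwidth('int8') == 8
    assert _bitwidth('uint32') == 32
    assert _bitwidth('bfloat16') == 16

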
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
    """
    Override `rs` if you're calling this function twice and don't want the same
    result for both calls.
    """
    if isinstance(shape, int):
        shape = (shape, )
    if rs is None:
        rs = RandomState(seed=17)
    if dtype_str in int_dtypes + uint_dtypes:
        iinfo = np.iinfo(getattr(np, dtype_str))
        low = iinfo.min if low is None else max(low, iinfo.min)
        high = iinfo.max if high is None else min(high, iinfo.max)
        dtype = getattr(np, dtype_str)
        x = rs.randint(low, high, shape, dtype=dtype)
        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
        return x
    elif dtype_str in float_dtypes:
        return rs.normal(0, 1, shape).astype(dtype_str)
    elif dtype_str == 'bfloat16':
        return (rs.normal(0, 1, shape).astype('float32').view('uint32')
                & np.uint32(0xffff0000)).view('float32')
    elif dtype_str in ['bool', 'int1', 'bool_']:
        return rs.normal(0, 1, shape) > 0.0
    else:
        raise RuntimeError(f'Unknown dtype {dtype_str}')
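

# Illustrative sanity check (added; not part of the upstream suite): the
# 'bfloat16' branch of `numpy_random` emulates bfloat16 on the numpy side by
# zeroing the low 16 bits of a float32, i.e. a truncating float32 -> bfloat16
# cast, so every generated value is exactly representable as bfloat16.
def test_numpy_random_bfloat16_is_truncated_float32():
    x = numpy_random((4, 4), dtype_str='bfloat16')
    assert np.all((x.view('uint32') & np.uint32(0x0000ffff)) == 0)

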
def to_triton(x: np.ndarray, device='cuda', dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
    '''
    Note: We need dst_type because the type of x can be different from dst_type.
    For example: x is of type `float32`, dst_type is `bfloat16`.
    If dst_type is None, we infer dst_type from x.
    '''
    t = x.dtype.name
    if t in uint_dtypes:
        signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
        x_signed = x.astype(getattr(np, signed_type_name))
        return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
    else:
        if t == 'float32' and dst_type == 'bfloat16':
            return torch.tensor(x, device=device).bfloat16()
        return torch.tensor(x, device=device)


def torch_dtype_name(dtype) -> str:
    if isinstance(dtype, triton.language.dtype):
        return dtype.name
    elif isinstance(dtype, torch.dtype):
        # 'torch.int64' -> 'int64'
        m = re.match(r'^torch\.(\w+)$', str(dtype))
        return m.group(1)
    else:
        raise TypeError(f'not a triton or torch dtype: {type(dtype)}')
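

# Illustrative sanity check (added; not part of the upstream suite):
# `torch_dtype_name` strips the 'torch.' prefix from torch dtypes and returns
# `.name` for Triton dtypes, so both spellings normalize to the same string.
def test_torch_dtype_name_helper():
    assert torch_dtype_name(torch.float16) == 'float16'
    assert torch_dtype_name(tl.int32) == 'int32'

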
def to_numpy(x):
    if isinstance(x, TensorWrapper):
        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
    elif isinstance(x, torch.Tensor):
        if x.dtype is torch.bfloat16:
            return x.cpu().float().numpy()
        return x.cpu().numpy()
    else:
        raise ValueError(f"Not a triton-compatible tensor: {x}")


def patch_kernel(template, to_replace):
    kernel = triton.JITFunction(template.fn)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel
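

# Illustrative sanity check (added; not part of the upstream suite):
# `patch_kernel` re-wraps the JIT function and textually substitutes
# placeholders in its source; the generic `_test_unary` / `_test_binary`
# helpers below rely on exactly this to inject the expression under test.
def test_patch_kernel_substitutes_placeholder():
    @triton.jit
    def template(X):
        x = tl.load(X)
        tl.store(X, GENERATE_TEST_HERE)

    patched = patch_kernel(template, {'GENERATE_TEST_HERE': 'x + 1'})
    assert 'GENERATE_TEST_HERE' not in patched.src
    assert 'x + 1' in patched.src

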
def check_type_supported(dtype):
    '''
    skip test if dtype is not supported on the current device
    '''
    cc = torch.cuda.get_device_capability()
    if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
        pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")


class MmaLayout:
    def __init__(self, version, warps_per_cta):
        self.version = version
        self.warps_per_cta = str(warps_per_cta)

    def __str__(self):
        return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"


class BlockedLayout:
    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
        self.sz_per_thread = str(size_per_thread)
        self.threads_per_warp = str(threads_per_warp)
        self.warps_per_cta = str(warps_per_cta)
        self.order = str(order)

    def __str__(self):
        return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
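

# Illustrative sanity check (added; not part of the upstream suite): these two
# helper classes only render TritonGPU layout attributes as MLIR text for the
# hand-written IR used by the layout tests further down.
def test_blocked_layout_str():
    s = str(BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]))
    assert s.startswith("#triton_gpu.blocked<{")
    assert "sizePerThread=[1, 4]" in s and "order=[1, 0]" in s

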
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
def test_empty_kernel(dtype_x, device='cuda'):
    SIZE = 128

    @triton.jit
    def kernel(X, SIZE: tl.constexpr):
        pass
    check_type_supported(dtype_x)
    x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
    kernel[(1, )](x, SIZE=SIZE, num_warps=4)


# generic test functions
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
    check_type_supported(dtype_x)  # early return if dtype_x is not supported
    SIZE = 128
    # define the kernel / launch-grid

    @triton.jit
    def kernel(Z, X, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = numpy_random(SIZE, dtype_str=dtype_x)
    if 'log' in expr:
        x = np.abs(x) + 0.01
    # reference result
    z_ref = eval(expr if numpy_expr is None else numpy_expr)
    # triton result
    x_tri = to_triton(x, device=device, dst_type=dtype_x)
    z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
    kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4)
    # compare
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
    """
    Given two dtype strings, returns the numpy dtype Triton thinks binary
    operations on the two types should return. Returns None if the return value
    matches numpy. This is generally needed because Triton and pytorch return
    narrower floating point types than numpy in mixed operations, and because
    Triton follows C/C++ semantics around mixed signed/unsigned operations, and
    numpy/pytorch do not.
    """
    overrides = {
        ('float16', 'int16'): np.float16,
        ('float16', 'int32'): np.float16,
        ('float16', 'int64'): np.float16,
        ('float16', 'uint16'): np.float16,
        ('float16', 'uint32'): np.float16,
        ('float16', 'uint64'): np.float16,
        ('int8', 'uint8'): np.uint8,
        ('int8', 'uint16'): np.uint16,
        ('int8', 'uint32'): np.uint32,
        ('int8', 'uint64'): np.uint64,
        ('int16', 'uint16'): np.uint16,
        ('int16', 'uint32'): np.uint32,
        ('int16', 'uint64'): np.uint64,
        ('int32', 'uint32'): np.uint32,
        ('int32', 'uint64'): np.uint64,
        ('int64', 'uint64'): np.uint64,
    }
    key = (a, b) if a < b else (b, a)
    return overrides.get(key)
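

# Illustrative sanity check (added; not part of the upstream suite): mixed
# signed/unsigned operands follow C-style promotion to the unsigned type,
# while float16 mixed with any integer stays float16 instead of widening.
def test_binary_op_dtype_override_helper():
    assert _binary_op_dtype_override('int32', 'uint32') is np.uint32
    assert _binary_op_dtype_override('float16', 'int64') is np.float16
    assert _binary_op_dtype_override('float32', 'float32') is None

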
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
    check_type_supported(dtype_x)  # early return if dtype_x is not supported
    check_type_supported(dtype_y)
    SIZE = 128
    # define the kernel / launch-grid

    @triton.jit
    def kernel(Z, X, Y, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        y = tl.load(Y + off)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    rs = RandomState(17)
    x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
    if mode_x == 'nan':
        x[:] = float('nan')
    if mode_y == 'nan':
        y[:] = float('nan')
    # reference result
    z_ref = eval(expr if numpy_expr is None else numpy_expr)
    dtype_z = _binary_op_dtype_override(dtype_x, dtype_y)
    if dtype_z is not None:
        z_ref = z_ref.astype(dtype_z)
    # triton result
    x_tri = to_triton(x, device=device, dst_type=dtype_x)
    y_tri = to_triton(y, device=device, dst_type=dtype_y)
    z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
    kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4)
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01)


def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
    # The result of x % y is ill-conditioned if x % y is much smaller than x.
    # pytorch/CUDA has slightly different (probably better) rounding on
    # remainders than stock LLVM. We currently don't expect to match it
    # bit-for-bit.
    return (dtype_x, dtype_y) in [
        ('int32', 'bfloat16'),
        ('int32', 'float16'),
        ('int32', 'float32'),
        ('int64', 'bfloat16'),
        ('int64', 'float16'),
        ('int64', 'float32'),
        ('int64', 'float64'),
        ('uint16', 'bfloat16'),
        ('uint16', 'float16'),
        ('uint16', 'float32'),
        ('uint32', 'bfloat16'),
        ('uint32', 'float16'),
        ('uint32', 'float32'),
        ('uint64', 'bfloat16'),
        ('uint64', 'float16'),
        ('uint64', 'float32'),
        ('uint64', 'float64'),
    ]


# ---------------
# test binary ops
# ---------------


@pytest.mark.parametrize("dtype_x, dtype_y, op", [
    (dtype_x, dtype_y, op)
    for op in ['+', '-', '*', '/', '%']
    for dtype_x in dtypes_with_bfloat16
    for dtype_y in dtypes_with_bfloat16
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f' x {op} y'
    if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
        numpy_expr = 'np.fmod(x, y)'
    elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'):
        # Triton promotes 16-bit floating-point / and % to 32-bit because there
        # are no native div or FRem operations on float16. Since we have to
        # convert anyway, we may as well take the accuracy bump.
        numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
    elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
        with pytest.raises(AssertionError, match='Not equal to tolerance'):
            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
    elif (op in ('%', '/') and
          ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
           (dtype_x in uint_dtypes and dtype_y in int_dtypes))):
        with pytest.raises(triton.CompilationError) as exc_info:
            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
        assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
    else:
        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


@pytest.mark.parametrize("dtype_x, dtype_y",
                         [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
                         [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
                         )
def test_floordiv(dtype_x, dtype_y, device='cuda'):
    # Triton has IEEE, not numpy/torch, semantics for %, and those carry
    # through to //, so we have to use a nonstandard expression to get a
    # reference result for //.
    expr = 'x // y'
    numpy_expr = '((x - np.fmod(x, y)) / y)'
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


def test_unsigned_name_mangling(device='cuda'):
    # Test that uint32 and int32 are mangled differently by the compiler
    SIZE = 128
    # define the kernel / launch-grid

    @triton.jit
    def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        y = tl.load(Y + off)
        out1 = tl.abs(x)  # uint32 -> nop
        out2 = tl.abs(-y)  # int32 -> should have an effect
        tl.store(O1 + off, out1)
        tl.store(O2 + off, out2)

    dtype_x = 'uint32'
    dtype_y = 'int32'
    # inputs
    rs = RandomState(17)
    x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
    # reference result
    expect = (np.abs(x), np.abs(-y))
    # triton result
    x_tri = to_triton(x, device=device, dst_type=dtype_x)
    y_tri = to_triton(y, device=device, dst_type=dtype_y)
    actual = tuple(
        to_triton(np.empty_like(e), device=device)
        for e in expect
    )
    kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)

    # Bitwise op, so expect exact equality
    assert (expect[0] == to_numpy(actual[0])).all()
    assert (expect[1] == to_numpy(actual[1])).all()


# ---------------
# test bitwise ops
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
    (dtype_x, dtype_y, op)
    for op in ['&', '|', '^']
    for dtype_x in dtypes + dtypes_with_bfloat16
    for dtype_y in dtypes + dtypes_with_bfloat16
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f'x {op} y'
    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    if 'float' in dtype_x + dtype_y:
        with pytest.raises(triton.CompilationError) as exc_info:
            _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
        # The CompilationError must have been caused by a C++ exception with this text.
        assert re.match('invalid operands of type', str(exc_info.value.__cause__))
    else:
        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


@pytest.mark.parametrize("dtype_x, dtype_y, op", [
    (dtype_x, dtype_y, op)
    for op in ['<<', '>>']
    for dtype_x in int_dtypes + uint_dtypes
    for dtype_y in int_dtypes + uint_dtypes
])
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f'x {op} y'
    bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
    if dtype_x.startswith('int'):
        dtype_z = f'int{bw}'
    else:
        dtype_z = f'uint{bw}'
    numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)


# ---------------
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']


@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y",
                         # real
                         [
                             (dtype_x, dtype_y, op, 'real', 'real')
                             for op in ops
                             for dtype_x in dtypes
                             for dtype_y in dtypes
                         ] +
                         # NaNs
                         [('float32', 'float32', op, mode_x, mode_y)
                          for op in ops
                          for mode_x, mode_y in [('nan', 'real'),
                                                 ('real', 'nan'),
                                                 ('nan', 'nan')]

                          ])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
    expr = f'x {op} y'
    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)


# ---------------
# test broadcast
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype):
    @triton.jit
    def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
        offset1 = tl.arange(0, M)
        offset2 = tl.arange(0, N)
        x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])
        y = tl.load(y_ptr + offset2)
        _, y_broadcasted = tl.broadcast(x, y)
        tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)

    M = 32
    N = 64
    rs = RandomState(17)
    x = numpy_random((M, N), dtype_str=dtype, rs=rs)
    y = numpy_random(N, dtype_str=dtype, rs=rs)
    _, y_broadcasted_np = np.broadcast_arrays(x, y)

    x_tri = to_triton(x, device='cuda', dst_type=dtype)
    y_tri = to_triton(y, device='cuda', dst_type=dtype)
    y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype)

    broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
    assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()


# ----------------
# test expand_dims
# ----------------
def test_expand_dims():
    @triton.jit
    def expand_dims_kernel(dummy, N: tl.constexpr):
        offset1 = tl.arange(0, N)

        t = tl.expand_dims(offset1, 0)
        tl.static_assert(t.shape == [1, N])

        t = tl.expand_dims(offset1, 1)
        tl.static_assert(t.shape == [N, 1])

        t = tl.expand_dims(offset1, -1)
        tl.static_assert(t.shape == [N, 1])

        t = tl.expand_dims(offset1, -2)
        tl.static_assert(t.shape == [1, N])

        t = tl.expand_dims(offset1, (0, -1))
        tl.static_assert(t.shape == [1, N, 1])

        t = tl.expand_dims(offset1, (0, 1, 3))
        tl.static_assert(t.shape == [1, 1, N, 1])

        t = tl.expand_dims(offset1, (-4, 2, -1))
        tl.static_assert(t.shape == [1, N, 1, 1])

        t = tl.expand_dims(offset1, (3, 1, 2))
        tl.static_assert(t.shape == [N, 1, 1, 1])

    N = 32
    dummy_tensor = torch.empty((), device="cuda")
    expand_dims_kernel[(1,)](dummy_tensor, N)


def test_expand_dims_error_cases():
    @triton.jit
    def dim_out_of_range1(dummy, N: tl.constexpr):
        offset1 = tl.arange(0, N)

        t = tl.expand_dims(offset1, -2)
        t = tl.expand_dims(offset1, -3)

    @triton.jit
    def dim_out_of_range2(dummy, N: tl.constexpr):
        offset1 = tl.arange(0, N)

        t = tl.expand_dims(offset1, 1)
        t = tl.expand_dims(offset1, 2)

    @triton.jit
    def duplicate_dim1(dummy, N: tl.constexpr):
        offset1 = tl.arange(0, N)

        t = tl.expand_dims(offset1, (0, 0))

    @triton.jit
    def duplicate_dim2(dummy, N: tl.constexpr):
        offset1 = tl.arange(0, N)

        t = tl.expand_dims(offset1, (0, -3))

    N = 32
    dummy_tensor = torch.empty((), device="cuda")

    with pytest.raises(triton.CompilationError, match="invalid axis -3"):
        dim_out_of_range1[(1,)](dummy_tensor, N)

    with pytest.raises(triton.CompilationError, match="invalid axis 2"):
        dim_out_of_range2[(1,)](dummy_tensor, N)

    with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
        duplicate_dim1[(1,)](dummy_tensor, N)

    with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
        duplicate_dim2[(1,)](dummy_tensor, N)


# ---------------
# test where
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
def test_where(dtype):
    select_ptrs = False
    if dtype == "*int32":
        dtype = "int64"
        select_ptrs = True
    check_type_supported(dtype)

    @triton.jit
    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
                     BLOCK_SIZE: tl.constexpr,
                     TEST_POINTERS: tl.constexpr,
                     TEST_SCALAR_POINTERS: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        decide = tl.load(cond_ptr + offsets, mask=mask)
        if TEST_SCALAR_POINTERS:
            ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
            output = tl.load(ptr + offsets, mask=mask)
        else:
            if TEST_POINTERS:
                a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
                b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
            else:
                a = tl.load(a_ptr + offsets, mask=mask)
                b = tl.load(b_ptr + offsets, mask=mask)
            output = tl.where(decide, a, b)
        tl.store(output_ptr + offsets, output, mask=mask)

    SIZE = 1_000
    rs = RandomState(17)
    cond = numpy_random(SIZE, 'bool', rs)
    x = numpy_random(SIZE, dtype_str=dtype, rs=rs)
    y = numpy_random(SIZE, dtype_str=dtype, rs=rs)
    z = np.where(cond, x, y)

    cond_tri = to_triton(cond, device='cuda')
    x_tri = to_triton(x, device='cuda', dst_type=dtype)
    y_tri = to_triton(y, device='cuda', dst_type=dtype)
    z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)

    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
    assert (z == to_numpy(z_tri)).all()
    if select_ptrs:
        where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
        z = np.where(cond[0], x, y)
        assert (z == to_numpy(z_tri)).all()


def test_where_broadcast():
    @triton.jit
    def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]

        mask = tl.load(cond_ptr + yoffsets)
        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
        res = tl.where(mask, vals, 0.)
        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

    @triton.jit
    def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
        mask = 0
        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
        res = tl.where(mask, vals, 0.)
        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

    SIZE = 32
    dtype = 'float32'
    rs = RandomState(17)
    x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
    mask = numpy_random(SIZE, 'bool', rs=rs)
    z = np.where(mask, x, 0)
    cond_tri = to_triton(mask, device="cuda")
    x_tri = to_triton(x, device='cuda', dst_type=dtype)
    z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
    where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
    assert (z == to_numpy(z_tri)).all()
    where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
    z = np.where(0, x, 0)
    assert (z == to_numpy(z_tri)).all()


# ---------------
# test unary ops
# ---------------


@pytest.mark.parametrize("dtype_x, expr", [
    (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16
] + [
    (dtype_x, ' ~x') for dtype_x in int_dtypes
])
def test_unary_op(dtype_x, expr, device='cuda'):
    _test_unary(dtype_x, expr, device=device)


# ----------------
# test math ops
# ----------------
@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin']])
def test_math_op(dtype_x, expr, device='cuda'):
    _test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)


# ----------------
# test abs
# ----------------
@pytest.mark.parametrize("dtype_x", [
    (dtype_x)
    for dtype_x in dtypes_with_bfloat16
])
def test_abs(dtype_x, device='cuda'):
    _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device)


@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
def test_abs_f8(in_dtype):

    @triton.jit
    def abs_kernel(X, Z, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        z = tl.abs(x)
        tl.store(Z + off, z)

    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
    all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
    f8_tensor[all_exp_ones] = 0
    f8 = triton.reinterpret(f8_tensor, in_dtype)
    n_elements = f8_tensor.numel()
    out_f8 = torch.empty_like(f8_tensor)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    abs_kernel[(1,)](f8, triton.reinterpret(out_f8, in_dtype), n_elements)

    f32_tensor = convert_float_to_float32(f8_tensor, in_dtype)
    expect = f32_tensor.abs()
    actual_f8 = convert_float_to_float32(out_f8, in_dtype)
    torch.testing.assert_allclose(expect, actual_f8)


# ----------------
# test indexing
# ----------------


def make_ptr_str(name, shape):
    rank = len(shape)
    offsets = []
    stride = 1
    for i in reversed(range(rank)):
        idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
        offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
        stride *= shape[i]
    return f"{name} + {' + '.join(offsets)}"
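

# Illustrative sanity check (added; not part of the upstream suite):
# `make_ptr_str` builds the row-major pointer-arithmetic expression that gets
# textually patched into the indexing kernel below, e.g. for a 32x32 block:
def test_make_ptr_str_helper():
    assert make_ptr_str('X', [32, 32]) == \
        "X + tl.arange(0, 32)[None, :]*1 + tl.arange(0, 32)[:, None]*32"

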
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
@pytest.mark.parametrize("expr, dtype_str", [
    (f'x[{s}]', d)
    for s in ['None, :', ':, None',
              'None, :, :',
              ':, :, None']
    for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
    rank_x = expr.count(':')
    rank_y = expr.count(',') + 1
    shape_x = [32 for _ in range(rank_x)]
    shape_z = [32 for _ in range(rank_y)]
    shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)]
    shape_z_dim_mismatch = [64 for _ in range(rank_y)]

    # Triton kernel
    @triton.jit
    def kernel(Z, X, SIZE: tl.constexpr):
        m = tl.arange(0, SIZE)
        n = tl.arange(0, SIZE)
        x = tl.load(X_PTR_EXPR)
        z = GENERATE_TEST_HERE
        tl.store(Z_PTR_EXPR, z)

    def generate_kernel(shape_x, shape_z):
        to_replace = {
            'X_PTR_EXPR': make_ptr_str('X', shape_x),
            'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
            'GENERATE_TEST_HERE': expr,
        }
        return patch_kernel(kernel, to_replace)

    kernel_match = generate_kernel(shape_x, shape_z)
    kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch)
    kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch)

    # torch result
    x = numpy_random(shape_x, dtype_str=dtype_str)
    y = np.zeros(shape_z, dtype=getattr(np, dtype_str))
    z_ref = eval(expr) + y
    # triton result
    z_tri = to_triton(np.empty_like(z_ref), device=device)
    x_tri = to_triton(x)
    kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
    # compare
    assert (z_ref == to_numpy(z_tri)).all()

    def catch_compilation_error(kernel):
        try:
            kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
        except triton.CompilationError as e:
            np.testing.assert_(True)
        except BaseException:
            np.testing.assert_(False)

    catch_compilation_error(kernel_dim_mismatch)
    catch_compilation_error(kernel_rank_mismatch)


# ---------------
# test tuples
# ---------------


@triton.jit
def tuples_fn(a, b):
    return a + b, \
        a - b, \
        a * b


def test_tuples():
    device = 'cuda'

    @triton.jit
    def with_fn(X, Y, A, B, C):
        x = tl.load(X)
        y = tl.load(Y)
        a, b, c = tuples_fn(x, y)
        tl.store(A, a)
        tl.store(B, b)
        tl.store(C, c)

    @triton.jit
    def without_fn(X, Y, A, B, C):
        x = tl.load(X)
        y = tl.load(Y)
        a, b, c = x + y, x - y, x * y
        tl.store(A, a)
        tl.store(B, b)
        tl.store(C, c)

    x = torch.tensor([1.3], device=device, dtype=torch.float32)
    y = torch.tensor([1.9], device=device, dtype=torch.float32)
    a_tri = torch.tensor([0], device=device, dtype=torch.float32)
    b_tri = torch.tensor([0], device=device, dtype=torch.float32)
    c_tri = torch.tensor([0], device=device, dtype=torch.float32)
    for kernel in [with_fn, without_fn]:
        kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
        a_ref, b_ref, c_ref = x + y, x - y, x * y
        assert a_tri == a_ref
        assert b_tri == b_ref
        assert c_tri == c_ref


@triton.jit(noinline=True)
def noinline_simple_fn(x, y, Z):
    z = x + y
    tl.store(Z, z)


@triton.jit(noinline=True)
def noinline_call_graph_fn1(x):
    return x + 1


@triton.jit(noinline=True)
def noinline_call_graph_fn2(y):
    return y + 2


@triton.jit(noinline=True)
def noinline_call_graph_fn(x, y, Z):
    t0 = noinline_call_graph_fn1(x)
    t1 = noinline_call_graph_fn2(y)
    z = t0 + t1
    tl.store(Z, z)


@triton.jit(noinline=True)
def noinline_shared_fn(x, y, Z):
    offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
    z = tl.load(Z + offs)
    z = tl.dot(z, z) + x + y
    tl.store(Z + offs, z)


@triton.jit(noinline=True)
def noinline_dynamic_fn(x, y, Z):
    if x >= 1:
        x = noinline_call_graph_fn1(x)
    else:
        x = noinline_call_graph_fn2(x)
    if y >= 2:
        y = noinline_call_graph_fn2(y)
    else:
        y = noinline_call_graph_fn1(y)
    z = x + y
    tl.store(Z, z)


@triton.jit(noinline=True)
def noinline_call_multi_values_fn(x, y):
    return x + 1, y + 2


@triton.jit(noinline=True)
def noinline_multi_values_fn(x, y, Z):
    x, y = noinline_call_multi_values_fn(x, y)
    z = x + y
    tl.store(Z, z)


@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"])
def test_noinline(mode):
    device = 'cuda'

    @triton.jit
    def kernel(X, Y, Z):
        x = tl.load(X)
        y = tl.load(Y)
        GENERATE_TEST_HERE(x, y, Z)

    func_name = f'noinline_{mode}_fn'
    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name})
    x = torch.tensor([1.0], device=device, dtype=torch.float32)
    y = torch.tensor([2.0], device=device, dtype=torch.float32)
    if mode == "shared":
        z = torch.ones((16, 16), device=device, dtype=torch.float32)
    else:
        z = torch.tensor([0.0], device=device, dtype=torch.float32)
    kernel[(1,)](x, y, z, num_warps=1)
    if mode == "simple":
        assert torch.equal(z, x + y)
    elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values":
        assert torch.equal(z, x + 1 + y + 2)
    elif mode == "shared":
        ref = torch.full((16, 16), 16, device=device, dtype=torch.float32)
        assert torch.equal(z, ref + x + y)


# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
    [
        ('add', 'float16', mode),
        ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
        ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
        ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
    ]
    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
    capability = torch.cuda.get_device_capability()
    if capability[0] < 7:
        if dtype_x_str == 'float16':
            pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
    n_programs = 5

    # triton kernel
    @triton.jit
    def kernel(X, Z):
        pid = tl.program_id(0)
        x = tl.load(X + pid)
        old = GENERATE_TEST_HERE

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

    # triton result
    rs = RandomState(17)
    x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str))
    if mode == 'all_neg':
        x = -np.abs(x)
    if mode == 'all_pos':
        x = np.abs(x)
    if mode == 'min_neg':
        idx = rs.randint(n_programs, size=(1, )).item()
        x[idx] = -np.max(np.abs(x)) - 1
    if mode == 'max_pos':
        idx = rs.randint(n_programs, size=(1, )).item()
        x[idx] = np.max(np.abs(x)) + 1
    x_tri = to_triton(x, device=device)

    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
    kernel[(n_programs, )](x_tri, z_tri)
    # torch result
    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
    # compare
    exact = op not in ['add']
    if exact:
        assert z_ref.item() == to_numpy(z_tri).item()
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


def test_atomic_rmw_predicate(device="cuda"):
    @triton.jit
    def kernel(X):
        val = tl.program_id(0)
        if val < 64:
            tl.atomic_max(X, val)
    x = torch.zeros((1,), device=device, dtype=torch.int32)
    kernel[(4096,)](x)
    assert x.item() == 63


@pytest.mark.parametrize("shape, axis",
                         [(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
def test_tensor_atomic_rmw(shape, axis, device="cuda"):
    shape0, shape1 = shape
    # triton kernel

    @triton.jit
    def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
        z = tl.sum(x, axis=AXIS)
        if AXIS == 1:
            tl.atomic_add(Z + off0, z)
        else:
            tl.atomic_add(Z + off1, z)
    rs = RandomState(17)
    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
    # reference result
    z_ref = np.sum(x, axis=axis, keepdims=False)
    # triton result
    x_tri = to_triton(x, device=device)
    z_shape = (shape0, ) if axis == 1 else (shape1, )
    z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device)
    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)


def test_tensor_atomic_rmw_block(device="cuda"):
    shape = (8, 8)

    @triton.jit
    def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        offs = off0[:, None] * SHAPE1 + off1[None, :]
        val = offs.to(tl.float32)
        x = X + offs
        tl.atomic_min(x, val)
    x = torch.ones((8, 8), device=device, dtype=torch.float32)
    kernel[(2,)](x, shape[0], shape[1])
    assert torch.min(x).item() == 0.0


def test_atomic_cas():
    # 1. make sure that atomic_cas changes the original value (Lock)
    @triton.jit
    def change_value(Lock):
        tl.atomic_cas(Lock, 0, 1)

    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
    change_value[(1,)](Lock)

    assert (Lock[0] == 1)

    # 2. only one block enters the critical section
    @triton.jit
    def serialized_add(data, Lock):
        ptrs = data + tl.arange(0, 128)
        while tl.atomic_cas(Lock, 0, 1) == 1:
            pass

        tl.store(ptrs, tl.load(ptrs) + 1.0)

        # release lock
        tl.atomic_xchg(Lock, 0)

    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
    data = torch.zeros((128,), device='cuda', dtype=torch.float32)
    ref = torch.full((128,), 64.0)
    serialized_add[(64,)](data, Lock)
    np.testing.assert_allclose(to_numpy(data), to_numpy(ref))


# ---------------
# test cast
# ---------------


@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
    (dtype_x, dtype_z, False)
    for dtype_x in dtypes
    for dtype_z in dtypes
] + [
    ('float32', 'bfloat16', False),
    ('bfloat16', 'float32', False),
    ('float32', 'int32', True),
    ('float32', 'int1', False),
] + [
    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [
    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
    # bfloat16 on cc < 80 will not be tested
    check_type_supported(dtype_x)
    check_type_supported(dtype_z)

    # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
    x0 = 43 if dtype_x in int_dtypes else 43.5
    if dtype_x in float_dtypes and dtype_z == 'int1':
        x0 = 0.5
    if dtype_x.startswith('bfloat'):
        x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
    else:
        x = np.array([x0], dtype=getattr(np, dtype_x))
        x_tri = to_triton(x)

    # triton kernel
    @triton.jit
    def kernel(X, Z, BITCAST: tl.constexpr):
        x_ptr = X + tl.arange(0, 1)
        z_ptr = Z + tl.arange(0, 1)
        x = tl.load(x_ptr)
        z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
        tl.store(z_ptr, z)

    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
    # triton result
    if dtype_z.startswith('bfloat'):
        z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
    else:
        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
    # torch result
    if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
        assert bitcast is False
        z_ref = x_tri.to(z_tri.dtype)
        assert z_tri == z_ref
    else:
        if bitcast:
            z_ref = x.view(getattr(np, dtype_z_np))
        else:
            z_ref = x.astype(getattr(np, dtype_z_np))
        assert to_numpy(z_tri) == z_ref


@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
def test_store_constant(dtype_str):
    check_type_supported(dtype_str)

    """Tests that boolean True is stored as 1"""
    @triton.jit
    def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        output = GENERATE_TEST_HERE
        tl.store(output_ptr + offsets, output, mask=mask)

    triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str
    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'})
    block_size = 128
    ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device='cuda')
    output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device='cuda')
    kernel[(1,)](output, block_size, BLOCK_SIZE=block_size)

    assert torch.all(output == ref)


def test_load_store_same_ptr():
    @triton.jit()
    def kernel(in_out_ptr):
        pid = tl.program_id(axis=0)
        x = tl.load(in_out_ptr + pid)
        out = x * 2
        tl.store(in_out_ptr + pid, out)

    for _ in range(1000):
        x = torch.ones((65536,), device="cuda", dtype=torch.float32)
        kernel[(65536,)](x, num_warps=32)
        assert torch.all(x == 2)


def convert_float_to_float32(fp: torch.tensor, dtype=None):
    if not dtype:
        dtype = getattr(tl, torch_dtype_name(fp.dtype))

    fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}"))
    exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1
    exp_bias = 2 ** (exp_width - 1) - 1
    sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int()
    exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int()
    frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int()

    output = torch.where(exp == 0,
                         # subnormal
                         ((-1.0) ** sign) * (2.0 ** (1 - exp_bias)) * (frac / (2.0 ** dtype.fp_mantissa_width)),
                         # normal
                         ((-1.0) ** sign) * (2.0 ** (exp - exp_bias)) * (1.0 + frac / (2.0 ** dtype.fp_mantissa_width))).float()

    extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width
    # special cases, exp is 0b11..1
    if dtype == tl.float8e4:
        # float8e4m3 does not have infinities
        output[fp == torch.tensor(0b01111111, dtype=torch.int8)] = torch.nan
        output[fp == torch.tensor(0b11111111, dtype=torch.int8)] = torch.nan
    else:
        output = torch.where(exp == (1 << exp_width) - 1,
                             ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),
                             output)
    return output
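

# Illustrative sanity check (added; not part of the upstream suite):
# `convert_float_to_float32` decodes sign/exponent/mantissa by hand, so known
# float16 bit patterns such as 0x3c00 (1.0) and 0xc000 (-2.0) must decode exactly.
def test_convert_float_to_float32_known_bit_patterns():
    bits = torch.tensor([0x3c00, -0x4000], dtype=torch.int16)  # 0x3c00 -> 1.0, 0xc000 -> -2.0
    decoded = convert_float_to_float32(bits.view(torch.float16))
    assert torch.equal(decoded, torch.tensor([1.0, -2.0]))

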
@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16])
def test_convert_float16_to_float32(in_dtype):
    """Tests that check convert_float_to_float32 function"""
    check_type_supported(in_dtype)

    f16_input = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16).view(in_dtype)
    f32_output = convert_float_to_float32(f16_input)

    nan = f16_input.isnan()
    assert torch.all(f32_output[nan].isnan())
    inf = f16_input.isinf()
    assert torch.all(f32_output[inf].isinf())
    other = torch.logical_not(torch.logical_or(nan, inf))
    assert torch.all(f16_input[other] == f32_output[other])


@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
    check_type_supported(out_dtype)

    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
    all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
    f8_tensor[all_exp_ones] = 0
    f8 = triton.reinterpret(f8_tensor, in_dtype)
    n_elements = f8_tensor.numel()
    xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)

    # exponent_mask = 0b01111100 for float8e5
    # exponent_mask = 0b01111000 for float8e4
    exponent_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1)
    normal = torch.logical_and((f8_tensor & exponent_mask) != 0, (f8_tensor & exponent_mask) != exponent_mask)
    ref16 = convert_float_to_float32(f8_tensor, in_dtype)
    # WARN: currently only normal float8s are handled
    assert torch.all(xf16[normal] == ref16[normal])

    f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
    copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)

    assert torch.all(f8_tensor == f8_output_tensor)


@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_f16_to_f8_rounding(in_dtype, out_dtype):
    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
    error is the minimum over all float8.
    Or the same explanation a bit mathier:
    for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        input = tl.load(input_ptr + offsets, mask=mask)
        output = input
        tl.store(output_ptr + offsets, output, mask=mask)

    i16_input = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16, device='cuda')
    f16_input = i16_input.view(out_dtype)
    n_elements = f16_input.numel()
    f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)

    f16_output = torch.empty_like(f16_input, dtype=out_dtype)
    copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)

    abs_error = torch.abs(f16_input - f16_output)

    all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
    all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
    all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=out_dtype)
    copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)

    all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
        torch.isfinite(all_f8_vals_in_f16)
    ]

    min_error = torch.min(
        torch.abs(
            f16_input.reshape((-1, 1))
            - all_finite_f8_vals_in_f16.reshape((1, -1))
        ),
        dim=1,
    )[0]

    # WARN: only normalized numbers are handled
    f8_normal_min = 1 << in_dtype.fp_mantissa_width  # 0b00001000 for float8e4
    f8_normal_max = 0b01111110 if in_dtype == tl.float8e4 else 0b01111011
    f16_min, f16_max, f16_max_minus_1 = convert_float_to_float32(torch.tensor([f8_normal_min, f8_normal_max, f8_normal_max - 1], dtype=torch.int8), in_dtype)
    assert torch.all(torch.isfinite(f16_min))
    assert torch.all(torch.isfinite(f16_max))
    thres_error = f16_max - f16_max_minus_1
    mismatch = torch.logical_and(
        torch.logical_or(abs_error != min_error, abs_error > thres_error), torch.logical_and(torch.isfinite(f16_input), torch.logical_and(torch.abs(f16_input) <= f16_max, torch.abs(f16_input) >= f16_min))
    )
    assert torch.all(
        torch.logical_not(mismatch)
    ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"


# ---------------
# test reduce
# ---------------


def get_reduced_dtype(dtype_str, op):
    if op in ('argmin', 'argmax'):
        return 'int32'
    if dtype_str in ['int8', 'uint8', 'int16', 'uint16']:
        return 'int32'
    if dtype_str == 'bfloat16':
        return 'float32'
    return dtype_str
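

# Illustrative sanity check (added; not part of the upstream suite): reductions
# widen small integer types to int32, accumulate bfloat16 in float32, and
# always return int32 indices for argmin/argmax.
def test_get_reduced_dtype_helper():
    assert get_reduced_dtype('int16', 'sum') == 'int32'
    assert get_reduced_dtype('bfloat16', 'max') == 'float32'
    assert get_reduced_dtype('float32', 'argmax') == 'int32'
    assert get_reduced_dtype('float32', 'sum') == 'float32'

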
@pytest.mark.parametrize("op, dtype_str, shape",
                         [(op, dtype, shape)
                          for op in ['min', 'max', 'sum', 'argmin', 'argmax']
                          for dtype in dtypes_with_bfloat16
                          for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

    # triton kernel
    @triton.jit
    def kernel(X, Z, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.store(Z, GENERATE_TEST_HERE)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
    # input
    rs = RandomState(17)
    # limit the range of integers so that the sum does not overflow
    x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
    x_tri = to_triton(x, device=device)
    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                'argmin': np.argmin, 'argmax': np.argmax}[op]
    # numpy result
    z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str
    z_tri_dtype_str = z_dtype_str
    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
        z_dtype_str = 'float32'
        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
        # trunc mantissa for a fair comparison of accuracy
        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
        z_tri_dtype_str = 'bfloat16'
    else:
        z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
    # triton result
    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs),
                      device=device, dst_type=z_tri_dtype_str)
    kernel[(1,)](x_tri, z_tri, BLOCK=shape)
    z_tri = to_numpy(z_tri)
    # compare
    if op == 'sum':
        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
    else:
        if op in ('argmin', 'argmax'):
            # argmin and argmax can have multiple valid indices.
            # so instead we compare the values pointed by indices
            np.testing.assert_equal(x[z_ref], x[z_tri])
        else:
            np.testing.assert_equal(z_ref, z_tri)


# TODO: [Qingyi] Fix argmin / argmax
reduce_configs1 = [
    (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16
    for op in ['min', 'max', 'sum', 'argmin', 'argmax']
    for axis in [1]
]


# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory
# exceeds the limit of 99KB
reduce2d_shapes = [(2, 32), (4, 32), (4, 128)]
# TODO: fix and uncomment
# , (32, 64), (64, 128)]
if 'V100' in torch.cuda.get_device_name(0):
    reduce2d_shapes += [(128, 256), (32, 1024)]


reduce_configs2 = [
    (op, 'float32', shape, axis)
    for op in ['min', 'max', 'sum', 'argmin', 'argmax']
    for shape in reduce2d_shapes
    for axis in [0, 1]
]


@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

    # triton kernel
    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
        range_m = tl.arange(0, BLOCK_M)
        range_n = tl.arange(0, BLOCK_N)
        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
        z = GENERATE_TEST_HERE
        if AXIS == 1:
            tl.store(Z + range_m, z)
        else:
            tl.store(Z + range_n, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
    # input
    rs = RandomState(17)
    # limit the range of integers so that the sum does not overflow
    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
    x_tri = to_triton(x)
    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
                'argmin': np.argmin, 'argmax': np.argmax}[op]
    z_dtype_str = get_reduced_dtype(dtype_str, op)
    z_tri_dtype_str = z_dtype_str
    # numpy result
    if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
        z_dtype_str = 'float32'
        z_tri_dtype_str = 'bfloat16'
        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
        # trunc mantissa for a fair comparison of accuracy
        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
    else:
        z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
    # triton result
    z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
                      device=device, dst_type=z_tri_dtype_str)
    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
    z_tri = to_numpy(z_tri)
    # compare
    if op == 'sum':
        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
    else:
        if op in ('argmin', 'argmax'):
            # argmin and argmax can have multiple valid indices.
            # so instead we compare the values pointed by indices
            z_ref_index = np.expand_dims(z_ref, axis=axis)
            z_tri_index = np.expand_dims(z_tri, axis=axis)
            z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
            z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
            np.testing.assert_equal(z_ref_value, z_tri_value)
        else:
            np.testing.assert_equal(z_ref, z_tri)


layouts = [
    BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
    MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]


@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128]])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
    rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
    rdims_1d = f"{N}" if axis == 0 else f"{M}"
    store_range = "%7" if axis == 0 else "%1"
    ir = f"""
    #blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}}>
    #src = {src_layout}
    module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
    tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
        %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
        %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
        %2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked>
        %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
        %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #blocked>
        %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<f32>, #blocked>, tensor<{M}x1xi32, #blocked>
        %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
        %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
        %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<f32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<f32>, #blocked>
        %9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
        %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<f32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
        %11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>
        %12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
        %13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf32, #blocked>
        %14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xf32, #blocked>) -> tensor<{M}x{N}xf32, #src>
        %15 = "tt.reduce"(%14) ({{
        ^bb0(%arg3: f32, %arg4: f32):
            %16 = "triton_gpu.cmpf"(%arg3, %arg4) {{predicate = 2 : i64}} : (f32, f32) -> i1
            %17 = arith.select %16, %arg3, %arg4 : f32
            tt.reduce.return %17 : f32
        }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xf32, #src>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
        %18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
        %19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked>
        tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked>
        tt.return
    }}
    }}
    """

    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        kernel = triton.compile(f.name)

    rs = RandomState(17)
    x = rs.randint(0, 4, (M, N)).astype('float32')
    x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')

    if axis == 0:
        z = np.zeros((1, N)).astype('float32')
    else:
        z = np.zeros((M, 1)).astype('float32')

    x_tri = torch.tensor(x, device=device)
    z_tri = torch.tensor(z, device=device)

    pgm = kernel[(1, 1, 4)](x_tri, x_tri.stride(0), z_tri)

    z_ref = np.max(x, axis=axis, keepdims=True)

    np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


layouts = [
    BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
    MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]


@pytest.mark.parametrize("M", [32, 64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
def test_store_op(M, src_layout, device='cuda'):
    ir = f"""
    #src = {src_layout}
    module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
    tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
        %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
        %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
        %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
        %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
        %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
        %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
        %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
        %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
        %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
        tt.store %8, %4 : tensor<{M}x1xf32, #src>
        tt.return
    }}
    }}
    """

    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        store_kernel = triton.compile(f.name)

    rs = RandomState(17)
    x = rs.randint(0, 4, (M, 1)).astype('float32')
    y = np.zeros((M, 1), dtype='float32')
    x_tri = torch.tensor(x, device=device)
    y_tri = torch.tensor(y, device=device)

    pgm = store_kernel[(1, 1, 1)](x_tri, y_tri)
    y_ref = x
    np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


layouts = [
    BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
    MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]


@pytest.mark.parametrize("M", [64, 128, 256])
|
|
@pytest.mark.parametrize("src_layout", layouts)
|
|
@pytest.mark.parametrize("dst_layout", layouts)
|
|
@pytest.mark.parametrize("src_dim", [0, 1])
|
|
@pytest.mark.parametrize("dst_dim", [0, 1])
|
|
def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device='cuda'):
|
|
ir = f"""
|
|
#dst = {dst_layout}
|
|
#src = {src_layout}
|
|
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
|
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
|
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
|
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
|
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
|
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
|
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
|
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
|
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
|
%7 = triton_gpu.convert_layout %3 : (tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
|
tt.store %6, %7 : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
|
tt.return
|
|
}}
|
|
}}
|
|
"""
|
|
import tempfile
|
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
|
f.write(ir)
|
|
f.flush()
|
|
kernel = triton.compile(f.name)
|
|
|
|
rs = RandomState(17)
|
|
x = rs.randint(0, 4, (M, )).astype('int32')
|
|
y = np.zeros((M, ), dtype='int32')
|
|
x_tri = torch.tensor(x, device=device)
|
|
y_tri = torch.tensor(y, device=device)
|
|
pgm = kernel[(1, 1, 1)](x_tri, y_tri)
|
|
y_ref = x
|
|
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
|
|
|
|
|
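# Welford-style combine for a parallel variance reduction: merges two partial
# (mean, M2, weight) accumulators so the reduction can be evaluated pairwise.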
@triton.jit
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = weight_2 / new_weight
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )
|
|
|
|
|
|
# layouts = [
|
|
# BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
|
|
# BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
|
|
# BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
|
|
# BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
|
|
# ]
|
|
|
|
|
|
# @pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
|
|
# @pytest.mark.parametrize("src_layout", layouts)
|
|
# def test_reduce_2d(M, N, src_layout, device='cuda'):
|
|
# ir = f"""
|
|
# #src = {src_layout}
|
|
# module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
|
# tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
|
# %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
|
|
# %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
|
# %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
|
# %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
|
|
# %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
|
|
# %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
|
|
# %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
|
# %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
|
# %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
|
|
# %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
|
|
# %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
|
|
# %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
|
|
# %11 = "tt.reduce"(%10) ({{
|
|
# ^bb0(%arg2: i32, %arg3: i32):
|
|
# %13 = arith.addi %arg2, %arg3 : i32
|
|
# tt.reduce.return %13 : i32
|
|
# }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
|
# %12 = "tt.reduce"(%11) ({{
|
|
# ^bb0(%arg2: i32, %arg3: i32):
|
|
# %13 = arith.addi %arg2, %arg3 : i32
|
|
# tt.reduce.return %13 : i32
|
|
# }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
|
|
# tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
|
|
# tt.return
|
|
# }}
|
|
# }}
|
|
# """
|
|
# import tempfile
|
|
# with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
|
# f.write(ir)
|
|
# f.flush()
|
|
# kernel = triton.compile(f.name)
|
|
#
|
|
# rs = RandomState(17)
|
|
# x = rs.randint(0, 4, (M, N)).astype('int32')
|
|
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
|
|
#
|
|
# z = np.zeros((1,)).astype('int32')
|
|
#
|
|
# x_tri = torch.tensor(x, device=device)
|
|
# z_tri = torch.tensor(z, device=device)
|
|
#
|
|
# pgm = kernel[(1, 1, 1)](x_tri, z_tri)
|
|
#
|
|
# z_ref = np.sum(x)
|
|
#
|
|
# np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
|
|
|
|
|
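# tl.reduce with a tuple accumulator: single-pass mean/variance checked against torch.var_mean.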
def test_generic_reduction(device='cuda'):

    @triton.jit
    def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr):
        xindex = tl.arange(0, BLOCK)
        x = tl.load(X + xindex)
        mean = x
        m2 = tl.zeros_like(x)
        weight = tl.full(x.shape, 1, x.dtype)
        (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine)
        tl.store(out_mean, mean)
        tl.store(out_var, m2 / weight)

    SIZE = 512
    x = torch.rand(SIZE, device=device)
    out_mean = torch.empty((), device=device)
    out_var = torch.empty((), device=device)

    var_mean_kernel[(1,)](x, out_mean, out_var, BLOCK=SIZE)

    expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0)
    torch.testing.assert_close(out_mean, expect_mean)
    torch.testing.assert_close(out_var, expect_var)
|
|
|
|
|
|
# ---------------
|
|
# test permute
|
|
# ---------------
|
|
|
|
|
|
@pytest.mark.parametrize("dtype_str, shape, perm",
|
|
[(dtype, shape, perm)
|
|
# TODO: bfloat16
|
|
for dtype in ['float16', 'float32']
|
|
for shape in [(64, 64), (128, 128)]
|
|
for perm in [(1, 0)]])
|
|
def test_permute(dtype_str, shape, perm, device='cuda'):
|
|
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
|
|
|
|
# triton kernel
|
|
@triton.jit
|
|
def kernel(X, stride_xm, stride_xn,
|
|
Z, stride_zm, stride_zn,
|
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
|
off_m = tl.arange(0, BLOCK_M)
|
|
off_n = tl.arange(0, BLOCK_N)
|
|
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
|
|
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
|
tl.store(Zs, tl.load(Xs))
|
|
# input
|
|
x = numpy_random(shape, dtype_str=dtype_str)
|
|
# triton result
|
|
z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
|
|
z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
|
|
x_tri = to_triton(x, device=device, dst_type=dtype_str)
|
|
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
|
z_tri, z_tri.stride(1), z_tri.stride(0),
|
|
BLOCK_M=shape[0], BLOCK_N=shape[1])
|
|
pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
|
|
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
|
|
BLOCK_M=shape[0], BLOCK_N=shape[1])
|
|
# numpy result
|
|
z_ref = x.transpose(*perm)
|
|
# compare
|
|
np.testing.assert_allclose(to_numpy(z_tri), z_ref)
|
|
np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref)
|
|
# parse ptx to make sure ld/st are vectorized
|
|
ptx = pgm.asm['ptx']
|
|
assert 'ld.global.v4' in ptx
|
|
assert 'st.global.v4' in ptx
|
|
ptx = pgm_contiguous.asm['ptx']
|
|
assert 'ld.global.v4' in ptx
|
|
assert 'st.global.v4' in ptx
|
|
|
|
# ---------------
|
|
# test dot
|
|
# ---------------
|
|
|
|
|
|
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
|
|
[(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
|
|
for shape in [(64, 64, 64), (16, 16, 16)]
|
|
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
|
|
for allow_tf32 in [True, False]
|
|
for in_dtype, out_dtype in [('float16', 'float16'),
|
|
('float16', 'float32'),
|
|
('float32', 'float32')]
|
|
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
|
|
|
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype)
|
|
for shape_nw in [[128, 256, 32, 8],
|
|
[128, 16, 32, 4],
|
|
[32, 128, 64, 4],
|
|
[128, 128, 64, 4],
|
|
[64, 128, 128, 4],
|
|
[32, 128, 64, 2],
|
|
[64, 64, 32, 4],
|
|
[32, 32, 128, 16],
|
|
[128, 128, 64, 2],
|
|
[64, 128, 128, 2]]
|
|
for allow_tf32 in [True]
|
|
for col_a in [True, False]
|
|
for col_b in [True, False]
|
|
for in_dtype, out_dtype in [('int8', 'int8'),
|
|
('float16', 'float16'),
|
|
('float16', 'float32'),
|
|
('float32', 'float32')]])
|
|
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'):
|
|
capability = torch.cuda.get_device_capability()
|
|
if capability[0] < 7:
|
|
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
|
if capability[0] < 8:
|
|
if in_dtype == 'int8':
|
|
pytest.skip("Only test int8 on devices with sm >= 80")
|
|
elif in_dtype == 'float32' and allow_tf32:
|
|
pytest.skip("Only test tf32 on devices with sm >= 80")
|
|
if capability[0] == 7:
|
|
if (M, N, K, num_warps) == (128, 256, 32, 8):
|
|
pytest.skip("shared memory out of resource")
|
|
if out_dtype == 'float16':
|
|
# TODO: support out_dtype=float16 for tl.dot on V100
|
|
pytest.skip("Only test out_dtype=float16 on devices with sm >=80")
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
|
|
|
# triton kernel
|
|
@triton.jit
|
|
def kernel(X, stride_xm, stride_xk,
|
|
Y, stride_yk, stride_yn,
|
|
W, stride_wn, stride_wl,
|
|
Z, stride_zm, stride_zn,
|
|
out_dtype: tl.constexpr,
|
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
|
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
|
ALLOW_TF32: tl.constexpr,
|
|
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
|
COL_A: tl.constexpr, COL_B: tl.constexpr):
|
|
off_m = tl.arange(0, BLOCK_M)
|
|
off_n = tl.arange(0, BLOCK_N)
|
|
off_l = tl.arange(0, BLOCK_N)
|
|
off_k = tl.arange(0, BLOCK_K)
|
|
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
|
|
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
|
|
Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
|
|
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
|
|
x = tl.load(Xs)
|
|
y = tl.load(Ys)
|
|
z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
|
|
if ADD_MATRIX:
|
|
z += tl.load(Zs)
|
|
if ADD_ROWS:
|
|
ZRs = Z + off_m * stride_zm
|
|
z += tl.load(ZRs)[:, None]
|
|
if ADD_COLS:
|
|
ZCs = Z + off_n * stride_zn
|
|
z += tl.load(ZCs)[None, :]
|
|
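        # numerically stable softmax: subtract the row max before exponentiating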
if DO_SOFTMAX:
|
|
max = tl.max(z, 1)
|
|
z = z - max[:, None]
|
|
num = tl.exp(z.to(tl.float32)).to(max.dtype)
|
|
den = tl.sum(num, 1)
|
|
z = num / den[:, None]
|
|
if CHAIN_DOT:
|
|
w = tl.load(Ws)
|
|
z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
|
|
tl.store(Zs, z)
|
|
# input
|
|
rs = RandomState(17)
|
|
if col_a:
|
|
x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T
|
|
else:
|
|
x = numpy_random((M, K), dtype_str=in_dtype, rs=rs)
|
|
if col_b:
|
|
y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T
|
|
else:
|
|
y = numpy_random((K, N), dtype_str=in_dtype, rs=rs)
|
|
w = numpy_random((N, N), dtype_str=in_dtype, rs=rs)
|
|
if 'int' not in in_dtype:
|
|
x *= .1
|
|
y *= .1
|
|
if in_dtype == 'float32' and allow_tf32:
|
|
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
|
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
|
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
|
x_tri = to_triton(x, device=device)
|
|
y_tri = to_triton(y, device=device)
|
|
w_tri = to_triton(w, device=device)
|
|
# triton result
|
|
if out_dtype == 'int8':
|
|
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
|
|
else:
|
|
z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1
|
|
|
|
z_tri = to_triton(z, device=device)
|
|
if epilogue == 'trans':
|
|
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
|
|
|
if out_dtype == 'int8':
|
|
out_dtype = tl.int8
|
|
elif out_dtype == 'float16' and epilogue != 'softmax':
|
|
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
|
|
# fail with the following error: 'llvm.fmul' op requires the same type
|
|
# for all operands and results
|
|
out_dtype = tl.float16
|
|
else:
|
|
out_dtype = tl.float32
|
|
|
|
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
|
y_tri, y_tri.stride(0), y_tri.stride(1),
|
|
w_tri, w_tri.stride(0), w_tri.stride(1),
|
|
z_tri, z_tri.stride(0), z_tri.stride(1),
|
|
out_dtype,
|
|
COL_A=col_a, COL_B=col_b,
|
|
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
|
ADD_MATRIX=epilogue == 'add-matrix',
|
|
ADD_ROWS=epilogue == 'add-rows',
|
|
ADD_COLS=epilogue == 'add-cols',
|
|
DO_SOFTMAX=epilogue == 'softmax',
|
|
CHAIN_DOT=epilogue == 'chain-dot',
|
|
ALLOW_TF32=allow_tf32,
|
|
num_warps=num_warps)
|
|
# torch result
|
|
    if in_dtype == 'int8':
        z_ref = np.matmul(x.astype(np.float32),
                          y.astype(np.float32)).astype(np.int32)
    else:
        z_ref = np.matmul(x, y)
|
|
|
|
if epilogue == 'add-matrix':
|
|
z_ref += z
|
|
if epilogue == 'add-rows':
|
|
z_ref += z[:, 0][:, None]
|
|
if epilogue == 'add-cols':
|
|
z_ref += z[0, :][None, :]
|
|
if epilogue == 'softmax':
|
|
num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
|
|
denom = np.sum(num, axis=-1, keepdims=True)
|
|
z_ref = num / denom
|
|
if epilogue == 'chain-dot':
|
|
z_ref = np.matmul(z_ref, w)
|
|
# compare
|
|
# print(z_ref[:,0], z_tri[:,0])
|
|
if in_dtype == 'float32':
|
|
# XXX: Somehow there's a larger difference when we use float32
|
|
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
|
elif out_dtype == tl.float16:
|
|
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
|
else:
|
|
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
|
# make sure ld/st are vectorized
|
|
ptx = pgm.asm['ptx']
|
|
    if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
        # XXX: skip small sizes because they are not vectorized
        assert 'ld.global.v4' in ptx
        assert 'st.global.v4' in ptx
    if in_dtype == 'float32' and allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
    elif in_dtype == 'float32' and not allow_tf32:
        assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
    elif in_dtype == 'int8':
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
    elif out_dtype == tl.float16:
        assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx
|
|
|
|
|
|
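# tl.full is exercised with a compile-time constant value (patched into the source)
# and with a runtime value plus an explicit dtype argument.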
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
|
|
def test_full(dtype_str):
|
|
dtype = getattr(torch, dtype_str)
|
|
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
|
|
|
@triton.jit
|
|
def kernel_static(out):
|
|
a = GENERATE_TEST_HERE
|
|
out_ptr = out + tl.arange(0, 128)[:]
|
|
tl.store(out_ptr, a)
|
|
|
|
@triton.jit
|
|
def kernel_dynamic(out, val, dtype: tl.constexpr):
|
|
a = tl.full((128,), val, dtype)
|
|
out_ptr = out + tl.arange(0, 128)[:]
|
|
tl.store(out_ptr, a)
|
|
|
|
kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
|
|
out_static = torch.zeros((128), dtype=dtype, device="cuda")
|
|
kernel_static_patched[(1,)](out_static)
|
|
out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
|
|
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
|
|
assert torch.all(out_static == 2)
|
|
assert torch.all(out_dynamic == 2)
|
|
|
|
|
|
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
|
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
|
# def test_dot_without_load(dtype_str):
|
|
# @triton.jit
|
|
# def _kernel(out):
|
|
# a = GENERATE_TEST_HERE
|
|
# b = GENERATE_TEST_HERE
|
|
# c = tl.dot(a, b)
|
|
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
|
# tl.store(out_ptr, c)
|
|
|
|
# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
|
|
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
|
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
|
# out_ref = torch.matmul(a, b)
|
|
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
|
|
# kernel[(1,)](out)
|
|
# assert torch.all(out == out_ref)
|
|
|
|
# ---------------
|
|
# test arange
|
|
# ---------------
|
|
|
|
|
|
@pytest.mark.parametrize("start", [0, 1, 7, 16])
|
|
def test_arange(start, device='cuda'):
|
|
BLOCK = 128
|
|
z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
|
|
|
|
@triton.jit
|
|
def _kernel(z, BLOCK: tl.constexpr,
|
|
START: tl.constexpr, END: tl.constexpr):
|
|
off = tl.arange(0, BLOCK)
|
|
val = tl.arange(START, END)
|
|
tl.store(z + off, val)
|
|
_kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK)
|
|
z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
|
|
np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref))
|
|
|
|
# ---------------
|
|
# test load
|
|
# ---------------
|
|
|
|
|
|
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
|
|
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
|
|
dtype = getattr(torch, dtype_str)
|
|
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
|
|
|
input_size = size - size_diff
|
|
output_size = size
|
|
if dtype_str == 'bool':
|
|
input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
|
|
elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
|
|
input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
|
|
else:
|
|
input = torch.rand(input_size, dtype=dtype, device=device)
|
|
output = torch.zeros((output_size,), dtype=dtype, device=device)
|
|
|
|
@triton.jit
|
|
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
|
|
in_offsets = tl.arange(0, out_size)
|
|
# Load inputs.
|
|
x = GENERATE_TEST_HERE
|
|
# Store output
|
|
output_offsets = tl.arange(0, out_size)
|
|
tl.store(out_ptr + output_offsets, x)
|
|
|
|
mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
|
|
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
|
|
kernel[(1,)](input, output, input_size, output_size)
|
|
|
|
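    # out-of-bounds lanes load other=1, so the reference pads the truncated input with ones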
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
|
|
# print((output - reference_out).nonzero())
|
|
torch.testing.assert_allclose(output, reference_out)
|
|
|
|
# Test masked loads that require an intermediate copy to shared memory.
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
|
|
def test_masked_load_shared_memory(dtype, device='cuda'):
|
|
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
|
|
|
|
M = 32
|
|
N = 32
|
|
K = 16
|
|
|
|
in1 = torch.rand((M, K), dtype=dtype, device=device)
|
|
in2 = torch.rand((K, N), dtype=dtype, device=device)
|
|
out = torch.zeros((M, N), dtype=dtype, device=device)
|
|
|
|
@triton.jit
|
|
def _kernel(in1_ptr, in2_ptr, output_ptr,
|
|
in_stride, in2_stride, out_stride,
|
|
in_numel, in2_numel, out_numel,
|
|
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
|
|
|
|
M_offsets = tl.arange(0, M)
|
|
N_offsets = tl.arange(0, N)
|
|
K_offsets = tl.arange(0, K)
|
|
|
|
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
|
|
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
|
|
|
|
# Load inputs.
|
|
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K)
|
|
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N)
|
|
|
|
# Without a dot product the memory doesn't get promoted to shared.
|
|
o = tl.dot(x, w, out_dtype=tl.float32)
|
|
|
|
# Store output
|
|
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
|
|
tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N)
|
|
|
|
pgm = _kernel[(1,)](in1, in2, out,
|
|
in1.stride()[0],
|
|
in2.stride()[0],
|
|
out.stride()[0],
|
|
in1.numel(),
|
|
in2.numel(),
|
|
out.numel(),
|
|
M=M, N=N, K=K)
|
|
|
|
reference_out = torch.matmul(in1, in2)
|
|
torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0)
|
|
|
|
|
|
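# cache_modifier should be reflected as ld.global.ca / ld.global.cg in the generated PTX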
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
|
def test_load_cache_modifier(cache):
|
|
src = torch.empty(128, device='cuda')
|
|
dst = torch.empty(128, device='cuda')
|
|
|
|
@triton.jit
|
|
def _kernel(dst, src, CACHE: tl.constexpr):
|
|
offsets = tl.arange(0, 128)
|
|
x = tl.load(src + offsets, cache_modifier=CACHE)
|
|
tl.store(dst + offsets, x)
|
|
|
|
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
|
ptx = pgm.asm['ptx']
|
|
if cache == '':
|
|
assert 'ld.global.ca' not in ptx
|
|
assert 'ld.global.cg' not in ptx
|
|
if cache == '.cg':
|
|
assert 'ld.global.cg' in ptx
|
|
assert 'ld.global.ca' not in ptx
|
|
if cache == '.ca':
|
|
assert 'ld.global.ca' in ptx
|
|
assert 'ld.global.cg' not in ptx
|
|
|
|
|
|
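# loads should be vectorized to ld.global.v4 only when the masked extent N is a multiple of 16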
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
|
|
def test_vectorization(N):
|
|
src = torch.empty(1024, device='cuda')
|
|
dst = torch.empty(1024, device='cuda')
|
|
|
|
@triton.jit
|
|
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
|
|
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
x = tl.load(src + offsets, mask=offsets < N)
|
|
tl.store(dst + offsets, x, mask=offsets < N)
|
|
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
|
ptx = pgm.asm["ptx"]
|
|
if N % 16 == 0:
|
|
assert "ld.global.v4.b32" in ptx
|
|
else:
|
|
assert "ld.global.b32" in ptx
|
|
# np.testing.assert_allclose(dst, src[:N])
|
|
|
|
|
|
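# an offset loaded at run time hides alignment from the compiler, so vectorization
# requires explicit tl.multiple_of / tl.max_contiguous hints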
@pytest.mark.parametrize("has_hints", [False, True])
|
|
def test_vectorization_hints(has_hints):
|
|
src = torch.empty(1024, device='cuda')
|
|
dst = torch.empty(1024, device='cuda')
|
|
off = torch.zeros(1, device='cuda', dtype=torch.int32)
|
|
|
|
@triton.jit
|
|
def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
|
|
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
offsets = offsets + tl.load(off)
|
|
if HINT:
|
|
tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024)
|
|
x = tl.load(src + offsets, mask=offsets < N)
|
|
tl.store(dst + offsets, x, mask=offsets < N)
|
|
pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints)
|
|
ptx = pgm.asm["ptx"]
|
|
if has_hints:
|
|
assert "ld.global.v4.b32" in ptx
|
|
else:
|
|
assert "ld.global.v4.b32" not in ptx
|
|
|
|
# ---------------
|
|
# test store
|
|
# ---------------
|
|
|
|
# ---------------
|
|
# test if
|
|
# ---------------
|
|
|
|
# ---------------
|
|
# test for
|
|
# ---------------
|
|
|
|
# ---------------
|
|
# test while
|
|
# ---------------
|
|
|
|
# ---------------
|
|
# test default
|
|
# ---------------
|
|
# TODO: can't be local to test_default
|
|
|
|
|
|
@triton.jit
def _impl(value=10):
    return value


def test_default():
    value = 5
    ret0 = torch.zeros(1, dtype=torch.int32, device='cuda')
    ret1 = torch.zeros(1, dtype=torch.int32, device='cuda')

    @triton.jit
    def _kernel(ret0, ret1, value):
        tl.store(ret0, _impl())
        tl.store(ret1, _impl(value))

    _kernel[(1,)](ret0, ret1, value)
    assert ret0.item() == 10
    assert ret1.item() == value
|
|
|
|
# ---------------
# test noop
# ----------------


def test_noop(device='cuda'):
    @triton.jit
    def kernel(x):
        pass
    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
    kernel[(1, )](x)
|
|
|
|
|
|
@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned'])
|
|
def test_pointer_arguments(device):
|
|
@triton.jit
|
|
def kernel(x):
|
|
pass
|
|
pin_memory = 'pinned' in device
|
|
x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory)
|
|
if device == "cpu":
|
|
with pytest.raises(ValueError):
|
|
kernel[(1,)](x)
|
|
else:
|
|
kernel[(1, )](x)
|
|
|
|
|
|
@pytest.mark.parametrize("value, value_type", [
|
|
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
|
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
|
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
|
])
|
|
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
|
spec_type = None
|
|
|
|
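    # the compile cache hook records how the first argument (VALUE) was specialized in the signature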
def cache_hook(*args, **kwargs):
|
|
nonlocal spec_type
|
|
spec_type = kwargs["compile"]["signature"][0]
|
|
JITFunction.cache_hook = cache_hook
|
|
|
|
@triton.jit
|
|
def kernel(VALUE, X):
|
|
pass
|
|
|
|
x = torch.tensor([3.14159], device='cuda')
|
|
pgm = kernel[(1, )](value, x)
|
|
|
|
JITFunction.cache_hook = None
|
|
assert spec_type == value_type
|
|
|
|
# --------------------
|
|
# value specialization
|
|
# --------------------
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"value, overflow",
|
|
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
|
|
)
|
|
def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
|
|
|
|
@triton.jit
|
|
def kernel(VALUE, X):
|
|
pass
|
|
|
|
x = torch.tensor([3.14159], device='cuda')
|
|
|
|
if overflow:
|
|
with pytest.raises(OverflowError):
|
|
kernel[(1, )](value, x)
|
|
else:
|
|
kernel[(1, )](value, x)
|
|
|
|
|
|
# ----------------
|
|
# test constexpr
|
|
# ----------------
|
|
|
|
@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|'])
|
|
@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
|
|
@pytest.mark.parametrize("is_rhs_constexpr", [True, False])
|
|
def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr):
|
|
|
|
@triton.jit
|
|
def kernel(Z, X, Y):
|
|
x = tl.load(X)
|
|
y = tl.load(Y)
|
|
z = GENERATE_TEST_HERE
|
|
tl.store(Z, z)
|
|
|
|
if op in ['<<', '>>', '&', '^', '|']: # int op
|
|
x_str = "3" if is_lhs_constexpr else "x"
|
|
y_str = "4" if is_rhs_constexpr else "y"
|
|
x = numpy_random((1,), dtype_str="int32")
|
|
y = numpy_random((1,), dtype_str="int32")
|
|
else:
|
|
x_str = "3.14" if is_lhs_constexpr else "x"
|
|
y_str = "4.13" if is_rhs_constexpr else "y"
|
|
x = numpy_random((1,), dtype_str="float32")
|
|
y = numpy_random((1,), dtype_str="float32")
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"})
|
|
z = np.array(eval(f"{x_str} {op} {y_str}"))
|
|
x_tri = to_triton(x)
|
|
y_tri = to_triton(y)
|
|
z_tri = to_triton(np.empty((1,), dtype=z.dtype))
|
|
kernel[(1,)](z_tri, x_tri, y_tri)
|
|
np.testing.assert_allclose(z, to_numpy(z_tri))
|
|
|
|
|
|
def test_constexpr_shape():

    @triton.jit
    def kernel(X):
        off = tl.arange(0, 128 + 128)
        tl.store(X + off, off)

    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
    kernel[(1,)](x_tri)
    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))


def test_constexpr_scalar_shape():

    @triton.jit
    def kernel(X, s):
        off = tl.arange(0, 256)
        val = off % (256 // s)
        tl.store(X + off, val)

    x_tri = to_triton(np.empty((256, ), dtype=np.int32))
    kernel[(1,)](x_tri, 32)
    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
|
|
|
|
# -------------
|
|
# test call
|
|
# -------------
|
|
|
|
|
|
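# Helpers for test_call below: vecmul_kernel(ptr, n, rep) multiplies the vector by
# 1 * 2 * ... * (rep - 1) via repeated calls to val_multiplier, so calling it with
# rep=3 and rep=5 scales the input by (1*2) * (1*2*3*4).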
@triton.jit
def val_multiplier(val, i):
    return val * i


@triton.jit
def vecmul_kernel(ptr, n_elements, rep):
    pid = tl.program_id(axis=0)
    offsets = pid * 128 + tl.arange(0, 128)
    mask = offsets < n_elements
    vec = tl.load(ptr + offsets, mask=mask)
    for i in range(1, rep):
        vec = val_multiplier(vec, i)
    tl.store(ptr + offsets, vec, mask=mask)


def test_call():

    @triton.jit
    def kernel(ptr, n_elements, num1, num2):
        vecmul_kernel(ptr, n_elements, num1)
        vecmul_kernel(ptr, n_elements, num2)

    size = 1024
    rand_val = numpy_random((size,), dtype_str="float32")
    rand_val_tri = to_triton(rand_val, device='cuda')
    kernel[(size // 128,)](rand_val_tri, size, 3, 5)

    ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
    np.testing.assert_equal(to_numpy(rand_val_tri), ans)
|
|
|
|
# -------------
|
|
# test if
|
|
# -------------
|
|
|
|
|
|
@pytest.mark.parametrize("if_type", ["if", "if_exp"])
|
|
def test_if(if_type):
|
|
|
|
@triton.jit
|
|
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr):
|
|
pid = tl.program_id(0)
|
|
cond = tl.load(Cond)
|
|
if IfType == "if":
|
|
if pid % 2:
|
|
tl.store(Ret, tl.load(XTrue))
|
|
else:
|
|
tl.store(Ret, tl.load(XFalse))
|
|
else:
|
|
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
|
|
|
|
cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
|
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
|
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
|
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
|
kernel[(1,)](cond, x_true, x_false, ret, if_type)
|
|
|
|
|
|
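# num_warps must be a power of two; launching with num_warps=3 should fail the assertion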
def test_num_warps_pow2():
    dst = torch.empty(128, device='cuda')

    @triton.jit
    def _kernel(dst):
        pass

    with pytest.raises(AssertionError, match='must be a power of 2'):
        _kernel[(1,)](dst=dst, num_warps=3)
    _kernel[(1,)](dst=dst, num_warps=1)
    _kernel[(1,)](dst=dst, num_warps=2)
    _kernel[(1,)](dst=dst, num_warps=4)
|
|
|
|
# -------------
|
|
# test extern
|
|
# -------------
|
|
|
|
|
|
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
|
[('int32', 'math.ffs', ''),
|
|
('float32', 'math.log2', ''),
|
|
('float32', 'math.scalbn', ''),
|
|
('float32', 'math.pow', tl.math.libdevice_path()),
|
|
('float64', 'math.pow_dtype', tl.math.libdevice_path()),
|
|
('float64', 'math.norm4d', '')])
|
|
def test_math_tensor(dtype_str, expr, lib_path):
|
|
|
|
@triton.jit
|
|
def kernel(X, Y, BLOCK: tl.constexpr):
|
|
x = tl.load(X + tl.arange(0, BLOCK))
|
|
y = GENERATE_TEST_HERE
|
|
tl.store(Y + tl.arange(0, BLOCK), y)
|
|
|
|
shape = (128, )
|
|
rs = RandomState(17)
|
|
# limit the range of integers so that the sum does not overflow
|
|
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
|
|
|
|
if expr == 'math.log2':
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.broadcast_to(tl.{expr}(5.0), x.shape)'})
|
|
y_ref = np.log2(5.0)
|
|
elif expr == 'math.ffs':
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{expr}(x)'})
|
|
y_ref = np.zeros(shape, dtype=x.dtype)
|
|
for i in range(shape[0]):
|
|
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
|
|
elif expr == 'math.scalbn':
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{expr}(x, 2)'})
|
|
y_ref = x * pow(2, 2)
|
|
elif expr == 'math.pow_dtype':
|
|
x = np.abs(x)
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.math.pow(x, 0.5)'})
|
|
y_ref = np.power(x, 0.5)
|
|
    elif expr == 'math.pow':
        # numpy does not allow negative factors in power, so we use abs()
        x = np.abs(x)
        kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{expr}(x, x)'})
        y_ref = np.power(x, x)
    elif expr == 'math.norm4d':
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{expr}(x, x, x, x)'})
|
|
y_ref = np.sqrt(4 * np.power(x, 2))
|
|
|
|
x_tri = to_triton(x)
|
|
# triton result
|
|
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
|
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
|
# compare
|
|
if expr == 'math.ffs':
|
|
np.testing.assert_equal(y_ref, to_numpy(y_tri))
|
|
else:
|
|
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
|
|
|
|
|
@pytest.mark.parametrize("dtype_str, expr, lib_path",
|
|
[('float32', 'math.pow', ''),
|
|
('float64', 'math.pow_dtype', ''),
|
|
('float64', 'math.pow', tl.math.libdevice_path())])
|
|
def test_math_scalar(dtype_str, expr, lib_path):
|
|
|
|
@triton.jit
|
|
def kernel(X, Y, BLOCK: tl.constexpr):
|
|
x = X
|
|
y = GENERATE_TEST_HERE
|
|
tl.store(Y + tl.arange(0, BLOCK), y)
|
|
|
|
shape = (128, )
|
|
rs = RandomState(17)
|
|
# limit the range of integers so that the sum does not overflow
|
|
x = numpy_random((1,), dtype_str=dtype_str, rs=rs)
|
|
y_ref = np.zeros(shape, dtype=x.dtype)
|
|
|
|
# numpy does not allow negative factors in power, so we use abs()
|
|
if expr == 'math.pow':
|
|
x = np.abs(x)
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, x)'})
|
|
y_ref[:] = np.power(x, x)
|
|
elif expr == 'math.pow_dtype':
|
|
x = np.abs(x)
|
|
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, 0.5)'})
|
|
y_ref[:] = np.power(x, 0.5)
|
|
|
|
# triton result
|
|
x_tri = to_triton(x)[0].item()
|
|
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
|
|
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
|
|
# compare
|
|
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
|
|
|
|
# -----------------------
|
|
# test control flow
|
|
# -----------------------
|
|
|
|
|
|
@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3),
|
|
(15, -16, -1), (15, -16, -2), (15, -16, -3),
|
|
(-18, -22, -1), (22, 18, -1)])
|
|
def test_for_iv(lo, hi, iv):
|
|
|
|
@triton.jit
|
|
def kernel(Out, lo, hi, iv: tl.constexpr):
|
|
acc = 0
|
|
acc = acc.to(tl.int64)
|
|
for i in range(lo, hi, iv):
|
|
acc += i
|
|
tl.store(Out, acc)
|
|
|
|
lo = 2**35
|
|
hi = 2**35 + 20
|
|
out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
|
|
kernel[(1,)](out, lo, hi, iv)
|
|
assert out[0] == sum(range(lo, hi, iv))
|
|
|
|
|
|
def test_if_else():
|
|
|
|
@triton.jit
|
|
def kernel(Cond, TrueVal, FalseVal, Out):
|
|
if tl.load(Cond):
|
|
val = tl.load(TrueVal)
|
|
else:
|
|
val = tl.load(FalseVal)
|
|
tl.store(Out, val)
|
|
|
|
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
true_val = to_triton(np.full((1,), 1, dtype=np.int32), device='cuda')
|
|
false_val = to_triton(np.full((1,), 2, dtype=np.int32), device='cuda')
|
|
cond = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
# True
|
|
cond[0] = True
|
|
kernel[(1,)](cond, true_val, false_val, out)
|
|
assert to_numpy(out)[0] == true_val[0]
|
|
# False
|
|
cond[0] = False
|
|
kernel[(1,)](cond, true_val, false_val, out)
|
|
assert to_numpy(out)[0] == false_val[0]
|
|
|
|
|
|
def test_if_return():
|
|
|
|
@triton.jit
|
|
def kernel(ExitEarly, Out):
|
|
if tl.load(ExitEarly):
|
|
tl.store(Out, 0)
|
|
return
|
|
tl.store(Out, 1)
|
|
|
|
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
# exit early path taken
|
|
exit_early[0] = 1
|
|
kernel[(1,)](exit_early, out)
|
|
assert to_numpy(out)[0] == 0
|
|
# exit early path not taken
|
|
exit_early[0] = 0
|
|
kernel[(1,)](exit_early, out)
|
|
assert to_numpy(out)[0] == 1
|
|
|
|
|
|
@triton.jit
|
|
def add_fn(x):
|
|
return x + 1
|
|
|
|
|
|
@pytest.mark.parametrize("call_type", ["attribute", "jit_function"])
|
|
def test_if_call(call_type):
|
|
@triton.jit
|
|
def kernel(Out, call_type: tl.constexpr):
|
|
pid = tl.program_id(0)
|
|
o = tl.load(Out)
|
|
if pid == 0:
|
|
if call_type == "attribute":
|
|
a = o + 1
|
|
a = a.to(tl.int32)
|
|
o = a
|
|
else:
|
|
a = o
|
|
a = add_fn(a)
|
|
o = a
|
|
tl.store(Out, o)
|
|
|
|
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
kernel[(1,)](out, call_type)
|
|
assert to_numpy(out)[0] == 1
|
|
|
|
|
|
@pytest.mark.parametrize("_cond1", [True, False])
|
|
@pytest.mark.parametrize("_cond2", [True, False])
|
|
@pytest.mark.parametrize("_cond3", [True, False])
|
|
def test_nested_if_else_return(_cond1, _cond2, _cond3):
|
|
|
|
@triton.jit
|
|
def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out):
|
|
val = 0
|
|
if tl.load(Cond1):
|
|
if tl.load(Cond2):
|
|
val = tl.load(Val1)
|
|
else:
|
|
return
|
|
else:
|
|
if tl.load(Cond3):
|
|
val = tl.load(Val2)
|
|
else:
|
|
val = tl.load(Val3)
|
|
tl.store(Out, val)
|
|
|
|
out = to_triton(np.full((1,), -1, dtype=np.int32), device='cuda')
|
|
cond1 = to_triton(np.full((1,), _cond1, dtype=np.int32), device='cuda')
|
|
cond2 = to_triton(np.full((1,), _cond2, dtype=np.int32), device='cuda')
|
|
cond3 = to_triton(np.full((1,), _cond3, dtype=np.int32), device='cuda')
|
|
val1 = to_triton(np.full((1,), 1, dtype=np.int32), device='cuda')
|
|
val2 = to_triton(np.full((1,), 2, dtype=np.int32), device='cuda')
|
|
val3 = to_triton(np.full((1,), 3, dtype=np.int32), device='cuda')
|
|
kernel[(1,)](cond1, cond2, cond3, val1, val2, val3, out)
|
|
targets = {
|
|
(True, True, True): val1[0],
|
|
(True, True, False): val1[0],
|
|
(True, False, True): out[0],
|
|
(True, False, False): out[0],
|
|
(False, True, True): val2[0],
|
|
(False, True, False): val3[0],
|
|
(False, False, True): val2[0],
|
|
(False, False, False): val3[0],
|
|
}
|
|
assert out[0] == targets[(_cond1, _cond2, _cond3)]
|
|
|
|
|
|
def test_while():
|
|
|
|
@triton.jit
|
|
def kernel(InitI, Bound, CutOff, OutI, OutJ):
|
|
init_i = tl.load(InitI)
|
|
curr_i = init_i
|
|
j = 0
|
|
while curr_i == init_i and j < tl.load(Bound):
|
|
curr_i = curr_i + (j == tl.load(CutOff))
|
|
j += 1
|
|
tl.store(OutI, curr_i)
|
|
tl.store(OutJ, j)
|
|
|
|
out_i = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
out_j = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
init_i = to_triton(np.full((1,), 1, dtype=np.int32), device='cuda')
|
|
bound = to_triton(np.full((1,), 10, dtype=np.int32), device='cuda')
|
|
cut_off = to_triton(np.full((1,), 5, dtype=np.int32), device='cuda')
|
|
kernel[(1,)](init_i, bound, cut_off, out_i, out_j)
|
|
assert out_i[0] == init_i[0] + 1
|
|
assert out_j[0] == cut_off[0] + 1
|
|
|
|
# def test_for_if():
|
|
|
|
# @triton.jit
|
|
# def kernel(bound, cutoff, M, N):
|
|
# m = 0
|
|
# n = 0
|
|
# for i in range(bound):
|
|
# if i > cutoff:
|
|
# m = m + 1
|
|
# else:
|
|
# n = n + 1
|
|
# tl.store(M, m)
|
|
# tl.store(N, n)
|
|
|
|
# m = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
# n = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
|
# kernel[(1,)](10, 7, m, n)
|
|
# print(m[0])
|
|
# print(n[0])
|
|
|
|
# -----------------------
|
|
# test extra
|
|
# -----------------------
|
|
|
|
|
|
def test_globaltimer():

    @triton.jit
    def kernel(Out1, Out2):
        start = tl.extra.cuda.globaltimer()
        off = tl.arange(0, 128)
        for i in range(100):
            tl.store(Out1 + off, tl.load(Out1 + off) + 1)
        end = tl.extra.cuda.globaltimer()
        tl.store(Out2, end - start)

    out1 = to_triton(np.zeros((128,), dtype=np.int64), device='cuda')
    out2 = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
    h = kernel[(1,)](out1, out2)
    assert out2[0] > 0
    # 2 inlined globaltimers + one extra in the wrapper extern function
    assert h.asm["ptx"].count("%globaltimer") == 3


def test_smid():

    @triton.jit
    def kernel(Out):
        tl.store(Out + tl.program_id(0), tl.extra.cuda.smid())

    out = to_triton(np.zeros((1024,), dtype=np.int32), device='cuda')
    h = kernel[(out.shape[0],)](out)
    assert out.sort()[0].unique().shape[0] > 0
    assert h.asm["ptx"].count("%smid") == 2
|
|
|
|
# -----------------------
|
|
# test layout conversions
|
|
# -----------------------
|
|
# TODO: backend should be tested separately
|
|
|
|
|
|
layouts = [
|
|
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
|
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
|
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
|
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
|
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
|
|
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
|
|
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
|
|
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
|
|
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
|
|
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
|
|
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
|
]
|
|
|
|
|
|
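# Round-trip a 128x128 fp16 tile through triton_gpu.convert_layout (#src -> #dst);
# the stored result must match the input exactly.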
@pytest.mark.parametrize("shape", [(128, 128)])
|
|
@pytest.mark.parametrize("dtype", ['float16'])
|
|
@pytest.mark.parametrize("src_layout", layouts)
|
|
@pytest.mark.parametrize("dst_layout", layouts)
|
|
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
|
if str(src_layout) == str(dst_layout):
|
|
pytest.skip()
|
|
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
|
|
pytest.skip()
|
|
|
|
ir = f"""
|
|
#src = {src_layout}
|
|
#dst = {dst_layout}
|
|
""" + """
|
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
|
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
|
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
|
|
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
|
|
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
|
|
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
|
|
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
|
|
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
|
|
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
|
|
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
|
|
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
|
|
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
|
|
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
|
|
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
|
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
|
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
|
|
tt.store %14, %13 : tensor<128x128xf16, #dst>
|
|
tt.return
|
|
}
|
|
}
|
|
"""
|
|
|
|
x = to_triton(numpy_random(shape, dtype_str=dtype))
|
|
z = torch.empty_like(x)
|
|
|
|
    # write the IR to a temporary file so it can be compiled from disk
|
|
import tempfile
|
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
|
f.write(ir)
|
|
f.flush()
|
|
kernel = triton.compile(f.name)
|
|
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
|
|
|
|
assert torch.equal(z, x)
|
|
|
|
|
|
def test_load_scalar_with_mask():
    @triton.jit
    def kernel(Input, Index, Out, N: int):
        index = tl.load(Index)
        scalar = tl.load(Input + index, mask=index < N, other=0)
        tl.store(Out, scalar, mask=index < N)
    Index = torch.tensor([0], dtype=torch.int32, device='cuda')
    Input = torch.tensor([0], dtype=torch.int32, device='cuda')
    Out = torch.empty_like(Index, device='cuda')
    kernel[(1,)](Input, Index, Out, Index.numel())
    assert Out.data[0] == 0
|
|
|
|
|
|
# This test is used to test our own PTX codegen for float16 and int16 conversions
|
|
# maybe delete it later after ptxas has been fixed
|
|
@pytest.mark.parametrize("dtype_str", ['float16', 'int16'])
|
|
def test_ptx_cast(dtype_str):
|
|
@triton.jit
|
|
def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
|
|
xoffset = tl.program_id(0) * XBLOCK
|
|
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
|
xmask = xindex < xnumel
|
|
rbase = tl.arange(0, RBLOCK)[None, :]
|
|
x0 = xindex
|
|
_tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype)
|
|
for roffset in range(0, rnumel, RBLOCK):
|
|
rindex = roffset + rbase
|
|
rmask = rindex < rnumel
|
|
r1 = rindex
|
|
tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype)
|
|
tmp1 = 2
|
|
tmp2 = tmp0 * tmp1
|
|
tmp3 = tmp2.to(dtype)
|
|
tmp5 = _tmp4 < tmp3
|
|
_tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4)
|
|
tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)
|
|
|
|
torch.manual_seed(123)
|
|
if dtype_str == 'int16':
|
|
torch_dtype = torch.int16
|
|
triton_dtype = tl.int32
|
|
else:
|
|
torch_dtype = torch.float16
|
|
triton_dtype = tl.float32
|
|
|
|
s0 = 4
|
|
buf11 = -torch.ones((6 * s0, 197, 197), device='cuda', dtype=torch_dtype)
|
|
buf14 = -torch.ones((s0, 6, 197, 197), device='cuda', dtype=torch_dtype)
|
|
kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
|
|
assert buf14.to(torch.float32).mean() == -2.0
|