mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
@@ -1,4 +1,7 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from numpy.random import RandomState
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -66,3 +69,162 @@ def test_chained_matmul():
|
||||
block_k=block_k)
|
||||
|
||||
assert (torch_result == triton_result).all()
|
||||
|
||||
|
||||
def test_vecmat():
|
||||
@triton.jit
|
||||
def batched_vecmat(
|
||||
# inputs
|
||||
A, # shape: [dim_m, dim_k]
|
||||
B, # shape: [dim_m, dim_n, dim_k]
|
||||
# dimensions
|
||||
dim_m, dim_n, dim_k,
|
||||
# outputs
|
||||
output,
|
||||
# block information
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
|
||||
):
|
||||
m_index = tl.program_id(0)
|
||||
n_index = tl.program_id(1)
|
||||
# Output tile
|
||||
output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \
|
||||
+ (n_index * block_n + tl.arange(0, block_n))[None, :]
|
||||
|
||||
vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)
|
||||
k_blocks = dim_k // block_k
|
||||
for k_index in range(k_blocks):
|
||||
# Load A tile
|
||||
a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \
|
||||
+ (k_index * block_k + tl.arange(0, block_k))[None, :]
|
||||
a = tl.load(A + a_tile)
|
||||
|
||||
# Load B tile, transposed to [n, m, k] in order to broadcast A on a
|
||||
# leading dimension.
|
||||
b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \
|
||||
+ (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \
|
||||
+ (k_index * block_k + tl.arange(0, block_k))[None, None, :]
|
||||
b = tl.load(B + b_tile)
|
||||
|
||||
expanded_a, _ = tl.broadcast(a, b)
|
||||
vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))
|
||||
|
||||
tl.store(output + output_tile, vecmat)
|
||||
|
||||
M, N, K = 128, 128, 128
|
||||
block_m, block_n, block_k = 16, 32, 64
|
||||
|
||||
rs = RandomState(17)
|
||||
A_vec = rs.randint(0, 4, (M, K)).astype('float32')
|
||||
B_vec = rs.randint(0, 4, (M, N, K)).astype('float32')
|
||||
A = A_vec
|
||||
B = B_vec
|
||||
|
||||
A_tri = torch.tensor(A, device='cuda')
|
||||
B_tri = torch.tensor(B, device='cuda')
|
||||
C_tri = torch.zeros((M, N), dtype=torch.float32, device='cuda')
|
||||
|
||||
grid = (M // block_m, N // block_n)
|
||||
|
||||
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
|
||||
block_m=block_m, block_n=block_n, block_k=block_k,
|
||||
num_warps=4, num_stages=1)
|
||||
|
||||
A_expanded = A[:, np.newaxis, :]
|
||||
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
|
||||
AB = A_broadcasted * B
|
||||
C_ref = np.sum(AB, axis=2)
|
||||
|
||||
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
|
||||
def test_iv_dependent_matmul(type):
|
||||
@triton.jit
|
||||
def kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
type: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
a_ptrs = a_ptr
|
||||
b_ptrs = b_ptr
|
||||
if type == "post_load_two_iters":
|
||||
a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_load_three_iters":
|
||||
a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
|
||||
a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
if type == "pre_load":
|
||||
a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_pre_mixed":
|
||||
a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
accumulator += tl.dot(a, b)
|
||||
if type == "post_load":
|
||||
a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_pre_mixed":
|
||||
b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_load_two_iters":
|
||||
a_ptrs = a_ptrs_next
|
||||
b_ptrs = b_ptrs_next
|
||||
a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk
|
||||
elif type == "post_load_three_iters":
|
||||
a_ptrs = a_ptrs_next
|
||||
b_ptrs = b_ptrs_next
|
||||
a_ptrs_next = a_ptrs_next_next
|
||||
b_ptrs_next = b_ptrs_next_next
|
||||
a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk
|
||||
c = accumulator.to(tl.float16)
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
M = 256
|
||||
K = 256
|
||||
N = 256
|
||||
BLOCK_SIZE_K = 32
|
||||
BLOCK_SIZE_N = 32
|
||||
BLOCK_SIZE_M = 32
|
||||
|
||||
a = torch.rand((M, K), device='cuda')
|
||||
b = torch.rand((K, N), device='cuda')
|
||||
|
||||
torch_output = torch.mm(a, b)
|
||||
triton_output = torch.empty_like(
|
||||
torch_output, device=torch_output.device)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
|
||||
num_stages = 4 if type == "post_load_three_iters" else 3
|
||||
kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
type=type, num_stages=num_stages)
|
||||
torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@@ -22,6 +22,13 @@ def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit(debug=False)
|
||||
def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
@@ -43,6 +50,9 @@ def test_assert(func: str):
|
||||
if func == "device_assert":
|
||||
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
# TRITON_DEBUG=True can override the debug flag
|
||||
kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
@@ -50,5 +60,72 @@ def test_assert(func: str):
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def jit_device_assert_none(x):
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit(debug=True)
|
||||
def jit_device_assert_true(x):
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit(debug=False)
|
||||
def jit_device_assert_false(x):
|
||||
tl.device_assert(x == 0, "x != 0")
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
if jit_debug == "true":
|
||||
jit_device_assert_true(x)
|
||||
elif jit_debug == "false":
|
||||
jit_device_assert_false(x)
|
||||
else:
|
||||
jit_device_assert_none(x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit(debug=True)
|
||||
def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
if jit_debug == "true":
|
||||
jit_device_assert_true(x)
|
||||
elif jit_debug == "false":
|
||||
jit_device_assert_false(x)
|
||||
else:
|
||||
jit_device_assert_none(x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit(debug=False)
|
||||
def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
if jit_debug == "true":
|
||||
jit_device_assert_true(x)
|
||||
elif jit_debug == "false":
|
||||
jit_device_assert_false(x)
|
||||
else:
|
||||
jit_device_assert_none(x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
def test_assert_nested(caller: str, callee: str):
|
||||
shape = (128, )
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if caller == "none":
|
||||
kernel_device_assert_nested[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "true":
|
||||
kernel_device_assert_nested_true[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "false":
|
||||
kernel_device_assert_nested_false[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_assert(sys.argv[1])
|
||||
if len(sys.argv) == 3:
|
||||
test_assert_nested(sys.argv[1], sys.argv[2])
|
||||
else:
|
||||
test_assert(sys.argv[1])
|
||||
|
||||
@@ -130,6 +130,17 @@ class BlockedLayout:
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
|
||||
|
||||
class SharedLayout:
|
||||
def __init__(self, vec, per_phase, max_phase, order):
|
||||
self.vec = str(vec)
|
||||
self.per_phase = str(per_phase)
|
||||
self.max_phase = str(max_phase)
|
||||
self.order = str(order)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
SIZE = 128
|
||||
@@ -456,6 +467,21 @@ def test_broadcast(dtype):
|
||||
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
|
||||
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
|
||||
|
||||
# ------------------
|
||||
# test invalid slice
|
||||
# ------------------
|
||||
|
||||
|
||||
def test_invalid_slice():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
dst[10:]
|
||||
|
||||
with pytest.raises(triton.CompilationError, match='unsupported tensor index'):
|
||||
_kernel[(1,)](dst=dst)
|
||||
|
||||
|
||||
# ----------------
|
||||
# test expand_dims
|
||||
@@ -537,6 +563,20 @@ def test_expand_dims_error_cases():
|
||||
duplicate_dim2[(1,)](dummy_tensor, N)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# test invalid program id axis
|
||||
# ----------------------------
|
||||
def test_invalid_pid_axis():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
pid = tl.program_id(20)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"):
|
||||
_kernel[(1,)](dst)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test where
|
||||
# ---------------
|
||||
@@ -1368,6 +1408,9 @@ reduce_configs2 = [
|
||||
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
|
||||
for shape in reduce2d_shapes
|
||||
for axis in [0, 1]
|
||||
] + [
|
||||
(op, 'float32', [16, 32], None)
|
||||
for op in ['min', 'max', 'sum']
|
||||
]
|
||||
|
||||
|
||||
@@ -1382,7 +1425,9 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
if AXIS == 1:
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 1:
|
||||
tl.store(Z + range_m, z)
|
||||
else:
|
||||
tl.store(Z + range_n, z)
|
||||
@@ -1407,7 +1452,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
else:
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
|
||||
ret_numel = 1 if axis is None else shape[1 - axis]
|
||||
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
|
||||
device=device, dst_type=z_tri_dtype_str)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
|
||||
z_tri = to_numpy(z_tri)
|
||||
@@ -1958,6 +2004,23 @@ def test_full(dtype_str):
|
||||
assert torch.all(out_dynamic == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("literal, dtype_str",
|
||||
[(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"),
|
||||
('float("inf")', "f32"), ('float("-inf")', "f32"),
|
||||
('float("nan")', "f32"), ('float("-nan")', "f32"),
|
||||
(0., "f32"),
|
||||
(5, "i32"), (2**40, "i64"),])
|
||||
def test_constexpr(literal, dtype_str):
|
||||
@triton.jit
|
||||
def kernel(out_ptr):
|
||||
val = GENERATE_TEST_HERE
|
||||
tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val)
|
||||
|
||||
kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"})
|
||||
out = torch.zeros((1,), dtype=torch.float32, device="cuda")
|
||||
h = kernel_patched[(1,)](out)
|
||||
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None
|
||||
|
||||
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
|
||||
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
# def test_dot_without_load(dtype_str):
|
||||
@@ -2628,41 +2691,64 @@ def add_fn_static_cond(x, cond: tl.constexpr):
|
||||
return x + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
|
||||
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "attribute_jit",
|
||||
"jit", "jit_if", "jit_ifexp", "jit_expr",
|
||||
"jit_static_cond", "jit_noinline", "jit_extern"])
|
||||
def test_if_call(call_type):
|
||||
@triton.jit
|
||||
def kernel(Out, call_type: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
o = tl.load(Out)
|
||||
if pid == 0:
|
||||
if call_type == "attribute":
|
||||
# call attribute
|
||||
a = o + 1
|
||||
a = a.to(tl.int32).to(tl.int32)
|
||||
o = a
|
||||
else:
|
||||
if call_type == "attribute":
|
||||
# call attribute
|
||||
if pid == 0:
|
||||
a = o
|
||||
if call_type == "jit_function":
|
||||
# regular function call
|
||||
a = add_fn(a)
|
||||
elif call_type == "jit_function_return":
|
||||
# function without end_if block
|
||||
a = add_fn_return(a, pid)
|
||||
elif call_type == "ifexp":
|
||||
# ifexp expression
|
||||
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
|
||||
elif call_type == "expr":
|
||||
if pid == 1:
|
||||
return
|
||||
a = add_fn(a)
|
||||
if pid == 0:
|
||||
# call without return
|
||||
add_fn_expr(Out, a)
|
||||
elif call_type == "jit_function_static_cond":
|
||||
a = add_fn_static_cond(a, call_type)
|
||||
elif call_type == "jit_function_noinline":
|
||||
a = add_fn_noinline(a)
|
||||
a = a.to(tl.int32).to(tl.int32) + 1
|
||||
o = a
|
||||
elif call_type == "attribute_jit":
|
||||
# call attribute and jit function
|
||||
if pid == 0:
|
||||
a = o
|
||||
a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1
|
||||
o = a
|
||||
elif call_type == "jit":
|
||||
if pid == 0:
|
||||
# regular function call
|
||||
a = o
|
||||
a = add_fn(a)
|
||||
o = a
|
||||
elif call_type == "jit_if":
|
||||
# function without end_if block
|
||||
if pid == 0:
|
||||
a = o
|
||||
a = add_fn_return(a, pid)
|
||||
o = a
|
||||
elif call_type == "jit_ifexp":
|
||||
# ifexp expression
|
||||
if pid == 0:
|
||||
a = o
|
||||
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
|
||||
o = a
|
||||
elif call_type == "jit_expr":
|
||||
# call without return
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
add_fn_expr(Out, a)
|
||||
o = a
|
||||
elif call_type == "jit_static_cond":
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
add_fn_static_cond(o, call_type)
|
||||
o = a
|
||||
elif call_type == "jit_noinline":
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
add_fn_noinline(a)
|
||||
o = a
|
||||
elif call_type == "jit_extern":
|
||||
if pid == 0:
|
||||
a = o + 1
|
||||
tl.cdiv(a, a)
|
||||
o = a
|
||||
|
||||
tl.store(Out, o)
|
||||
@@ -2766,7 +2852,7 @@ def test_globaltimer():
|
||||
def kernel(Out1, Out2):
|
||||
start = tl.extra.cuda.globaltimer()
|
||||
off = tl.arange(0, 128)
|
||||
for i in range(100):
|
||||
for i in range(10000):
|
||||
tl.store(Out1 + off, tl.load(Out1 + off) + 1)
|
||||
end = tl.extra.cuda.globaltimer()
|
||||
tl.store(Out2, end - start)
|
||||
@@ -2810,22 +2896,49 @@ layouts = [
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
]
|
||||
|
||||
intermediate_layouts = [
|
||||
None,
|
||||
SharedLayout(1, 1, 1, [1, 0]),
|
||||
SharedLayout(4, 2, 4, [1, 0]),
|
||||
SharedLayout(2, 2, 4, [1, 0]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(128, 128)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
|
||||
@pytest.mark.parametrize("dst_layout", layouts)
|
||||
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device='cuda'):
|
||||
if str(src_layout) == str(dst_layout):
|
||||
pytest.skip()
|
||||
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
|
||||
pytest.skip()
|
||||
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
layouts = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" if interm_layout is None else f"""
|
||||
#src = {src_layout}
|
||||
#interm = {interm_layout}
|
||||
#dst = {dst_layout}
|
||||
"""
|
||||
|
||||
conversion = f"""
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
""" if interm_layout is None else f"""
|
||||
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
|
||||
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
|
||||
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
|
||||
|
||||
%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
"""
|
||||
|
||||
ir = layouts + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
@@ -2840,8 +2953,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
|
||||
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
|
||||
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
""" + conversion + """
|
||||
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
|
||||
tt.store %14, %13 : tensor<128x128xf16, #dst>
|
||||
tt.return
|
||||
|
||||
@@ -9,7 +9,8 @@ print_path = os.path.join(dir_path, "print_helper.py")
|
||||
assert_path = os.path.join(dir_path, "assert_helper.py")
|
||||
|
||||
# TODO: bfloat16 after LLVM-15
|
||||
func_types = ["device_assert", "assert", "static_assert"]
|
||||
func_types = ["device_assert", "assert", "static_assert", "no_debug"]
|
||||
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
|
||||
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
|
||||
|
||||
|
||||
@@ -51,3 +52,29 @@ def test_assert(func_type: str):
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
|
||||
def test_assert_nested(caller_type, callee_type):
|
||||
proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
|
||||
_, errs = proc.communicate()
|
||||
errs = errs.splitlines()
|
||||
num_errs = 0
|
||||
for err in errs:
|
||||
if "x != 0" in err.decode("utf-8"):
|
||||
num_errs += 1
|
||||
if caller_type == "none":
|
||||
if callee_type == "true":
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
elif caller_type == "true":
|
||||
if callee_type == "false":
|
||||
assert num_errs == 0
|
||||
else:
|
||||
assert num_errs == 127
|
||||
elif caller_type == "false":
|
||||
if callee_type == "true":
|
||||
assert num_errs == 127
|
||||
else:
|
||||
assert num_errs == 0
|
||||
|
||||
@@ -5,7 +5,10 @@ import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
|
||||
(4, 48, 1024, 32),
|
||||
(4, 48, 1024, 64),
|
||||
(4, 48, 1024, 128)])
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
|
||||
@@ -160,12 +160,12 @@ def test_jit_debug() -> None:
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
kernel_add.debug = False
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 1
|
||||
assert len(kernel_add.cache[device]) == 2
|
||||
kernel_add.debug = True
|
||||
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add.cache[device]) == 2
|
||||
assert len(kernel_add.cache[device]) == 3
|
||||
bins = list(kernel_add.cache[device].values())
|
||||
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
|
||||
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user