Merge pull request #229 from ROCmSoftwarePlatform/ifu230601

IFU 230601
This commit is contained in:
jayfurmanek
2023-06-09 07:55:32 -05:00
committed by GitHub
44 changed files with 1367 additions and 414 deletions

View File

@@ -1,4 +1,7 @@
import numpy as np
import pytest
import torch
from numpy.random import RandomState
import triton
import triton.language as tl
@@ -66,3 +69,162 @@ def test_chained_matmul():
block_k=block_k)
assert (torch_result == triton_result).all()
def test_vecmat():
    """Compare a batched vector-matrix Triton kernel with a NumPy reference.

    For every row ``m`` the kernel computes ``output[m, n] = sum_k A[m, k] * B[m, n, k]``.
    The inputs are integer-valued float32 arrays drawn from a fixed seed.
    """

    @triton.jit
    def batched_vecmat(
        # inputs
        A,  # shape: [dim_m, dim_k]
        B,  # shape: [dim_m, dim_n, dim_k]
        # dimensions
        dim_m, dim_n, dim_k,
        # outputs
        output,
        # block information
        block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
    ):
        m_index = tl.program_id(0)
        n_index = tl.program_id(1)
        # Output tile
        output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \
            + (n_index * block_n + tl.arange(0, block_n))[None, :]

        vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)
        k_blocks = dim_k // block_k
        for k_index in range(k_blocks):
            # Load A tile
            a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \
                + (k_index * block_k + tl.arange(0, block_k))[None, :]
            a = tl.load(A + a_tile)

            # Load B tile, transposed to [n, m, k] in order to broadcast A on a
            # leading dimension.
            b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \
                + (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \
                + (k_index * block_k + tl.arange(0, block_k))[None, None, :]
            b = tl.load(B + b_tile)

            expanded_a, _ = tl.broadcast(a, b)
            # The partial sums come out as [n, m]; transpose back to [m, n].
            vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))

        tl.store(output + output_tile, vecmat)

    M, N, K = 128, 128, 128
    block_m, block_n, block_k = 16, 32, 64

    # Draw A before B so the seeded stream is consumed in a fixed order.
    rng = RandomState(17)
    a_host = rng.randint(0, 4, (M, K)).astype('float32')
    b_host = rng.randint(0, 4, (M, N, K)).astype('float32')

    a_dev = torch.tensor(a_host, device='cuda')
    b_dev = torch.tensor(b_host, device='cuda')
    c_dev = torch.zeros((M, N), dtype=torch.float32, device='cuda')

    launch_grid = (M // block_m, N // block_n)
    batched_vecmat[launch_grid](
        a_dev, b_dev, M, N, K, c_dev,
        block_m=block_m, block_n=block_n, block_k=block_k,
        num_warps=4, num_stages=1,
    )

    # Reference: broadcast A over the N axis, multiply element-wise, reduce over K.
    expected = (a_host[:, np.newaxis, :] * b_host).sum(axis=2)
    np.testing.assert_allclose(expected, c_dev.cpu().numpy(), rtol=0.01, atol=1e-3)
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
def test_iv_dependent_matmul(type):
    """Matmul whose tile pointers are derived from the loop induction variable.

    ``type`` selects where, relative to the loads, the A/B tile pointers are
    recomputed from the loop index ``k``:

    * ``pre_load``: both pointers are computed from ``k`` before the loads.
    * ``post_load``: both pointers are computed for ``k + 1`` after the loads.
    * ``post_pre_mixed``: A is computed before the loads, B after them.
    * ``post_load_two_iters``: pointers are carried one iteration ahead.
    * ``post_load_three_iters``: pointers are carried two iterations ahead.

    The result is compared with ``torch.mm``.
    """

    @triton.jit
    def kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        type: tl.constexpr
    ):
        pid = tl.program_id(axis=0)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n

        # Row/column offsets wrap modulo M/N; the final store is masked instead.
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        # From here on a_ptr / b_ptr are the tile pointers for k == 0.
        a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        a_ptrs = a_ptr
        b_ptrs = b_ptr
        # The multi-iteration variants pre-compute the pointers for the
        # following one or two iterations before entering the loop.
        if type == "post_load_two_iters":
            a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
            b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
        elif type == "post_load_three_iters":
            a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak
            b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk
            a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak
            b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            if type == "pre_load":
                a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak
                b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk
            elif type == "post_pre_mixed":
                a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak

            # Mask off the K tail of the last tile.
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            accumulator += tl.dot(a, b)

            if type == "post_load":
                a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak
                b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
            elif type == "post_pre_mixed":
                b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk
            elif type == "post_load_two_iters":
                a_ptrs = a_ptrs_next
                b_ptrs = b_ptrs_next
                a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak
                b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk
            elif type == "post_load_three_iters":
                a_ptrs = a_ptrs_next
                b_ptrs = b_ptrs_next
                a_ptrs_next = a_ptrs_next_next
                b_ptrs_next = b_ptrs_next_next
                a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak
                b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk
        # The float32 accumulator is rounded to float16 before the store.
        c = accumulator.to(tl.float16)

        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    M = 256
    K = 256
    N = 256
    BLOCK_SIZE_K = 32
    BLOCK_SIZE_N = 32
    BLOCK_SIZE_M = 32

    a = torch.rand((M, K), device='cuda')
    b = torch.rand((K, N), device='cuda')

    torch_output = torch.mm(a, b)
    triton_output = torch.empty_like(
        torch_output, device=torch_output.device)

    def grid(META):
        return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)

    # The three-iteration variant is launched with one more pipeline stage.
    num_stages = 4 if type == "post_load_three_iters" else 3
    kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
                 b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
                 BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
                 type=type, num_stages=num_stages)
    # assert_close replaces the deprecated torch.testing.assert_allclose; the
    # argument order and tolerances are unchanged.
    torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)

View File

@@ -22,6 +22,13 @@ def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit(debug=False)
def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):
    # Copies BLOCK elements from X to Y with a device-side check that every
    # loaded element is zero; the kernel itself is declared with debug=False.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    tl.device_assert(values == 0, "x != 0")
    tl.store(Y + offsets, values)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -43,6 +50,9 @@ def test_assert(func: str):
if func == "device_assert":
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "no_debug":
# TRITON_DEBUG=True can override the debug flag
kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "static_assert":
@@ -50,5 +60,72 @@ def test_assert(func: str):
assert_close(y, x)
# Callee for the nested-assert kernels: device-side check that every element of
# x is zero. Declared without an explicit debug setting.
@triton.jit
def jit_device_assert_none(x):
    tl.device_assert(x == 0, "x != 0")
# Callee for the nested-assert kernels: device-side check that every element of
# x is zero. Declared with debug=True.
@triton.jit(debug=True)
def jit_device_assert_true(x):
    tl.device_assert(x == 0, "x != 0")
# Callee for the nested-assert kernels: device-side check that every element of
# x is zero. Declared with debug=False.
@triton.jit(debug=False)
def jit_device_assert_false(x):
    tl.device_assert(x == 0, "x != 0")
@triton.jit
def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    # Caller declared without an explicit debug setting. Loads BLOCK elements
    # from X, runs the callee selected by jit_debug ("true", "false", anything
    # else -> the callee without a debug setting), then copies the values to Y.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    if jit_debug == "true":
        jit_device_assert_true(values)
    elif jit_debug == "false":
        jit_device_assert_false(values)
    else:
        jit_device_assert_none(values)
    tl.store(Y + offsets, values)
@triton.jit(debug=True)
def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    # Caller declared with debug=True. Loads BLOCK elements from X, runs the
    # callee selected by jit_debug ("true", "false", anything else -> the
    # callee without a debug setting), then copies the values to Y.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    if jit_debug == "true":
        jit_device_assert_true(values)
    elif jit_debug == "false":
        jit_device_assert_false(values)
    else:
        jit_device_assert_none(values)
    tl.store(Y + offsets, values)
@triton.jit(debug=False)
def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
    # Caller declared with debug=False. Loads BLOCK elements from X, runs the
    # callee selected by jit_debug ("true", "false", anything else -> the
    # callee without a debug setting), then copies the values to Y.
    offsets = tl.arange(0, BLOCK)
    values = tl.load(X + offsets)
    if jit_debug == "true":
        jit_device_assert_true(values)
    elif jit_debug == "false":
        jit_device_assert_false(values)
    else:
        jit_device_assert_none(values)
    tl.store(Y + offsets, values)
def test_assert_nested(caller: str, callee: str):
    """Launch one nested-assert kernel and check that it copied X into Y.

    ``caller`` ("none", "true" or "false") selects the kernel by its debug
    setting; ``callee`` is forwarded as ``jit_debug`` and selects the inner
    jitted function. An unrecognised ``caller`` launches nothing, so the final
    comparison runs against the zero-initialised output.
    """
    numel = 128
    x = torch.arange(0, numel, dtype=torch.int32, device='cuda')
    y = torch.zeros((numel, ), dtype=x.dtype, device="cuda")

    kernels_by_caller = {
        "none": kernel_device_assert_nested,
        "true": kernel_device_assert_nested_true,
        "false": kernel_device_assert_nested_false,
    }
    selected = kernels_by_caller.get(caller)
    if selected is not None:
        selected[(1,)](x, y, num_warps=2, BLOCK=numel, jit_debug=callee)

    assert_close(y, x)
if __name__ == "__main__":
    # Two command-line arguments (caller, callee) select the nested check;
    # otherwise the single argument names the function type for test_assert.
    # test_assert is not called unconditionally first: in the two-argument
    # mode argv[1] is a caller tag, not a function type.
    if len(sys.argv) == 3:
        test_assert_nested(sys.argv[1], sys.argv[2])
    else:
        test_assert(sys.argv[1])

View File

@@ -130,6 +130,17 @@ class BlockedLayout:
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
class SharedLayout:
    """Textual form of a ``#triton_gpu.shared`` layout attribute.

    Every constructor argument is stored as its ``str()`` form and spliced
    verbatim into the attribute text returned by ``__str__``.
    """

    def __init__(self, vec, per_phase, max_phase, order):
        # Stringify once up front so __str__ only has to join the pieces.
        self.vec, self.per_phase, self.max_phase, self.order = map(
            str, (vec, per_phase, max_phase, order)
        )

    def __str__(self):
        fields = f"vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}"
        return "#triton_gpu.shared<{" + fields + "}>"
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@@ -456,6 +467,21 @@ def test_broadcast(dtype):
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
# ------------------
# test invalid slice
# ------------------
def test_invalid_slice():
    """Subscripting the kernel argument with a slice must be rejected.

    The launch is expected to raise ``triton.CompilationError`` whose message
    matches 'unsupported tensor index'.
    """
    out_buf = torch.empty(128, device='cuda')

    @triton.jit
    def _kernel(dst):
        dst[10:]

    with pytest.raises(triton.CompilationError, match='unsupported tensor index'):
        _kernel[(1,)](dst=out_buf)
# ----------------
# test expand_dims
@@ -537,6 +563,20 @@ def test_expand_dims_error_cases():
duplicate_dim2[(1,)](dummy_tensor, N)
# ----------------------------
# test invalid program id axis
# ----------------------------
def test_invalid_pid_axis():
    """Requesting program_id on axis 20 must be rejected.

    The launch is expected to raise ``triton.CompilationError`` whose message
    matches 'program_id must be in [0,3]'.
    """
    out_buf = torch.empty(128, device='cuda')

    @triton.jit
    def _kernel(dst):
        pid = tl.program_id(20)

    with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"):
        _kernel[(1,)](out_buf)
# ---------------
# test where
# ---------------
@@ -1368,6 +1408,9 @@ reduce_configs2 = [
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
for shape in reduce2d_shapes
for axis in [0, 1]
] + [
(op, 'float32', [16, 32], None)
for op in ['min', 'max', 'sum']
]
@@ -1382,7 +1425,9 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = GENERATE_TEST_HERE
if AXIS == 1:
if AXIS is None:
tl.store(Z, z)
elif AXIS == 1:
tl.store(Z + range_m, z)
else:
tl.store(Z + range_n, z)
@@ -1407,7 +1452,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
else:
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
ret_numel = 1 if axis is None else shape[1 - axis]
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
device=device, dst_type=z_tri_dtype_str)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
z_tri = to_numpy(z_tri)
@@ -1958,6 +2004,23 @@ def test_full(dtype_str):
assert torch.all(out_dynamic == 2)
@pytest.mark.parametrize("literal, dtype_str",
                         [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"),
                          ('float("inf")', "f32"), ('float("-inf")', "f32"),
                          ('float("nan")', "f32"), ('float("-nan")', "f32"),
                          (0., "f32"),
                          (5, "i32"), (2**40, "i64"),])
def test_constexpr(literal, dtype_str):
    """Check the IR type Triton assigns to a Python literal.

    ``literal`` is spliced into the kernel source in place of
    ``GENERATE_TEST_HERE``. The test passes when the generated TTIR contains
    an ``arith.constant`` whose type is ``dtype_str``.
    """

    @triton.jit
    def kernel(out_ptr):
        val = GENERATE_TEST_HERE
        # Reinterpret the output pointer so the store uses the literal's own dtype.
        tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val)

    kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"})
    # Two float32 slots = 8 bytes: the f64 and i64 cases store an 8-byte value
    # through the reinterpreted pointer, which would overrun a one-element
    # float32 buffer.
    out = torch.zeros((2,), dtype=torch.float32, device="cuda")
    h = kernel_patched[(1,)](out)
    assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None
# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
@@ -2628,41 +2691,64 @@ def add_fn_static_cond(x, cond: tl.constexpr):
return x + 1
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
@pytest.mark.parametrize("call_type", ["attribute", "attribute_jit",
"jit", "jit_if", "jit_ifexp", "jit_expr",
"jit_static_cond", "jit_noinline", "jit_extern"])
def test_if_call(call_type):
@triton.jit
def kernel(Out, call_type: tl.constexpr):
pid = tl.program_id(0)
o = tl.load(Out)
if pid == 0:
if call_type == "attribute":
# call attribute
a = o + 1
a = a.to(tl.int32).to(tl.int32)
o = a
else:
if call_type == "attribute":
# call attribute
if pid == 0:
a = o
if call_type == "jit_function":
# regular function call
a = add_fn(a)
elif call_type == "jit_function_return":
# function without end_if block
a = add_fn_return(a, pid)
elif call_type == "ifexp":
# ifexp expression
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
elif call_type == "expr":
if pid == 1:
return
a = add_fn(a)
if pid == 0:
# call without return
add_fn_expr(Out, a)
elif call_type == "jit_function_static_cond":
a = add_fn_static_cond(a, call_type)
elif call_type == "jit_function_noinline":
a = add_fn_noinline(a)
a = a.to(tl.int32).to(tl.int32) + 1
o = a
elif call_type == "attribute_jit":
# call attribute and jit function
if pid == 0:
a = o
a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1
o = a
elif call_type == "jit":
if pid == 0:
# regular function call
a = o
a = add_fn(a)
o = a
elif call_type == "jit_if":
# function without end_if block
if pid == 0:
a = o
a = add_fn_return(a, pid)
o = a
elif call_type == "jit_ifexp":
# ifexp expression
if pid == 0:
a = o
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
o = a
elif call_type == "jit_expr":
# call without return
if pid == 0:
a = o + 1
add_fn_expr(Out, a)
o = a
elif call_type == "jit_static_cond":
if pid == 0:
a = o + 1
add_fn_static_cond(o, call_type)
o = a
elif call_type == "jit_noinline":
if pid == 0:
a = o + 1
add_fn_noinline(a)
o = a
elif call_type == "jit_extern":
if pid == 0:
a = o + 1
tl.cdiv(a, a)
o = a
tl.store(Out, o)
@@ -2766,7 +2852,7 @@ def test_globaltimer():
def kernel(Out1, Out2):
start = tl.extra.cuda.globaltimer()
off = tl.arange(0, 128)
for i in range(100):
for i in range(10000):
tl.store(Out1 + off, tl.load(Out1 + off) + 1)
end = tl.extra.cuda.globaltimer()
tl.store(Out2, end - start)
@@ -2810,22 +2896,49 @@ layouts = [
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
# Intermediate layouts used to parametrize test_convert2d (`interm_layout`).
# None means no #interm layout is emitted; otherwise the SharedLayout arguments
# are (vec, per_phase, max_phase, order).
intermediate_layouts = [
    None,
    SharedLayout(1, 1, 1, [1, 0]),
    SharedLayout(4, 2, 4, [1, 0]),
    SharedLayout(2, 2, 4, [1, 0]),
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
layouts = f"""
#src = {src_layout}
#dst = {dst_layout}
""" if interm_layout is None else f"""
#src = {src_layout}
#interm = {interm_layout}
#dst = {dst_layout}
"""
conversion = f"""
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
""" if interm_layout is None else f"""
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
"""
ir = layouts + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
@@ -2840,8 +2953,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
""" + conversion + """
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
tt.return

View File

@@ -9,7 +9,8 @@ print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")
# TODO: bfloat16 after LLVM-15
func_types = ["device_assert", "assert", "static_assert"]
func_types = ["device_assert", "assert", "static_assert", "no_debug"]
nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
@@ -51,3 +52,29 @@ def test_assert(func_type: str):
assert num_errs == 127
else:
assert num_errs == 0
@pytest.mark.parametrize("caller_type, callee_type", nested_types)
def test_assert_nested(caller_type, callee_type):
    """Run the assert helper in a subprocess and count 'x != 0' lines on stderr.

    Expected counts:

    * caller "true": 0 when the callee is "false", otherwise 127.
    * caller "none" or "false": 127 when the callee is "true", otherwise 0.

    Any other caller value makes no assertion.
    """
    proc = subprocess.Popen(
        [sys.executable, assert_path, caller_type, callee_type],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False,
    )
    _, stderr_bytes = proc.communicate()
    num_errs = sum(
        "x != 0" in line.decode("utf-8") for line in stderr_bytes.splitlines()
    )

    if caller_type == "true":
        expected = 0 if callee_type == "false" else 127
    elif caller_type in ("none", "false"):
        expected = 127 if callee_type == "true" else 0
    else:
        # Unknown caller tags are not checked.
        return
    assert num_errs == expected

View File

@@ -5,7 +5,10 @@ import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
(4, 48, 1024, 32),
(4, 48, 1024, 64),
(4, 48, 1024, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()

View File

@@ -160,12 +160,12 @@ def test_jit_debug() -> None:
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 1
assert len(kernel_add.cache[device]) == 2
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 2
assert len(kernel_add.cache[device]) == 3
bins = list(kernel_add.cache[device].values())
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
assert bins[2].asm['ttir'] != bins[1].asm['ttir']
@triton.jit