mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: remove old tests
This commit is contained in:
@@ -1,32 +0,0 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# trigger the torch.device implicitly to ensure cuda context initialization
# (a tiny allocation forces torch to set up its CUDA context before any
# Triton kernel in this file compiles or launches)
torch.zeros([10], device=torch.device('cuda'))
|
||||
|
||||
|
||||
@triton.jit
def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
    # Deliberately empty body: this kernel exists only to exercise the
    # Triton compile and launch paths in the tests below.
    pass
|
||||
|
||||
|
||||
def test_empty_kernel_cubin_compile():
    """Compile empty_kernel ahead of time and check a binary artifact exists."""
    compiled = triton.compile(
        empty_kernel,
        signature="*fp32,i32,i32",
        constants={"BLOCK": 256},
    )
    # ROCm builds produce an hsaco file on disk; CUDA builds produce a cubin blob.
    artifact_key = "hsaco_path" if torch.version.hip is not None else "cubin"
    assert len(compiled.asm[artifact_key]) > 0
|
||||
|
||||
|
||||
def test_empty_kernel_launch():
    """Launch the no-op kernel; the test passes if the launch raises nothing."""
    def grid(META):
        blocks_per_dim = triton.cdiv(1024, META['BLOCK'])
        return (blocks_per_dim * blocks_per_dim,)

    A = torch.zeros([1024], device="cuda")
    empty_kernel[grid](X=A, stride_xm=256, BLOCK=256)
|
||||
@@ -1,148 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def matmul_no_scf_kernel(
    a_ptr, b_ptr, c_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr
):
    # Single-tile GEMM: the full (M, K) and (K, N) operands are loaded in one
    # shot and multiplied with a single tl.dot — there is no K loop, hence
    # "no scf" (no structured control flow).
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # row-major pointer grids built from the strides passed by the caller
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)

    c = tl.dot(a, b)

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, c)
|
||||
|
||||
# TODO: num_warps could only be 4 for now
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
    [128, 64, 32, 4],
    [256, 128, 16, 4],
    [128, 32, 32, 4],
    [32, 128, 64, 4],
    [128, 32, 64, 4],
    [64, 128, 128, 4],
    [64, 128, 128, 2],
])
def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
    """Single-tile (no K loop) GEMM checked against torch.matmul."""
    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)

    def grid(META):
        return (1, )

    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                               stride_am=a.stride(0), stride_ak=a.stride(1),
                               stride_bk=b.stride(0), stride_bn=b.stride(1),
                               stride_cm=c.stride(0), stride_cn=c.stride(1),
                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
                               num_warps=NUM_WARPS)

    # print whole tensors on failure, not a truncated summary
    torch.set_printoptions(profile="full")
    assert_close(c, torch.matmul(a, b), rtol=1e-3, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Single-program GEMM with a K loop: one (BLOCK_SIZE_M, BLOCK_SIZE_N)
    # output tile is accumulated in float32 over K in BLOCK_SIZE_K steps.
    offs_m = tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        # advance both operand tiles one K step
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator)
|
||||
|
||||
# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
|
||||
|
||||
|
||||
def get_variant_golden(a, b):
    """Compute a "variant" reference matmul via zero-padded inputs.

    Padding the operands changes the floating-point summation order without
    changing the mathematical product, so the difference between this result
    and torch.matmul(a, b) calibrates a realistic rtol/atol for the kernel.
    """
    size_m = a.shape[0]
    size_k = a.shape[1]
    size_n = b.shape[1]
    assert a.shape[1] == b.shape[0]
    pad_m_k = torch.zeros((size_m, size_k)).cuda()
    pad_3m_k = torch.zeros((3 * size_m, size_k)).cuda()
    pad_k_n = torch.zeros((size_k, size_n)).cuda()
    pad_3k_n = torch.zeros((3 * size_k, size_n)).cuda()
    # grow a to (3M, 3K) and b to (3K, 3N) with zero blocks
    a_padded = torch.cat((a, pad_m_k, pad_m_k), 0)
    a_padded = torch.cat((a_padded, pad_3m_k, pad_3m_k), 1)
    b_padded = torch.cat((b, pad_k_n, pad_k_n), 0)
    b_padded = torch.cat((b_padded, pad_3k_n, pad_3k_n), 1)
    c_padded = torch.matmul(a_padded, b_padded)
    # only the top-left (M, N) corner is the real product
    return c_padded[:size_m, :size_n]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
    # Non-forloop
    [64, 32, 64, 4, 64, 32, 64],
    [128, 64, 128, 4, 128, 64, 128],
    # K-Forloop
    [64, 32, 128, 4, 64, 32, 64],
    [128, 32, 128, 4, 128, 32, 32],
    [32, 32, 128, 4, 32, 32, 32],
    [32, 64, 128, 4, 32, 64, 32],
    [32, 128, 256, 4, 32, 128, 64],
    [64, 128, 64, 4, 64, 128, 32],
    [64, 64, 128, 4, 64, 64, 32],
    [128, 128, 64, 4, 128, 128, 32],
    [128, 128, 128, 4, 128, 128, 32],
    [128, 64, 64, 4, 128, 64, 32],
    [128, 64, 128, 4, 128, 64, 32],
    [64, 64, 256, 4, 64, 64, 64],
    [128, 64, 128, 4, 128, 64, 32],
    [256, 128, 64, 4, 256, 128, 16],
    [128, 64, 128, 4, 128, 64, 32],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
    """K-loop GEMM kernel against torch.matmul with calibrated tolerances."""
    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)

    def grid(META):
        return (1, )

    matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
                        stride_am=a.stride(0), stride_ak=a.stride(1),
                        stride_bk=b.stride(0), stride_bn=b.stride(1),
                        stride_cm=c.stride(0), stride_cn=c.stride(1),
                        M=a.shape[0], N=b.shape[1], K=a.shape[1],
                        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
                        BLOCK_SIZE_K=BLOCK_SIZE_K,
                        num_warps=NUM_WARPS)

    golden = torch.matmul(a, b)

    # There is no single tolerance that fits every problem size. A "variant"
    # golden (same product, different summation order via padding) is computed,
    # and the golden-vs-variant error provides the reference for rtol/atol.
    golden_variant = get_variant_golden(a, b)
    diff = golden - golden_variant
    abs_err = torch.max(torch.abs(diff)).item()
    rel_err = torch.max(torch.abs(diff / golden)).item()

    # print whole tensors on failure, not a truncated summary
    torch.set_printoptions(profile="full")
    assert_close(c, golden,
                 rtol=max(1e-4, 1.5 * rel_err),
                 atol=max(1e-4, 1.5 * abs_err),
                 check_dtype=False)
|
||||
@@ -1,137 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# dtype universe shared by the reduction tests below
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8']  # PyTorch does not support uint16/uint32/uint64
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
# NOTE(review): despite its name, this list does NOT include 'bfloat16' —
# possibly disabled deliberately for this backend; confirm before relying on it.
dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
# map each dtype string to the corresponding torch.dtype object
dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
|
||||
|
||||
|
||||
def get_reduced_dtype(dtype):
    """Return the dtype a reduction result should be produced/compared in.

    Narrow integer types accumulate in int32 and bfloat16 in float32;
    every other dtype reduces in its own type.
    """
    if dtype in (torch.int8, torch.int16, torch.uint8):
        return torch.int32
    if dtype in (torch.bfloat16,):
        return torch.float32
    return dtype
|
||||
|
||||
|
||||
def patch_kernel(template, to_replace):
    # Clone the JIT function and textually substitute placeholders in its
    # source (e.g. 'OP' -> 'sum'), so a single template kernel covers many
    # reduction ops. The clone keeps `template` itself untouched.
    kernel = triton.JITFunction(template.fn)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel
|
||||
|
||||
|
||||
@triton.jit
def reduce1d_kernel(x_ptr, z_ptr, block: tl.constexpr):
    # `tl.OP` is a placeholder: patch_kernel textually replaces "OP" with the
    # actual reduction (sum/min/max) before compilation.
    x = tl.load(x_ptr + tl.arange(0, block))
    tl.store(z_ptr, tl.OP(x, axis=0))
|
||||
|
||||
|
||||
@triton.jit
def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, block_n: tl.constexpr):
    # `tl.OP` is a placeholder replaced by patch_kernel (sum/min/max).
    range_m = tl.arange(0, block_m)
    range_n = tl.arange(0, block_n)
    # load a (block_m, block_n) row-major tile
    x = tl.load(x_ptr + range_m[:, None] * block_n + range_n[None, :])
    z = tl.OP(x, axis=axis)
    # reducing over axis 0 leaves block_n values; over axis 1, block_m values
    if axis == 0:
        tl.store(z_ptr + range_n, z)
    else:
        tl.store(z_ptr + range_m, z)
|
||||
|
||||
|
||||
# full cross product of every reduction op, dtype, and power-of-two length
reduce1d_configs = [
    (op, dtype, shape)
    for op in ['sum', 'min', 'max']
    for dtype in dtypes
    for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
def test_reduce1d(op, dtype, shape):
    """1D reduction (sum/min/max) against the torch reference."""
    dtype = dtype_mapping[dtype]
    reduced_dtype = get_reduced_dtype(dtype)

    # sample input appropriately for the dtype (uint8 cannot hold negatives)
    if dtype.is_floating_point:
        x = torch.randn((shape,), device='cuda', dtype=dtype)
    elif dtype is torch.uint8:
        x = torch.randint(0, 20, (shape,), device='cuda', dtype=dtype)
    else:
        x = torch.randint(-20, 20, (shape,), device='cuda', dtype=dtype)
    # 0-d output tensor in the accumulation dtype
    z = torch.empty(tuple(), device=x.device, dtype=reduced_dtype)

    kernel = patch_kernel(reduce1d_kernel, {'OP': op})
    kernel[(1,)](x_ptr=x, z_ptr=z, block=shape)

    if op == 'sum':
        expected = torch.sum(x, dtype=reduced_dtype)
    else:
        reducer = torch.min if op == 'min' else torch.max
        expected = reducer(x).to(reduced_dtype)

    # float sums accumulate rounding error that grows with the length
    if dtype.is_floating_point and op == 'sum':
        if shape >= 256:
            rtol, atol = 0.05, 0.1
        elif shape >= 32:
            rtol, atol = 0.05, 0.02
        else:
            rtol, atol = 0.01, 0.01
    else:
        rtol, atol = 0.001, 0.001
    assert_close(z, expected, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
# cross product of op, dtype, 2D shape, and which axis to reduce over
reduce2d_configs = [
    (op, dtype, shape, axis)
    for op in ['sum', 'min', 'max']
    for dtype in dtypes
    for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
    for axis in [0, 1]
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
def test_reduce2d(op, dtype, shape, axis):
    """2D reduction along either axis against the torch reference."""
    dtype = dtype_mapping[dtype]
    reduced_dtype = get_reduced_dtype(dtype)
    # reducing over `axis` leaves the other dimension's extent
    reduced_shape = (shape[1 - axis],)

    # sample input appropriately for the dtype (uint8 cannot hold negatives)
    if dtype.is_floating_point:
        x = torch.randn(shape, device='cuda', dtype=dtype)
    elif dtype is torch.uint8:
        x = torch.randint(0, 20, shape, device='cuda', dtype=dtype)
    else:
        x = torch.randint(-20, 20, shape, device='cuda', dtype=dtype)
    z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)

    kernel = patch_kernel(reduce2d_kernel, {'OP': op})
    kernel[(1,)](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])

    if op == 'sum':
        expected = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
    else:
        reducer = torch.min if op == 'min' else torch.max
        expected = reducer(x, dim=axis, keepdim=False)[0].to(reduced_dtype)

    # float sums accumulate rounding error that grows with the reduced extent
    if dtype.is_floating_point and op == 'sum':
        if shape[axis] >= 256:
            rtol, atol = 0.05, 0.1
        elif shape[axis] >= 32:
            rtol, atol = 0.05, 0.02
        else:
            rtol, atol = 0.01, 0.01
    else:
        rtol, atol = 0.001, 0.001
    assert_close(z, expected, rtol=rtol, atol=atol)
|
||||
@@ -1,47 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def kernel(x_ptr, stride_xm,
           z_ptr, stride_zn,
           SIZE_M: tl.constexpr, SIZE_N: tl.constexpr):
    # Copies an (SIZE_M, SIZE_N) tile of X into Z with swapped strides —
    # a transpose realized purely through pointer arithmetic, which
    # exercises the layout-conversion paths in the backend.
    off_m = tl.arange(0, SIZE_M)
    off_n = tl.arange(0, SIZE_N)
    Xs = x_ptr + off_m[:, None] * stride_xm + off_n[None, :] * 1
    Zs = z_ptr + off_m[:, None] * 1 + off_n[None, :] * stride_zn
    tl.store(Zs, tl.load(Xs))
|
||||
|
||||
# These sizes cover the case of:
|
||||
# - blocked layout and sliced layout with block parent
|
||||
# -- blocked layout in which sizePerThread/threadsPerWarp/warpsPerCTA
|
||||
# need/need not to be wrapped
|
||||
# -- sliced layout in case sizePerThread needs to be wrapped
|
||||
# -- different orders
|
||||
# - LayoutConversion from blocked -> blocked
|
||||
# - tt.Broadcast which requires for broadcast in either/both of
|
||||
# CTA/perThread level
|
||||
|
||||
# What is not covered and requires for TODO:
|
||||
# - vectorization load/store of shared memory
|
||||
# - multiple replication of layout conversion
|
||||
|
||||
|
||||
@pytest.mark.parametrize('NUM_WARPS,SIZE_M,SIZE_N', [
    [1, 16, 16],
    [1, 32, 32],
    [1, 32, 64],
    [2, 64, 128],
    [2, 128, 64]
])
def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
    """The stride-swapping copy kernel must produce torch.t(x)."""
    source = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
    result = torch.empty((SIZE_N, SIZE_M), device=source.device, dtype=source.dtype)
    kernel[(1,)](x_ptr=source, stride_xm=source.stride(0),
                 z_ptr=result, stride_zn=result.stride(0),
                 SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
    assert_close(result, torch.t(source), rtol=1e-7, atol=1e-7, check_dtype=False)
|
||||
@@ -1,215 +0,0 @@
|
||||
import math
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_vecadd_scf_no_mask(num_warps, block_size, iter_size):
    """Chunked (scf-loop) vector add without load/store masks."""

    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        # walk the block in iter_size chunks, advancing the base pointers
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset

            x = tl.load(x_ptrs)
            y = tl.load(y_ptrs)
            z = x + y
            z_ptrs = z_ptr + offset
            tl.store(z_ptrs, z)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    out = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    def grid(meta):
        return (x.shape.numel() // (block_size),)

    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=out,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    assert_close(out, x + y, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shape, num_warps, block_size, iter_size', [
    [(127, 3), 2, 128, 1],
    [(127, 3), 2, 128, 32],
])
def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
    """Chunked (scf-loop) vector add with a bounds mask on every access."""
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               num_elements,
               block_size: tl.constexpr,
               iter_size: tl.constexpr
               ):
        '''
        @block_size: size of a block
        @iter_size: size of the iteration, a block has multiple iterations
        @num_elements: number of elements
        '''
        pid = tl.program_id(axis=0)
        for i in range(math.ceil(block_size / iter_size)):
            # TODO: a bug here, if put the offset outside the forloop, there will be a GPU mis-aligned error.
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset

            x = tl.load(x_ptrs, mask=offset < num_elements)
            y = tl.load(y_ptrs, mask=offset < num_elements)
            z = x + y
            z_ptrs = z_ptr + offset
            tl.store(z_ptrs, z, mask=offset < num_elements)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size

    x = torch.randn(shape, device='cuda', dtype=torch.float32)
    y = torch.randn(shape, device='cuda', dtype=torch.float32)
    out = torch.empty(shape, device=x.device, dtype=x.dtype)

    def grid(meta):
        return (math.ceil(x.numel() / block_size),)

    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=out,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps,
                 num_elements=x.numel())

    assert_close(out, x + y, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
def vecadd_no_scf_tester(num_warps, block_size, shape):
    """Run a one-shot (no loop) masked vector add and verify against torch."""
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               n_elements,
               block_size_N: tl.constexpr):
        pid = tl.program_id(axis=0)

        offset = pid * block_size_N + tl.arange(0, block_size_N)
        x_ptrs = x_ptr + offset
        y_ptrs = y_ptr + offset

        # lanes past the end of the data (tail of the last block) are masked
        mask = offset < n_elements

        x = tl.load(x_ptrs, mask=mask)
        y = tl.load(y_ptrs, mask=mask)
        z = x + y
        z_ptrs = z_ptr + offset
        tl.store(z_ptrs, z, mask=mask)

    x = torch.randn(shape, device='cuda', dtype=torch.float32)
    y = torch.randn(shape, device='cuda', dtype=torch.float32)
    out = torch.empty(shape, device=x.device, dtype=x.dtype)

    def grid(meta):
        return (math.ceil(x.shape.numel() / block_size),)

    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=out, n_elements=x.shape.numel(),
                 block_size_N=block_size, num_warps=num_warps)

    assert_close(out, x + y, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
    """Vector-add tester that uses a float comparison as the store mask.

    Elements of x + y lying inside [0, 1] are masked out of the store, so
    the corresponding positions of z keep their zero initialization.
    """
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               n_elements,
               block_size_N: tl.constexpr):
        pid = tl.program_id(axis=0)

        offset = pid * block_size_N + tl.arange(0, block_size_N)
        x_ptrs = x_ptr + offset
        y_ptrs = y_ptr + offset

        io_mask = offset < n_elements
        x = tl.load(x_ptrs, mask=io_mask)
        y = tl.load(y_ptrs, mask=io_mask)

        z = x + y
        # store only in-bounds elements whose value falls outside [0, 1]
        val_mask = offset < n_elements and (z < 0. or z > 1.)

        z_ptrs = z_ptr + offset
        tl.store(z_ptrs, z, mask=val_mask)

    x = torch.randn(shape, device='cuda', dtype=torch.float32)
    y = torch.randn(shape, device='cuda', dtype=torch.float32)
    # zeros, not empty: masked-out positions must stay 0 for the comparison
    z = torch.zeros(shape, device=x.device, dtype=x.dtype)

    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(),
                 block_size_N=block_size, num_warps=num_warps)

    golden_z = x + y
    # Zero out the masked-off positions in one vectorized op. The original
    # code did this with a per-element Python loop over a GPU tensor, which
    # forced a device synchronization on every element.
    keep = (golden_z < 0.) | (golden_z > 1.)
    golden_z = torch.where(keep, golden_z, torch.zeros_like(golden_z))

    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_warps, block_size, shape', [
    [4, 256, (256,)],
    [2, 256, (256,)],
    [1, 256, (256,)],
    [4, 16, (256,)],
    [2, 64, (256,)],
    [1, 128, (256,)],
])
def test_vecadd_no_scf(num_warps, block_size, shape):
    # Every shape here is an exact multiple of block_size, so the tester's
    # bounds mask never actually excludes anything.
    vecadd_no_scf_tester(num_warps, block_size, shape)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_warps, block_size, shape', [
    [1, 128, (256 + 1,)],
    [1, 256, (256 + 1,)],
    [2, 256, (3, 256 + 7)],
    [4, 256, (3, 256 + 7)],
])
def test_vecadd_no_scf_masked(num_warps, block_size, shape):
    # Sizes deliberately not multiples of block_size, so the tail block
    # exercises the bounds mask.
    vecadd_no_scf_tester(num_warps, block_size, shape)
|
||||
|
||||
|
||||
def test_vecadd_no_scf_masked_randomly():
    """Randomized masked vecadd over power-of-two warp counts."""
    random.seed(0)  # fix seed to make random test reproducible
    for _ in range(10):
        num_elements = random.randint(128, 2048)
        shape = (num_elements,)
        max_warps = num_elements // 32  # floor div
        for num_warps in range(1, max_warps):
            # only power-of-two warp counts are exercised
            if num_warps == 0 or num_warps & (num_warps - 1) != 0:
                continue
            # NOTE(review): min(32, num_warps * 32) is always 32 for
            # num_warps >= 1 — was max() intended? Preserved as-is.
            block_size = min(32, num_warps * 32)
            vecadd_no_scf_tester(num_warps, block_size, shape)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_warps, block_size, shape', [
    [1, 128, (256 + 1,)],
    [1, 256, (256 + 1,)],
    [2, 256, (3, 256 + 7)],
    [4, 256, (3, 256 + 7)],
])
def test_vecadd_fcmp_no_scf_masked(num_warps, block_size, shape):
    # Same non-multiple sizes as test_vecadd_no_scf_masked, but through the
    # tester whose store mask is a float comparison on the result.
    vecadd_fcmp_no_scf_tester(num_warps, block_size, shape)
|
||||
Reference in New Issue
Block a user