mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Add arith::SelectOp => LLVM::SelectOp conversion (#1307)
This commit is contained in:
@@ -385,17 +385,22 @@ def test_where(dtype):
|
||||
@triton.jit
|
||||
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
TEST_POINTERS: tl.constexpr):
|
||||
TEST_POINTERS: tl.constexpr,
|
||||
TEST_SCALAR_POINTERS: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
decide = tl.load(cond_ptr + offsets, mask=mask)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
if TEST_SCALAR_POINTERS:
|
||||
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
|
||||
output = tl.load(ptr + offsets, mask=mask)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
SIZE = 1_000
|
||||
@@ -411,8 +416,12 @@ def test_where(dtype):
|
||||
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
|
||||
|
||||
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
if select_ptrs:
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
|
||||
z = np.where(cond[0], x, y)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
|
||||
|
||||
def test_where_broadcast():
|
||||
|
||||
Reference in New Issue
Block a user