mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] shfl ptx insts should have side effects (#2376)
Otherwise, llvm pass could generate very weird structure of CFG and yield incorrect results. https://github.com/openai/triton/issues/2361
This commit is contained in:
@@ -170,7 +170,26 @@ def test_scan2d_broadcast(RBLOCK, num_warps):
|
||||
tl.store(out_ptr + xindex * RBLOCK + rindex, scan)
|
||||
|
||||
XBLOCK = 4
|
||||
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int32, device='cuda')
|
||||
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int32, device='cuda')
|
||||
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda')
|
||||
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda')
|
||||
fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
|
||||
torch.testing.assert_allclose(output, input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)))
|
||||
ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))
|
||||
torch.testing.assert_close(output, ref)
|
||||
|
||||
|
||||
def test_scan2d_for():
|
||||
@triton.jit
|
||||
def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
|
||||
rbase = tl.arange(0, RBLOCK)[None, :]
|
||||
for roffset in range(0, rnumel, RBLOCK):
|
||||
rindex = roffset + rbase
|
||||
rmask = rindex < rnumel
|
||||
tmp3 = tl.where(rmask, 1, 0)
|
||||
tmp6 = tl.cumsum(tmp3, 1)
|
||||
tl.store(out_ptr0 + rindex, tmp6, rmask)
|
||||
|
||||
RBLOCK = 8
|
||||
out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64)
|
||||
fn[(1,)](out0, RBLOCK, RBLOCK)
|
||||
ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1
|
||||
torch.testing.assert_close(out0, ref)
|
||||
|
||||
Reference in New Issue
Block a user