[BACKEND] shfl ptx insts should have side effects (#2376)

Otherwise, llvm pass could generate very weird structure of CFG and
yield incorrect results.

https://github.com/openai/triton/issues/2361
This commit is contained in:
Keren Zhou
2023-09-23 13:05:20 -04:00
committed by GitHub
parent cb83b42ed6
commit 57fc6d1f13
3 changed files with 24 additions and 7 deletions

View File

@@ -170,7 +170,26 @@ def test_scan2d_broadcast(RBLOCK, num_warps):
tl.store(out_ptr + xindex * RBLOCK + rindex, scan)
XBLOCK = 4
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int32, device='cuda')
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int32, device='cuda')
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda')
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda')
fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
torch.testing.assert_allclose(output, input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)))
ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))
torch.testing.assert_close(output, ref)
def test_scan2d_for():
@triton.jit
def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
rbase = tl.arange(0, RBLOCK)[None, :]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
tmp3 = tl.where(rmask, 1, 0)
tmp6 = tl.cumsum(tmp3, 1)
tl.store(out_ptr0 + rindex, tmp6, rmask)
RBLOCK = 8
out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64)
fn[(1,)](out0, RBLOCK, RBLOCK)
ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1
torch.testing.assert_close(out0, ref)