mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix negative induction variable (#1382)
This commit is contained in:
@@ -1963,21 +1963,24 @@ def test_math_scalar(dtype_str, expr, lib_path):
|
||||
# -----------------------
|
||||
|
||||
|
||||
def test_for_iv_int64():
|
||||
@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3),
|
||||
(15, -16, -1), (15, -16, -2), (15, -16, -3),
|
||||
(-18, -22, -1), (22, 18, -1)])
|
||||
def test_for_iv(lo, hi, iv):
|
||||
|
||||
@triton.jit
|
||||
def kernel(Out, lo, hi):
|
||||
def kernel(Out, lo, hi, iv: tl.constexpr):
|
||||
acc = 0
|
||||
acc = acc.to(tl.int64)
|
||||
for i in range(lo, hi):
|
||||
for i in range(lo, hi, iv):
|
||||
acc += i
|
||||
tl.store(Out, acc)
|
||||
|
||||
lo = 2**35
|
||||
hi = 2**35 + 20
|
||||
out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
|
||||
kernel[(1,)](out, lo, hi)
|
||||
assert out[0] == sum(range(lo, hi))
|
||||
kernel[(1,)](out, lo, hi, iv)
|
||||
assert out[0] == sum(range(lo, hi, iv))
|
||||
|
||||
|
||||
def test_if_else():
|
||||
|
||||
Reference in New Issue
Block a user