Don't rewrite idiv to rshift when numerator is negative (#8885)

* more conditions for shift rewrite mul/idiv

* make ptx test uint so the new condition is true

* delete idiv test

* rewrite to 0 is wrong for idiv, as denominator is cast to 0 before division

* mul/div by 2**(large count) is unsupported anyway
This commit is contained in:
eliotgolding
2025-02-04 23:47:33 +00:00
committed by GitHub
parent 666b6149bc
commit bb5ded85cc
3 changed files with 14 additions and 8 deletions

View File

@@ -725,6 +725,12 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: tor.__rshift__(2), lambda: ten.__rshift__(2).cast(dtypes.int32), forward_only=True)
helper_test_op([], lambda: tor.bitwise_right_shift(2), lambda: ten.rshift(2).cast(dtypes.int32), forward_only=True)
def test_idiv_shift_rewrite_negative(self):
a = Tensor(-5).idiv(2).item()
b = Tensor(-5).contiguous().idiv(2).item()
self.assertEqual(a, b)
self.assertEqual(Tensor(-1).contiguous().idiv(4).item(), 0) # NOTE this is trunc-div behaviour
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin())
helper_test_op([()], lambda x: x.sin())

View File

@@ -355,12 +355,12 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.MUL, ops)
def test_bitshift_right(self):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
a1 = UOp(Ops.IDIV, dtypes.int, (l1, c1))
a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2))
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c1 = UOp(Ops.CONST, dtypes.uint, (), 2)
c2 = UOp(Ops.CONST, dtypes.uint, (), 3)
l1 = UOp(Ops.LOAD, dtypes.uint, (g1.index(c1),))
a1 = UOp(Ops.IDIV, dtypes.uint, (l1, c1))
a2 = UOp(Ops.IDIV, dtypes.uint, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
ops = [x.op for x in uops]