mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
UOp bound for div negative number (#5808)
This commit is contained in:
@@ -265,14 +265,6 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
|
||||
|
||||
def test_div_neg_all_range(self):
|
||||
gidx = Variable("gidx", 0, 124)
|
||||
lidx = Variable("lidx", 0, 7)
|
||||
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 0, 250, "(((1+lidx)//4)+(gidx*2))")
|
||||
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 249, "((gidx*2)+(lidx//4))")
|
||||
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, -1, 249, "(((3+lidx)//4)+(gidx*2)+-1)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, -1, 249, "(((2+lidx)//4)+(gidx*2)+-1)")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
# TODO: why are the negative tests broken? (even if we did support negative variables)
|
||||
|
||||
@@ -325,20 +325,34 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_div_into_mod(self):
|
||||
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
# TODO: simplify the expression
|
||||
def test_div_neg_cancel(self):
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((1+idx)//4)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "(idx//4)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "(((3+idx)//4)+-1)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((-idx)+199)//(-4))+50)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((-idx)+200)//(-4))+50)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((-idx)+201)//(-4))+50)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
# NOTE: tests are not correct in symbolic
|
||||
# TODO: simplify the expression
|
||||
def test_div_neg_all_range(self):
|
||||
gidx = Variable("gidx", 0, 124)
|
||||
lidx = Variable("lidx", 0, 7)
|
||||
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 0, 250, "(((1+lidx)//4)+(gidx*2))")
|
||||
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 249, "((gidx*2)+(lidx//4))")
|
||||
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, -1, 249, "(((3+lidx)//4)+(gidx*2)+-1)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, -1, 249, "(((2+lidx)//4)+(gidx*2)+-1)")
|
||||
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((-gidx)*8)+(-lidx)+999)//(-4))+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1000)//(-4))+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1001)//(-4))+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((-gidx)*8)+(-lidx)+1002)//(-4))+250)")
|
||||
|
||||
# NOTE: tests are not correct in symbolic
|
||||
def test_div_neg_then_neg(self):
|
||||
# taken from arange opts
|
||||
lidx0 = Variable("lidx0", 0, 7)
|
||||
lidx1 = Variable("lidx1", 0, 7)
|
||||
alu2 = -lidx0-lidx1
|
||||
self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
|
||||
self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "(-4)")
|
||||
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "((((-lidx0)+(-lidx1)+134)//(-32))+4)")
|
||||
self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
|
||||
self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
|
||||
self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
|
||||
@@ -117,7 +117,9 @@ class UOp:
|
||||
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
|
||||
return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
|
||||
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST and s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
if s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
|
||||
if s1.arg < 0: return self.const(-(s0.vmax.arg//-s1.arg)), self.const(-(s0.vmin.arg//-s1.arg))
|
||||
if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
|
||||
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, True), UOp.const(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
|
||||
(UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
|
||||
|
||||
Reference in New Issue
Block a user