mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
improve idiv _min_max (#8066)
for the cases where we don't know the exact bounds, we might still know the sign. with this, we can remove some resolve calls for the symbolic shapetracker
This commit is contained in:
@@ -527,6 +527,15 @@ class TestSymbolic(unittest.TestCase):
|
||||
# not combining # TODO: can combine if one is identity element const
|
||||
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
|
||||
|
||||
def test_symbolic_div(self):
|
||||
# from symbolic arange
|
||||
a = Variable("a", 1, 10)
|
||||
denominator = ((a*-2)+1)
|
||||
numerator = (((((a*2)+-1)*2)+1)*a)
|
||||
self.helper_test_variable(denominator, -19, -1, "((a*-2)+1)")
|
||||
self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))")
|
||||
self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
@@ -425,9 +425,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
|
||||
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
|
||||
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
|
||||
if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
|
||||
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
|
||||
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
|
||||
if self.op is Ops.IDIV:
|
||||
if s1_vmin == s1_vmax: # min/max are equal in a CONST
|
||||
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
|
||||
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
|
||||
# don't know exact bounds, but know the sign
|
||||
if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
|
||||
if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
|
||||
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
|
||||
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
|
||||
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
|
||||
|
||||
Reference in New Issue
Block a user