diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index cdbf7ec1a4..64f9cfddf9 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -527,6 +527,15 @@ class TestSymbolic(unittest.TestCase): # not combining # TODO: can combine if one is identity element const self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))") + def test_symbolic_div(self): + # from symbolic arange + a = Variable("a", 1, 10) + denominator = ((a*-2)+1) + numerator = (((((a*2)+-1)*2)+1)*a) + self.helper_test_variable(denominator, -19, -1, "((a*-2)+1)") + self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))") + self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True") + class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): MIN, MAX = 0, 10 diff --git a/tinygrad/ops.py b/tinygrad/ops.py index a21215f316..07c47d62ef 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -425,9 +425,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals) if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1 - if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST - if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin - if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin) + if self.op is Ops.IDIV: + if s1_vmin == s1_vmax: # min/max are equal in a CONST + if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin + if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin) + # don't know exact bounds, but know the sign + if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype) + if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0 if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax) if self.op is Ops.CMPLT: return (s0_vmax