_min_max for WHERE (#6564)

Prerequisite for the gated load simplification.

Only integer dtypes are handled for now; float WHERE keeps the full dtype range because of NaN.
This commit is contained in:
chenyu
2024-09-18 23:47:48 -04:00
committed by GitHub
parent 1b6eee02ad
commit 7f9fd556b0
3 changed files with 17 additions and 1 deletions

View File

@@ -441,6 +441,12 @@ class TestSymbolic(unittest.TestCase):
unrolled_div = (alu0+2559)//-2+(alu0+2560)//-2+2559
self.helper_test_variable(unrolled_div, 0, 2559, "gidx")
def test_gated_load(self):
  # An index in [0, 24] divided by 4: checked against bounds 0..6 and the
  # string "(idx//4)" (argument meaning per helper_test_variable, defined elsewhere).
  index = Variable("idx", 0, 24)
  plain_div = index//4
  self.helper_test_variable(plain_div, 0, 6, "(idx//4)")
  # The same division behind an (idx<4) gate, with -1 as the fallback value.
  # TODO: simplify the true branch
  gate = index.lt(4)
  gated_div = gate.where(index//4, index.const_like(-1))
  self.helper_test_variable(gated_div, -1, 6, "((idx<4)?(idx//4):(-1))")
@unittest.skip("not supported on uops yet")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):

View File

@@ -36,6 +36,14 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(uop.vmin, 5)
self.assertEqual(uop.vmax, 8)
def test_vmin_vmax_where(self):
  # Three int variables: x drives the condition, y and z are the two branches.
  specs = (('x', 0, 10), ('y', 1, 11), ('z', 2, 12))
  selector, on_true, on_false = (UOp.define_var(name, dtypes.int, lo, hi) for name, lo, hi in specs)
  picked = selector.lt(5).where(on_true, on_false)
  # The range of the WHERE must cover both branches: min(1, 2) .. max(11, 12).
  self.assertEqual(picked.vmin, 1)
  self.assertEqual(picked.vmax, 12)
class TestVminVmaxDivMod(unittest.TestCase):
def test_vmin_vmax_division_positive(self):
# vmin and vmax for division of a variable by a positive constant

View File

@@ -258,7 +258,7 @@ class UOp(MathTrait):
if self.op is UOps.CONST: return self.arg, self.arg
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
if self.op is UOps.ALU and self.dtype.count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.arg is BinaryOps.MUL:
# both are non-positive
@@ -274,6 +274,8 @@ class UOp(MathTrait):
if s1.arg < 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
if self.arg is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
# float has NAN issue and we use explicit NAN in transcendental
if self.arg is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
return dtypes.min(self.dtype), dtypes.max(self.dtype)
@dataclass(frozen=True)