mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
_min_max for WHERE (#6564)
prereq to gated load simplification just for int
This commit is contained in:
@@ -441,6 +441,12 @@ class TestSymbolic(unittest.TestCase):
|
||||
unrolled_div = (alu0+2559)//-2+(alu0+2560)//-2+2559
|
||||
self.helper_test_variable(unrolled_div, 0, 2559, "gidx")
|
||||
|
||||
def test_gated_load(self):
|
||||
idx = Variable("idx", 0, 24)
|
||||
self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
|
||||
# TODO: simplify the true branch
|
||||
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx<4)?(idx//4):(-1))")
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
|
||||
@@ -36,6 +36,14 @@ class TestVminVmaxProperties(unittest.TestCase):
|
||||
self.assertEqual(uop.vmin, 5)
|
||||
self.assertEqual(uop.vmax, 8)
|
||||
|
||||
def test_vmin_vmax_where(self):
|
||||
x = UOp.define_var('x', dtypes.int, 0, 10)
|
||||
y = UOp.define_var('y', dtypes.int, 1, 11)
|
||||
z = UOp.define_var('z', dtypes.int, 2, 12)
|
||||
uop = x.lt(5).where(y, z)
|
||||
self.assertEqual(uop.vmin, 1)
|
||||
self.assertEqual(uop.vmax, 12)
|
||||
|
||||
class TestVminVmaxDivMod(unittest.TestCase):
|
||||
def test_vmin_vmax_division_positive(self):
|
||||
# vmin and vmax for division of a variable by a positive constant
|
||||
|
||||
@@ -258,7 +258,7 @@ class UOp(MathTrait):
|
||||
if self.op is UOps.CONST: return self.arg, self.arg
|
||||
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
|
||||
if self.op is UOps.ALU and self.dtype.count == 1:
|
||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
|
||||
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
|
||||
if self.arg is BinaryOps.MUL:
|
||||
# both are non-positive
|
||||
@@ -274,6 +274,8 @@ class UOp(MathTrait):
|
||||
if s1.arg < 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
|
||||
if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
|
||||
if self.arg is BinaryOps.CMPLT: return (s0.vmax<s1.vmin, s0.vmin<s1.vmax)
|
||||
# float has NAN issue and we use explicit NAN in transcendental
|
||||
if self.arg is TernaryOps.WHERE and dtypes.is_int(s1.dtype): return min(s1.vmin, s2.vmin), max(s1.vmax, s2.vmax)
|
||||
return dtypes.min(self.dtype), dtypes.max(self.dtype)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
Reference in New Issue
Block a user