Invalid idx (#12067)

* merge index_dtype_3

* new lowering with Invalid idx

* remove that dtype from range

* finish merge

* annotate better

* indentation

* dont need that anymore

* always process replay for openpilot

* more uop_given_valid for idx

* valid past index_child

* fix bug preventing load getting an alt value

* add track_match_stats back in in shapetracker and remove cache

* get_valid_idx -> get_valid and get_idx

* fix heuristics with new idx

* split line

* fix typo

* fix signature

* dont skip idx if stride is 0

the idx may still be invalid

* lower const with new valid

* delete to_indexed_uops

* update shapetracker test

* delete axis_is_masked

* add cache back

* move around comment

* fix get_valid bug

* move invalid fold to symbolic so its earlier

* cleanup

* update applying padto to new idx

* add unit tests

* cleanup

* fold line

* improve spec

* dont try to render Invalid as a float

* more consistent invalid index

* update some tests

* Fold index with true cond

* skip test

* vconst min max if Invalid in arg

* fix signature of UOp.const

* add test for min/max of Invalid CONST/VCONST

* add InvalidType to as_const signature

* is Invalid to isinstance

* Add InvalidType to ConstLike

* index gate is a where gate

* make that a metaclass

* fix heurisics for new idx

* mypy happy
This commit is contained in:
Sieds Lykles
2025-09-12 01:42:02 +02:00
committed by GitHub
parent 544eb2c402
commit 1f3950a484
18 changed files with 188 additions and 86 deletions

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, Invalid
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad import Variable
@@ -10,7 +10,8 @@ from tinygrad.codegen.late.devectorizer import sym
from itertools import product
def shapetracker_getitem(st:ShapeTracker, val:int):
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
valid_idx = st.reshape((st.size,)).to_valid_uop([UOp.const(dtypes.int, val)])
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
assert idx.op is Ops.CONST and valid.op is Ops.CONST
return idx.arg, valid.arg
@@ -68,7 +69,7 @@ class CheckingShapeTracker:
def contiguous(self): return self.st.contiguous
def assert_same(self):
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))]
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] and v[0] is not Invalid else -1) for i in range(prod(self.st.shape))]
y = [self[i] for i in range(prod(self.shape))]
assert self.st.shape == self.shape
assert x == y, f"mismatch shapetracker:{x} real:{y}"
@@ -154,7 +155,7 @@ class TestRealStrides(unittest.TestCase):
View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
))
self.assertEqual(st.real_strides(), (132, None, None, None, None))
self.assertEqual(st.real_strides(), (132, 12, None, None, None))
class TestRealSimplifies(unittest.TestCase):
def tearDown(self):
@@ -816,12 +817,14 @@ class TestShapeTrackerSize(unittest.TestCase):
class TestRender(unittest.TestCase):
def test_render(self):
st = ShapeTracker.from_shape((2, 3))
idx, valid = st.to_indexed_uops()
valid_idx = st.to_valid_uop()
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
self.assertEqual(valid.render(), "True")
st = st.pad(((0, 1), (0, 0)))
idx, valid = st.to_indexed_uops()
valid_idx = st.to_valid_uop()
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
self.assertEqual(valid.render(), "(ridx0<2)")

View File

@@ -269,6 +269,7 @@ class TestImageSimplification(unittest.TestCase):
load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
self.check(load, None, "gidx0", "(gidx1+5)")
@unittest.skip("this should be constructed with an invalid gate")
def test_valid_empty_set(self):
gidx0 = Special("gidx0", 32)
gidx1 = Special("gidx1", 32)

View File

@@ -2,7 +2,7 @@
import unittest, pickle, functools, math
import z3
from tinygrad.dtype import dtypes, ConstType, DType
from tinygrad.dtype import dtypes, ConstType, DType, Invalid
from tinygrad.codegen import full_rewrite
from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites
@@ -934,6 +934,46 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
assert c == uconst(2)
"""
class TestInvalidIndex(unittest.TestCase):
def test_invalid_times_0(self):
ridx = Variable("ridx", 0, 10)
idx = (ridx<5).where(ridx, UOp.invalid())*0
self.assertIs(idx.simplify(), (ridx<5).where(0, UOp.invalid()), "multiplying an index by 0 should preserve the invalid")
def test_invalid_comparison_drops_invalid(self):
# comparisons return a bool, and bools can't be invalid
ridx = Variable("ridx", 0, 10)
idx = (ridx<5).where(ridx, UOp.invalid())<3
self.assertIs(idx.simplify(), (ridx<3), "comparison of index should drop the invalid")
self.assertIs(idx.where(UOp.const(dtypes.int, 1), 0).simplify(), (ridx<3).where(UOp.const(dtypes.int, 1), 0),
"comparison of index should drop the invalid")
def test_alu_moves_inside_invalid(self):
ridx = Variable("ridx", 0, 10)
idx = (ridx<5).where(ridx, UOp.invalid())*10
self.assertIs(idx.simplify(), (ridx<5).where(ridx*10, UOp.invalid()), "multiplying an index by 0 should preserve the invalid")
def test_merge_invalid_conditions(self):
ridx0 = Variable("ridx0", 0, 10)
ridx1 = Variable("ridx1", 0, 10)
idx0 = (ridx0<5).where(ridx0, UOp.invalid())
idx1 = (ridx1<5).where(idx0//2, UOp.invalid())
self.assertIs(idx1.simplify(), ((ridx1<5)&(ridx0<5)).where(ridx0//2, UOp.invalid()),
"valid inside a valid should make a single valid and & the conditions")
def test_alu_invalid(self):
self.assertIs((UOp.invalid()*2).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()*0).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()+8).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()+Variable("a",0,10)).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()*Variable("a",0,10)).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()<Variable("a",0,10)).simplify().dtype, dtypes.bool)
def test_alu_invalid_vconst(self):
c1 = UOp.const(dtypes.index.vec(4), (1, 1, Invalid, Invalid))
c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1))
self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid)))
class TestSymbolicRealWorld(unittest.TestCase):
def test_resnet_half(self):
gidx0 = Variable("gidx0", 0, 3)

View File

@@ -1,6 +1,6 @@
import unittest, math
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, Invalid
class TestVminVmaxProperties(unittest.TestCase):
def test_vmin_vmax_constant(self):
@@ -122,6 +122,15 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))
def test_vmin_vmax_invalid(self):
i = UOp.invalid()
self.assertNotEqual(i.vmin, i.vmax)
def test_vmin_vmax_invalid_vconst(self):
x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
self.assertLess(x.vmin, 0)
self.assertGreater(x.vmax, 4)
class TestVminVmaxDivMod(unittest.TestCase):
def test_vmin_vmax_division_positive(self):
# vmin and vmax for division of a variable by a positive constant