mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Invalid idx (#12067)
* merge index_dtype_3 * new lowering with Invalid idx * remove that dtype from range * finish merge * annotate better * indentation * dont need that anymore * always process replay for openpilot * more uop_given_valid for idx * valid past index_child * fix bug preventing load getting an alt value * add track_match_stats back in in shapetracker and remove cache * get_valid_idx -> get_valid and get_idx * fix heuristics with new idx * split line * fix typo * fix signature * dont skip idx if stride is 0 the idx may still be invalid * lower const with new valid * delete to_indexed_uops * update shapetracker test * delete axis_is_masked * add cache back * move around comment * fix get_valid bug * move invalid fold to symbolic so its earlier * cleanup * update applying padto to new idx * add unit tests * cleanup * fold line * improve spec * dont try to render Invalid as a float * more consistent invalid index * update some tests * Fold index with true cond * skip test * vconst min max if Invalid in arg * fix signature of UOp.const * add test for min/max of Invalid CONST/VCONST * add InvalidType to as_const signature * is Invalid to isinstance * Add InvalidType to ConstLike * index gate is a where gate * make that a metaclass * fix heurisics for new idx * mypy happy
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, Invalid
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad import Variable
|
||||
@@ -10,7 +10,8 @@ from tinygrad.codegen.late.devectorizer import sym
|
||||
from itertools import product
|
||||
|
||||
def shapetracker_getitem(st:ShapeTracker, val:int):
|
||||
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
|
||||
valid_idx = st.reshape((st.size,)).to_valid_uop([UOp.const(dtypes.int, val)])
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
|
||||
assert idx.op is Ops.CONST and valid.op is Ops.CONST
|
||||
return idx.arg, valid.arg
|
||||
@@ -68,7 +69,7 @@ class CheckingShapeTracker:
|
||||
def contiguous(self): return self.st.contiguous
|
||||
|
||||
def assert_same(self):
|
||||
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))]
|
||||
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] and v[0] is not Invalid else -1) for i in range(prod(self.st.shape))]
|
||||
y = [self[i] for i in range(prod(self.shape))]
|
||||
assert self.st.shape == self.shape
|
||||
assert x == y, f"mismatch shapetracker:{x} real:{y}"
|
||||
@@ -154,7 +155,7 @@ class TestRealStrides(unittest.TestCase):
|
||||
View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
|
||||
View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
|
||||
))
|
||||
self.assertEqual(st.real_strides(), (132, None, None, None, None))
|
||||
self.assertEqual(st.real_strides(), (132, 12, None, None, None))
|
||||
|
||||
class TestRealSimplifies(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
@@ -816,12 +817,14 @@ class TestShapeTrackerSize(unittest.TestCase):
|
||||
class TestRender(unittest.TestCase):
|
||||
def test_render(self):
|
||||
st = ShapeTracker.from_shape((2, 3))
|
||||
idx, valid = st.to_indexed_uops()
|
||||
valid_idx = st.to_valid_uop()
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "True")
|
||||
|
||||
st = st.pad(((0, 1), (0, 0)))
|
||||
idx, valid = st.to_indexed_uops()
|
||||
valid_idx = st.to_valid_uop()
|
||||
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
|
||||
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
|
||||
self.assertEqual(valid.render(), "(ridx0<2)")
|
||||
|
||||
|
||||
@@ -269,6 +269,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
|
||||
self.check(load, None, "gidx0", "(gidx1+5)")
|
||||
|
||||
@unittest.skip("this should be constructed with an invalid gate")
|
||||
def test_valid_empty_set(self):
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import unittest, pickle, functools, math
|
||||
import z3
|
||||
|
||||
from tinygrad.dtype import dtypes, ConstType, DType
|
||||
from tinygrad.dtype import dtypes, ConstType, DType, Invalid
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites
|
||||
@@ -934,6 +934,46 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
assert c == uconst(2)
|
||||
"""
|
||||
|
||||
class TestInvalidIndex(unittest.TestCase):
|
||||
def test_invalid_times_0(self):
|
||||
ridx = Variable("ridx", 0, 10)
|
||||
idx = (ridx<5).where(ridx, UOp.invalid())*0
|
||||
self.assertIs(idx.simplify(), (ridx<5).where(0, UOp.invalid()), "multiplying an index by 0 should preserve the invalid")
|
||||
|
||||
def test_invalid_comparison_drops_invalid(self):
|
||||
# comparisons return a bool, and bools can't be invalid
|
||||
ridx = Variable("ridx", 0, 10)
|
||||
idx = (ridx<5).where(ridx, UOp.invalid())<3
|
||||
self.assertIs(idx.simplify(), (ridx<3), "comparison of index should drop the invalid")
|
||||
self.assertIs(idx.where(UOp.const(dtypes.int, 1), 0).simplify(), (ridx<3).where(UOp.const(dtypes.int, 1), 0),
|
||||
"comparison of index should drop the invalid")
|
||||
|
||||
def test_alu_moves_inside_invalid(self):
|
||||
ridx = Variable("ridx", 0, 10)
|
||||
idx = (ridx<5).where(ridx, UOp.invalid())*10
|
||||
self.assertIs(idx.simplify(), (ridx<5).where(ridx*10, UOp.invalid()), "multiplying an index by 0 should preserve the invalid")
|
||||
|
||||
def test_merge_invalid_conditions(self):
|
||||
ridx0 = Variable("ridx0", 0, 10)
|
||||
ridx1 = Variable("ridx1", 0, 10)
|
||||
idx0 = (ridx0<5).where(ridx0, UOp.invalid())
|
||||
idx1 = (ridx1<5).where(idx0//2, UOp.invalid())
|
||||
self.assertIs(idx1.simplify(), ((ridx1<5)&(ridx0<5)).where(ridx0//2, UOp.invalid()),
|
||||
"valid inside a valid should make a single valid and & the conditions")
|
||||
|
||||
def test_alu_invalid(self):
|
||||
self.assertIs((UOp.invalid()*2).simplify(), UOp.invalid())
|
||||
self.assertIs((UOp.invalid()*0).simplify(), UOp.invalid())
|
||||
self.assertIs((UOp.invalid()+8).simplify(), UOp.invalid())
|
||||
self.assertIs((UOp.invalid()+Variable("a",0,10)).simplify(), UOp.invalid())
|
||||
self.assertIs((UOp.invalid()*Variable("a",0,10)).simplify(), UOp.invalid())
|
||||
self.assertIs((UOp.invalid()<Variable("a",0,10)).simplify().dtype, dtypes.bool)
|
||||
|
||||
def test_alu_invalid_vconst(self):
|
||||
c1 = UOp.const(dtypes.index.vec(4), (1, 1, Invalid, Invalid))
|
||||
c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1))
|
||||
self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid)))
|
||||
|
||||
class TestSymbolicRealWorld(unittest.TestCase):
|
||||
def test_resnet_half(self):
|
||||
gidx0 = Variable("gidx0", 0, 3)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest, math
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, Invalid
|
||||
|
||||
class TestVminVmaxProperties(unittest.TestCase):
|
||||
def test_vmin_vmax_constant(self):
|
||||
@@ -122,6 +122,15 @@ class TestVminVmaxProperties(unittest.TestCase):
|
||||
self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
|
||||
self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))
|
||||
|
||||
def test_vmin_vmax_invalid(self):
|
||||
i = UOp.invalid()
|
||||
self.assertNotEqual(i.vmin, i.vmax)
|
||||
|
||||
def test_vmin_vmax_invalid_vconst(self):
|
||||
x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
|
||||
self.assertLess(x.vmin, 0)
|
||||
self.assertGreater(x.vmax, 4)
|
||||
|
||||
class TestVminVmaxDivMod(unittest.TestCase):
|
||||
def test_vmin_vmax_division_positive(self):
|
||||
# vmin and vmax for division of a variable by a positive constant
|
||||
|
||||
Reference in New Issue
Block a user