Invalid idx (#12067)

* merge index_dtype_3

* new lowering with Invalid idx

* remove that dtype from range

* finish merge

* annotate better

* indentation

* don't need that anymore

* always process replay for openpilot

* more uop_given_valid for idx

* valid past index_child

* fix bug preventing load getting an alt value

* add track_match_stats back in shapetracker and remove cache

* get_valid_idx -> get_valid and get_idx

* fix heuristics with new idx

* split line

* fix typo

* fix signature

* don't skip idx if stride is 0

the idx may still be invalid

* lower const with new valid

* delete to_indexed_uops

* update shapetracker test

* delete axis_is_masked

* add cache back

* move around comment

* fix get_valid bug

* move invalid fold to symbolic so it's earlier

* cleanup

* update applying padto to new idx

* add unit tests

* cleanup

* fold line

* improve spec

* don't try to render Invalid as a float

* more consistent invalid index

* update some tests

* Fold index with true cond

* skip test

* vconst min max if Invalid in arg

* fix signature of UOp.const

* add test for min/max of Invalid CONST/VCONST

* add InvalidType to as_const signature

* is Invalid to isinstance

* Add InvalidType to ConstLike

* index gate is a where gate

* make that a metaclass

* fix heuristics for new idx

* mypy happy
Sieds Lykles
2025-09-12 01:42:02 +02:00
committed by GitHub
parent 544eb2c402
commit 1f3950a484
18 changed files with 188 additions and 86 deletions
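
For orientation before the file diffs: a masked index was previously carried as a separate (idx, valid) pair of UOps; after this change it is a single index-dtype UOp of the form valid.where(idx, Invalid), and the pair is recovered with the new get_idx()/get_valid() helpers. A minimal sketch of the new API, assuming tinygrad at this commit (it mirrors the updated TestRender below):

from tinygrad.shape.shapetracker import ShapeTracker

# a (2, 3) tensor padded on the first axis, so part of the index space is masked
st = ShapeTracker.from_shape((2, 3)).pad(((0, 1), (0, 0)))
vidx = st.to_valid_uop()           # one UOp: valid.where(idx, Invalid)
print(vidx.get_idx().render())     # ((ridx0*3)+ridx1)
print(vidx.get_valid().render())   # (ridx0<2)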


@@ -1,6 +1,6 @@
import gc
from tinygrad import Tensor, UOp, Device
from tinygrad.shape.shapetracker import views_to_indexed_uops
from tinygrad.shape.shapetracker import views_to_valid_uop
from tinygrad.engine.realize import method_cache, get_program
def uops_allocated(): return sum([isinstance(x, UOp) for x in gc.get_objects()])
@@ -60,7 +60,7 @@ if __name__ == "__main__":
# these caches will keep uops alive
method_cache.clear()
views_to_indexed_uops.cache_clear()
views_to_valid_uop.cache_clear()
new_uops = uops_allocated()
gc.collect()


@@ -560,7 +560,7 @@ class TestUOpGraph(unittest.TestCase):
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(idx, UOp.const(dtypes.bool, False)),))
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-1].src[-1]
@@ -573,7 +573,7 @@ class TestUOpGraph(unittest.TestCase):
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+1, UOp.const(dtypes.bool, False)), barrier))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(UOp.invalid()), barrier))
ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
@@ -586,7 +586,7 @@ class TestUOpGraph(unittest.TestCase):
idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)
st0 = glbl.index(idx0, UOp.const(dtypes.bool, False)).store(val)
st0 = glbl.index(UOp.invalid()).store(val)
st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val)
uops = to_uops_list([st0, st1])
# only the second store happens


@@ -1,7 +1,7 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, Invalid
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad import Variable
@@ -10,7 +10,8 @@ from tinygrad.codegen.late.devectorizer import sym
from itertools import product
def shapetracker_getitem(st:ShapeTracker, val:int):
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
valid_idx = st.reshape((st.size,)).to_valid_uop([UOp.const(dtypes.int, val)])
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
assert idx.op is Ops.CONST and valid.op is Ops.CONST
return idx.arg, valid.arg
@@ -68,7 +69,7 @@ class CheckingShapeTracker:
def contiguous(self): return self.st.contiguous
def assert_same(self):
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))]
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] and v[0] is not Invalid else -1) for i in range(prod(self.st.shape))]
y = [self[i] for i in range(prod(self.shape))]
assert self.st.shape == self.shape
assert x == y, f"mismatch shapetracker:{x} real:{y}"
@@ -154,7 +155,7 @@ class TestRealStrides(unittest.TestCase):
View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
))
self.assertEqual(st.real_strides(), (132, None, None, None, None))
self.assertEqual(st.real_strides(), (132, 12, None, None, None))
class TestRealSimplifies(unittest.TestCase):
def tearDown(self):
@@ -816,12 +817,14 @@ class TestShapeTrackerSize(unittest.TestCase):
class TestRender(unittest.TestCase):
def test_render(self):
st = ShapeTracker.from_shape((2, 3))
idx, valid = st.to_indexed_uops()
valid_idx = st.to_valid_uop()
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
self.assertEqual(valid.render(), "True")
st = st.pad(((0, 1), (0, 0)))
idx, valid = st.to_indexed_uops()
valid_idx = st.to_valid_uop()
idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
self.assertEqual(valid.render(), "(ridx0<2)")


@@ -269,6 +269,7 @@ class TestImageSimplification(unittest.TestCase):
load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
self.check(load, None, "gidx0", "(gidx1+5)")
@unittest.skip("this should be constructed with an invalid gate")
def test_valid_empty_set(self):
gidx0 = Special("gidx0", 32)
gidx1 = Special("gidx1", 32)


@@ -2,7 +2,7 @@
import unittest, pickle, functools, math
import z3
from tinygrad.dtype import dtypes, ConstType, DType
from tinygrad.dtype import dtypes, ConstType, DType, Invalid
from tinygrad.codegen import full_rewrite
from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites
@@ -934,6 +934,46 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
assert c == uconst(2)
"""
class TestInvalidIndex(unittest.TestCase):
def test_invalid_times_0(self):
ridx = Variable("ridx", 0, 10)
idx = (ridx<5).where(ridx, UOp.invalid())*0
self.assertIs(idx.simplify(), (ridx<5).where(0, UOp.invalid()), "multiplying an index by 0 should preserve the invalid")
def test_invalid_comparison_drops_invalid(self):
# comparisons return a bool, and bools can't be invalid
ridx = Variable("ridx", 0, 10)
idx = (ridx<5).where(ridx, UOp.invalid())<3
self.assertIs(idx.simplify(), (ridx<3), "comparison of index should drop the invalid")
self.assertIs(idx.where(UOp.const(dtypes.int, 1), 0).simplify(), (ridx<3).where(UOp.const(dtypes.int, 1), 0),
"comparison of index should drop the invalid")
def test_alu_moves_inside_invalid(self):
ridx = Variable("ridx", 0, 10)
idx = (ridx<5).where(ridx, UOp.invalid())*10
self.assertIs(idx.simplify(), (ridx<5).where(ridx*10, UOp.invalid()), "the alu should move inside the where and preserve the invalid")
def test_merge_invalid_conditions(self):
ridx0 = Variable("ridx0", 0, 10)
ridx1 = Variable("ridx1", 0, 10)
idx0 = (ridx0<5).where(ridx0, UOp.invalid())
idx1 = (ridx1<5).where(idx0//2, UOp.invalid())
self.assertIs(idx1.simplify(), ((ridx1<5)&(ridx0<5)).where(ridx0//2, UOp.invalid()),
"valid inside a valid should make a single valid and & the conditions")
def test_alu_invalid(self):
self.assertIs((UOp.invalid()*2).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()*0).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()+8).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()+Variable("a",0,10)).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()*Variable("a",0,10)).simplify(), UOp.invalid())
self.assertIs((UOp.invalid()<Variable("a",0,10)).simplify().dtype, dtypes.bool)
def test_alu_invalid_vconst(self):
c1 = UOp.const(dtypes.index.vec(4), (1, 1, Invalid, Invalid))
c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1))
self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid)))
class TestSymbolicRealWorld(unittest.TestCase):
def test_resnet_half(self):
gidx0 = Variable("gidx0", 0, 3)


@@ -1,6 +1,6 @@
import unittest, math
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, Invalid
class TestVminVmaxProperties(unittest.TestCase):
def test_vmin_vmax_constant(self):
@@ -122,6 +122,15 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))
def test_vmin_vmax_invalid(self):
i = UOp.invalid()
self.assertNotEqual(i.vmin, i.vmax)
def test_vmin_vmax_invalid_vconst(self):
x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
self.assertLess(x.vmin, 0)
self.assertGreater(x.vmax, 4)
class TestVminVmaxDivMod(unittest.TestCase):
def test_vmin_vmax_division_positive(self):
# vmin and vmax for division of a variable by a positive constant


@@ -2,16 +2,16 @@ from typing import Any, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
from tinygrad.helpers import getenv, flatten, AMX, prod
from tinygrad.renderer import Renderer
# ***** image load valid simplification *****
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
# wait for it to be image indexed before running simplification
@@ -53,8 +53,10 @@ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp,
load_store_indexing = PatternMatcher([
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# index True is just Index
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
# lowering turns the invalid into a gate, must come before index dtype lowering
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
# drop true gate
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
# remove hanging cast
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
@@ -76,6 +78,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
idx: Any = midx.src[i].src[1]
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
@@ -255,7 +258,7 @@ pm_render = PatternMatcher([
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# give any loads that are masked an alt value
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE) else None),
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \


@@ -38,8 +38,8 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
#assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
idx, valid = x.st_arg.to_indexed_uops(new_idxs)
used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
idx = x.st_arg.to_valid_uop(new_idxs)
used_idxs = [x for x in idx.toposort() if x in new_idxs]
real_new_idxs = []
for i in range(len(x.src[0].shape)):
if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
@@ -47,7 +47,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
stored = subblock(ctx, real_new_idxs, x.src[1])
used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
return buf.index(idx, valid).store(stored, *used_ranges)
return buf.index(idx).store(stored, *used_ranges)
def fixup_wmma(ctx:IndexContext, x:UOp):
if x.tag is not None: return None
@@ -71,9 +71,9 @@ pm_lowerer = PatternMatcher([
# consts and loads
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_indexed_uops(ctx.idxs)[1].where(c, c.const_like(0))),
lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_valid_uop(ctx.idxs).get_valid().where(c, c.const_like(0))),
(UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(*x.st_arg.to_indexed_uops(ctx.idxs)),)+x.src[1:])),
lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(x.st_arg.to_valid_uop(ctx.idxs)),)+x.src[1:])),
# reduce/view_const
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),


@@ -53,7 +53,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
(mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
idx0, idx1 = mulop.src[0].src[0].src[1], mulop.src[1].src[0].src[1]
idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx()
first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
for global_idx in k.axes_of(AxisType.GLOBAL):
@@ -77,7 +77,8 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
for buf_index,buf in enumerate(k.bufs):
if isinstance(buf.src[0].dtype, ImageDType):
# part of real_strides
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].split_uop(Ops.ADD) if c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
if len(unit_stride_axes_mul_4):
if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
@@ -94,8 +95,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
for axis in k.upcastable_dims:
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
is_masked = any(len(st.src) > 2 and k.rngs[axis] in st.src[2].parents for st in k.bufs) or \
any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
is_masked = any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
@@ -111,11 +111,13 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
rng = k.rngs[axis]
if any(rng not in b.src[1].parents and all(r2 in b.src[1].parents for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
if any(rng not in b.src[1].get_idx().parents and all(r2 in b.src[1].get_idx().parents
for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
num_strides, sum_strides = 0, 0
for b in k.bufs:
if rng in b.src[1].parents: num_strides += 1
for c in b.src[1].split_uop(Ops.ADD):
idx = b.src[1].get_idx()
if rng in idx.parents: num_strides += 1
for c in idx.split_uop(Ops.ADD):
if c is rng: sum_strides += 1
if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg
if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg
@@ -157,7 +159,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
k.apply_opt(Opt(OptOps.NOLOCALS))
else:
# prioritize making expand axes local
local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].parents for b in k.bufs), axis) \
local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].get_idx().parents for b in k.bufs), axis) \
for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST]
to_local: list[tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):


@@ -180,11 +180,10 @@ class Scheduler:
check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
replaced_rng = UOp.range(new_sz, *rng.arg)
replaces = {rng:replaced_rng}
valid = replaced_rng < rng.vmax+1
for b in self.bufs:
if rng in b.src[1].sparents:
valid = replaced_rng < rng.vmax+1
if len(b.src) > 2: valid = b.src[2] & valid
replaces[b] = b.replace(src=b.src[0:2]+(valid,))
if rng in (i:=b.src[1].get_idx()).sparents:
replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
elif opt.op is OptOps.SWAP:
try:


@@ -5,6 +5,21 @@ from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod
from enum import Enum, auto
class InvalidTypeMetaClass(type):
instance:None|InvalidType = None
def __call__(cls, *args, **kwargs):
if (ret:=InvalidTypeMetaClass.instance) is not None: return ret
InvalidTypeMetaClass.instance = ret = super().__call__()
return ret
class InvalidType(metaclass=InvalidTypeMetaClass):
def __eq__(self, other): return self is other
def __hash__(self): return id(self)
def __repr__(self): return "Invalid"
def __reduce__(self): return (InvalidType, ()) # Return the global Invalid instance
Invalid = InvalidType()
ConstType = float|int|bool
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
@@ -104,10 +119,11 @@ class dtypes:
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
if isinstance(val, tuple):
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)
if isinstance(val, InvalidType): return val
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
@staticmethod
@functools.cache
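
The InvalidTypeMetaClass above makes Invalid a true singleton, so identity checks like "x is Invalid" stay reliable even across pickling; a small illustration, assuming tinygrad at this commit:

import pickle
from tinygrad.dtype import Invalid, InvalidType

assert InvalidType() is Invalid                        # the metaclass always hands back the one instance
assert pickle.loads(pickle.dumps(Invalid)) is Invalid  # __reduce__ reconstructs through InvalidType()
print(repr(Invalid))                                   # Invalid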


@@ -1,8 +1,10 @@
from typing import Any
import functools, operator
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY
from tinygrad.uop.symbolic import sym
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.kernelize import Kernel
@@ -136,9 +138,8 @@ def map_pad(idx:UOp, r:UOp):
if resolve(e > 0): where = where & (ret[i] < (sh-e))
if resolve(s > 0): where = where & (ret[i] >= s)
bigwhere = bigwhere & where
# this is safe but dumb
# TODO (S-Lykles): switch to mixed index/valid
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
with Context(TRACK_MATCH_STATS=0):
ret[i] = graph_rewrite(where.where(ret[i]-s, UOp.invalid()), sym)
# PAD is with 0
return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
@@ -235,17 +236,20 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
out_rngs = []
end_ranges = []
idx_ranges = []
for i,r in enumerate(all_rngs):
if all_same(r):
out_rngs.append(r[0])
for i,valid_rngs in enumerate(all_rngs):
rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs])
# we compare the ranges without their valids
if all_same(rngs):
# the new valid is the OR of all the children valids
minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
out_rngs.append(minimum_valid.where(rngs[0], UOp.invalid()).simplify())
else:
out_rngs.append(ctx.new_range(c.shape[i]))
end_ranges.append(out_rngs[-1])
idx_ranges.append(i)
ctx.seen_child[c] = (idx_ranges, end_ranges)
ctx.seen_child[c] = (out_rngs, idx_ranges, end_ranges)
else:
out_rngs = list(idx.src[1:])
idx_ranges, end_ranges = ctx.seen_child[c]
out_rngs, idx_ranges, end_ranges = ctx.seen_child[c]
for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
# index based on the shared ranges
ret = c.index(*out_rngs)


@@ -5,30 +5,24 @@ import functools
from typing import Callable
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, unravel
from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid
from tinygrad.uop.symbolic import sym
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
@functools.cache
def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
idx, valid = views[-1].to_indexed_uops(_idxs)
def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> UOp:
idx = views[-1].to_valid_uop(_idxs)
for view in reversed(views[0:-1]):
view = view.minify()
idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
idx = view.to_valid_uop([sint_to_uop(i) for i in unravel(view.shape, idx)])
with Context(TRACK_MATCH_STATS=0):
# symbolic
idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 1").src
# simplify
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx
# symbolic again
return graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 2").src
return graph_rewrite(idx, sym, name="indexing sym @ 1")
@functools.cache
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:
# NOTE: if a stride is not always valid, it will be None
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
ret: list[sint|None] = [None] * len(views[-1].shape)
idx, valid = views_to_indexed_uops(views)
idx, valid = (vidx:=views_to_valid_uop(views)).get_idx(), vidx.get_valid()
for c in idx.split_uop(Ops.ADD):
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
@@ -69,14 +63,14 @@ class ShapeTracker:
def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
def to_indexed_uops(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
def to_valid_uop(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> UOp:
return views_to_valid_uop(self.views, tuple(_idxs) if _idxs is not None else None)
# upper bound on buffer size required to fit this shapetracker
def real_size(self) -> int:
if 0 in self.shape: return 0
view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
idx, _ = views_to_indexed_uops((view,))
idx = views_to_valid_uop((view,)).get_idx()
assert idx.vmax < 1e12, f"real_size broken for {self}"
return int(idx.vmax + 1)


@@ -112,16 +112,17 @@ class View:
mask:tuple[tuple[sint, sint], ...]|None
contiguous:bool
def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
"""(idx, valid)"""
def to_valid_uop(self, idxs:Sequence[UOp]|None=None) -> UOp:
"""valid.where(idx, INVALID)"""
if idxs is None: idxs = [UOp.range(s, i) for i,s in enumerate(self.shape)]
iexpr = sint_to_uop(self.offset)
where = UOp.const(dtypes.bool, True)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
iexpr = iexpr + idx*sint_to_uop(st)
if m is not None:
if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])
if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
return iexpr, vexpr
if resolve(m[0] != 0): where &= (idx >= sint_to_uop(m[0]))
if resolve(m[1] != sh): where &= (idx < sint_to_uop(m[1]))
return where.where(iexpr, UOp.invalid())
@functools.cache # pylint: disable=method-cache-max-size-none
def size(self) -> int:


@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.mathtraits import MathTrait
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
if TYPE_CHECKING:
@@ -319,6 +319,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert len(axis) == len(new_axis)
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
@staticmethod
def invalid(): return UOp(Ops.CONST, dtypes.index, src=(), arg=Invalid)
def get_idx(self) -> UOp:
assert self.dtype is dtypes.index, "Can only call get_idx on index dtype"
return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
def get_valid(self) -> UOp:
assert self.dtype is dtypes.index, "Can only call get_valid on index dtype"
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def realize(self, *args, **kwargs): return UOp(Ops.REALIZE, dtype=self.dtype, src=(self,)+args, **kwargs)
@@ -570,8 +578,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
if self.op is Ops.CONST: return self.arg, self.arg
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg))
if self.op is Ops.GEP: return self.src[0]._min_max
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
@@ -626,6 +634,7 @@ python_alu: dict[Ops, Callable] = {
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
if dtype==dtypes.index and op in GroupOp.Binary and Invalid in operands: return Invalid
alu = python_alu[op](*operands)
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
@@ -699,7 +708,7 @@ class UPat(MathTrait):
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True):
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
@staticmethod
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b)
# lil helper
def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)
@@ -1120,4 +1129,4 @@ def pyrender(ast:UOp) -> list[str]:
sint = int|UOp
Variable = UOp
ConstLike = ConstType|Variable|tuple[ConstType, ...]
ConstLike = ConstType|InvalidType|Variable|tuple[ConstType|InvalidType, ...]
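
The ops.py changes above also teach constant handling about Invalid: exec_alu short-circuits any binary op on dtypes.index when an operand is Invalid, and a CONST/VCONST carrying Invalid no longer folds to tight vmin/vmax bounds. A quick check, assuming tinygrad at this commit:

from tinygrad.dtype import dtypes, Invalid
from tinygrad.uop.ops import UOp, Ops, exec_alu

print(exec_alu(Ops.ADD, dtypes.index, [Invalid, 5]))          # Invalid propagates through the ALU
x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
print(x.vmin, x.vmax)   # conservative bounds (wider than 0..4) because Invalid is in the arg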


@@ -1,6 +1,6 @@
from typing import cast, Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
from tinygrad.shape.shapetracker import ShapeTracker
try:
@@ -178,6 +178,8 @@ spec = PatternMatcher([
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
(UPat(Ops.CONST, arg=Invalid), lambda: False),
(UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.arg))
# INDEX is used in new style load/store
# INDEX takes a <buf, alu, gate?>


@@ -3,7 +3,7 @@ from typing import cast
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING
from tinygrad.uop.decompositions import xpow
@@ -22,7 +22,28 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
symbolic_simple = PatternMatcher([
invalid_pat = UPat.const(dtypes.index, Invalid).named("i")
invalid_gate = UPat.var("cond").where(UPat.var("x",dtype=dtypes.index), invalid_pat)
propagate_invalid = PatternMatcher([
# this needs to be before symbolic so that 0*something_that_might_be_invalid doesn't become 0
# propagate invalid, push it past children
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
for op in GroupOp.Binary-GroupOp.Comparison),
*((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
# invalid + y -> invalid, same for other ops
*((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison),
# i < y -> a_bool_value_that_will_never_be_used: we choose a random bool const
*((invalid_pat.alu(op, UPat(dtype=dtypes.index)), lambda i: UOp.const(dtypes.bool, True)) for op in GroupOp.Comparison),
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# order of gate&!cond matters! and-clauses are only simplified left to right and we need the gate to be used to fold cond
(UPat.var("gate").where(invalid_gate, UPat.var("y")), lambda gate,cond,x,y,i: ((gate&cond.logical_not()).logical_not()).where(gate.where(x,y), i)),
# unswap the branches for the rule above
(UPat.var("gate").where(UPat.var("y"), invalid_gate).named("where"), lambda gate,cond,x,y,i: gate.logical_not().where(cond.where(x,i), y))
])
symbolic_simple = propagate_invalid + PatternMatcher([
# ** self folding **
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
@@ -295,7 +316,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
(UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)),
(UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)
if f.arg is not Invalid else None),
# alu of two where with same conds can combine, only do if true branch or false branch is const
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
@@ -395,7 +417,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
for stmt in valid.split_uop(Ops.AND):
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: return uop # give up if we cannot parse the valid
except ValueError: continue # skip valids we cannot parse
bounds[expr][int(is_upper)] = c
# don't simplify any other gates, can lead to OOB, we substitute them back later
@@ -466,6 +488,8 @@ REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+PatternMatcher([
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda cond,x,i: cond.where(newx, i) if
(newx:=uop_given_valid(cond, x)) is not x else None),
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
@@ -489,21 +513,16 @@ sym = symbolic_flat+PatternMatcher([
# ** where **
# push cast to branches
(UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# ** pow **
((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
# index true is index without op
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
# ** load/store folding **
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])),
# fold gated LOAD/STORE
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # NULL pointer store does nothing. NULL pointer load produces 0
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
(UPat(Ops.BARRIER, name="root"),
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
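
Taken together, the propagate_invalid rules above merge nested invalid gates into one gate with AND-ed conditions, which is what the new TestInvalidIndex.test_merge_invalid_conditions exercises; the same check as a standalone sketch, assuming tinygrad at this commit:

from tinygrad import Variable
from tinygrad.uop.ops import UOp

ridx0, ridx1 = Variable("ridx0", 0, 10), Variable("ridx1", 0, 10)
inner = (ridx0 < 5).where(ridx0, UOp.invalid())
outer = (ridx1 < 5).where(inner // 2, UOp.invalid())
# nested valids collapse into a single gate: ((ridx1<5)&(ridx0<5)).where(ridx0//2, Invalid)
assert outer.simplify() is ((ridx1 < 5) & (ridx0 < 5)).where(ridx0 // 2, UOp.invalid())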


@@ -77,7 +77,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src):
if x in excluded:
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(u.dtype) else f"{x.arg}"
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
try:
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None: