Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
Invalid idx (#12067)
* merge index_dtype_3
* new lowering with Invalid idx
* remove that dtype from range
* finish merge
* annotate better
* indentation
* dont need that anymore
* always process replay for openpilot
* more uop_given_valid for idx
* valid past index_child
* fix bug preventing load getting an alt value
* add track_match_stats back in in shapetracker and remove cache
* get_valid_idx -> get_valid and get_idx
* fix heuristics with new idx
* split line
* fix typo
* fix signature
* dont skip idx if stride is 0 the idx may still be invalid
* lower const with new valid
* delete to_indexed_uops
* update shapetracker test
* delete axis_is_masked
* add cache back
* move around comment
* fix get_valid bug
* move invalid fold to symbolic so its earlier
* cleanup
* update applying padto to new idx
* add unit tests
* cleanup
* fold line
* improve spec
* dont try to render Invalid as a float
* more consistent invalid index
* update some tests
* Fold index with true cond
* skip test
* vconst min max if Invalid in arg
* fix signature of UOp.const
* add test for min/max of Invalid CONST/VCONST
* add InvalidType to as_const signature
* is Invalid to isinstance
* Add InvalidType to ConstLike
* index gate is a where gate
* make that a metaclass
* fix heurisics for new idx
* mypy happy
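The core of the change: instead of carrying a separate (idx, valid) pair through lowering, an index expression is now a single UOp of dtype index encoded as valid.where(idx, Invalid), and get_idx()/get_valid() recover the two halves. A minimal sketch using only the APIs introduced in this diff (it assumes UOps are hash-consed, which is what the identity-based assertIs checks in the tests below rely on):

from tinygrad import Variable
from tinygrad.uop.ops import UOp

ridx = Variable("ridx", 0, 10)                 # index-dtype variable, as in the tests below
vidx = (ridx < 5).where(ridx, UOp.invalid())   # the new encoding: valid.where(idx, Invalid)
assert vidx.get_idx() is ridx                  # recovers the index expression
assert vidx.get_valid() is (ridx < 5)          # recovers the valid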
test/external/external_uop_gc.py (vendored, 4 lines changed)
@@ -1,6 +1,6 @@
 import gc
 from tinygrad import Tensor, UOp, Device
-from tinygrad.shape.shapetracker import views_to_indexed_uops
+from tinygrad.shape.shapetracker import views_to_valid_uop
 from tinygrad.engine.realize import method_cache, get_program

 def uops_allocated(): return sum([isinstance(x, UOp) for x in gc.get_objects()])
@@ -60,7 +60,7 @@ if __name__ == "__main__":

   # these caches will keep uops alive
   method_cache.clear()
-  views_to_indexed_uops.cache_clear()
+  views_to_valid_uop.cache_clear()

   new_uops = uops_allocated()
   gc.collect()
@@ -560,7 +560,7 @@ class TestUOpGraph(unittest.TestCase):
     glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
     glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
     idx = UOp.const(dtypes.int, 0)
-    ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(idx, UOp.const(dtypes.bool, False)),))
+    ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(UOp.invalid()),))
     ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),))
     uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
     ld0 = uops[-1].src[-1]
@@ -573,7 +573,7 @@ class TestUOpGraph(unittest.TestCase):
     lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
     st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
     barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
-    ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+1, UOp.const(dtypes.bool, False)), barrier))
+    ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(UOp.invalid()), barrier))
     ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier))
     uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])

@@ -586,7 +586,7 @@ class TestUOpGraph(unittest.TestCase):
     idx0 = UOp.const(dtypes.int, 0)
     idx1 = UOp.const(dtypes.int, 0)
     val = UOp.const(dtypes.int, 42)
-    st0 = glbl.index(idx0, UOp.const(dtypes.bool, False)).store(val)
+    st0 = glbl.index(UOp.invalid()).store(val)
     st1 = glbl.index(idx0, UOp.const(dtypes.bool, True)).store(val)
     uops = to_uops_list([st0, st1])
     # only the second store happens
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 import unittest
 import numpy as np
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, Invalid
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad import Variable
@@ -10,7 +10,8 @@ from tinygrad.codegen.late.devectorizer import sym
 from itertools import product

 def shapetracker_getitem(st:ShapeTracker, val:int):
-  idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
+  valid_idx = st.reshape((st.size,)).to_valid_uop([UOp.const(dtypes.int, val)])
+  idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
   idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
   assert idx.op is Ops.CONST and valid.op is Ops.CONST
   return idx.arg, valid.arg
@@ -68,7 +69,7 @@ class CheckingShapeTracker:
   def contiguous(self): return self.st.contiguous

   def assert_same(self):
-    x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))]
+    x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] and v[0] is not Invalid else -1) for i in range(prod(self.st.shape))]
     y = [self[i] for i in range(prod(self.shape))]
     assert self.st.shape == self.shape
     assert x == y, f"mismatch shapetracker:{x} real:{y}"
@@ -154,7 +155,7 @@ class TestRealStrides(unittest.TestCase):
       View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
       View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
     ))
-    self.assertEqual(st.real_strides(), (132, None, None, None, None))
+    self.assertEqual(st.real_strides(), (132, 12, None, None, None))

 class TestRealSimplifies(unittest.TestCase):
   def tearDown(self):
@@ -816,12 +817,14 @@ class TestShapeTrackerSize(unittest.TestCase):
 class TestRender(unittest.TestCase):
   def test_render(self):
     st = ShapeTracker.from_shape((2, 3))
-    idx, valid = st.to_indexed_uops()
+    valid_idx = st.to_valid_uop()
+    idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
     self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
     self.assertEqual(valid.render(), "True")

     st = st.pad(((0, 1), (0, 0)))
-    idx, valid = st.to_indexed_uops()
+    valid_idx = st.to_valid_uop()
+    idx, valid = valid_idx.get_idx(), valid_idx.get_valid()
     self.assertEqual(idx.render(), "((ridx0*3)+ridx1)")
     self.assertEqual(valid.render(), "(ridx0<2)")
@@ -269,6 +269,7 @@ class TestImageSimplification(unittest.TestCase):
     load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
     self.check(load, None, "gidx0", "(gidx1+5)")

+  @unittest.skip("this should be constructed with an invalid gate")
   def test_valid_empty_set(self):
     gidx0 = Special("gidx0", 32)
     gidx1 = Special("gidx1", 32)
@@ -2,7 +2,7 @@
 import unittest, pickle, functools, math
 import z3

-from tinygrad.dtype import dtypes, ConstType, DType
+from tinygrad.dtype import dtypes, ConstType, DType, Invalid
 from tinygrad.codegen import full_rewrite
 from tinygrad.helpers import Context
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer, track_rewrites
@@ -934,6 +934,46 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
     assert c == uconst(2)
   """

+class TestInvalidIndex(unittest.TestCase):
+  def test_invalid_times_0(self):
+    ridx = Variable("ridx", 0, 10)
+    idx = (ridx<5).where(ridx, UOp.invalid())*0
+    self.assertIs(idx.simplify(), (ridx<5).where(0, UOp.invalid()), "multiplying an index by 0 should preserve the invalid")
+
+  def test_invalid_comparison_drops_invalid(self):
+    # comparisons return a bool, and bools can't be invalid
+    ridx = Variable("ridx", 0, 10)
+    idx = (ridx<5).where(ridx, UOp.invalid())<3
+    self.assertIs(idx.simplify(), (ridx<3), "comparison of index should drop the invalid")
+    self.assertIs(idx.where(UOp.const(dtypes.int, 1), 0).simplify(), (ridx<3).where(UOp.const(dtypes.int, 1), 0),
+                  "comparison of index should drop the invalid")
+
+  def test_alu_moves_inside_invalid(self):
+    ridx = Variable("ridx", 0, 10)
+    idx = (ridx<5).where(ridx, UOp.invalid())*10
+    self.assertIs(idx.simplify(), (ridx<5).where(ridx*10, UOp.invalid()), "multiplying an index by 0 should preserve the invalid")
+
+  def test_merge_invalid_conditions(self):
+    ridx0 = Variable("ridx0", 0, 10)
+    ridx1 = Variable("ridx1", 0, 10)
+    idx0 = (ridx0<5).where(ridx0, UOp.invalid())
+    idx1 = (ridx1<5).where(idx0//2, UOp.invalid())
+    self.assertIs(idx1.simplify(), ((ridx1<5)&(ridx0<5)).where(ridx0//2, UOp.invalid()),
+                  "valid inside a valid should make a single valid and & the conditions")
+
+  def test_alu_invalid(self):
+    self.assertIs((UOp.invalid()*2).simplify(), UOp.invalid())
+    self.assertIs((UOp.invalid()*0).simplify(), UOp.invalid())
+    self.assertIs((UOp.invalid()+8).simplify(), UOp.invalid())
+    self.assertIs((UOp.invalid()+Variable("a",0,10)).simplify(), UOp.invalid())
+    self.assertIs((UOp.invalid()*Variable("a",0,10)).simplify(), UOp.invalid())
+    self.assertIs((UOp.invalid()<Variable("a",0,10)).simplify().dtype, dtypes.bool)
+
+  def test_alu_invalid_vconst(self):
+    c1 = UOp.const(dtypes.index.vec(4), (1, 1, Invalid, Invalid))
+    c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1))
+    self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid)))
+
 class TestSymbolicRealWorld(unittest.TestCase):
   def test_resnet_half(self):
     gidx0 = Variable("gidx0", 0, 3)
@@ -1,6 +1,6 @@
 import unittest, math
 from tinygrad.uop.ops import UOp, Ops
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import dtypes, Invalid

 class TestVminVmaxProperties(unittest.TestCase):
   def test_vmin_vmax_constant(self):
@@ -122,6 +122,15 @@ class TestVminVmaxProperties(unittest.TestCase):
     self.assertEqual(x_uint.vmin, dtypes.min(dtypes.uint))
     self.assertEqual(x_uint.vmax, dtypes.max(dtypes.uint))

+  def test_vmin_vmax_invalid(self):
+    i = UOp.invalid()
+    self.assertNotEqual(i.vmin, i.vmax)
+
+  def test_vmin_vmax_invalid_vconst(self):
+    x = UOp.const(dtypes.index.vec(4), (0, 4, Invalid, Invalid))
+    self.assertLess(x.vmin, 0)
+    self.assertGreater(x.vmax, 4)
+
 class TestVminVmaxDivMod(unittest.TestCase):
   def test_vmin_vmax_division_positive(self):
     # vmin and vmax for division of a variable by a positive constant
@@ -2,16 +2,16 @@ from typing import Any, cast
 import functools, operator, itertools
 from collections import defaultdict
 from dataclasses import dataclass
-from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace
+from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
 from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
-from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat
+from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
 from tinygrad.helpers import getenv, flatten, AMX, prod
 from tinygrad.renderer import Renderer

 # ***** image load valid simplification *****

 def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
-  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
+  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
   if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)

   # wait for it to be image indexed before running simplification
@@ -53,8 +53,10 @@ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp,
 load_store_indexing = PatternMatcher([
   # image load valid idx simplification
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
-  # index True is just Index
-  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
+  # lower turn the invalid into a gate, must come before index dtype lowering
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
+  # drop true gate
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
   # remove hanging cast
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
@@ -76,6 +78,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
     idx: Any = midx.src[i].src[1]
     if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
     elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
+    elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
     elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
     else: root_src, arg = idx, 0
     if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
@@ -255,7 +258,7 @@ pm_render = PatternMatcher([
   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
   # give any loads that are masked an alt value
   (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
-   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
+   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE) else None),
   # gate any stores that aren't gated with ifs
   (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
    lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
@@ -38,8 +38,8 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
   #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"

   new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
-  idx, valid = x.st_arg.to_indexed_uops(new_idxs)
-  used_idxs = [x for x in UOp.sink(idx, valid).toposort() if x in new_idxs]
+  idx = x.st_arg.to_valid_uop(new_idxs)
+  used_idxs = [x for x in idx.toposort() if x in new_idxs]
   real_new_idxs = []
   for i in range(len(x.src[0].shape)):
     if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
@@ -47,7 +47,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):

   stored = subblock(ctx, real_new_idxs, x.src[1])
   used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
-  return buf.index(idx, valid).store(stored, *used_ranges)
+  return buf.index(idx).store(stored, *used_ranges)

 def fixup_wmma(ctx:IndexContext, x:UOp):
   if x.tag is not None: return None
@@ -71,9 +71,9 @@ pm_lowerer = PatternMatcher([

   # consts and loads
   (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
-   lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_indexed_uops(ctx.idxs)[1].where(c, c.const_like(0))),
+   lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_valid_uop(ctx.idxs).get_valid().where(c, c.const_like(0))),
   (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
-   lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(*x.st_arg.to_indexed_uops(ctx.idxs)),)+x.src[1:])),
+   lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(x.st_arg.to_valid_uop(ctx.idxs)),)+x.src[1:])),

   # reduce/view_const
   (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
@@ -53,7 +53,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
   if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
       k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
       (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
-    idx0, idx1 = mulop.src[0].src[0].src[1], mulop.src[1].src[0].src[1]
+    idx0, idx1 = mulop.src[0].src[0].src[1].get_idx(), mulop.src[1].src[0].src[1].get_idx()
     first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
     if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
       for global_idx in k.axes_of(AxisType.GLOBAL):
@@ -77,7 +77,8 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
   for buf_index,buf in enumerate(k.bufs):
     if isinstance(buf.src[0].dtype, ImageDType):
       # part of real_strides
-      unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].split_uop(Ops.ADD) if c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
+      unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
+                                c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
       if len(unit_stride_axes_mul_4):
         if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
           k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
@@ -94,8 +95,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
   # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
   for axis in k.upcastable_dims:
     # for Schedule, we check if the range is used in INDEX gates or WHERE gates
-    is_masked = any(len(st.src) > 2 and k.rngs[axis] in st.src[2].parents for st in k.bufs) or \
-      any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
+    is_masked = any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
     if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
       if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
       to_upcast.append(axis)
@@ -111,11 +111,13 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
     # if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
     if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
     rng = k.rngs[axis]
-    if any(rng not in b.src[1].parents and all(r2 in b.src[1].parents for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
+    if any(rng not in b.src[1].get_idx().parents and all(r2 in b.src[1].get_idx().parents
+           for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
       num_strides, sum_strides = 0, 0
       for b in k.bufs:
-        if rng in b.src[1].parents: num_strides += 1
-        for c in b.src[1].split_uop(Ops.ADD):
+        idx = b.src[1].get_idx()
+        if rng in idx.parents: num_strides += 1
+        for c in idx.split_uop(Ops.ADD):
           if c is rng: sum_strides += 1
           if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg
           if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg
@@ -157,7 +159,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
     k.apply_opt(Opt(OptOps.NOLOCALS))
   else:
     # prioritize making expand axes local
-    local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].parents for b in k.bufs), axis) \
+    local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].get_idx().parents for b in k.bufs), axis) \
      for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST]
     to_local: list[tuple[int, int]] = []
     for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
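These heuristics used to walk b.src[1] directly; since a buffer's index may now be a where(valid, idx, Invalid), they first strip the gate with .get_idx() before splitting on ADD to read off strides. A sketch of that walk on a standalone gated index (variable names are hypothetical):

from tinygrad import Variable
from tinygrad.uop.ops import UOp, Ops

r0, r1 = Variable("r0", 0, 3), Variable("r1", 0, 2)
vidx = (r0 < 2).where(r0*3 + r1, UOp.invalid())   # gated index, as bufs now carry
terms = list(vidx.get_idx().split_uop(Ops.ADD))   # walk idx only: [r0*3, r1]
assert any(t is r1 for t in terms)                # r1 appears bare, i.e. stride 1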
@@ -180,11 +180,10 @@ class Scheduler:
       check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
       replaced_rng = UOp.range(new_sz, *rng.arg)
       replaces = {rng:replaced_rng}
+      valid = replaced_rng < rng.vmax+1
       for b in self.bufs:
-        if rng in b.src[1].sparents:
-          valid = replaced_rng < rng.vmax+1
-          if len(b.src) > 2: valid = b.src[2] & valid
-          replaces[b] = b.replace(src=b.src[0:2]+(valid,))
+        if rng in (i:=b.src[1].get_idx()).sparents:
+          replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
       self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
     elif opt.op is OptOps.SWAP:
       try:
@@ -5,6 +5,21 @@ from dataclasses import dataclass, fields
 from tinygrad.helpers import getenv, prod
 from enum import Enum, auto

+class InvalidTypeMetaClass(type):
+  instance:None|InvalidType = None
+  def __call__(cls, *args, **kwargs):
+    if (ret:=InvalidTypeMetaClass.instance) is not None: return ret
+    InvalidTypeMetaClass.instance = ret = super().__call__()
+    return ret
+
+class InvalidType(metaclass=InvalidTypeMetaClass):
+  def __eq__(self, other): return self is other
+  def __hash__(self): return id(self)
+  def __repr__(self): return "Invalid"
+  def __reduce__(self): return (InvalidType, ())  # Return the global Invalid instance
+
+Invalid = InvalidType()
+
 ConstType = float|int|bool

 FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
@@ -104,10 +119,11 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
+  def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
     if isinstance(val, tuple):
       assert len(val) == dtype.count, f"mismatch {val} {dtype}"
       return tuple(dtypes.as_const(x, dtype) for x in val)
+    if isinstance(val, InvalidType): return val
     return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
   @functools.cache
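Why the metaclass: InvalidType() always hands back the one global sentinel, and __reduce__ re-routes unpickling to it, so identity checks like `arg is Invalid` stay sound across construction and pickling. A quick check against the definitions above:

import pickle
from tinygrad.dtype import Invalid, InvalidType, dtypes

assert InvalidType() is Invalid                           # metaclass returns the singleton
assert pickle.loads(pickle.dumps(Invalid)) is Invalid     # __reduce__ preserves identity
assert dtypes.as_const(Invalid, dtypes.index) is Invalid  # as_const passes it through untouched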
@@ -1,8 +1,10 @@
 from typing import Any
 import functools, operator
 from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
-from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY
+from tinygrad.uop.symbolic import sym
+from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context
 from tinygrad.schedule.multi import multi_pm

 from tinygrad.schedule.kernelize import Kernel
@@ -136,9 +138,8 @@ def map_pad(idx:UOp, r:UOp):
     if resolve(e > 0): where = where & (ret[i] < (sh-e))
     if resolve(s > 0): where = where & (ret[i] >= s)
     bigwhere = bigwhere & where
-    # this is safe but dumb
-    # TODO (S-Lykles): switch to mixed index/valid
-    ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
+    with Context(TRACK_MATCH_STATS=0):
+      ret[i] = graph_rewrite(where.where(ret[i]-s, UOp.invalid()), sym)
   # PAD is with 0
   return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
@@ -235,17 +236,20 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
     out_rngs = []
     end_ranges = []
     idx_ranges = []
-    for i,r in enumerate(all_rngs):
-      if all_same(r):
-        out_rngs.append(r[0])
+    for i,valid_rngs in enumerate(all_rngs):
+      rngs, valids = zip(*[(r.get_idx(), r.get_valid()) for r in valid_rngs])
+      # we compare the ranges without their valids
+      if all_same(rngs):
+        # the new valid is the OR of all the children valids
+        minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
+        out_rngs.append(minimum_valid.where(rngs[0], UOp.invalid()).simplify())
       else:
         out_rngs.append(ctx.new_range(c.shape[i]))
         end_ranges.append(out_rngs[-1])
         idx_ranges.append(i)
-    ctx.seen_child[c] = (idx_ranges, end_ranges)
+    ctx.seen_child[c] = (out_rngs, idx_ranges, end_ranges)
   else:
-    out_rngs = list(idx.src[1:])
-    idx_ranges, end_ranges = ctx.seen_child[c]
+    out_rngs, idx_ranges, end_ranges = ctx.seen_child[c]
     for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
   # index based on the shared ranges
   ret = c.index(*out_rngs)
@@ -5,30 +5,24 @@ import functools
 from typing import Callable
 from tinygrad.helpers import merge_dicts, getenv
 from tinygrad.shape.view import View, unravel
-from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid
+from tinygrad.uop.symbolic import sym
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context

 @functools.cache
-def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
-  idx, valid = views[-1].to_indexed_uops(_idxs)
+def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> UOp:
+  idx = views[-1].to_valid_uop(_idxs)
   for view in reversed(views[0:-1]):
     view = view.minify()
-    idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
+    idx = view.to_valid_uop([sint_to_uop(i) for i in unravel(view.shape, idx)])
   with Context(TRACK_MATCH_STATS=0):
-    # symbolic
-    idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 1").src
-    # simplify
-    if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
-    if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx
-    # symbolic again
-    return graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 2").src
+    return graph_rewrite(idx, sym, name="indexing sym @ 1")

 @functools.cache
 def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:
   # NOTE: if a stride is not always valid, it will be None
   if len(views) == 1 and views[-1].mask is None: return views[-1].strides
   ret: list[sint|None] = [None] * len(views[-1].shape)
-  idx, valid = views_to_indexed_uops(views)
+  idx, valid = (vidx:=views_to_valid_uop(views)).get_idx(), vidx.get_valid()
   for c in idx.split_uop(Ops.ADD):
     if c.op is Ops.RANGE: ret[c.arg[0]] = 1
     if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
@@ -69,14 +63,14 @@ class ShapeTracker:

   def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))

-  def to_indexed_uops(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
-    return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
+  def to_valid_uop(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> UOp:
+    return views_to_valid_uop(self.views, tuple(_idxs) if _idxs is not None else None)

   # upper bound on buffer size required to fit this shapetracker
   def real_size(self) -> int:
     if 0 in self.shape: return 0
     view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
-    idx, _ = views_to_indexed_uops((view,))
+    idx = views_to_valid_uop((view,)).get_idx()
     assert idx.vmax < 1e12, f"real_size broken for {self}"
     return int(idx.vmax + 1)
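Usage of the new ShapeTracker API, mirroring TestRender earlier in this diff: to_valid_uop returns one index-dtype UOp, and get_idx/get_valid split it back apart.

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((2, 3)).pad(((0, 1), (0, 0)))
vidx = st.to_valid_uop()
idx, valid = vidx.get_idx(), vidx.get_valid()
print(idx.render(), valid.render())  # ((ridx0*3)+ridx1) (ridx0<2), as in TestRender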
@@ -112,16 +112,17 @@ class View:
   mask:tuple[tuple[sint, sint], ...]|None
   contiguous:bool

-  def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
-    """(idx, valid)"""
+  def to_valid_uop(self, idxs:Sequence[UOp]|None=None) -> UOp:
+    """valid.where(idx, INVALID)"""
     if idxs is None: idxs = [UOp.range(s, i) for i,s in enumerate(self.shape)]
     iexpr = sint_to_uop(self.offset)
+    where = UOp.const(dtypes.bool, True)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
-      if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
+      iexpr = iexpr + idx*sint_to_uop(st)
       if m is not None:
-        if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])
-        if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
-    return iexpr, vexpr
+        if resolve(m[0] != 0): where &= (idx >= sint_to_uop(m[0]))
+        if resolve(m[1] != sh): where &= (idx < sint_to_uop(m[1]))
+    return where.where(iexpr, UOp.invalid())

   @functools.cache  # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 from enum import Enum, auto
 from tinygrad.uop import Ops, GroupOp
 from tinygrad.uop.mathtraits import MathTrait
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype
+from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
 from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
 if TYPE_CHECKING:
@@ -319,6 +319,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert len(axis) == len(new_axis)
     ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
     return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
+  @staticmethod
+  def invalid(): return UOp(Ops.CONST, dtypes.index, src=(), arg=Invalid)
+  def get_idx(self) -> UOp:
+    assert self.dtype is dtypes.index, "Can only call get_idx on index dtype"
+    return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
+  def get_valid(self) -> UOp:
+    assert self.dtype is dtypes.index, "Can only call get_valid on index dtype"
+    return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
   def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
   def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
   def realize(self, *args, **kwargs): return UOp(Ops.REALIZE, dtype=self.dtype, src=(self,)+args, **kwargs)
@@ -570,8 +578,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
     if self.op is Ops.BIND: return self.src[0]._min_max  # ignore the bound value
     if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
-    if self.op is Ops.CONST: return self.arg, self.arg
-    if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
+    if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
+    if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg))
     if self.op is Ops.GEP: return self.src[0]._min_max
     # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
     if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
@@ -626,6 +634,7 @@ python_alu: dict[Ops, Callable] = {
 def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
   if dtype.count > 1:
     return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
+  if dtype==dtypes.index and op in GroupOp.Binary and Invalid in operands: return Invalid
   alu = python_alu[op](*operands)
   return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu

@@ -699,7 +708,7 @@ class UPat(MathTrait):
   def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True):
     return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
   @staticmethod
-  def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
+  def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b)

   # lil helper
   def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)
@@ -1120,4 +1129,4 @@ def pyrender(ast:UOp) -> list[str]:
 sint = int|UOp
 Variable = UOp

-ConstLike = ConstType|Variable|tuple[ConstType, ...]
+ConstLike = ConstType|InvalidType|Variable|tuple[ConstType|InvalidType, ...]
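The exec_alu change makes constant folding Invalid-aware: any binary op on the index dtype with an Invalid operand folds straight to Invalid instead of reaching python_alu. A small sketch of that behavior:

from tinygrad.uop.ops import exec_alu
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, Invalid

assert exec_alu(Ops.ADD, dtypes.index, [Invalid, 5]) is Invalid  # short-circuits before python_alu
assert exec_alu(Ops.ADD, dtypes.index, [2, 3]) == 5              # normal folding is unchanged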
@@ -1,6 +1,6 @@
 from typing import cast, Callable
 from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite
-from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
+from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
 from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
 from tinygrad.shape.shapetracker import ShapeTracker
 try:
@@ -178,6 +178,8 @@ spec = PatternMatcher([

   # make sure all index dtypes have been lowered
   (UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
+  (UPat(Ops.CONST, arg=Invalid), lambda: False),
+  (UPat(Ops.VCONST, name="x"), lambda x: all(v is not Invalid for v in x.src)),

   # INDEX is used in new style load/store
   # INDEX takes a <buf, alu, gate?>
@@ -3,7 +3,7 @@ from typing import cast
 import math, operator, struct, functools
 from collections import defaultdict
 from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
-from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast
+from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid
 from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING
 from tinygrad.uop.decompositions import xpow
@@ -22,7 +22,28 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
   def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
   return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))

-symbolic_simple = PatternMatcher([
+invalid_pat = UPat.const(dtypes.index, Invalid).named("i")
+invalid_gate = UPat.var("cond").where(UPat.var("x",dtype=dtypes.index), invalid_pat)
+
+propagate_invalid = PatternMatcher([
+  # this needs to be before symbolic so that 0*something_that_might_be_invalid doesnt become 0
+  # propagate invalid, push it past children
+  *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: cond.where(x.alu(alu.op,y), i))
+    for op in GroupOp.Binary-GroupOp.Comparison),
+  *((invalid_gate.alu(op, UPat.var("y")).named("alu"), lambda cond,x,y,alu,i: x.alu(alu.op,y)) for op in GroupOp.Comparison),
+  # invalid + y -> y same for other ops
+  *((invalid_pat.alu(op, UPat(dtype=dtypes.index)).named("alu"), lambda alu,i: i) for op in GroupOp.Binary-GroupOp.Comparison),
+  # i < y -> a_bool_value_that_will_never_be_used: we choose a random bool const
+  *((invalid_pat.alu(op, UPat(dtype=dtypes.index)), lambda i: UOp.const(dtypes.bool, True)) for op in GroupOp.Comparison),
+  # a.where(b.where(c, d), d) -> (a & b).where(c, d)
+  (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
+  # order of gate&!cond matters!, and-clauses are only simplified left to right and we need to gate to be used to fold cond
+  (UPat.var("gate").where(invalid_gate, UPat.var("y")), lambda gate,cond,x,y,i: ((gate&cond.logical_not()).logical_not()).where(gate.where(x,y), i)),
+  # unswap the branches for the rule above
+  (UPat.var("gate").where(UPat.var("y"), invalid_gate).named("where"), lambda gate,cond,x,y,i: gate.logical_not().where(cond.where(x,i), y))
+])
+
+symbolic_simple = propagate_invalid + PatternMatcher([
   # ** self folding **
   (UPat.var("x") + 0, lambda x: x),  # x+0 -> x
   (UPat.var("x") * 1, lambda x: x),  # x*1 -> x
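The ordering comment above is load-bearing: if symbolic's x*0 -> 0 rule ran first, a possibly-invalid index times 0 would fold to a plain 0 and the gate would be lost. With propagate_invalid in front, the constant is pushed inside the gate instead. A sketch mirroring test_invalid_times_0 earlier in this diff:

from tinygrad import Variable
from tinygrad.uop.ops import UOp

ridx = Variable("ridx", 0, 10)
gated = (ridx < 5).where(ridx, UOp.invalid())
assert (gated * 0).simplify() is (ridx < 5).where(0, UOp.invalid())  # 0 only where valid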
@@ -295,7 +316,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   # a conditional with the same results either way is a noop, also fold const conditionals
   (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
-  (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)),
+  (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)
+   if f.arg is not Invalid else None),
   # alu of two where with same conds can combine, only do if true branch or false branch is const
   (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
    lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
@@ -395,7 +417,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
   bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
   for stmt in valid.split_uop(Ops.AND):
     try: expr, is_upper, c = parse_valid(stmt)
-    except ValueError: return uop  # give up if we cannot parse the valid
+    except ValueError: continue  # give up if we cannot parse the valid
     bounds[expr][int(is_upper)] = c

   # don't simplify any other gates, can lead to OOB, we substitute them back later
@@ -466,6 +488,8 @@ REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
 sym = symbolic_flat+PatternMatcher([
   # simplify valid
   (UPat(Ops.AND, name="valid"), simplify_valid),
+  (UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda cond,x,i: cond.where(newx, i) if
+   (newx:=uop_given_valid(cond, x)) is not x else None),
   # LOAD/STORE -> NOOP
   (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
   (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
@@ -489,21 +513,16 @@ sym = symbolic_flat+PatternMatcher([
   # ** where **
   # push cast to branches
   (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
-  # a.where(b.where(c, d), d) -> (a & b).where(c, d)
-  (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
   # ** pow **
   ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
-  # index true is index without op
-  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
   # ** load/store folding **
   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
    UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
    lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])),
   # fold gated LOAD/STORE
-  (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])),  # remove True
-  (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"),
-   lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)),  # NULL pointer store does nothing. NULL pointer load produces 0
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
+   lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)),  # invalid store does nothing. invalid load produces 0
   # remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
   (UPat(Ops.BARRIER, name="root"),
    lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
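These sym folds are what the TestUOpGraph changes near the top rely on: a LOAD through an Invalid index rewrites to a 0 constant, and a STORE through one becomes a NOOP. A sketch with the same DEFINE_GLOBAL setup as those tests (assuming graph_rewrite with sym reaches the pattern as written):

from tinygrad.uop.ops import UOp, Ops, graph_rewrite
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes

glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
ld = UOp(Ops.LOAD, dtypes.int, (glbl.index(UOp.invalid()),))
assert graph_rewrite(ld, sym).op is Ops.CONST  # invalid load produces 0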
@@ -77,7 +77,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     if u.dtype != dtypes.void: label += f"\n{u.dtype}"
     for idx,x in enumerate(u.src):
       if x in excluded:
-        arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(u.dtype) else f"{x.arg}"
+        arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
         label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
     try:
       if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None: