simplify spec

This commit is contained in:
George Hotz
2025-10-28 09:36:09 +08:00
parent 62e62d8760
commit 7d26342ab6
2 changed files with 2 additions and 60 deletions

View File

@@ -4,7 +4,6 @@ from tinygrad.dtype import AddrSpace
from tinygrad.helpers import DEBUG, Context
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp, AxisType
from tinygrad.uop.symbolic import sym
from tinygrad.codegen import full_rewrite_to_sink
from tinygrad.codegen.late.expander import expander
from test.test_uops import to_uops_list
@@ -811,54 +810,6 @@ class TestExpander(unittest.TestCase):
sink = expander_rewrite(sink)
print(sink)
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, addrspace=AddrSpace.LOCAL), (), "smem")
valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "gidx0")<5
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "lidx0")
gate = valid&(lidx.ne(2))
idx = UOp.const(dtypes.int, 0)
st = UOp(Ops.STORE, dtypes.void, (sbuf.index(idx), UOp.const(dtypes.float, 42)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, 0)), barrier))
store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf))
sink = UOp(Ops.SINK, dtypes.void, (store,))
sink = full_rewrite_to_sink(sink)
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
def test_expand_ifs_one_gate(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, addrspace=AddrSpace.LOCAL), (), "smem")
valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "gidx0")<1
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
gate = valid&(lidx.ne(2))
st = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.float, 42)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)]
stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_rewrite_to_sink(sink)
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
# this will be fixed with the merge gated stores bounty
@unittest.expectedFailure
def test_expand_ifs_dumb(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "gidx0")<5
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "lidx0")
gate = valid&(lidx.ne(2))
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_rewrite_to_sink(sink)
if_uops = [u for u in sink.toposort() if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
class TestUOpTags(unittest.TestCase):
def test_inc_by_one(self):
g = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 1)

View File

@@ -162,7 +162,7 @@ kernel_spec = PatternMatcher([
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), allow_any_len=True, dtype=dtypes.void), lambda: True),
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
# bufferize (must be on ranges)
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.op in {Ops.RANGE, Ops.CONST} for y in x.src[1:])),
@@ -175,9 +175,6 @@ kernel_spec = PatternMatcher([
# *** this spec should match all UOps ever created ***
full_spec = PatternMatcher([
# any END
(UPat(Ops.END), lambda: True),
# NOOP in the full spec
(UPat(Ops.NOOP), lambda: True),
@@ -191,8 +188,6 @@ full_spec = PatternMatcher([
# rangeify: buffer view with index or load is okay
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True),
# copy on index
(UPat(Ops.COPY, src=(UPat(Ops.INDEX), UPat())), lambda: True),
# assign on index. the third op is the shape
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
@@ -215,19 +210,15 @@ full_spec = PatternMatcher([
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.VECTORIZE), dtype=dtypes.index), lambda: True),
# while BIND is being casted
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(), UPat()), arg=None), lambda: True),
(UPat(Ops.BIND, (dtypes.int, dtypes.index), (UPat(), UPat()), arg=None), lambda: True),
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# all loads/stores
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
# all ifs
(UPat(Ops.IF), lambda: True),
# all DEFINE_VAR to deal with the floats used in reduce collapse
(UPat(Ops.DEFINE_VAR), lambda: True),
# reshape on STORE
(UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
])+tensor_spec+kernel_spec+program_spec+shared_spec