This commit is contained in:
George Hotz
2025-10-23 10:21:27 +08:00
parent 6e00dec95d
commit f4cea6a403
5 changed files with 26 additions and 36 deletions

View File

@@ -1,4 +1,4 @@
import heapq
import heapq, functools
from typing import cast
from collections import defaultdict
from tinygrad.dtype import dtypes
@@ -45,7 +45,7 @@ def linearize(u:UOp) -> list[UOp]:
if u.op is Ops.LOAD: priority.append(-1000)
if u.op is Ops.BARRIER: priority.append(-1500)
# ranges are scheduled as late as possible so anything that can be outside is
#if u.op is Ops.RANGE: priority = [2000]
if u.op is Ops.RANGE: priority = [2000]
# move defines and consts to the top
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
priorities[u] = min(priority)
@@ -101,9 +101,6 @@ pm_add_control_flow = PatternMatcher([
# rewrite rules that attach END ops to STOREs and normalize existing ENDs
# NOTE(review): two rules with the identical STORE pattern appear below; they look like the before/after lines of a diff -- confirm which one is live
pm_add_ends = PatternMatcher([
# put the end on the store
# (returns None when the STORE has at most two sources)
(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[:2]).end(ends=s.src[2:]) if len(s.src) > 2 else None),
# END is only on RANGES
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
# for rendering and linearizing, all ends must end one loop
(UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None),
# trim the STORE to its first two sources, then fold over the RANGE entries of s.src[2:]: each RANGE calls .end() on the result so far
(UPat(Ops.STORE, name="s"), lambda s:
functools.reduce(lambda x,y: y.end(x), [x for x in s.src[2:] if x.op is Ops.RANGE], s.replace(src=s.src[:2]))),
])

View File

@@ -5,7 +5,7 @@ from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType, AddrSpace
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
@@ -87,11 +87,11 @@ class Scheduler:
self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
# returns one color name per pair in zip(self.axis_types, self.rngs)
def colors(self) -> list[str]:
globalizible_rngs = self._globalizable_rngs()
# everything past the first two sources of each source of the AST
output_rngs = flatten([s.src[2:] for s in self.ast.src])
ret = []
for x,r in zip(self.axis_types, self.rngs):
# GLOBAL axes are "BLUE" when dont_use_locals is set
if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
# NOTE(review): two consecutive LOOP branches with different membership tests both yield "BLACK" -- looks like before/after diff lines, confirm which is live
elif r not in globalizible_rngs and x == AxisType.LOOP: ret.append("BLACK")
elif r not in output_rngs and x == AxisType.LOOP: ret.append("BLACK")
# everything else takes the axis_colors entry for its axis type
else: ret.append(axis_colors[x])
return ret
# space-joined string: each range's src[0] rendered right-aligned to width 4 and passed through colored() with the matching entry of self.colors()
def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])

View File

@@ -268,20 +268,24 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# product of self.shape; UOp entries contribute int(x.vmax), any other entry is used as-is
@property
def size(self) -> int: return prod([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape])
# slice of self.src selected by self.op; ops not listed in the match return an empty tuple
@functools.cached_property
def ended_ranges(self):
# copy of range_start
match self.op:
case Ops.REDUCE | Ops.BUFFERIZE: return self.src[1:]
case Ops.STORE: return self.src[2:]
case Ops.WMMA: return self.src[3:]
case Ops.END: return self.src[:1]
case _: return ()
# determine what ranges this is in
# result is a dict whose keys are ranges and whose values are always None
# NOTE(review): this body mixes lines keyed on range_start with lines keyed on ended_ranges -- looks like before/after diff lines, confirm which set is live
@recursive_property
def _ranges(self) -> dict[UOp, None]:
ret: dict[UOp, None] = {}
if self.op in range_start.keys():
for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
for s in UOp.sink(*self.src[range_start[self.op]:]).ranges:
# collect the ranges of every source
for s in self.src: ret.update(s.ranges)
if (er:=self.ended_ranges):
# drop every range reported by a sink over the ended sources
for s in UOp.sink(*er).ranges:
if s in ret: del ret[s]
elif self.op is Ops.END:
for s in self.src[self.arg:]: ret.update(s.ranges)
for s in UOp.sink(*self.src[:self.arg]).ranges:
if s in ret: del ret[s]
else:
for s in self.src: ret.update(s.ranges)
return ret
@property
@@ -289,15 +293,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.RANGE: return {self:None}
return self._ranges
# slice of self.src selected by self.op: REDUCE -> src[1:], STORE -> src[2:], END -> its first self.arg sources
# raises RuntimeError for any other op
@functools.cached_property
def ended_ranges(self):
# copy of range_start
match self.op:
case Ops.REDUCE: return self.src[1:]
case Ops.STORE: return self.src[2:]
case Ops.END: return self.src[:self.arg]
case _: raise RuntimeError(f"{self.op} doesn't end ranges")
# *** uop evaluation ***
def simplify(self, tracked=False, full_symbolic=True):
@@ -363,11 +358,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
# build a LOAD with self as the first source, then src; dtype is the "dtype" keyword if given, else self.dtype.base
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
# build a STORE with self as the first source, then src; dtype is the "dtype" keyword if given, else dtypes.void
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs)
# build an END whose sources are ends, then self, then src; arg records len(ends)
def end(self, *src:UOp, ends:Sequence[UOp]):
# nothing to end: return self unchanged, or a NOOP over (self, *src) when extra sources were passed
if len(ends) == 0:
if len(src): return UOp(Ops.NOOP, src=(self, *src))
return self
return UOp(Ops.END, src=(*ends, self, *src), arg=len(ends))
# build an END whose first source is self, followed by src; self must be a RANGE (checked by assert)
def end(self, *src:UOp):
assert self.op is Ops.RANGE, "end only ends ranges"
return UOp(Ops.END, src=(self,)+src)
# build an AFTER with self's dtype; sources are self followed by src
def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src)
# build an ASSIGN with self's dtype; sources are (self, x)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
# build a BARRIER whose sources are self followed by src (no dtype is passed)
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
@@ -1196,6 +1189,8 @@ pm_lower_index_dtype = PatternMatcher([
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
# TODO: this is only triggering if they are all casts, correct?
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
# no CAST on END
(UPat(Ops.END, src=(UPat(Ops.CAST),), allow_any_len=True, name="e"), lambda e: e.replace(src=(e.src[0].src[0],)+e.src[1:])),
])
# wrap u in a sink, run graph_rewrite with pm_lower_index_dtype, and return the first source of the rewritten sink
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View File

@@ -120,7 +120,7 @@ program_spec = PatternMatcher([
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), allow_any_len=True, arg=1, dtype=dtypes.void), lambda: True),
(UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), dtype=dtypes.void), lambda: True),
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),

View File

@@ -382,8 +382,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
# END is only on RANGES
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
])+gep_pushing
symbolic_flat = symbolic+PatternMatcher([