mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simpler
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import heapq
|
||||
import heapq, functools
|
||||
from typing import cast
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -45,7 +45,7 @@ def linearize(u:UOp) -> list[UOp]:
|
||||
if u.op is Ops.LOAD: priority.append(-1000)
|
||||
if u.op is Ops.BARRIER: priority.append(-1500)
|
||||
# ranges are scheduled as late as possible so anything that can be outside is
|
||||
#if u.op is Ops.RANGE: priority = [2000]
|
||||
if u.op is Ops.RANGE: priority = [2000]
|
||||
# move defines and consts to the top
|
||||
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
|
||||
priorities[u] = min(priority)
|
||||
@@ -101,9 +101,6 @@ pm_add_control_flow = PatternMatcher([
|
||||
|
||||
pm_add_ends = PatternMatcher([
|
||||
# put the end on the store
|
||||
(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[:2]).end(ends=s.src[2:]) if len(s.src) > 2 else None),
|
||||
# END is only on RANGES
|
||||
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
|
||||
# for rendering and linearizing, all ends must end one loop
|
||||
(UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None),
|
||||
(UPat(Ops.STORE, name="s"), lambda s:
|
||||
functools.reduce(lambda x,y: y.end(x), [x for x in s.src[2:] if x.op is Ops.RANGE], s.replace(src=s.src[:2]))),
|
||||
])
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import cast, Final
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import dtypes, ImageDType, AddrSpace
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -87,11 +87,11 @@ class Scheduler:
|
||||
self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
|
||||
|
||||
def colors(self) -> list[str]:
|
||||
globalizible_rngs = self._globalizable_rngs()
|
||||
output_rngs = flatten([s.src[2:] for s in self.ast.src])
|
||||
ret = []
|
||||
for x,r in zip(self.axis_types, self.rngs):
|
||||
if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE")
|
||||
elif r not in globalizible_rngs and x == AxisType.LOOP: ret.append("BLACK")
|
||||
elif r not in output_rngs and x == AxisType.LOOP: ret.append("BLACK")
|
||||
else: ret.append(axis_colors[x])
|
||||
return ret
|
||||
def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
|
||||
|
||||
@@ -268,20 +268,24 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def size(self) -> int: return prod([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape])
|
||||
|
||||
@functools.cached_property
|
||||
def ended_ranges(self):
|
||||
# copy of range_start
|
||||
match self.op:
|
||||
case Ops.REDUCE | Ops.BUFFERIZE: return self.src[1:]
|
||||
case Ops.STORE: return self.src[2:]
|
||||
case Ops.WMMA: return self.src[3:]
|
||||
case Ops.END: return self.src[:1]
|
||||
case _: return ()
|
||||
|
||||
# determine what ranges this is in
|
||||
@recursive_property
|
||||
def _ranges(self) -> dict[UOp, None]:
|
||||
ret: dict[UOp, None] = {}
|
||||
if self.op in range_start.keys():
|
||||
for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
|
||||
for s in UOp.sink(*self.src[range_start[self.op]:]).ranges:
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
if (er:=self.ended_ranges):
|
||||
for s in UOp.sink(*er).ranges:
|
||||
if s in ret: del ret[s]
|
||||
elif self.op is Ops.END:
|
||||
for s in self.src[self.arg:]: ret.update(s.ranges)
|
||||
for s in UOp.sink(*self.src[:self.arg]).ranges:
|
||||
if s in ret: del ret[s]
|
||||
else:
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
return ret
|
||||
|
||||
@property
|
||||
@@ -289,15 +293,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.RANGE: return {self:None}
|
||||
return self._ranges
|
||||
|
||||
@functools.cached_property
|
||||
def ended_ranges(self):
|
||||
# copy of range_start
|
||||
match self.op:
|
||||
case Ops.REDUCE: return self.src[1:]
|
||||
case Ops.STORE: return self.src[2:]
|
||||
case Ops.END: return self.src[:self.arg]
|
||||
case _: raise RuntimeError(f"{self.op} doesn't end ranges")
|
||||
|
||||
# *** uop evaluation ***
|
||||
|
||||
def simplify(self, tracked=False, full_symbolic=True):
|
||||
@@ -363,11 +358,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs)
|
||||
def end(self, *src:UOp, ends:Sequence[UOp]):
|
||||
if len(ends) == 0:
|
||||
if len(src): return UOp(Ops.NOOP, src=(self, *src))
|
||||
return self
|
||||
return UOp(Ops.END, src=(*ends, self, *src), arg=len(ends))
|
||||
def end(self, *src:UOp):
|
||||
assert self.op is Ops.RANGE, "end only ends ranges"
|
||||
return UOp(Ops.END, src=(self,)+src)
|
||||
def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src)
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
@@ -1196,6 +1189,8 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
|
||||
# TODO: this is only triggering if they are all casts, correct?
|
||||
(UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
|
||||
# no CAST on END
|
||||
(UPat(Ops.END, src=(UPat(Ops.CAST),), allow_any_len=True, name="e"), lambda e: e.replace(src=(e.src[0].src[0],)+e.src[1:])),
|
||||
])
|
||||
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ program_spec = PatternMatcher([
|
||||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
|
||||
(UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), allow_any_len=True, arg=1, dtype=dtypes.void), lambda: True),
|
||||
(UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), dtype=dtypes.void), lambda: True),
|
||||
|
||||
# make sure all index dtypes have been lowered
|
||||
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),
|
||||
|
||||
@@ -382,8 +382,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
# END is only on RANGES
|
||||
(UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))),
|
||||
])+gep_pushing
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user