mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
rangeify load cleanups + multi support (#11765)
* use the old buf_uop + cleanups * simpler handling of load * everything needed for multi too
This commit is contained in:
@@ -33,7 +33,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
for ss in s.src:
|
||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||
if ss.op is not Ops.BUFFER:
|
||||
assert ss.op is Ops.ASSIGN
|
||||
assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
|
||||
children[ss.src[1]].append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op is Ops.BUFFER:
|
||||
|
||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, PARTIAL_CONTIG
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
|
||||
from tinygrad.schedule.kernelize import Kernel
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element, sint
|
||||
@@ -330,25 +331,14 @@ def bufferize_to_store(x:UOp):
|
||||
assert assign_target.op is Ops.INDEX
|
||||
return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype)
|
||||
buf = UOp.new_buffer(x.arg, prod(shape), x.dtype)
|
||||
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype)
|
||||
|
||||
def add_load_on_buffer(idx:UOp, b:UOp):
|
||||
if isinstance(idx.dtype, PtrDType): return None
|
||||
return idx.replace(dtype=idx.dtype.ptr(b.size), arg=None).load()
|
||||
|
||||
def add_load_on_store(x:UOp, st:UOp):
|
||||
if isinstance(x.dtype, PtrDType): return None
|
||||
assert isinstance(st.dtype, PtrDType), f"{st} has the wrong dtype"
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([int(r.vmax+1) for r in rngs])
|
||||
#assert st.dtype.size == prod(shape) # if it doesn't, we need a BUFFER_VIEW
|
||||
if st.dtype.size != prod(shape): st = st.shrink(((0, prod(shape)),))
|
||||
return st.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=prod(shape))).load()
|
||||
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
|
||||
|
||||
pm_add_buffers = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat()), name="idx"), add_load_on_buffer),
|
||||
(UPat(Ops.STORE, name="st").f(Ops.INDEX, allow_any_len=True, name="x"), add_load_on_store),
|
||||
|
||||
# move RESHAPEs through MSELECT/MSTACK
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
|
||||
])
|
||||
|
||||
# 5. split into kernels
|
||||
@@ -372,13 +362,20 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
|
||||
def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
|
||||
buf = assign.as_buf()
|
||||
assert buf not in ctx.map
|
||||
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
|
||||
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
|
||||
ctx.map[buf] = assign
|
||||
return buf
|
||||
|
||||
to_define_global = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="buf"), debuf),
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
(UPat(Ops.ASSIGN, name="assign"), handle_assign),
|
||||
(UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign),
|
||||
|
||||
# add loads to non ptr indexes
|
||||
# TODO: this can be moved into codegen?
|
||||
(UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: idx.replace(dtype=dg.dtype, arg=None).load() if not isinstance(idx.dtype, PtrDType) else None),
|
||||
|
||||
# TODO: this can be moved into codegen
|
||||
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
|
||||
@@ -408,7 +405,7 @@ split_kernels = PatternMatcher([
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
tensor_map = graph_rewrite_map(sink, earliest_rewrites, name="earliest")
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+earliest_rewrites, name="earliest")
|
||||
realize_map: dict[UOp, UOp] = {}
|
||||
graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add contiguous")
|
||||
|
||||
@@ -378,7 +378,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.st == ret.st: return self # ignore NOOPs, also check ret.st
|
||||
return ret
|
||||
|
||||
def forced_reshape(self, arg:tuple[sint, ...]): return UOp(Ops.RESHAPE, self.dtype, src=(self,), arg=arg)
|
||||
def forced_reshape(self, arg:tuple[sint, ...], **kwargs): return UOp(Ops.RESHAPE, kwargs.pop("dtype", self.dtype), src=(self,), arg=arg)
|
||||
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
|
||||
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
|
||||
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
|
||||
@@ -417,7 +417,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return self.src[0].base
|
||||
|
||||
def as_buf(self) -> UOp:
|
||||
# TODO: this should be the only one of these
|
||||
if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg)
|
||||
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.as_buf() for x in self.src))
|
||||
# TODO: this should be the only one of these. this is the one RANGEIFY uses
|
||||
s = self
|
||||
while len(s.src) and s.op is not Ops.BUFFER: s = s.src[0]
|
||||
return s
|
||||
@@ -568,7 +570,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
|
||||
|
||||
def render(self, simplify=True, pm:PatternMatcher|None=None) -> str:
|
||||
ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
|
||||
return ret.arg if ret.op is Ops.NOOP else str(ret)
|
||||
|
||||
class AxisType(Enum):
|
||||
@@ -773,7 +776,7 @@ class PatternMatcher:
|
||||
def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
|
||||
|
||||
@functools.cache # pylint: disable=method-cache-max-size-none
|
||||
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
|
||||
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
||||
ler = {u.op for u in uop.src}
|
||||
|
||||
Reference in New Issue
Block a user