rangeify load cleanups + multi support (#11765)

* use the old buf_uop + cleanups

* simpler handling of load

* everything needed for multi too
This commit is contained in:
George Hotz
2025-08-20 20:55:49 -07:00
committed by GitHub
parent 9f94c25a25
commit a044648111
3 changed files with 23 additions and 23 deletions

View File

@@ -33,7 +33,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.ASSIGN
assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
children[ss.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, PARTIAL_CONTIG
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element, sint
@@ -330,25 +331,14 @@ def bufferize_to_store(x:UOp):
assert assign_target.op is Ops.INDEX
return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype)
buf = UOp.new_buffer(x.arg, prod(shape), x.dtype)
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype)
# Rewrite callback for an INDEX node (`idx`) whose first source is a BUFFER (`b`).
# If `idx` is already pointer-typed it returns None (presumably "no rewrite" for the
# PatternMatcher -- confirm). Otherwise the INDEX is re-typed as a pointer built with
# idx.dtype.ptr(b.size), its arg is cleared, and the result is wrapped in a load.
def add_load_on_buffer(idx:UOp, b:UOp):
if isinstance(idx.dtype, PtrDType): return None
return idx.replace(dtype=idx.dtype.ptr(b.size), arg=None).load()
# Rewrite callback for an INDEX node (`x`) whose first source is a STORE (`st`).
# Returns None when `x` is already pointer-typed; otherwise returns a load of a
# freshly built pointer-typed INDEX over `st`.
def add_load_on_store(x:UOp, st:UOp):
if isinstance(x.dtype, PtrDType): return None
# the store being indexed must itself carry a pointer dtype
assert isinstance(st.dtype, PtrDType), f"{st} has the wrong dtype"
# every source after the first is treated as a range; each extent is taken as vmax+1
rngs = x.src[1:]
shape = tuple([int(r.vmax+1) for r in rngs])
#assert st.dtype.size == prod(shape) # if it doesn't, we need a BUFFER_VIEW
# on a size mismatch, shrink the store to its first prod(shape) elements before reshaping
if st.dtype.size != prod(shape): st = st.shrink(((0, prod(shape)),))
return st.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=prod(shape))).load()
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
# Pattern matcher made of pm_mops followed by the rules listed here.
pm_add_buffers = pm_mops+PatternMatcher([
# BUFFERIZE nodes are handled by bufferize_to_store
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
# INDEX with sources (BUFFER, anything) is handled by add_load_on_buffer
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat()), name="idx"), add_load_on_buffer),
# INDEX whose first source is a STORE (any number of further sources) is handled by add_load_on_store
(UPat(Ops.STORE, name="st").f(Ops.INDEX, allow_any_len=True, name="x"), add_load_on_store),
# move RESHAPEs through MSELECT/MSTACK
# the node is rebuilt over each source's src[0], then reshaped with the FIRST source's arg only
# NOTE(review): this assumes every RESHAPE source carries the same arg -- confirm
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
])
# 5. split into kernels
@@ -372,13 +362,20 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
# Rewrite callback registered in to_define_global.
# Resolves `assign` to its buffer via as_buf(), records buffer -> assign in ctx.map,
# and returns the buffer (which becomes the replacement for the matched node).
def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
buf = assign.as_buf()
# NOTE(review): this duplicate check runs BEFORE the MSTACK/MSELECT unwrapping below, so an
# unwrapped buf.src[0] that is already a key in ctx.map would be overwritten silently -- confirm intended
assert buf not in ctx.map
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
ctx.map[buf] = assign
return buf
to_define_global = PatternMatcher([
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat(Ops.ASSIGN, name="assign"), handle_assign),
(UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign),
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
(UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: idx.replace(dtype=dg.dtype, arg=None).load() if not isinstance(idx.dtype, PtrDType) else None),
# TODO: this can be moved into codegen
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
@@ -408,7 +405,7 @@ split_kernels = PatternMatcher([
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tensor_map = graph_rewrite_map(sink, earliest_rewrites, name="earliest")
tensor_map = graph_rewrite_map(sink, multi_pm+earliest_rewrites, name="earliest")
realize_map: dict[UOp, UOp] = {}
graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add contiguous")

View File

@@ -378,7 +378,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.st == ret.st: return self # ignore NOOPs, also check ret.st
return ret
# Builds a RESHAPE UOp over `self` directly through the UOp constructor, unlike reshape(),
# which goes through _mop.
# Two variants are listed: the first always uses self.dtype; the second additionally accepts
# a `dtype` keyword that overrides it.
def forced_reshape(self, arg:tuple[sint, ...]): return UOp(Ops.RESHAPE, self.dtype, src=(self,), arg=arg)
# only "dtype" is read from kwargs; any other keyword is accepted and silently ignored
def forced_reshape(self, arg:tuple[sint, ...], **kwargs): return UOp(Ops.RESHAPE, kwargs.pop("dtype", self.dtype), src=(self,), arg=arg)
# Thin wrappers: each forwards its argument to self._mop together with the matching Ops member.
# reshape: arg is the target shape tuple
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
# pad: arg holds one pair of values per entry (presumably per-dimension padding -- confirm)
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
# expand: arg is a tuple of sint (presumably the expanded shape -- confirm)
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
@@ -417,7 +417,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return self.src[0].base
# Returns the buffer-side UOp for this node.
#  - MSELECT: resolve src[0] recursively, then re-apply mselect with the same arg.
#  - MSTACK:  build a new MSTACK (same dtype) over as_buf() of every source.
#  - otherwise: follow src[0] links until reaching a BUFFER op or a node with no sources.
def as_buf(self) -> UOp:
# TODO: this should be the only one of these
if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg)
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.as_buf() for x in self.src))
# TODO: this should be the only one of these. this is the one RANGEIFY uses
s = self
# stops on the first BUFFER, or on a source-less node (which may not be a BUFFER)
while len(s.src) and s.op is not Ops.BUFFER: s = s.src[0]
return s
@@ -568,7 +570,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
# Renders this UOp to a string.
# The UOp (simplified first unless simplify=False) is rewritten with `pm`, or with `renderer`
# when pm is None. A NOOP result yields its arg; any other result yields str(ret).
def render(self, simplify=True, pm:PatternMatcher|None=None) -> str:
ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
# the same graph_rewrite call is also listed wrapped in Context(TRACK_MATCH_STATS=0)
# (presumably so rendering is excluded from match-stat tracking -- confirm)
with Context(TRACK_MATCH_STATS=0):
ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
return ret.arg if ret.op is Ops.NOOP else str(ret)
class AxisType(Enum):
@@ -773,7 +776,7 @@ class PatternMatcher:
# Pickle support: reconstructs via PatternMatcher(patterns). Callbacks whose __name__ is "<lambda>"
# are first passed through deconstruct_function (defined elsewhere); named functions are kept as-is.
def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
# Combining two matchers yields a new PatternMatcher whose pattern list is self.patterns followed
# by more.patterns. functools.cache memoizes the result per (self, more) pair.
# Two variants are listed; the second only adds the return annotation.
@functools.cache # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
ler = {u.op for u in uop.src}