diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 8e3d214216..1e6b5259ad 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -33,7 +33,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       for ss in s.src:
         if ss.op is Ops.MSELECT: ss = ss.src[0]
         if ss.op is not Ops.BUFFER:
-          assert ss.op is Ops.ASSIGN
+          assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}"
           children[ss.src[1]].append(k)
           in_degree[k] += 1
     elif s.op is Ops.BUFFER:
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index fc7dd02e9a..74fa551d08 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, PARTIAL_CONTIG
+from tinygrad.schedule.multi import multi_pm
 from tinygrad.schedule.kernelize import Kernel
 from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element, sint
@@ -330,25 +331,14 @@ def bufferize_to_store(x:UOp):
     assert assign_target.op is Ops.INDEX
     return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype)
   buf = UOp.new_buffer(x.arg, prod(shape), x.dtype)
-  return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype)
-
-def add_load_on_buffer(idx:UOp, b:UOp):
-  if isinstance(idx.dtype, PtrDType): return None
-  return idx.replace(dtype=idx.dtype.ptr(b.size), arg=None).load()
-
-def add_load_on_store(x:UOp, st:UOp):
-  if isinstance(x.dtype, PtrDType): return None
-  assert isinstance(st.dtype, PtrDType), f"{st} has the wrong dtype"
-  rngs = x.src[1:]
-  shape = tuple([int(r.vmax+1) for r in rngs])
-  #assert st.dtype.size == prod(shape) # if it doesn't, we need a BUFFER_VIEW
-  if st.dtype.size != prod(shape): st = st.shrink(((0, prod(shape)),))
-  return st.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=prod(shape))).load()
+  return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
 
 pm_add_buffers = pm_mops+PatternMatcher([
   (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
-  (UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat()), name="idx"), add_load_on_buffer),
-  (UPat(Ops.STORE, name="st").f(Ops.INDEX, allow_any_len=True, name="x"), add_load_on_store),
+
+  # move RESHAPEs through MSELECT/MSTACK
+  (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
+   lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
 ])
 
 # 5. split into kernels
@@ -372,13 +362,20 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
 def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
   buf = assign.as_buf()
   assert buf not in ctx.map
+  # HACK to put the buffer in the MAP instead of MSTACK/MSELECT
+  if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
   ctx.map[buf] = assign
   return buf
 
 to_define_global = PatternMatcher([
   (UPat(Ops.BUFFER, name="buf"), debuf),
   (UPat(Ops.BIND, name="b"), unbind_kernel),
-  (UPat(Ops.ASSIGN, name="assign"), handle_assign),
+  (UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign),
+
+  # add loads to non ptr indexes
+  # TODO: this can be moved into codegen?
+  (UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
+   lambda dg,idx: idx.replace(dtype=dg.dtype, arg=None).load() if not isinstance(idx.dtype, PtrDType) else None),
 
   # TODO: this can be moved into codegen
   (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
@@ -408,7 +405,7 @@ split_kernels = PatternMatcher([
 
 @track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
 def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
-  tensor_map = graph_rewrite_map(sink, earliest_rewrites, name="earliest")
+  tensor_map = graph_rewrite_map(sink, multi_pm+earliest_rewrites, name="earliest")
   realize_map: dict[UOp, UOp] = {}
   graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")
   tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add contiguous")
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 1cde66ae3e..6a16d82f6a 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -378,7 +378,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.st == ret.st: return self # ignore NOOPs, also check ret.st
     return ret
 
-  def forced_reshape(self, arg:tuple[sint, ...]): return UOp(Ops.RESHAPE, self.dtype, src=(self,), arg=arg)
+  def forced_reshape(self, arg:tuple[sint, ...], **kwargs): return UOp(Ops.RESHAPE, kwargs.pop("dtype", self.dtype), src=(self,), arg=arg)
   def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
   def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
   def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
@@ -417,7 +417,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     return self.src[0].base
 
   def as_buf(self) -> UOp:
+    if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg)
+    if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.as_buf() for x in self.src))
-    # TODO: this should be the only one of these
+    # TODO: this should be the only one of these. this is the one RANGEIFY uses
     s = self
     while len(s.src) and s.op is not Ops.BUFFER: s = s.src[0]
     return s
@@ -568,7 +570,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
 
   def render(self, simplify=True, pm:PatternMatcher|None=None) -> str:
-    ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
+    with Context(TRACK_MATCH_STATS=0):
+      ret = graph_rewrite(self.simplify() if simplify else self, renderer if pm is None else pm)
     return ret.arg if ret.op is Ops.NOOP else str(ret)
 
 class AxisType(Enum):
@@ -773,7 +776,7 @@ class PatternMatcher:
   def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
 
   @functools.cache  # pylint: disable=method-cache-max-size-none
-  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
+  def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
 
   def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
     ler = {u.op for u in uop.src}
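
Note on the rangeify changes above: the deleted add_load_on_buffer/add_load_on_store helpers are replaced by two pieces. bufferize_to_store now ends in .forced_reshape(shape, dtype=x.dtype), which is why UOp.forced_reshape in ops.py grows an optional dtype override, and a new pm_add_buffers rule hoists RESHAPEs through MSELECT/MSTACK so the multi-device op sits directly on buffers, with a single RESHAPE applied to its result. A minimal standalone sketch of that hoist, using a hypothetical Node class rather than tinygrad's UOp/UPat machinery:

  from dataclasses import dataclass

  @dataclass(frozen=True)
  class Node:
      op: str                # "BUFFER", "RESHAPE", "MSELECT", or "MSTACK"
      src: tuple = ()
      arg: object = None     # shape for RESHAPE, device index for MSELECT

  def hoist_reshape(m: Node):
      # match an MSELECT/MSTACK whose sources are all RESHAPEs, mirroring
      # UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m")
      if m.op not in ("MSELECT", "MSTACK"): return None
      if not (m.src and all(s.op == "RESHAPE" for s in m.src)): return None
      # rebuild the multi op over the RESHAPEs' sources, then reshape once on top
      inner = Node(m.op, tuple(s.src[0] for s in m.src), m.arg)
      return Node("RESHAPE", (inner,), m.src[0].arg)

  bufs = tuple(Node("BUFFER", arg=i) for i in range(2))
  stacked = Node("MSTACK", tuple(Node("RESHAPE", (b,), (2, 2)) for b in bufs))
  out = hoist_reshape(stacked)
  assert out is not None and out.op == "RESHAPE" and out.src[0].op == "MSTACK"

As in the real rule (lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)), the rewrite assumes every source of the multi op is a RESHAPE and reuses the first source's shape argument for the hoisted reshape.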