From 357dac8425b57fa8db108f4b609b4e7692168cc7 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Sun, 19 Oct 2025 19:11:05 -0700 Subject: [PATCH 01/30] feat: allow tuple indexing on uops (#12797) --- tinygrad/uop/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 55ee4a69c1..488fc0774c 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -324,7 +324,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def index(self, *srcs:UOp|None, **kwargs): return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) - def __getitem__(self, idx): return self.index(idx) + def __getitem__(self, *idx): return self.index(*idx) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source return UOp.const(self.dtype, b, device=self._device, shape=self._shape) From 339e6edb7d2a574d91b3e10f148fc7b03a3fa316 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:15:15 +0800 Subject: [PATCH 02/30] viz: ui prereqs for hierarchical rewrites (#12799) --- tinygrad/viz/index.html | 10 +++++++--- tinygrad/viz/js/index.js | 6 ++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 34f68ce448..8731a457b6 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -41,11 +41,13 @@ } ul { padding: 0; - opacity: 0.6; white-space: nowrap; cursor: pointer; } - ul.active { + ul > p { + opacity: 0.6; + } + ul.active > p { opacity: 1; } ul > ul { @@ -54,8 +56,10 @@ ul.expanded > ul { display: block; } - ul.disabled { + ul.disabled > p { opacity: 0.4; + } + ul.disabled { pointer-events: none; } label { diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 90ad10b6be..a0f5a3af25 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -608,7 +608,8 @@ async function main() { for (const [j,u] of steps.entries()) { const inner = ul.appendChild(document.createElement("ul")); inner.id = `step-${i}-${j}`; - inner.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : ''); + const p = inner.appendChild(document.createElement("p")); + p.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : ''); inner.style.marginLeft = `${8*u.depth}px`; inner.onclick = (e) => { e.stopPropagation(); @@ -706,8 +707,9 @@ async function main() { rewriteList.className = "rewrite-list"; for (let s=0; s<=step.match_count; s++) { const ul = rewriteList.appendChild(document.createElement("ul")); - ul.innerText = s; ul.id = `rewrite-${s}`; + const p = ul.appendChild(document.createElement("p")); + p.innerText = s; ul.onclick = () => setState({ currentRewrite:s }); ul.className = s > ret.length-1 ? "disabled" : s === currentRewrite ? "active" : ""; if (s > 0 && s === currentRewrite) { From 2e9082e0bcc325467c464b8e54d8e1c409882621 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:27:56 +0800 Subject: [PATCH 03/30] after op (#12801) * after op * fix tests --- test/unit/test_kernelize.py | 6 +++--- tinygrad/engine/schedule.py | 8 ++++---- tinygrad/schedule/indexing.py | 4 ++-- tinygrad/schedule/rangeify.py | 31 ++++++++++++++++--------------- tinygrad/tensor.py | 6 +++--- tinygrad/uop/__init__.py | 3 +++ tinygrad/uop/ops.py | 9 ++++++--- tinygrad/uop/spec.py | 7 +++++-- tinygrad/viz/serve.py | 2 +- 9 files changed, 43 insertions(+), 33 deletions(-) diff --git a/test/unit/test_kernelize.py b/test/unit/test_kernelize.py index e571c1d297..3cc0b0c0cc 100644 --- a/test/unit/test_kernelize.py +++ b/test/unit/test_kernelize.py @@ -20,8 +20,8 @@ class TestKernelize(unittest.TestCase): self.assertEqual(len([s for s in a0.uop.toposort() if s.op is Ops.KERNEL]), 2) self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS) # input Tensor and user contiguous kernelize - self.assertIs(a0.uop.base.op, Ops.ASSIGN) - self.assertIs(a.uop.base.op, Ops.ASSIGN) + self.assertIs(a0.uop.base.op, Ops.AFTER) + self.assertIs(a.uop.base.op, Ops.AFTER) def test_two_reduce_w_add(self): a = Tensor.ones(16,16).contiguous() @@ -31,7 +31,7 @@ class TestKernelize(unittest.TestCase): # NOTE: the +1 is fused with a1, so a1 is not kernelized self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS) # the input to the REDUCE_AXIS is an ASSIGN though - self.assertIs(a1.uop.base.src[0].base.op, Ops.ASSIGN) + self.assertIs(a1.uop.base.src[0].base.op, Ops.AFTER) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e56c908309..655cd0d242 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -22,18 +22,18 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ in_degree: dict[UOp, int] = {} var_vals: dict[str, int] = {} for u in sched_sink.toposort(): - if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip + if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip k = u.src[1] in_degree.setdefault(k, 0) for s in k.src: - if s.op is Ops.ASSIGN: + if s.op is Ops.AFTER: children[s.src[1]].append(k) in_degree[k] += 1 elif s.op in {Ops.MSELECT, Ops.MSTACK}: for ss in s.src: if ss.op is Ops.MSELECT: ss = ss.src[0] if ss.op is not Ops.BUFFER: - assert ss.op is Ops.ASSIGN, f"ss.op is not ASSIGN, it's {ss.op}" + assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}" children[ss.src[1]].append(k) in_degree[k] += 1 elif s.op is Ops.BUFFER: @@ -43,7 +43,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[ assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}" var_vals[var.expr] = val else: - raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}") + raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}") # linearize KERNEL UOps into ScheduleItems in BFS order diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index d4b048a823..a78ac20cdc 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -52,11 +52,11 @@ class IndexingContext: def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): if x.op in {Ops.BUFFERIZE, Ops.INDEX, Ops.KERNEL}: return None - if x.op is Ops.ASSIGN and x.src[1].op is Ops.KERNEL: return None + if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None new_srcs = [] for s in x.src: new_src = s - if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL): + if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL): if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: realized_ranges = ctx.realize_map[s] diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index c058a255bb..9e1eaf002c 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -242,7 +242,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp): bufs: set[UOp] = set() def gate_input(u:UOp): # TODO: add cache to fix n^2 - if is_load:=(u.op in {Ops.BUFFERIZE, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u) + if is_load:=(u.op in {Ops.BUFFERIZE, Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u) return not is_load root.toposort(gate=gate_input) @@ -277,7 +277,8 @@ def bufferize_to_store(x:UOp): assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index" # in assign, this is the buffer size, not the bufferize size # TODO: assign_mops here - ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype).replace(tag=x.tag) + do_store = assign_target.replace(dtype=sdtype).store(assign_src, *rngs).replace(tag=x.tag) + ret = assign_target.src[0].after(do_store) mops = [] walk = assign_mops while walk is not assign_mops.base: @@ -289,8 +290,8 @@ def bufferize_to_store(x:UOp): # NOTE: the DEFINE_LOCAL needs to be disambiguated here if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg.device, size, x.dtype) - ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype).replace(tag=x.tag) - ret = ret.forced_reshape(shape) + do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs).replace(tag=x.tag) + ret = buf.after(do_store).forced_reshape(shape) # TODO: is this right? what if it's offset if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs): sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs]) @@ -302,7 +303,7 @@ def bufferize_to_store(x:UOp): if tag is None: tag = UOp.unique().arg # TODO: hack buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) # store has the other dtype here - # TODO: how is this unified? + # TODO: use after here? return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).reshape(shape) pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([ @@ -335,12 +336,12 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp): ctx.vars[b] = None return b.src[0] -def handle_assign(ctx:LocalAddBufferContext, assign:UOp): - buf = assign.as_buf() +def handle_after(ctx:LocalAddBufferContext, after:UOp): + buf = after.as_buf() # HACK to put the buffer in the MAP instead of MSTACK/MSELECT if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0] assert buf not in ctx.map - ctx.map[buf] = assign + ctx.map[buf] = after return buf def renumber_range(ctx:LocalAddBufferContext, r:UOp): @@ -350,7 +351,7 @@ def renumber_range(ctx:LocalAddBufferContext, r:UOp): return ret def find_bufs(x:UOp): - idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.ASSIGN) if s.op is Ops.INDEX] + idxs = [s for s in x.toposort(gate=lambda x: x.op is not Ops.AFTER) if s.op is Ops.INDEX] read_from: dict[UOp, Ops] = {} if any((buf:=idx.as_buf()).op is Ops.BUFFER and read_from.setdefault(buf, op:=idx.src[0].op) is not op for idx in idxs): raise RuntimeError(f"cycle detected while indexing {buf}") @@ -359,7 +360,7 @@ to_define_global = PatternMatcher([ (UPat(Ops.STORE, name="x"), find_bufs), (UPat(Ops.BUFFER, name="buf"), debuf), (UPat(Ops.BIND, name="b"), unbind_kernel), - (UPat((Ops.ASSIGN, Ops.MSTACK, Ops.MSELECT), name="assign"), handle_assign), + (UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after), # HACK in case any CONSTs were replaced # this is only needed if you are using symbolic @@ -418,7 +419,7 @@ class Kernel: ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op) return f"" -def split_store(ctx:list[UOp], x:UOp): +def split_store(ctx:list[UOp], x:UOp) -> UOp|None: if len(x.ranges): return None if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None @@ -436,7 +437,7 @@ def split_store(ctx:list[UOp], x:UOp): kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]): raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}") - return x.as_buf().assign(kernel) + return kernel split_kernels = PatternMatcher([ (UPat(Ops.STORE, name="x"), split_store), @@ -523,13 +524,13 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: kernel_assign: dict[UOp, UOp] = {} assign_rep: dict[UOp, UOp] = {} for u in tsink.toposort(): - if u.op is not Ops.ASSIGN: continue + if u.op is not Ops.AFTER: continue kernel_assign[u.buf_uop] = u for s in u.src[1].src: # TODO: this is probably broken for MSELECT/MSTACK if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue - if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()): - raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER") + if any(x.op is Ops.AFTER and x.buf_uop is s for x in u.toposort()): + raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER") assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,)) if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign") diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1028519341..74edfc145f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -249,9 +249,9 @@ class Tensor(MathTrait): self.kernelize(*lst) sink = UOp.sink(*[x.uop for x in (self,)+lst]) - # remove all ASSIGNs, after scheduling, the tensors are just buffers - remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.ASSIGN} - _apply_map_to_tensors(remove_assign_map, name="Remove Assigns") + # remove all AFTERs, after scheduling, the tensors are just buffers + remove_assign_map = {u:u.buf_uop for u in sink.toposort() if u.op is Ops.AFTER} + _apply_map_to_tensors(remove_assign_map, name="Remove After") # create the schedule schedule, var_vals = create_schedule_with_vars(sink) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 2922fd4471..4879a6daa6 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -12,6 +12,9 @@ class Ops(FastEnum): NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702 SENTINEL = auto() + # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:] + AFTER = auto() + # buffer ops COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702 diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 488fc0774c..131ea70538 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -190,7 +190,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) # passthrough ops - case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE: return self.src[0]._shape + case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE | Ops.AFTER: + return self.src[0]._shape # ops with custom handling case Ops.KERNEL: return self.arg.ast._shape @@ -349,6 +350,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs) def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs) + def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src) def alu(self, op, *src:UOp, **kwargs): @@ -525,6 +527,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def _device(self) -> str|tuple[str, ...]|None: if self.op is Ops.DEVICE: return self.arg if self.op is Ops.BUFFERIZE: return self.arg.device + if self.op is Ops.AFTER: return self.src[0].device if self.op is Ops.MSELECT: assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device" return self.src[0].device[self.arg] @@ -538,8 +541,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.BUFFER: return self if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg) if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src)) - assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}" - return self.src[0].base + assert self.op is Ops.AFTER, f"must be AFTER {self.op}" + return self.src[0].buf_uop.base def as_buf(self) -> UOp: if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg) diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 78be792353..f35784ed83 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -66,8 +66,8 @@ buffer_spec = PatternMatcher([ ]) assign_spec = PatternMatcher([ - # KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER - (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), + # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER + (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True), # ASSIGN has a target and a value. It can also optionally depend on other assigns (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])), @@ -111,6 +111,9 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([ # REDUCE with an outerworld range (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), + + # AFTER if things were kernelized + (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True) ]) # ***** uop type spec ***** diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index baa3656850..ad597a53f3 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -20,7 +20,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", - Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00"} + Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866"} # VIZ API From 734c99f722a216612ea6a77677485c9d75a3c504 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:37:03 +0800 Subject: [PATCH 04/30] viz: show indexing rewrites during run_rangeify (#12802) * viz: show indexing rewrites during run_rangeify * sinking index --- tinygrad/viz/serve.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index ad597a53f3..88b67adf81 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -98,9 +98,10 @@ def _reconstruct(a:int): return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest) def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]: - ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"}) - yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink), ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None, - "diff":None, "upat":None} + next_sink = _reconstruct(ctx.sink) + ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"} or + any(s.dtype is dtypes.index for s in next_sink.src+(next_sink,))) + yield {"graph":uop_to_json(next_sink, ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches): replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num) From 12fd2c9c7bc74848b0513a125d48e7737530ab42 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 20 Oct 2025 13:11:57 +0800 Subject: [PATCH 05/30] explicitly set ignore_indexing for schedule only (#12803) --- tinygrad/viz/serve.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 88b67adf81..b92246224e 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -99,8 +99,9 @@ def _reconstruct(a:int): def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]: next_sink = _reconstruct(ctx.sink) - ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"} or - any(s.dtype is dtypes.index for s in next_sink.src+(next_sink,))) + # in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink) + ignore_indexing = trace.keys[i].display_name.startswith("Schedule") and not (ctx.name in {"kernel split"} or \ + any(s.dtype is dtypes.index for s in next_sink.src+(next_sink,))) yield {"graph":uop_to_json(next_sink, ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches): From b8a9cce7832e764c92b9d20d72c6635d9a8473d9 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 20 Oct 2025 15:34:32 +0800 Subject: [PATCH 06/30] replace NOOP with AFTER in reg init (#12804) * after op * fix tests * replace NOOP with AFTER in reg init * closer * or_after there * fix device * fix all renderers * better spec for after --- tinygrad/codegen/late/devectorizer.py | 16 ++++++++++------ tinygrad/renderer/cstyle.py | 3 +++ tinygrad/renderer/llvmir.py | 3 +++ tinygrad/renderer/nir.py | 7 +++++-- tinygrad/renderer/ptx.py | 3 +++ tinygrad/runtime/ops_python.py | 3 ++- tinygrad/uop/ops.py | 7 ++++++- tinygrad/uop/spec.py | 9 +++++++-- 8 files changed, 39 insertions(+), 12 deletions(-) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index c0012b73ff..7eeb9e68ac 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -123,7 +123,7 @@ def gep_on_store(gep:UOp, st:UOp, sto:UOp): return gep.src[0].store(st.gep(new_arg), *sto.src[2:]) load_store_folding = PatternMatcher([ - (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index), + (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines).or_after(name="buf")), UPat.var("vec"))), expand_index), # GEP after LOAD (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True), lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)), @@ -242,11 +242,13 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp): return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt)))) devectorize = PatternMatcher([ + # CAST after AFTER + (UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)), # no ALU on vectorized dtypes (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu), (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf), - (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index), + (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index), ]) pm_render = PatternMatcher([ @@ -296,12 +298,14 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE]) input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges]) identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar())) - acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0)) - do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity) - lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element + acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)) + acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \ + acc.index(UOp.const(dtypes.int, 0)).store(identity) + lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0)).load()] + lst # put acc as the first element ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) - return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret + if len(reduce_range) == 0: return ret + return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret, *reduce_range)).index(UOp.const(dtypes.int, 0)).load() pm_reduce = PatternMatcher([ # REDUCE -> DEFINE_ACC+ASSIGN diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index c3a8e1508d..2140abe6e7 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -144,6 +144,9 @@ class CStyleLanguage(Renderer): name = "test" for u in uops: if u.op is Ops.NOOP: continue + if u.op is Ops.AFTER: + r[u] = r[u.src[0]] + continue if u.op is Ops.SINK: if u.arg is not None: name = u.arg.function_name continue diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 8be73d536f..032532e75c 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -167,6 +167,9 @@ class LLVMRenderer(Renderer): name = "test" for u in uops: if u.op is Ops.NOOP: continue + if u.op is Ops.AFTER: + r[u] = r[u.src[0]] + continue if u.op is Ops.SINK: if u.arg is not None: name = u.arg.function_name continue diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 26cf519c7d..efaeddbecd 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -1,4 +1,4 @@ -from typing import Callable, cast +from typing import Callable, cast, Any from tinygrad.dtype import AddrSpace, DType, PtrDType, dtypes from tinygrad.helpers import DEBUG, OSX, unwrap from tinygrad.renderer import Renderer @@ -169,10 +169,13 @@ class NIRRenderer(Renderer): def render(self, uops:list[UOp]): self.prerender(uops) for u in [u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]: self.b.shader.contents.info.workgroup_size[int(u.arg[-1])] = u.src[0].arg - self.r, self.param_idx, ranges = {}, 0, [] + self.r: dict[UOp, Any] = {} + self.param_idx, ranges = 0, [] for u in uops: if u.op == Ops.NOOP or u.op == Ops.INDEX: pass + elif u.op is Ops.AFTER: + self.r[u] = self.r[u.src[0]] elif u.op == Ops.SINK: if u.arg is not None: self.b.shader.contents.info.name = mesa.char_pointer_cast(u.arg.function_name) elif u.op == Ops.DEFINE_LOCAL: diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 4695c880c3..a57ee6a838 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -179,6 +179,9 @@ class PTXRenderer(Renderer): name = "test" for u in uops: if u.op is Ops.NOOP: continue + if u.op is Ops.AFTER: + self.r[u] = self.r[u.src[0]] + continue if u.op is Ops.SINK: if u.arg is not None: name = u.arg.function_name continue diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 9dd145d299..afb1bb87f7 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -72,7 +72,8 @@ class PythonProgram: if g: _store(m, o+j, v, dtp[1].scalar()) i += 1 continue - if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: + if uop is Ops.AFTER: ul[i] = inp[0] + elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: assert isinstance(dtype, PtrDType), dtype storage_fmt = storage_fmt_for_dtype(dtype.base.scalar()) if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported") diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 131ea70538..72bb277b1e 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -527,7 +527,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def _device(self) -> str|tuple[str, ...]|None: if self.op is Ops.DEVICE: return self.arg if self.op is Ops.BUFFERIZE: return self.arg.device - if self.op is Ops.AFTER: return self.src[0].device + if self.op is Ops.AFTER: return self.src[0]._device if self.op is Ops.MSELECT: assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device" return self.src[0].device[self.arg] @@ -813,6 +813,8 @@ class UPat(MathTrait): @staticmethod def any(*src): return UPatAny(src=src) def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,))) + def or_after(self, name:str|None=None): + return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True)) @staticmethod @functools.cache @@ -1174,7 +1176,10 @@ pm_lower_index_dtype = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)), (UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"), lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))), + # TODO: this is only triggering if they are all casts, correct? (UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))), + # TODO: this should be more general + (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=tuple(y.src[0] if y.op is Ops.CAST and y.dtype.scalar()==dtypes.index else y for y in x.src))), ]) def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0] diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index f35784ed83..61ca7aeda1 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -165,6 +165,9 @@ spec = PatternMatcher([ (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), + # allow AFTER on buffers + (UPat(Ops.AFTER, src=(UPat(GroupOp.Defines),), allow_any_len=True), lambda: True), + # **** new style load/store **** # make sure all index dtypes have been lowered @@ -174,8 +177,8 @@ spec = PatternMatcher([ # INDEX is used in new style load/store # INDEX takes a - (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True), - (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True), + (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True), + (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), # LOAD on STORE (UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True), @@ -286,6 +289,8 @@ full_spec = PatternMatcher([ (UPat(Ops.DEFINE_VAR), lambda: True), # reshape on STORE (UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True), + # allow any AFTER + (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True), ])+tensor_uop_spec+spec # ***** uop helpers ***** From b5e36e3c6c0c94a89e7b50cdf83bbf0eada8316d Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 20 Oct 2025 18:13:16 +0800 Subject: [PATCH 07/30] nv: check if jitlink is avail (#12808) * nv: check if jitlink is avail * why * fix * fix --- tinygrad/runtime/support/compiler_cuda.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tinygrad/runtime/support/compiler_cuda.py b/tinygrad/runtime/support/compiler_cuda.py index 8f83c34657..7e8ff6150a 100644 --- a/tinygrad/runtime/support/compiler_cuda.py +++ b/tinygrad/runtime/support/compiler_cuda.py @@ -69,7 +69,9 @@ class PTXCompiler(Compiler): def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch) class NVPTXCompiler(PTXCompiler): - def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx") + def __init__(self, arch:str): + nvrtc_check(nvrtc.nvJitLinkVersion(ctypes.byref(ctypes.c_uint()), ctypes.byref(ctypes.c_uint()))) + super().__init__(arch, cache_key="nv_ptx") def compile(self, src:str) -> bytes: jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle) jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "".encode()), handle) From 1e93d19ee3a6cbe2c06dcf8aec52830836faa464 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:41:06 +0200 Subject: [PATCH 08/30] stable diffusion --fakeweights (#12810) --- examples/sdv2.py | 20 ++++++++++++-------- examples/stable_diffusion.py | 6 ++++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/sdv2.py b/examples/sdv2.py index 29b1abb8fd..856cf239ad 100644 --- a/examples/sdv2.py +++ b/examples/sdv2.py @@ -99,6 +99,7 @@ if __name__ == "__main__": parser.add_argument('--timing', action='store_true', help="Print timing per step") parser.add_argument('--noshow', action='store_true', help="Don't show the image") parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16") + parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights") args = parser.parse_args() N = 1 @@ -112,19 +113,22 @@ if __name__ == "__main__": model = StableDiffusionV2(**params) - default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors' - weights_fn = args.weights_fn - if not weights_fn: - weights_url = args.weights_url if args.weights_url else default_weights_url - weights_fn = fetch(weights_url, os.path.basename(str(weights_url))) - with WallTimeEvent(BenchEvent.LOAD_WEIGHTS): - load_state_dict(model, safe_load(weights_fn), strict=False) + if not args.fakeweights: + default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors' + weights_fn = args.weights_fn + if not weights_fn: + weights_url = args.weights_url if args.weights_url else default_weights_url + weights_fn = fetch(weights_url, os.path.basename(str(weights_url))) + + load_state_dict(model, safe_load(weights_fn), strict=False) if args.fp16: for k,v in get_state_dict(model).items(): if k.startswith("model"): - v.replace(v.cast(dtypes.float16).realize()) + v.replace(v.cast(dtypes.float16)) + + Tensor.realize(*get_state_dict(model).values()) c = { "crossattn": model.cond_stage_model(args.prompt) } uc = { "crossattn": model.cond_stage_model("") } diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 644c524476..4650b7e1d9 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -263,14 +263,16 @@ if __name__ == "__main__": parser.add_argument('--timing', action='store_true', help="Print timing per step") parser.add_argument('--seed', type=int, help="Set the random latent seed") parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength") + parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights") args = parser.parse_args() model = StableDiffusion() # load in weights with WallTimeEvent(BenchEvent.LOAD_WEIGHTS): - model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt') - load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False) + if not args.fakeweights: + model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt') + load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False) if args.fp16: for k,v in get_state_dict(model).items(): From a8e461443638212212e3fa6d232534642e70fe7c Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Mon, 20 Oct 2025 12:44:20 +0200 Subject: [PATCH 09/30] remove REAL_SUBSTITUTE=0 and make it fast (#12809) * fast REAL_substitute * remove REAL_SUBSTITUTE=0 --- test/test_rangeify.py | 2 +- tinygrad/helpers.py | 1 - tinygrad/schedule/rangeify.py | 34 ++++++---------------------------- tinygrad/uop/ops.py | 2 +- 4 files changed, 8 insertions(+), 31 deletions(-) diff --git a/test/test_rangeify.py b/test/test_rangeify.py index d0a4eea1c1..ab8f8b8cfb 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -62,7 +62,7 @@ class TestPcontig(unittest.TestCase): Tensor.realize(*ret) return ret - with Context(PCONTIG=2, REAL_SUBSTITUTE=1, DEBUG=2): + with Context(PCONTIG=2, DEBUG=2): grads = fa_bw() print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS") diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 1e39715692..aeb1fc5d8a 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -170,7 +170,6 @@ SPEC = ContextVar("SPEC", 0) # TODO: disable by default due to speed IGNORE_OOB = ContextVar("IGNORE_OOB", 1) PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify -REAL_SUBSTITUTE = ContextVar("REAL_SUBSTITUTE", 0) @dataclass(frozen=True) class Metadata: diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 9e1eaf002c..439bf4e613 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,9 +2,9 @@ from typing import cast from dataclasses import dataclass, field from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo -from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType +from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate from tinygrad.uop.symbolic import symbolic_flat -from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE +from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented from tinygrad.codegen.opt import Opt from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op @@ -136,6 +136,9 @@ def cleanup_dead_axes(b:UOp): # move the tag to the expand. NOTE: this expand tag might not survive return b.replace(src=b.src[0:1]+tuple(new_rng), tag=None).reshape(tuple(reshape)).expand(b.shape).replace(tag=b.tag) +def gate_substitute(ctx, b:UOp) -> None: + if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate() +pm_gate_substitute = PatternMatcher([(UPat(GroupOp.All, name="b"), gate_substitute)], compiled=False) # if a buffer is being stored just for permutes or something, remove it # we want to reexpress the indexes of idx2 in terms of the implied b1 def remove_bufferize(src:UOp, buf:UOp, idx:UOp): @@ -178,11 +181,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp): # if it makes it here, the bufferize is removed # this is the ranges replaced # NOTE: if buf src is a const, we don't replace it - if REAL_SUBSTITUTE: - return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}) - else: - replaces = flatten([(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]) - return UOp(Ops.SUBSTITUTE, dtype=src.dtype, src=(src, UOp(Ops.NOOP, src=tuple(replaces[0::2])), UOp(Ops.NOOP, src=tuple(replaces[1::2])))) + return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute) def pre_bufferize(b:UOp, x:UOp, copy:UOp): nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:]) @@ -471,25 +470,6 @@ replace_contiguous = PatternMatcher([ (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), ]) -def do_sub_recurse(s:UOp): - x,keys,values = s.src[0], s.src[1].src, s.src[2].src - # SUBSTITUTE applied to SUBSTITUTE runs the child SUB on the parents. though this is probably wrong in the generic case - if x.op is Ops.SUBSTITUTE: - sub_k = UOp(Ops.SUBSTITUTE, src=(x.src[1],)+s.src[1:]) - sub_v = UOp(Ops.SUBSTITUTE, src=(x.src[2],)+s.src[1:]) - return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v)) - # here we actually do the SUBSTITUTE - if x in keys: return values[keys.index(x)] - # we filter any keys where the ranges don't overlap. this keeps the algorithm O(output graph size) - x_ranges = x.ranges - new_kv = {k:v for k,v in zip(keys,values) if any(r in x_ranges for r in k.ranges)} - # if there's no SUBSTITUTEs left, we can just return x - if len(new_kv) == 0: return x - # then we add SUBSTITUTE to all parents - uop_keys, uop_values = UOp(Ops.NOOP, src=tuple(new_kv.keys())), UOp(Ops.NOOP, src=tuple(new_kv.values())) - return x.replace(src=tuple([UOp(Ops.SUBSTITUTE, dtype=y.dtype, src=(y,uop_keys,uop_values)) for y in x.src])) -pm_substitute_recurse = PatternMatcher([(UPat(Ops.SUBSTITUTE, src=(UPat(), UPat(Ops.NOOP), UPat(Ops.NOOP)), name="s"), do_sub_recurse)]) - @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True) def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph") @@ -504,8 +484,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers") - # TODO: can you substitute and remove costly buffers at the same time? - tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 72bb277b1e..951bc4ce8d 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1096,7 +1096,7 @@ class RewriteContext: new_n, test_n = test_n, self.cached_bpm_rewrite(test_n) except BottomUpGate: # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs - self.replace[n] = new_n + self.replace[n] = unwrap(test_n) continue stack.append((n, 1, new_n)) for x in reversed(new_n.src): From d1e2c393f8d59be8acde7f651313e995e894bc70 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 20 Oct 2025 18:54:37 +0800 Subject: [PATCH 10/30] after in sym, axis_letters in range (#12811) * after in sym, axis_letters in range * this is better * this work? --- tinygrad/renderer/cstyle.py | 4 ++-- tinygrad/uop/ops.py | 4 +--- tinygrad/uop/symbolic.py | 5 +++++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 2140abe6e7..e6d01bfc97 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -1,7 +1,7 @@ from typing import Literal, Callable, cast import os, math, sys from collections import defaultdict, Counter -from tinygrad.codegen.opt import tc +from tinygrad.codegen.opt import tc, axis_letters from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate @@ -163,7 +163,7 @@ class CStyleLanguage(Renderer): # naming prefix = None if u.op is Ops.SPECIAL: r[u] = u.arg - elif u.op is Ops.RANGE: r[u] = "ridx"+range_str(u) + elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u) else: prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast", diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 951bc4ce8d..322d713852 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -169,7 +169,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @property def ptrdtype(self) -> PtrDType: - if not isinstance(self.dtype, PtrDType): raise RuntimeError("ptrdtype called on UOp without PtrDType") + if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}") return self.dtype # *** uop shape stuff *** @@ -1178,8 +1178,6 @@ pm_lower_index_dtype = PatternMatcher([ lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))), # TODO: this is only triggering if they are all casts, correct? (UPat((Ops.SINK, Ops.NOOP), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))), - # TODO: this should be more general - (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=tuple(y.src[0] if y.op is Ops.CAST and y.dtype.scalar()==dtypes.index else y for y in x.src))), ]) def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0] diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index e9fec2ae9e..91cf8390e4 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -377,6 +377,11 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ (UPat(GroupOp.Binary, src=(UPat.var("x", dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y: x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None), ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), + # only RANGE/IF/STORE/KERNEL have side effects + (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+ + tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER} else y.src for y in x.src[1:]])))), + # after with 1 src is just src[0] + (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s), ])+gep_pushing symbolic_flat = symbolic+PatternMatcher([ From 5d0d3d7aac798e65f490e9c1a8dae957fff09316 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 20 Oct 2025 19:24:24 +0800 Subject: [PATCH 11/30] after clean up of locals (#12813) --- test/test_uops_stats.py | 10 ---------- tinygrad/codegen/gpudims.py | 10 +++++++++- tinygrad/codegen/late/expander.py | 3 +-- tinygrad/codegen/opt/postrange.py | 2 +- tinygrad/schedule/rangeify.py | 18 +++++------------- 5 files changed, 16 insertions(+), 27 deletions(-) diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 845ab8b325..39c631206b 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -208,16 +208,6 @@ class TestStatsOptimized(unittest.TestCase): self.check_gemm(p) self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N) - def test_gemm_group(self): - try: - p = get_program(self.ast_gemm, opts=[Opt(OptOps.GROUP, 0, 4)]) - except KernelOptError: - raise unittest.SkipTest("no locals") - SZ = N*N*4 - # NOTE: these are sort of wrong. they aren't honoring the IF statement - self.check_gemm(p, extra_flops=SZ*4) - self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4) - def test_reduce(self): p = get_program(self.ast_reduce, opts=[]) print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds) diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index 5f406f78b0..5169450883 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -1,4 +1,4 @@ -import math +import math, functools, operator from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop from tinygrad.helpers import all_int, dedup, get_contraction from tinygrad.dtype import dtypes @@ -87,7 +87,15 @@ def add_gpudims(ctx:Renderer, s:UOp): except ValueError: continue return s.substitute(subs) +def add_barrier_and_if(buf:UOp, s:UOp): + # TODO: this is not generic + local_ranges = [x for x in s.src[1:] if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE] + if len(local_ranges) == 0: return None + return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), s.barrier()))) + pm_add_gpudims = PatternMatcher([ # add gpudims must be last (UPat(Ops.SINK, name="s"), add_gpudims), + # add barrier and if + (UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.STORE, name="s"))), add_barrier_and_if), ]) diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index 9a42d414ce..c594d6315d 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -145,8 +145,7 @@ def fix_group_for_reduce(x:UOp): reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr] buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop) - # gate with an if on the store + do the final reduce - buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf)) + # do the final reduce (if/barrier are added in gpudims step) return buf.reduce(*reduce_loop, arg=x.arg) pm_pre_expander = PatternMatcher([ diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index 55a443dfdf..4263ba67ff 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -348,7 +348,7 @@ def apply_opts(ctx:Renderer, ast:UOp): elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()): from tinygrad.codegen.opt.heuristic import hand_coded_optimizations # NOTE: hand_coded_optimizations doesn't support multiblock opts yet - if all(len(u.src) == 1 for u in ast.backward_slice if u.op is Ops.LOAD): + if not any(u.op is Ops.AFTER and u.src[0].op is Ops.DEFINE_LOCAL for u in ast.backward_slice): k = hand_coded_optimizations(k) return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 439bf4e613..1f037406c1 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -301,9 +301,7 @@ def bufferize_to_store(x:UOp): tag = x.arg.device if tag is None: tag = UOp.unique().arg # TODO: hack buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) - # store has the other dtype here - # TODO: use after here? - return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).reshape(shape) + return buf.after(buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs)).reshape(shape) pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([ (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store), @@ -336,6 +334,7 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp): return b.src[0] def handle_after(ctx:LocalAddBufferContext, after:UOp): + if isinstance(after.dtype, PtrDType) and after.ptrdtype.addrspace == AddrSpace.LOCAL: return None buf = after.as_buf() # HACK to put the buffer in the MAP instead of MSTACK/MSELECT if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0] @@ -388,16 +387,9 @@ rangeify_codegen = PatternMatcher([ # add loads to non ptr indexes # TODO: this can be moved into codegen? - (UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True), - lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), - - # TODO: this can be moved into codegen - (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD), - lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())), - - # TODO: hack for group for reduce - (UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)), - lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))), + (UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) + .f(Ops.INDEX, name="idx", allow_any_len=True), + lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), ]) def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp): From 203a93363cd99fe2913c8814de12c3ebe1758266 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 20 Oct 2025 19:33:35 +0800 Subject: [PATCH 12/30] Revert "after clean up of locals (#12813)" (#12814) This reverts commit 5d0d3d7aac798e65f490e9c1a8dae957fff09316. --- test/test_uops_stats.py | 10 ++++++++++ tinygrad/codegen/gpudims.py | 10 +--------- tinygrad/codegen/late/expander.py | 3 ++- tinygrad/codegen/opt/postrange.py | 2 +- tinygrad/schedule/rangeify.py | 18 +++++++++++++----- 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 39c631206b..845ab8b325 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -208,6 +208,16 @@ class TestStatsOptimized(unittest.TestCase): self.check_gemm(p) self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N) + def test_gemm_group(self): + try: + p = get_program(self.ast_gemm, opts=[Opt(OptOps.GROUP, 0, 4)]) + except KernelOptError: + raise unittest.SkipTest("no locals") + SZ = N*N*4 + # NOTE: these are sort of wrong. they aren't honoring the IF statement + self.check_gemm(p, extra_flops=SZ*4) + self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4) + def test_reduce(self): p = get_program(self.ast_reduce, opts=[]) print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds) diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index 5169450883..5f406f78b0 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -1,4 +1,4 @@ -import math, functools, operator +import math from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop from tinygrad.helpers import all_int, dedup, get_contraction from tinygrad.dtype import dtypes @@ -87,15 +87,7 @@ def add_gpudims(ctx:Renderer, s:UOp): except ValueError: continue return s.substitute(subs) -def add_barrier_and_if(buf:UOp, s:UOp): - # TODO: this is not generic - local_ranges = [x for x in s.src[1:] if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE] - if len(local_ranges) == 0: return None - return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), s.barrier()))) - pm_add_gpudims = PatternMatcher([ # add gpudims must be last (UPat(Ops.SINK, name="s"), add_gpudims), - # add barrier and if - (UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.STORE, name="s"))), add_barrier_and_if), ]) diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index c594d6315d..9a42d414ce 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -145,7 +145,8 @@ def fix_group_for_reduce(x:UOp): reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr] buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop) - # do the final reduce (if/barrier are added in gpudims step) + # gate with an if on the store + do the final reduce + buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf)) return buf.reduce(*reduce_loop, arg=x.arg) pm_pre_expander = PatternMatcher([ diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index 4263ba67ff..55a443dfdf 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -348,7 +348,7 @@ def apply_opts(ctx:Renderer, ast:UOp): elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()): from tinygrad.codegen.opt.heuristic import hand_coded_optimizations # NOTE: hand_coded_optimizations doesn't support multiblock opts yet - if not any(u.op is Ops.AFTER and u.src[0].op is Ops.DEFINE_LOCAL for u in ast.backward_slice): + if all(len(u.src) == 1 for u in ast.backward_slice if u.op is Ops.LOAD): k = hand_coded_optimizations(k) return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 1f037406c1..439bf4e613 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -301,7 +301,9 @@ def bufferize_to_store(x:UOp): tag = x.arg.device if tag is None: tag = UOp.unique().arg # TODO: hack buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) - return buf.after(buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs)).reshape(shape) + # store has the other dtype here + # TODO: use after here? + return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).reshape(shape) pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([ (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store), @@ -334,7 +336,6 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp): return b.src[0] def handle_after(ctx:LocalAddBufferContext, after:UOp): - if isinstance(after.dtype, PtrDType) and after.ptrdtype.addrspace == AddrSpace.LOCAL: return None buf = after.as_buf() # HACK to put the buffer in the MAP instead of MSTACK/MSELECT if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0] @@ -387,9 +388,16 @@ rangeify_codegen = PatternMatcher([ # add loads to non ptr indexes # TODO: this can be moved into codegen? - (UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) - .f(Ops.INDEX, name="idx", allow_any_len=True), - lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), + (UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True), + lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), + + # TODO: this can be moved into codegen + (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD), + lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())), + + # TODO: hack for group for reduce + (UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)), + lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))), ]) def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp): From e284f6325a787145cc544fd6e470222723c075ec Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 20 Oct 2025 19:46:48 +0800 Subject: [PATCH 13/30] llvm: fix compile key for different processors (#12812) --- tinygrad/runtime/support/compiler_cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/runtime/support/compiler_cpu.py b/tinygrad/runtime/support/compiler_cpu.py index f9ec8d1062..04c2987180 100644 --- a/tinygrad/runtime/support/compiler_cpu.py +++ b/tinygrad/runtime/support/compiler_cpu.py @@ -58,7 +58,7 @@ class LLVMCompiler(Compiler): self.diag_msgs.append(msg) self.handle_diag = handle_diag llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None) - super().__init__(f"compile_llvm_{self.target_arch}{'_jit' if self.jit else ''}{'_opt' if opt else ''}") + super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}") def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo) From c7c59e6dd71158f50bbb9a87298b4ed1d65a6fb6 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 20 Oct 2025 12:24:58 -0400 Subject: [PATCH 14/30] unused UPat.or_broadcasted and GroupOp.Block [pr] (#12819) --- tinygrad/uop/__init__.py | 1 - tinygrad/uop/ops.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 4879a6daa6..cfdfa39a8a 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -94,7 +94,6 @@ class GroupOp: Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP} Buffer = {Ops.LOAD, Ops.STORE, Ops.CONST, Ops.DEFINE_VAR} - Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART} # BinaryOps that can be flipped Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR} diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 322d713852..b08c193f9c 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -841,7 +841,6 @@ class UPat(MathTrait): def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs) def fuse(self): return self.alu(Ops.FUSE) def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs) - def or_broadcasted(self, **kwargs): return UPat.any(self, self.broadcast(**kwargs)) def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs) def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b)) From 25beea576956b26b802e17379bc3a0cabc6003c9 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 21 Oct 2025 09:04:36 +0800 Subject: [PATCH 15/30] hotfix: suppress_finalizing on device __del__ --- tinygrad/device.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tinygrad/device.py b/tinygrad/device.py index 7db5310bf8..2e4b1f2520 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM from tinygrad.helpers import Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup -from tinygrad.helpers import unwrap_class_type +from tinygrad.helpers import unwrap_class_type, suppress_finalizing from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer @@ -163,6 +163,7 @@ class Buffer: return self._trace_num @property def nbytes(self): return self.size*self.dtype.itemsize + @suppress_finalizing def __del__(self): (not hasattr(self, '_buf')) or self.deallocate() def __repr__(self): return f" Date: Tue, 21 Oct 2025 09:22:39 +0800 Subject: [PATCH 16/30] num_batches_tracked has shape () (#12820) --- tinygrad/nn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index b27ab036c0..c8884146d3 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -36,7 +36,7 @@ class BatchNorm: self.weight: Tensor|None = Tensor.ones(sz) if affine else None self.bias: Tensor|None = Tensor.zeros(sz) if affine else None - self.num_batches_tracked = Tensor.zeros(1, dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False) + self.num_batches_tracked = Tensor.zeros(dtype='long' if is_dtype_supported(dtypes.long) else 'int', requires_grad=False) if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False) def calc_stats(self, x:Tensor) -> tuple[Tensor, Tensor]: From 990e8b97eea57599d2f93c5689141e5f5bdc6690 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Mon, 20 Oct 2025 18:30:34 -0700 Subject: [PATCH 17/30] feat: log openpilot 0.10.1 times (#12816) --- .github/workflows/benchmark.yml | 6 +++--- examples/openpilot/compile3.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index c563f79c62..39893f4388 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -626,11 +626,11 @@ jobs: - name: benchmark openpilot 0.9.9 dmonitoring run: BENCHMARK_LOG=openpilot_0_9_9_dmonitoring PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.9/selfdrive/modeld/models/dmonitoring_model.onnx - name: openpilot compile3 0.10.1 driving_vision - run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 + run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=25 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: openpilot compile3 0.10.1 driving_policy - run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx + run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=7 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/driving_policy.onnx - name: openpilot compile3 0.10.1 dmonitoring - run: PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx + run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=12 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/refs/heads/master/selfdrive/modeld/models/dmonitoring_model.onnx - name: benchmark MobileNetV2 on DSP run: | # generate quantized weights diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index 02b8496b26..1c831aa48d 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -121,6 +121,12 @@ def test_vs_onnx(new_inputs, test_val, onnx_file, tol): print("test vs onnx passed") return timings +def bench(run, inputs): + from extra.bench_log import WallTimeEvent, BenchEvent + for _ in range(10): + with WallTimeEvent(BenchEvent.STEP): + run(**inputs).numpy() + if __name__ == "__main__": onnx_file = fetch(OPENPILOT_MODEL) inputs, outputs = compile(onnx_file) @@ -131,3 +137,5 @@ if __name__ == "__main__": if not getenv("FLOAT16"): test_vs_onnx(inputs, outputs, onnx_file, 1e-4) + if getenv("BENCHMARK_LOG", ""): + bench(pickle_loaded, inputs) From 68c045bf0ad9014259e5cd38b676ae964a53af8b Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Mon, 20 Oct 2025 21:38:43 -0400 Subject: [PATCH 18/30] NIR: Check for brew packages tinymesa and tinymesa_cpu (#12739) * brew install tinymesa_cpu * brew --prefix tinygrad_cpu too * fix brew paths * check both brew paths * better errors * handle failure --- .github/actions/setup-tinygrad/action.yml | 2 +- autogen_stubs.sh | 12 ++++++------ tinygrad/runtime/autogen/mesa.py | 17 +++++++++-------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/.github/actions/setup-tinygrad/action.yml b/.github/actions/setup-tinygrad/action.yml index 76323bc415..0b2dbc05a5 100644 --- a/.github/actions/setup-tinygrad/action.yml +++ b/.github/actions/setup-tinygrad/action.yml @@ -302,4 +302,4 @@ runs: - name: Install mesa (macOS) if: inputs.mesa == 'true' && runner.os == 'macOS' shell: bash - run: brew install sirhcm/tinymesa/tinymesa + run: brew install sirhcm/tinymesa/tinymesa_cpu diff --git a/autogen_stubs.sh b/autogen_stubs.sh index 5d02cd37f4..58d919d597 100755 --- a/autogen_stubs.sh +++ b/autogen_stubs.sh @@ -520,17 +520,17 @@ generate_mesa() { LVP_NIR_OPTIONS=$(./extra/mesa/lvp_nir_options.sh $MESA_SRC) fixup $BASE/mesa.py - patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "f'{brew_prefix()}/lib/libtinymesa_cpu.dylib'" + patch_dlopen $BASE/mesa.py tinymesa_cpu "(BASE:=os.getenv('MESA_PATH', f\"/usr{'/local/' if helpers.OSX else '/'}lib\"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so')" "f'{BASE}/libtinymesa{EXT}'" "brew_path('tinymesa_cpu')" "brew_path('tinymesa')" echo "lvp_nir_options = gzip.decompress(base64.b64decode('$LVP_NIR_OPTIONS'))" >> $BASE/mesa.py cat <> $BASE/mesa.py + echo "def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa ($TINYMESA_TAG, $MESA_TAG)')" >> $BASE/mesa.py sed -i "s/ctypes.glsl_base_type/glsl_base_type/" $BASE/mesa.py # bitfield bug in clang2py sed -i "s/('fp_fast_math', ctypes.c_bool, 9)/('fp_fast_math', ctypes.c_uint32, 9)/" $BASE/mesa.py diff --git a/tinygrad/runtime/autogen/mesa.py b/tinygrad/runtime/autogen/mesa.py index 78a0efc2e6..66cd9e5342 100644 --- a/tinygrad/runtime/autogen/mesa.py +++ b/tinygrad/runtime/autogen/mesa.py @@ -7,13 +7,14 @@ # LONGDOUBLE_SIZE is: 16 # import ctypes, ctypes.util, os, gzip, base64, subprocess, tinygrad.helpers as helpers -def brew_prefix(): - try: return subprocess.check_output(['brew', '--prefix', 'tinymesa']).decode().strip() - except Exception: return '' +def brew_path(nm): + try: return f"{subprocess.check_output(['brew', '--prefix', nm]).decode().strip()}/lib/lib{nm}.dylib" + except Exception: return 'failed' PATHS_TO_TRY = [ (BASE:=os.getenv('MESA_PATH', f"/usr{'/local/' if helpers.OSX else '/'}lib"))+'/libtinymesa_cpu'+(EXT:='.dylib' if helpers.OSX else '.so'), f'{BASE}/libtinymesa{EXT}', - f'{brew_prefix()}/lib/libtinymesa_cpu.dylib', + brew_path('tinymesa_cpu'), + brew_path('tinymesa'), ] def _try_dlopen_tinymesa_cpu(): library = ctypes.util.find_library("tinymesa_cpu") @@ -6087,7 +6088,7 @@ struct_nir_op_info._fields_ = [ nir_op_info = struct_nir_op_info try: nir_op_infos = (struct_nir_op_info * 489).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_op_infos') -except AttributeError: pass +except (AttributeError, ValueError): pass try: nir_op_is_selection = _libraries['FIXME_STUB'].nir_op_is_selection nir_op_is_selection.restype = ctypes.c_bool @@ -8118,7 +8119,7 @@ c__EA_nir_intrinsic_index_flag = ctypes.c_uint32 # enum nir_intrinsic_index_flag = c__EA_nir_intrinsic_index_flag nir_intrinsic_index_flag__enumvalues = c__EA_nir_intrinsic_index_flag__enumvalues try: nir_intrinsic_index_names = (ctypes.POINTER(ctypes.c_char) * 75).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_index_names') -except AttributeError: pass +except (AttributeError, ValueError): pass class struct_nir_intrinsic_instr(Structure): pass @@ -8242,7 +8243,7 @@ struct_nir_intrinsic_info._fields_ = [ nir_intrinsic_info = struct_nir_intrinsic_info try: nir_intrinsic_infos = (struct_nir_intrinsic_info * 732).in_dll(_libraries['libtinymesa_cpu.so'], 'nir_intrinsic_infos') -except AttributeError: pass +except (AttributeError, ValueError): pass try: nir_intrinsic_src_components = _libraries['libtinymesa_cpu.so'].nir_intrinsic_src_components nir_intrinsic_src_components.restype = ctypes.c_uint32 @@ -19877,4 +19878,4 @@ __all__ = \ 'union_util_format_description_0', 'util_format_colorspace', 'util_format_layout', 'va_list'] lvp_nir_options = gzip.decompress(base64.b64decode('H4sIAAAAAAAAA2NgZGRkYGAAkYxgCsQFsxigwgwQBoxmhCqFq2WEKwIrAEGIkQxoAEMALwCqVsCiGUwLMHA0QPn29nBJkswHANb8YpH4AAAA')) -def __getattr__(nm): raise AttributeError() if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)') +def __getattr__(nm): raise AttributeError('LLVMpipe requires tinymesa_cpu' if 'tinymesa_cpu' not in dll._name else f'attribute {nm} not found') if dll else FileNotFoundError(f'libtinymesa not found (MESA_PATH={BASE}). See https://github.com/sirhcm/tinymesa (tinymesa-32dc66c, mesa-25.2.4)') From df2f8b9295fb52784f804e599f8cc9580d5ba850 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Oct 2025 10:29:12 +0800 Subject: [PATCH 19/30] use after on locals (#12815) * use after on locals * fix estimates * too much compute * correct for both ptx and normal * err, that * tighter spec * keep that --- tinygrad/codegen/gpudims.py | 10 +++++++++- tinygrad/codegen/late/expander.py | 3 +-- tinygrad/codegen/opt/postrange.py | 2 +- tinygrad/codegen/opt/search.py | 4 +++- tinygrad/renderer/__init__.py | 4 +++- tinygrad/schedule/rangeify.py | 18 +++++------------- tinygrad/uop/spec.py | 6 +----- 7 files changed, 23 insertions(+), 24 deletions(-) diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index 5f406f78b0..5169450883 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -1,4 +1,4 @@ -import math +import math, functools, operator from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop from tinygrad.helpers import all_int, dedup, get_contraction from tinygrad.dtype import dtypes @@ -87,7 +87,15 @@ def add_gpudims(ctx:Renderer, s:UOp): except ValueError: continue return s.substitute(subs) +def add_barrier_and_if(buf:UOp, s:UOp): + # TODO: this is not generic + local_ranges = [x for x in s.src[1:] if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE] + if len(local_ranges) == 0: return None + return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), s.barrier()))) + pm_add_gpudims = PatternMatcher([ # add gpudims must be last (UPat(Ops.SINK, name="s"), add_gpudims), + # add barrier and if + (UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.STORE, name="s"))), add_barrier_and_if), ]) diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index 9a42d414ce..c594d6315d 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -145,8 +145,7 @@ def fix_group_for_reduce(x:UOp): reduce_loop = [x.replace(arg=(x.arg[0]+100, AxisType.REDUCE)) for x in reduce_gfr] buf = ret.bufferize(*upstream_locals, *reduce_gfr, arg=BufferizeOpts(reduce_gfr[0].arg[0], AddrSpace.LOCAL)).index(*upstream_locals, *reduce_loop) - # gate with an if on the store + do the final reduce - buf = UOp(Ops.IF, dtype=buf.dtype, src=(functools.reduce(operator.and_, [x.eq(0) for x in reduce_gfr]), buf)) + # do the final reduce (if/barrier are added in gpudims step) return buf.reduce(*reduce_loop, arg=x.arg) pm_pre_expander = PatternMatcher([ diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index 55a443dfdf..4263ba67ff 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -348,7 +348,7 @@ def apply_opts(ctx:Renderer, ast:UOp): elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()): from tinygrad.codegen.opt.heuristic import hand_coded_optimizations # NOTE: hand_coded_optimizations doesn't support multiblock opts yet - if all(len(u.src) == 1 for u in ast.backward_slice if u.op is Ops.LOAD): + if not any(u.op is Ops.AFTER and u.src[0].op is Ops.DEFINE_LOCAL for u in ast.backward_slice): k = hand_coded_optimizations(k) return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None) diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py index 21cce836f3..bb87c103b9 100644 --- a/tinygrad/codegen/opt/search.py +++ b/tinygrad/codegen/opt/search.py @@ -156,7 +156,9 @@ def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=Tr if lib in seen_libs: continue # filter out kernels that use 1000x more compute than the smallest least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops) - if least_compute_ops*1000 < this_compute_ops: continue + if least_compute_ops*1000 < this_compute_ops: + if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too much compute. {this_compute_ops} when least is {least_compute_ops}") + continue seen_libs.add(lib) try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches')) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 87ddce695a..849ec9d48e 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -30,7 +30,9 @@ class Estimates: if ignore_indexing: for u in uops: if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): - dont_count = dont_count.union(u.src[0].toposort()) + # if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER + dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort()) + # TODO: is this correct? this all needs to be cleaned up if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort()) elif u.op is Ops.IF: dont_count = dont_count.union(u.src[0].toposort()) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 439bf4e613..1f037406c1 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -301,9 +301,7 @@ def bufferize_to_store(x:UOp): tag = x.arg.device if tag is None: tag = UOp.unique().arg # TODO: hack buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) - # store has the other dtype here - # TODO: use after here? - return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).reshape(shape) + return buf.after(buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs)).reshape(shape) pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([ (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store), @@ -336,6 +334,7 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp): return b.src[0] def handle_after(ctx:LocalAddBufferContext, after:UOp): + if isinstance(after.dtype, PtrDType) and after.ptrdtype.addrspace == AddrSpace.LOCAL: return None buf = after.as_buf() # HACK to put the buffer in the MAP instead of MSTACK/MSELECT if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0] @@ -388,16 +387,9 @@ rangeify_codegen = PatternMatcher([ # add loads to non ptr indexes # TODO: this can be moved into codegen? - (UPat((Ops.DEFINE_GLOBAL, Ops.STORE), name="dg").f(Ops.INDEX, name="idx", allow_any_len=True), - lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), - - # TODO: this can be moved into codegen - (UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD), - lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store if idx.dtype.addrspace != AddrSpace.LOCAL else store.barrier())), - - # TODO: hack for group for reduce - (UPat(Ops.IF, src=(UPat.var("gate"), UPat(Ops.LOAD, src=(UPat.var("src"), UPat.var("barrier"))),)), - lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))), + (UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) + .f(Ops.INDEX, name="idx", allow_any_len=True), + lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), ]) def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp): diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 61ca7aeda1..ca9b32c7ec 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -180,15 +180,11 @@ spec = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat())), lambda: True), (UPat(Ops.INDEX, src=(UPat(GroupOp.Defines).or_after(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), - # LOAD on STORE - (UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True), - # LOAD takes a (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])), (UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index), - # STORE takes a - (UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store), + # STORE takes a (UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store), # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE From 8521fd526367793b866e0abf86595f0deaba7c71 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Oct 2025 10:55:41 +0800 Subject: [PATCH 20/30] viz: hierarchical rewrites (#12805) * viz: hierarchical rewrites * count of subrewrites * arrows * better keyboard things * add select and deselect utils * works * diff * event stopPropagation * work * don't change the rewrite * walk tree back --- tinygrad/viz/index.html | 10 ++++++++- tinygrad/viz/js/index.js | 46 ++++++++++++++++++++++++++++++++-------- tinygrad/viz/serve.py | 4 ++-- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 8731a457b6..2f10e7d89b 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -2,6 +2,7 @@ tinygrad viz + @@ -52,6 +53,13 @@ } ul > ul { display: none; + margin-left: 6px; + } + ul.has-children > p::before { + content:"â–¸ "; + } + ul.has-children.expanded > p::before { + content:"â–¾ "; } ul.expanded > ul { display: block; @@ -141,7 +149,7 @@ .metadata > * + *, .rewrite-container > * + *, .ctx-list > * + * { margin-top: 12px; } - .ctx-list > ul > * + * { + ul > * + * { margin-top: 4px; } .graph { diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index a0f5a3af25..4c5d775b1c 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -573,12 +573,20 @@ function setState(ns) { } if (state.currentCtx !== prevCtx || state.currentStep !== prevStep) { document.getElementById(`step-${prevCtx}-${prevStep}`)?.classList.remove("active"); + // walk the tree back until all parents expanded so that the child is visible + let e = document.getElementById(`step-${state.currentCtx}-${state.currentStep}`); + while (e?.parentElement?.id.startsWith("step")) { + e.parentElement.classList.add("expanded"); + e = e.parentElement; + } setActive(document.getElementById(`step-${state.currentCtx}-${state.currentStep}`)); } // re-render main(); } +const getSubrewrites = (ul) => ul.querySelectorAll(":scope > ul"); + // set a new context and keep the old one in browser history function setCtxWithHistory(newCtx, step=0) { // NOTE: browser does a structured clone, passing a mutable object is safe. @@ -605,16 +613,25 @@ async function main() { p.onclick = () => { setState(i === state.currentCtx ? { expandSteps:!state.expandSteps } : { expandSteps:true, currentCtx:i, currentStep:0, currentRewrite:0 }); } + const stack = []; let list = ul; for (const [j,u] of steps.entries()) { - const inner = ul.appendChild(document.createElement("ul")); - inner.id = `step-${i}-${j}`; - const p = inner.appendChild(document.createElement("p")); + while (stack.length && stack.at(-1).depth >= u.depth) stack.pop(); + const list = stack.length > 0 ? stack.at(-1).li : ul; + u.li = list.appendChild(document.createElement("ul")); + u.li.id = `step-${i}-${j}`; + const p = u.li.appendChild(document.createElement("p")); p.innerText = `${u.name}`+(u.match_count ? ` - ${u.match_count}` : ''); - inner.style.marginLeft = `${8*u.depth}px`; - inner.onclick = (e) => { + p.onclick = (e) => { e.stopPropagation(); - setState({ currentStep:j, currentCtx:i, currentRewrite:0 }); + const subrewrites = getSubrewrites(e.currentTarget.parentElement); + if (subrewrites.length) { e.currentTarget.parentElement.classList.toggle("expanded"); } + setState({ currentStep:j, currentCtx:i }); } + stack.push(u); + } + for (const l of ul.querySelectorAll("ul > ul > p")) { + const subrewrites = getSubrewrites(l.parentElement); + if (subrewrites.length > 0) { l.innerText += ` (${subrewrites.length})`; l.parentElement.classList.add("has-children"); } } } return setState({ currentCtx:-1 }); @@ -764,22 +781,32 @@ appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWid // **** keyboard shortcuts +const select = (ctx, step) => ({ ctx:document.getElementById(`ctx-${ctx}`), step:document.getElementById(`step-${ctx}-${step}`) }); +const deselect = (element) => { + const parts = element?.id.split("-").map(Number); + return element?.id.startsWith("ctx") ? { ctx:parts[1], step:null } : element?.id.startsWith("step") ? {ctx:parts[1], step:parts[2]} : {}; +} +const isExpanded = (el) => el?.classList.contains("expanded"); + document.addEventListener("keydown", (event) => { const { currentCtx, currentStep, currentRewrite, expandSteps } = state; // up and down change the step or context from the list const changeStep = expandSteps && ctxs[currentCtx].steps?.length; + const { step, ctx } = select(currentCtx, currentStep); if (event.key == "ArrowUp") { event.preventDefault(); if (changeStep) { - return setState({ currentRewrite:0, currentStep:Math.max(0, currentStep-1) }); + let prev = deselect(step.previousElementSibling); + if (prev.step == null && isExpanded(step.parentElement)) prev = deselect(step.parentElement); + return prev.step != null && !isExpanded(step) && setState({ currentRewrite:0, currentStep:prev.step }); } return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.max(0, currentCtx-1), expandSteps:false }); } if (event.key == "ArrowDown") { event.preventDefault(); if (changeStep) { - const totalUOps = ctxs[currentCtx].steps.length-1; - return setState({ currentRewrite:0, currentStep:Math.min(totalUOps, currentStep+1) }); + const next = deselect(isExpanded(step) ? step.children[1] : step.nextElementSibling); + return next.step != null && setState({ currentRewrite:0, currentStep:next.step }); } return setState({ currentStep:0, currentRewrite:0, currentCtx:Math.min(ctxs.length-1, currentCtx+1), expandSteps:false }); } @@ -789,6 +816,7 @@ document.addEventListener("keydown", (event) => { if (currentCtx === -1) { return setState({ currentCtx:0, expandSteps:true }); } + if (expandSteps && getSubrewrites(step).length) return step.children[0].click(); return setState({ expandSteps:!expandSteps }); } // left and right go through rewrites in a single UOp diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index b92246224e..da864578ce 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -33,8 +33,8 @@ def get_rewrites(t:RewriteTrace) -> list[dict]: steps = [{"name":s.name, "loc":s.loc, "match_count":len(s.matches), "code_line":printable(s.loc), "query":f"/ctxs?ctx={i}&idx={j}", "depth":s.depth} for j,s in enumerate(v)] if isinstance(k.ret, ProgramSpec): - steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src"}) - steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm"}) + steps.append({"name":"View Program", "query":f"/render?ctx={i}&fmt=src", "depth":0}) + steps.append({"name":"View Disassembly", "query":f"/render?ctx={i}&fmt=asm", "depth":0}) for key in k.keys: ref_map[key] = i ret.append({"name":k.display_name, "steps":steps}) return ret From a71a41f6d1bd6de701ee8d0da1df8cf8f04a962e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Oct 2025 11:32:18 +0800 Subject: [PATCH 21/30] rename Ops.ENDRANGE -> Ops.END (#12824) --- test/test_linearizer.py | 4 ++-- test/test_uop_graph.py | 2 +- tinygrad/codegen/late/linearize.py | 4 ++-- tinygrad/renderer/__init__.py | 2 +- tinygrad/renderer/cstyle.py | 4 ++-- tinygrad/renderer/llvmir.py | 2 +- tinygrad/renderer/nir.py | 2 +- tinygrad/renderer/ptx.py | 4 ++-- tinygrad/runtime/ops_python.py | 4 ++-- tinygrad/uop/__init__.py | 2 +- tinygrad/uop/spec.py | 2 +- tinygrad/viz/serve.py | 3 ++- 12 files changed, 18 insertions(+), 17 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7af6294c83..f63a9e9a71 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase): def _test_no_nested_ranges(self, lins, skip=None): for l in lins: range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG]) - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)] + ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)] for i,u in enumerate(ranges): if skip and i in skip: continue assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" @@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase): # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE uops = get_program(ast, opts=opt).uops begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1] - end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0] + end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0] for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype) for u in uops: if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG: diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index d5d75462f3..ca93f3a0cf 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -674,7 +674,7 @@ class TestUOpGraph(unittest.TestCase): store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf)) uops = to_uops_list([store]) ranges = [x for x in uops if x.op is Ops.RANGE] - endranges = [x for x in uops if x.op is Ops.ENDRANGE] + endranges = [x for x in uops if x.op is Ops.END] # ranges are closed in the right order self.assertEqual(endranges[-1].src[0], ranges[0]) diff --git a/tinygrad/codegen/late/linearize.py b/tinygrad/codegen/late/linearize.py index d860125adf..af6727819e 100644 --- a/tinygrad/codegen/late/linearize.py +++ b/tinygrad/codegen/late/linearize.py @@ -105,7 +105,7 @@ def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp while len(ends_to_add): r:UOp = ends_to_add.pop(-1) new_ctx = tuple([z for z in new_ctx if z is not r]) - end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)) + end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.END, src=(r,)) base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt)) return base_block @@ -215,7 +215,7 @@ def remove_blockend(x:UOp): # NOTE: DEFINE_ACC doesn't have to be handled in any special way late_ops = list(x.arg.lst) # NOTE: we have to add a barrier at the start if barrier is used in the range - if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE: + if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.END: late_ops = [UOp(Ops.BARRIER)] + late_ops # peephole opt, remove any BARRIERs next to each other for i in range(len(late_ops)-1): diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 849ec9d48e..a1d8f89d5f 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -47,7 +47,7 @@ class Estimates: mults *= cast(sint, u.src[0].ssimplify()) # SPECIAL are already counted in mults mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults - elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1) + elif u.op is Ops.END: mults = mult_stack.pop(-1) elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): lds += u.dtype.itemsize * mults diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index e6d01bfc97..5afcd0711a 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -11,7 +11,7 @@ from tinygrad.codegen.late.devectorizer import no_vectorized_alu base_rewrite = PatternMatcher([ (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), - (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"), + (UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"), (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"), # r method accesses (UPat(Ops.RANGE, name="x"), @@ -173,7 +173,7 @@ class CStyleLanguage(Renderer): l = cast(str, self.string_rewrite.rewrite(u, ctx=self)) assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" - if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1 + if u.op in {Ops.ENDIF, Ops.END}: depth -= 1 if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \ (u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \ (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \ diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 032532e75c..b67bd9cb32 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -108,7 +108,7 @@ base_rewrite = PatternMatcher([ f" br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n" f" br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n" f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"), - (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x: + (UPat(Ops.END, name="x"), lambda ctx,x: f" br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n" f" {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n" f" {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n" diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index efaeddbecd..eec9cade89 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -186,7 +186,7 @@ class NIRRenderer(Renderer): nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype) mesa.nir_push_loop(self.b) self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype) - elif u.op == Ops.ENDRANGE: + elif u.op == Ops.END: nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]), functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break)) mesa.nir_pop_loop(self.b, None) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index a57ee6a838..cc95e357a3 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -115,7 +115,7 @@ string_rewrite = PatternMatcher([ if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"), (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []), (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]), - (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [ + (UPat(Ops.END, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [ ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]), ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]), f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]), @@ -219,7 +219,7 @@ class PTXRenderer(Renderer): [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)], [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]] r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] - prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None), + prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]), Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None)) if prefix: r[u] = ssa(prefix, u, dtype) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index afb1bb87f7..9a8ade8e18 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -52,11 +52,11 @@ class PythonProgram: loop_ends: dict[int, int] = {} while i < len(self.uops): uop, dtype, idp, arg = self.uops[i] - void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE} + void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE} inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp) - if uop is Ops.ENDRANGE: + if uop is Ops.END: loop_ends[idp[0]] = i i = idp[0] continue diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index cfdfa39a8a..fcb62cd1f3 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -70,7 +70,7 @@ class Ops(FastEnum): WHERE = auto(); MULACC = auto() # noqa: E702 # control flow ops - BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702 + BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702 # consts. VCONST is a vectorized const VCONST = auto(); CONST = auto() # noqa: E702 diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index ca9b32c7ec..667936cb49 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -195,7 +195,7 @@ spec = PatternMatcher([ (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)), - (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), + (UPat(Ops.END, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), # WMMA has a (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index da864578ce..827d672509 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -20,7 +20,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", - Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866"} + Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866", + Ops.END: "#524C46"} # VIZ API From 154cdfe46d5022651a46708169bf421798ea591e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Oct 2025 11:44:51 +0800 Subject: [PATCH 22/30] viz state cleanups (#12821) * viz state cleanups * more generic --- tinygrad/viz/js/index.js | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 4c5d775b1c..319f1f4ba6 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -530,10 +530,10 @@ function codeBlock(st, language, { loc, wrap }={}) { return ret; } -function setActive(e) { - if (e == null) return; - e.classList.add("active"); - requestAnimationFrame(() => e.scrollIntoView({ behavior: "auto", block: "nearest" })); +function toggleCls(prev, next, cls, value) { + prev?.classList.remove(cls); + next?.classList.toggle(cls, value ?? true); + requestAnimationFrame(() => next?.scrollIntoView({ behavior: "auto", block: "nearest" })); } // ** hljs extra definitions for UOps and float4 @@ -563,23 +563,20 @@ const evtSources = []; // context: collection of steps const state = {currentCtx:-1, currentStep:0, currentRewrite:0, expandSteps:false}; function setState(ns) { - const { currentCtx:prevCtx, currentStep:prevStep } = state; + const { ctx:prevCtx, step:prevStep } = select(state.currentCtx, state.currentStep); Object.assign(state, ns); // update element styles if needed - document.getElementById(`ctx-${state.currentCtx}`)?.classList.toggle("expanded", state.expandSteps); - if (state.currentCtx !== prevCtx) { - document.getElementById(`ctx-${prevCtx}`)?.classList.remove("active", "expanded"); - setActive(document.getElementById(`ctx-${state.currentCtx}`)); - } - if (state.currentCtx !== prevCtx || state.currentStep !== prevStep) { - document.getElementById(`step-${prevCtx}-${prevStep}`)?.classList.remove("active"); + const { ctx, step } = select(state.currentCtx, state.currentStep); + toggleCls(prevCtx, ctx, "expanded", state.expandSteps); + if (ctx?.id !== prevCtx?.id) toggleCls(prevCtx, ctx, "active"); + if (ctx?.id !== prevCtx?.id || step?.id !== prevStep?.id) { + toggleCls(prevStep, step, "active"); // walk the tree back until all parents expanded so that the child is visible - let e = document.getElementById(`step-${state.currentCtx}-${state.currentStep}`); + let e = step; while (e?.parentElement?.id.startsWith("step")) { e.parentElement.classList.add("expanded"); e = e.parentElement; } - setActive(document.getElementById(`step-${state.currentCtx}-${state.currentStep}`)); } // re-render main(); From 57f6b6f229f9c2ed73c0f00b787c74e495c872da Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Oct 2025 12:15:13 +0800 Subject: [PATCH 23/30] style view codegen like a link in profiler (#12825) --- tinygrad/viz/index.html | 2 ++ tinygrad/viz/js/index.js | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 2f10e7d89b..9a11acf95e 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -39,6 +39,8 @@ ::-webkit-scrollbar-thumb { background: #686977; } a { color: #4a90e2; + text-decoration: underline; + cursor: pointer; } ul { padding: 0; diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 319f1f4ba6..4c4888fc48 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -245,9 +245,9 @@ async function renderProfiler() { html.appendChild(tabulate([["Name", colored(e.name)], ["Duration", formatTime(e.dur)], ["Start Time", formatTime(e.st)]]).node()); if (e.info != null) html.appendChild(document.createElement("p")).innerText = "\n"+e.info; if (shapeRef != null) { - const p = html.appendChild(document.createElement("p")); - p.innerText = "\nView Codegen Rewrite"; p.style.cursor = "pointer"; - p.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step); + const a = html.appendChild(document.createElement("a")); + a.innerText = "\nView codegen rewrite"; + a.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step); } // tiny device events go straight to the rewrite rule const key = k.startsWith("TINY") ? null : `${k}-${j}`; From 367fbabc3063757eefc4b02a591fa785c3d0ce11 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Tue, 21 Oct 2025 08:19:42 +0200 Subject: [PATCH 24/30] remove Ops.SUBSTITUTE (#12827) * remove Ops.SUBSTITUTE * remove from viz --- tinygrad/uop/__init__.py | 1 - tinygrad/uop/ops.py | 2 +- tinygrad/uop/spec.py | 3 --- tinygrad/viz/serve.py | 3 +-- 4 files changed, 2 insertions(+), 7 deletions(-) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index fcb62cd1f3..f6130bc6e4 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -20,7 +20,6 @@ class Ops(FastEnum): # create buffer BUFFERIZE = auto() - SUBSTITUTE = auto() # ops that adjust the behavior of the scheduler CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702 diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index b08c193f9c..e8be4f0fe6 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -179,7 +179,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): match self.op: # late ops don't have shape case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ - Ops.VECTORIZE | Ops.VCONST | Ops.SUBSTITUTE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST: + Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST: return None # some ops init the shape diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 667936cb49..86e0e398be 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -236,9 +236,6 @@ full_spec = PatternMatcher([ # SENTINEL should never be in the graph (UPat(Ops.SENTINEL), lambda: False), - # allow any SUBSTITUTE - (UPat(Ops.SUBSTITUTE), lambda: True), - # Invalid must have type Index (UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index), # where on index in rhs position is fine diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 827d672509..592304116c 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -20,8 +20,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", - Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866", - Ops.END: "#524C46"} + Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"} # VIZ API From 32af1ff84b578226cf76e76735b58094074889b0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:51:32 +0800 Subject: [PATCH 25/30] viz graph drawing small cleanups (#12830) * viz graph drawing small cleanups * str literal --- test/unit/test_viz.py | 8 ++++++++ tinygrad/viz/js/index.js | 9 ++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 9810de38d4..4edee1323b 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -148,6 +148,14 @@ class TestViz(BaseTestViz): a2 = uop_to_json(a)[id(a)] self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')") + def test_colored_label_multiline(self): + arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta") + src = [Tensor.empty(1).uop for _ in range(10)] + a = UOp(Ops.CUSTOM, src=tuple(src), arg=arg) + exec_rewrite(a, [PatternMatcher([])]) + a2 = next(get_viz_details(0, 0))["graph"][id(a)] + self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw") + def test_inf_loop(self): a = UOp.variable('a', 0, 10, dtype=dtypes.int) b = a.replace(op=Ops.CONST) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 4c4888fc48..dd803bda5b 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -78,7 +78,7 @@ function renderDag(graph, additions, recenter) { if (parents == null && children == null) return; const src = [...parents, ...children, d.id]; nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id)); - const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : ""; + const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : ""; d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath"); d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port"); e.stopPropagation(); @@ -92,10 +92,9 @@ function renderDag(graph, additions, recenter) { }).selectAll("text").data(d => { const ret = [[]]; for (const { st, color } of parseColors(d.label, defaultColor="initial")) { - for (const [i, l] of st.split("\n").entries()) { - if (i > 0) ret.push([]); - ret.at(-1).push({ st:l, color }); - } + const lines = st.split("\n"); + ret.at(-1).push({ st:lines[0], color }); + for (let i=1; i d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan") From d59d4cdbe40726fffec57f3b4dc1b591dc0faef3 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 21 Oct 2025 17:09:44 +0800 Subject: [PATCH 26/30] lil less is okay --- test/external/speed_v_theoretical.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/external/speed_v_theoretical.py b/test/external/speed_v_theoretical.py index 8b04d4af2b..ec669781ba 100644 --- a/test/external/speed_v_theoretical.py +++ b/test/external/speed_v_theoretical.py @@ -91,11 +91,11 @@ class TestKernelSpeed(unittest.TestCase): # theoretical is nv_tflops=165, amd_tflops=123 def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=115, amd_tflops=65) - def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=60) + def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=115, amd_tflops=60) # theoretical is nv_gbs=1008, amd_gbs=960 def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750) - def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=750) + def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=820, amd_gbs=750) if __name__ == '__main__': unittest.main() From c780cd9abb9d0a3cb4f8d369aefc13684a5c0b74 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:37:48 +0800 Subject: [PATCH 27/30] new linearizer with early endrange (#12823) * new linearizer with early endrange * cleanups * second stage removal * not store * do that later * end cleanup * fix globals * end * multi end * fix ends earlier * work * do_merge_ends * mini change * range_gate * fix cpu * test fixups * ranges on index * not for ptx --- test/external/external_benchmark_schedule.py | 7 +- test/test_linearizer.py | 6 +- test/test_ops.py | 1 + test/test_rangeify.py | 8 +- test/test_tensor_uop.py | 8 +- test/test_uop_graph.py | 13 --- test/unit/test_simplify_valid_idx.py | 2 +- tinygrad/codegen/__init__.py | 12 ++- tinygrad/codegen/control_flow.py | 100 +++++++++++++++++++ tinygrad/codegen/gpudims.py | 8 +- tinygrad/codegen/late/devectorizer.py | 12 ++- tinygrad/codegen/late/expander.py | 2 +- tinygrad/codegen/opt/postrange.py | 25 ++--- tinygrad/codegen/simplify.py | 12 ++- tinygrad/renderer/__init__.py | 3 +- tinygrad/renderer/ptx.py | 2 +- tinygrad/schedule/rangeify.py | 20 ++-- tinygrad/uop/ops.py | 16 ++- tinygrad/uop/spec.py | 12 +-- tinygrad/uop/symbolic.py | 4 +- tinygrad/viz/serve.py | 4 +- 21 files changed, 193 insertions(+), 84 deletions(-) create mode 100644 tinygrad/codegen/control_flow.py diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index 3dce947828..92feedca84 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -3,6 +3,7 @@ from tinygrad import Tensor, nn, Device from tinygrad.helpers import Profiling, Timing, getenv from tinygrad.uop.ops import Ops from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer +from tinygrad.codegen.control_flow import linearize from tinygrad.uop.spec import type_verify if __name__ == "__main__": @@ -39,7 +40,7 @@ if __name__ == "__main__": with Timing("***** model linearize in "): uops_line = [] for u in rewritten_uops: - uops_line.append(apply_rewrites(u, rewrites_for_linearizer)) + uops_line.append(linearize(apply_rewrites(u, rewrites_for_linearizer))) with Timing("***** model verify in "): - for u in uops_line: type_verify(u.arg.lst) - print(sum(len(u.arg.lst) for u in uops_line)) + for u in uops_line: type_verify(u) + print(sum(len(u) for u in uops_line)) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index f63a9e9a71..9a505a3921 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -214,8 +214,8 @@ class TestLinearizer(unittest.TestCase): else: assert u.src[1].op in GroupOp.ALU assert begin_range < uops.index(u) < end_range - # children of STORE are placed after ENDRANGE - if any(x.op is Ops.STORE and x.src[1].op in GroupOp.ALU for x in u.src): + # children of END are placed after ENDRANGE + if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src): assert end_range < uops.index(u) def test_grouped_dims(self): @@ -400,7 +400,7 @@ class TestLinearizer(unittest.TestCase): # # check the children's vins # TODO: src ALU are not the same, should it? # assert barrier.src == tuple(local_stores) - assert len([u for u in uops if u.op is Ops.IF and u.src[-1] == barrier]) == 1 + assert len([u for u in uops if u.op is Ops.IF and u.src[1] == barrier]) == 1 @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") diff --git a/test/test_ops.py b/test/test_ops.py index fb3869a295..022131c50d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2602,6 +2602,7 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)), lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5) + @unittest.skipIf(Device.DEFAULT == "AMD" and CI, "remu failure?") def test_avg_pool3d_failure(self): with Context(NOOPT=0): helper_test_op([(1,1,16,16,16)], diff --git a/test/test_rangeify.py b/test/test_rangeify.py index ab8f8b8cfb..9bed5c1481 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,7 +1,9 @@ import unittest -from tinygrad import Tensor, nn -from tinygrad.helpers import Context, GlobalCounters, CI, CPU_LVP, getenv +from tinygrad import Tensor, nn, Device +from tinygrad.helpers import Context, GlobalCounters, CI, getenv from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops +from tinygrad.renderer.ptx import PTXRenderer +from tinygrad.renderer.nir import NIRRenderer class TestRangeifyAssign(unittest.TestCase): def test_assign_permuted(self): @@ -40,7 +42,7 @@ elif getenv("BIG") > 0: else: BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8 -@unittest.skipIf(CPU_LVP, "broken in LVP") +@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX") class TestPcontig(unittest.TestCase): def test_flash_attention_bw(self): def fa_bw(): diff --git a/test/test_tensor_uop.py b/test/test_tensor_uop.py index 0a526ef5a1..21dfe41b57 100644 --- a/test/test_tensor_uop.py +++ b/test/test_tensor_uop.py @@ -3,7 +3,7 @@ import numpy as np import unittest from tinygrad import Tensor, Device, dtypes from tinygrad.engine.realize import run_schedule -from tinygrad.uop.ops import Ops, UOp, UPat +from tinygrad.uop.ops import UOp from tinygrad.helpers import SPLIT_REDUCEOP class TestTensorUOp(unittest.TestCase): @@ -93,7 +93,6 @@ class TestTensorUOp(unittest.TestCase): out.realize() self.assertEqual(out.tolist(), Tensor.zeros(4, 8).tolist()) -reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, allow_any_len=True, src=(UPat(), UPat((Ops.REDUCE_AXIS, Ops.REDUCE)))))) @unittest.skipUnless(SPLIT_REDUCEOP, "only for SPLIT_REDUCEOP") class TestReduceOp(unittest.TestCase): def test_no_split_reduce_kernel(self): @@ -101,23 +100,18 @@ class TestReduceOp(unittest.TestCase): a = a.sum() sched = a.schedule() assert len(sched) == 1 - assert reduce_kernel.match(sched[0].ast, {}) def test_split_reduce_kernel_dim0(self): a = Tensor.rand(256, 255).realize() a = a.sum() sched = a.schedule() assert len(sched) == 2 - for s in sched: - assert reduce_kernel.match(s.ast, {}) def test_split_reduce_kernel_dim1(self): a = Tensor.rand(255, 256).realize() a = a.sum() sched = a.schedule() assert len(sched) == 2 - for s in sched: - assert reduce_kernel.match(s.ast, {}) if __name__ == "__main__": unittest.main() diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index ca93f3a0cf..56bb56f67d 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -665,19 +665,6 @@ class TestUOpGraph(unittest.TestCase): bad_gate = UOp.const(dtypes.int, 1) with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) - def test_switched_range_order(self): - glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) - cf = UOp.const(dtypes.float, 0.0) - r1 = UOp.range(2, 0) - r2 = UOp.range(2, 1) - alu = UOp(Ops.MUL, dtypes.int, (r2, r1)) - store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf)) - uops = to_uops_list([store]) - ranges = [x for x in uops if x.op is Ops.RANGE] - endranges = [x for x in uops if x.op is Ops.END] - # ranges are closed in the right order - self.assertEqual(endranges[-1].src[0], ranges[0]) - @track_rewrites() def expander_rewrite(sink): return graph_rewrite(sink, sym + expander) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 7f3790c217..619d10e5ca 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.symbolic import simplify_valid from tinygrad.helpers import Context -from .test_uop_symbolic import check_uop_against_string +from test.unit.test_uop_symbolic import check_uop_against_string def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(Ops.LOAD, dtypes.float, ( diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 155ed805c2..ae5ee00d0a 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -14,10 +14,11 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render -from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext from tinygrad.codegen.opt.postrange import pm_postrange_opt from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen +#from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext +from tinygrad.codegen.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize @dataclass class RewriteStep: @@ -30,11 +31,18 @@ class RewriteStep: def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink) +""" rewrites_for_linearizer = [ RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True), RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"), RewriteStep(block_merge, name="Linearizer: Merge Blocks"), RewriteStep(pm_finalize, name="Linearizer: Finalize")] +""" + +rewrites_for_linearizer = [ + RewriteStep(pm_merge_ends, CFGContext, name="merge ends", bottom_up=True), + RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True), +] def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]: # cache with the values of the context vars @@ -119,6 +127,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: Linear program in UOps. """ - lst = list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).arg.lst) + lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True)) if __debug__: type_verify(lst) return lst diff --git a/tinygrad/codegen/control_flow.py b/tinygrad/codegen/control_flow.py new file mode 100644 index 0000000000..891015c5a1 --- /dev/null +++ b/tinygrad/codegen/control_flow.py @@ -0,0 +1,100 @@ +import heapq +from collections import defaultdict +from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat + +def linearize(u:UOp) -> list[UOp]: + lst = list(u.toposort()) + in_this_block = set(lst) + local_children: defaultdict[UOp, list[UOp]] = defaultdict(list) + in_degree:dict[UOp, int] = {} + priorities:dict[UOp, int] = {} + + # get local children and assign priorities + # NOTE: this requires the lst be locally toposorted + for u in reversed(lst): + in_degree[u] = 0 + for s in u.src: + if s in in_this_block: + local_children[s].append(u) + in_degree[u] += 1 + # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too + priority = [0] + [priorities[x] for x in local_children[u]] + if u.op is Ops.LOAD: priority.append(-1000) + if u.op is Ops.BARRIER: priority.append(-1500) + # ranges are scheduled as late as possible so anything that can be outside is + #if u.op is Ops.RANGE: priority = [2000] + # move defines and consts to the top + if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000) + priorities[u] = min(priority) + + # number the uops in "ideal" order + nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))} + + # then force then to be toposorted in as close to the ideal order as possible + heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0]) + newlst = [] + while heap: + newlst.append(u:=heapq.heappop(heap)[1]) + for v in local_children[u]: + in_degree[v] -= 1 + if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v)) + + assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}" + return newlst + +class CFGContext: + def __init__(self, sink:UOp): + # there are 3 relationships between ranges: + # nested, meaning endrange y is a dependency of endrange x and range x is a dependency of endrange y + # dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y + # independent, endrange y is not a dependency of endrange x + # everything is nested inside the sink + deps: dict[UOp, set[UOp]] = {} + nesting: dict[UOp, UOp] = {} + for u in sink.toposort(): + deps[u] = set().union(*(deps[s] for s in u.src)) + if u.op in (Ops.END, Ops.ENDIF, Ops.SINK): + nesting |= {x:u for x in deps[u] if x.op in (Ops.END, Ops.ENDIF) and (u.op is Ops.SINK or u.src[0] in deps[x]) and x not in nesting} + if u.op in (Ops.RANGE, Ops.END, Ops.IF, Ops.ENDIF): deps[u] |= {u} + + self.edges: dict[UOp, UOp] = {} + siblings: dict[UOp, list[UOp]] = {} + for k,vv in nesting.items(): siblings.setdefault(vv, []).append(k) + for k,v in siblings.items(): + # range/if that have dependencies on other siblings need to run after them + order = sorted(v, key=lambda x: len(deps[x].intersection(v))) + zipped = zip(order, order[1:]) if k.op is Ops.SINK else zip([k.src[0]] + order, order) + for x,y in zipped: + # TODO: is this check correct? + if y.src[0] not in x.backward_slice_with_self: + self.edges[y.src[0]] = x + +pm_add_control_flow = PatternMatcher([ + (UPat((Ops.RANGE, Ops.IF), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None), +]) + +def do_merge_ends(s:UOp): + # NOTE: this can fail + stacked: dict[UOp, list[UOp]] = {} + dangling_ifs = [] + for x in s.toposort(): + if x.op in {Ops.END, Ops.ENDIF}: + assert x.op is not Ops.END or x.arg == 1, "ends must be single ends for linearizer" + stacked.setdefault(x.src[0], []).append(x) + if x.op is Ops.IF: dangling_ifs.append(x) + dangling_ifs = [x for x in dangling_ifs if x not in stacked] + replaces = {} + for k,v in stacked.items(): + if len(v) == 1: continue + rep = UOp(v[0].op, src=tuple([k] + [y for x in v for y in x.src[1:]]), arg=x[0].arg) + for x in v: replaces[x] = rep + if not len(replaces) and not len(dangling_ifs): return None + ret = s.substitute(replaces) + if len(dangling_ifs): + assert len(dangling_ifs) == 1, "we only support 1 dangling if" + ret = ret.replace(src=(UOp(Ops.ENDIF, src=(dangling_ifs[0], *ret.src)),)) + return ret + +pm_merge_ends = PatternMatcher([ + (UPat(Ops.SINK, name="s"), do_merge_ends), +]) \ No newline at end of file diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index 5169450883..15a82d2df9 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -87,15 +87,15 @@ def add_gpudims(ctx:Renderer, s:UOp): except ValueError: continue return s.substitute(subs) -def add_barrier_and_if(buf:UOp, s:UOp): +def add_barrier_and_if(buf:UOp, e:UOp): # TODO: this is not generic - local_ranges = [x for x in s.src[1:] if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE] + local_ranges = [x for x in e.ended_ranges if x.op is Ops.RANGE and x.arg[-1] == AxisType.GROUP_REDUCE] if len(local_ranges) == 0: return None - return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), s.barrier()))) + return buf.after(UOp(Ops.IF, dtype=dtypes.void, src=(functools.reduce(operator.and_, [x.eq(0) for x in local_ranges]), e.barrier()))) pm_add_gpudims = PatternMatcher([ # add gpudims must be last (UPat(Ops.SINK, name="s"), add_gpudims), # add barrier and if - (UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.STORE, name="s"))), add_barrier_and_if), + (UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL, name="buf"), UPat(Ops.END, name="e"))), add_barrier_and_if), ]) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 7eeb9e68ac..5c928a09eb 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -268,10 +268,12 @@ pm_render = PatternMatcher([ UPat.var("a")), lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)), (UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat.var("idx"), UPat.var("c").logical_not()).or_casted(),), allow_any_len=True, name="l").or_casted()), lambda c,idx,l,a: l.replace(src=(l.src[0], a.cast(l.dtype))+l.src[2:]).cast(a.dtype)), - # gate any stores that aren't gated with ifs + # gate any stores that aren't gated with if/endif pairs (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True), - lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \ + lambda store,idx: UOp(Ops.ENDIF, src=(uif:=UOp(Ops.IF, src=(idx.src[2],)), UOp(Ops.STORE, src=store.src[:2]+(uif,)+store.src[2:]))) if \ len(store.src) <= 2 or store.src[2].op != Ops.IF else None), + # for renderering and linearizing, all ends must end one loop + (UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None), ]) # *** Ops.REDUCE -> Ops.DEFINE_ACC *** @@ -295,8 +297,8 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): # if we have a range if len(reduce_range) != 0: topo = inp.toposort() - stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE]) - input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges]) + ended_ranges = flatten([x.src[:x.arg] for x in topo if x.op is Ops.END]) + input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges]) identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar())) acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)) acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \ @@ -305,7 +307,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) if len(reduce_range) == 0: return ret - return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret, *reduce_range)).index(UOp.const(dtypes.int, 0)).load() + return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(ends=reduce_range[::-1])).index(UOp.const(dtypes.int, 0)).load() pm_reduce = PatternMatcher([ # REDUCE -> DEFINE_ACC+ASSIGN diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index c594d6315d..1f270394e6 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -87,7 +87,7 @@ expander = PatternMatcher([ lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), # do expansion (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE, - Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand), + Ops.VECTORIZE, Ops.IF, Ops.REDUCE, Ops.END), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand), (UPat(Ops.CONTRACT, name="con"), do_contract), # BARRIERs aren't actually expanded (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)), diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index 4263ba67ff..720237b3b0 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -4,8 +4,8 @@ from collections import defaultdict from typing import cast, Final from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp from tinygrad.device import Buffer -from tinygrad.dtype import AddrSpace, dtypes, ImageDType -from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten +from tinygrad.dtype import dtypes, ImageDType +from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters from tinygrad.codegen.simplify import pm_flatten_range from tinygrad.renderer import Renderer @@ -64,21 +64,8 @@ class Scheduler: return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1) def _globalizable_rngs(self) -> list[UOp]: - store_rngs = self.ast.src[0].src[2:] - - # filter any not in local stores - local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \ - or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)] - for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls]) - - # filter any not in reduces - # TODO: enable this - """ - reduce_rngs = [x.ranges for x in self.ast.toposort() if x.op is Ops.REDUCE] - for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls]) - """ - - return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[-1] == AxisType.LOOP] if store_rngs else [] + # all ranges that end before any STOREs + return [x for x in self.ast.toposort(lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in self.ast.ranges] def convert_loop_to_global(self): if not self.opts.has_local: return None @@ -89,11 +76,11 @@ class Scheduler: self.ast = self.ast.substitute(dict(zip(self.rngs, rng))) def colors(self) -> list[str]: - store_rngs = flatten([x.src[2:] for x in self.ast.src]) + globalizible_rngs = self._globalizable_rngs() ret = [] for x,r in zip(self.axis_types, self.rngs): if self.dont_use_locals and x == AxisType.GLOBAL: ret.append("BLUE") - elif r not in store_rngs and x == AxisType.LOOP: ret.append("BLACK") + elif r not in globalizible_rngs and x == AxisType.LOOP: ret.append("BLACK") else: ret.append(axis_colors[x]) return ret def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())]) diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index 9053df5eaa..eeb84071eb 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -13,14 +13,16 @@ def flatten_range(r:UOp): pm_flatten_range = PatternMatcher([ # real ranges only (UPat((Ops.REDUCE, Ops.STORE), name="r"), flatten_range), + # END is only on RANGES. TODO: this is copied from symbolic + (UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))), ]) def count_divmod(x:UOp): return len([u for u in x.toposort() if u.op in {Ops.IDIV, Ops.MOD}]) def simplify_merge_adjacent(u:UOp) -> UOp|None: reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE] - i = range_start[u.op] - while i < len(u.src)-1: - r0, r1 = u.src[i], u.src[i+1] + i = 0 + while i < len(u.ended_ranges)-1: + r0, r1 = u.ended_ranges[i], u.ended_ranges[i+1] # check same type if r0.arg[-1] == r1.arg[-1]: # check if the ranges to merge are in the same reduces @@ -39,7 +41,7 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None: return u pm_simplify_ranges = PatternMatcher([ - (UPat((Ops.STORE, Ops.REDUCE), name="u"), simplify_merge_adjacent), + (UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent), ]) def mark_range_mod(ctx, r:UOp, c:UOp): @@ -57,7 +59,7 @@ def do_substitute(ctx, x: UOp): def dont_sub_ranges_for_image(ctx, x:UOp): if isinstance(x.src[0].dtype, ImageDType): - for s in x.src[1:]: ctx[s] = None + for s in x.src[0].ranges: ctx[s] = None pm_split_ranges = PatternMatcher([ (UPat(Ops.RANGE, name="r")%UPat.cvar("c"), mark_range_mod), diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index a1d8f89d5f..b70b51c012 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -28,10 +28,11 @@ class Estimates: mult_stack: list[sint] = [] dont_count: set[UOp] = set() if ignore_indexing: + def range_gate(x): return x.op is not Ops.RANGE for u in uops: if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): # if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER - dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort()) + dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate)) # TODO: is this correct? this all needs to be cleaned up if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort()) elif u.op is Ops.IF: diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index cc95e357a3..565faf52b4 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -115,7 +115,7 @@ string_rewrite = PatternMatcher([ if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"), (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []), (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]), - (UPat(Ops.END, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [ + (UPat(Ops.END, name="x", src=(UPat.var("src0"),), allow_any_len=True), lambda ctx, x, src0: [ ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]), ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]), f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]), diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 1f037406c1..eae57227ae 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -276,7 +276,7 @@ def bufferize_to_store(x:UOp): assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index" # in assign, this is the buffer size, not the bufferize size # TODO: assign_mops here - do_store = assign_target.replace(dtype=sdtype).store(assign_src, *rngs).replace(tag=x.tag) + do_store = assign_target.replace(dtype=sdtype).store(assign_src).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE]) ret = assign_target.src[0].after(do_store) mops = [] walk = assign_mops @@ -289,7 +289,7 @@ def bufferize_to_store(x:UOp): # NOTE: the DEFINE_LOCAL needs to be disambiguated here if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg.device, size, x.dtype) - do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs).replace(tag=x.tag) + do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).replace(tag=x.tag).end(ends=[x for x in rngs if x.op is Ops.RANGE]) ret = buf.after(do_store).forced_reshape(shape) # TODO: is this right? what if it's offset if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs): @@ -301,7 +301,8 @@ def bufferize_to_store(x:UOp): tag = x.arg.device if tag is None: tag = UOp.unique().arg # TODO: hack buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) - return buf.after(buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs)).reshape(shape) + do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).end(ends=[x for x in rngs if x.op is Ops.RANGE]) + return buf.after(do_store).reshape(shape) pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([ (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store), @@ -412,7 +413,6 @@ class Kernel: def split_store(ctx:list[UOp], x:UOp) -> UOp|None: if len(x.ranges): return None - if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None # local kernel rewrite lctx = LocalAddBufferContext() @@ -422,8 +422,14 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None: metadatas = [ctx[y].metadata for y in lctx.parent_tags] # NOTE: the hack for COPY is here - ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) \ - if ret.src[1].op not in {Ops.COPY, Ops.BUFFER_VIEW} else ret.src[1] + for u in ret.toposort(): + # TODO: this can be wrong if there's multiple of these + if u.op in {Ops.COPY, Ops.BUFFER_VIEW}: + ret = u + break + else: + ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None) + kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]) kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]): @@ -431,7 +437,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None: return kernel split_kernels = PatternMatcher([ - (UPat(Ops.STORE, name="x"), split_store), + (UPat((Ops.STORE, Ops.END), name="x"), split_store), ]) def tag_uop(ctx:list[UOp], x:UOp): diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index e8be4f0fe6..39dc77c6b7 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -190,7 +190,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) # passthrough ops - case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE | Ops.AFTER: + case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.FUSE | Ops.AFTER | Ops.END: return self.src[0]._shape # ops with custom handling @@ -276,6 +276,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): for s in self.src[:range_start[self.op]]: ret.update(s.ranges) for s in UOp.sink(*self.src[range_start[self.op]:]).ranges: if s in ret: del ret[s] + elif self.op is Ops.END: + for s in self.src[self.arg:]: ret.update(s.ranges) + for s in UOp.sink(*self.src[:self.arg]).ranges: + if s in ret: del ret[s] else: for s in self.src: ret.update(s.ranges) return ret @@ -285,6 +289,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.RANGE: return {self:None} return self._ranges + @functools.cached_property + def ended_ranges(self): + match self.op: + case Ops.REDUCE: return self.src[1:] + case Ops.END: return self.src[:self.arg] + case _: raise RuntimeError(f"{self.op} doesn't end ranges") + # *** uop evaluation *** def simplify(self, tracked=False, full_symbolic=True): @@ -350,6 +361,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs) def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs) + def end(self, *src:UOp, ends:Sequence[UOp]): + if len(ends) == 0: return self + return UOp(Ops.END, src=(*ends, self, *src), arg=len(ends)) def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src) diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 86e0e398be..d53785edfe 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -159,8 +159,9 @@ spec = PatternMatcher([ (UPat(Ops.DEFINE_REG, src=()), lambda: True), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), - (UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ - all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), + (UPat(Ops.RANGE, src=(UPat.var("x"),), allow_any_len=True, name="rng"), lambda rng,x: + rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \ + all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)), (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)), (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), @@ -195,7 +196,7 @@ spec = PatternMatcher([ (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)), - (UPat(Ops.END, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), + (UPat(Ops.END, dtype=dtypes.void), lambda: True), # WMMA has a (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), @@ -203,9 +204,8 @@ spec = PatternMatcher([ (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), # if has a - (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True), - (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True), - (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), + (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),), allow_any_len=True), lambda: True), + (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),), allow_any_len=True), lambda: True), (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}), (UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 91cf8390e4..85fa1a804b 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -379,9 +379,11 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)), # only RANGE/IF/STORE/KERNEL have side effects (UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+ - tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER} else y.src for y in x.src[1:]])))), + tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.IF, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END} else y.src for y in x.src[1:]])))), # after with 1 src is just src[0] (UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s), + # END is only on RANGES + (UPat(Ops.END, name="e"), lambda e: UOp.end(*e.src[e.arg:], ends=sorted(UOp.sink(*e.src[:e.arg]).ranges, key=lambda x: x.arg))), ])+gep_pushing symbolic_flat = symbolic+PatternMatcher([ diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 592304116c..70f8285511 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -71,7 +71,7 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]: if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg) label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}" if u.dtype != dtypes.void: label += f"\n{u.dtype}" - for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else u.src): + for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])): if x in excluded: arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}" label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "") @@ -82,6 +82,8 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]: label += f"\n{shape_to_str(u.shape)}" if u.op in {Ops.INDEX, Ops.BUFFERIZE}: label += f"\n{u.render()}" + if u.op is Ops.END: + label += "\n"+' '.join([f"{colored(u.src[i].arg[0], axis_colors[u.src[i].arg[-1]])}({u.src[i].vmax+1})" for i in range(u.arg)]) except Exception: label += "\n" if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}" From 40633ab34db6a469ea808a1513ffd699cd4ff45b Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:51:36 +0800 Subject: [PATCH 28/30] list buffer args to kernel in profiler (#12826) * list buffer args to kernel in profiler * stable order * back button works * deselect also works --- tinygrad/viz/js/index.js | 51 +++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index dd803bda5b..42fc837349 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -173,10 +173,16 @@ function tabulate(rows) { return root; } -var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity; +var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity, shapeMetadata = new Map(); +function focusShape(shape) { + saveToHistory({ shape:focusedShape }); + focusedShape = shape?.key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel); + return document.querySelector(".metadata").replaceChildren(shapeMetadata.get(focusedShape) ?? ""); +} + async function renderProfiler() { displayGraph("profiler"); - d3.select(".metadata").node().replaceChildren(focusedShape?.html ?? ""); + d3.select(".metadata").node().replaceChildren(shapeMetadata.get(focusedShape) ?? ""); // layout once! if (data != null) return updateProgress({ start:false }); const profiler = d3.select(".profiler").html(""); @@ -242,6 +248,7 @@ async function renderProfiler() { } const html = document.createElement("div"); html.appendChild(tabulate([["Name", colored(e.name)], ["Duration", formatTime(e.dur)], ["Start Time", formatTime(e.st)]]).node()); + const argsDiv = document.createElement("div"); argsDiv.id = "args"; html.appendChild(document.createElement("br")); html.appendChild(argsDiv); if (e.info != null) html.appendChild(document.createElement("p")).innerText = "\n"+e.info; if (shapeRef != null) { const a = html.appendChild(document.createElement("a")); @@ -250,7 +257,8 @@ async function renderProfiler() { } // tiny device events go straight to the rewrite rule const key = k.startsWith("TINY") ? null : `${k}-${j}`; - const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), html, key, ...shapeRef }; + if (key != null) shapeMetadata.set(key, html); + const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), key, ...shapeRef }; if (e.key != null) shapeMap.set(e.key, arg); // offset y by depth shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label, fillColor }); @@ -294,17 +302,30 @@ async function renderProfiler() { const rows = [["DType", dtype], ["Len", formatUnit(sz)], ["Size", formatUnit(nbytes, "B")], ["Lifetime", formatTime(dur)]]; if (users != null) rows.push(["Users", users.length]); const info = html.appendChild(tabulate(rows).node()); + const arg = {tooltipText:info.outerHTML, key:`${k}-${num}`}; for (let u=0; u focusShape(shape); + const args = shapeMetadata.get(shape.key).querySelector("#args"); + const bufArg = d3.create("p").text(`${bufInfo} ${rows[2][1]}`).style("cursor", "pointer").style("margin-top", "4px").on("click", () => { + const device = document.getElementById(k); + if (!isExpanded(device)) device.click(); + focusShape(arg); + }).node(); + bufArg.dataset.num = num; + let before = null; + for (const c of args.children) { if (+c.dataset.num > num) { before = c; break; } } + args.insertBefore(bufArg, before); } } - const arg = {tooltipText:info.outerHTML, html, key:`${k}-${num}`}; + shapeMetadata.set(arg.key, html) shapes.push({ x, y0:y.map(yscale), y1:y.map(y0 => yscale(y0+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, shapes.length) }); } // generic polygon merger @@ -337,6 +358,7 @@ async function renderProfiler() { else if (tid === focusedDevice) { track.shapes = track.views[0]; offset += rescaleTrack(track, tid, 1/track.scaleFactor); } } data.axes.y = newFocus != null ? { domain:[0, (t=data.tracks.get(newFocus)).peak], range:[t.offsetY+t.height, t.offsetY], fmt:"B" } : null; + toggleCls(document.getElementById(focusedDevice), document.getElementById(newFocus), "expanded"); focusedDevice = newFocus; return resize(); }); @@ -395,7 +417,7 @@ async function renderProfiler() { lw += e.label[li].width; } } - if (focusedShape?.key && e.arg?.key === focusedShape.key) { paths.push([p, pcolor]); } + if (focusedShape != null && e.arg?.key === focusedShape) { paths.push([p, pcolor]); } } } // draw axes @@ -462,15 +484,11 @@ async function renderProfiler() { } } - function focusShape(shape) { - focusedShape = shape; render(zoomLevel); - return document.querySelector(".metadata").replaceChildren(shape?.html ?? ""); - } canvas.addEventListener("click", e => { e.preventDefault(); const foundRect = findRectAtPosition(e.clientX, e.clientY); if (foundRect?.step != null && foundRect?.key == null) { return setCtxWithHistory(foundRect.ctx, foundRect.step); } - if (foundRect?.key != focusedShape?.key) { focusShape(foundRect); } + if (foundRect?.key != focusedShape) { focusShape(foundRect); } }); canvas.addEventListener("mousemove", e => { @@ -583,15 +601,20 @@ function setState(ns) { const getSubrewrites = (ul) => ul.querySelectorAll(":scope > ul"); +function saveToHistory(ns) { + // NOTE: browser does a structured clone, passing a mutable object is safe. + history.replaceState(ns, ""); + history.pushState(ns, ""); +} + // set a new context and keep the old one in browser history function setCtxWithHistory(newCtx, step=0) { - // NOTE: browser does a structured clone, passing a mutable object is safe. - history.replaceState(state, ""); - history.pushState(state, ""); + saveToHistory(state); setState({ expandSteps:true, currentCtx:newCtx+1, currentStep:step, currentRewrite:0 }); } window.addEventListener("popstate", (e) => { + if (e.state?.shape != null) return focusShape({ key:e.state?.shape }); if (e.state != null) setState(e.state); }); From d711a4b9339f711f015cd9ccc0146a7e2165e042 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:52:18 +0800 Subject: [PATCH 29/30] delete old linearizer (#12834) * new linearizer with early endrange * cleanups * second stage removal * not store * do that later * end cleanup * fix globals * end * multi end * fix ends earlier * work * do_merge_ends * mini change * range_gate * fix cpu * test fixups * ranges on index * not for ptx * delete linearizer * remove more junk * delete that test * we insert endif * all ends --- test/external/external_benchmark_schedule.py | 4 +- test/test_uop_graph.py | 6 - test/unit/test_block_reorder.py | 76 ------ tinygrad/codegen/__init__.py | 28 +-- tinygrad/codegen/control_flow.py | 2 + tinygrad/codegen/late/devectorizer.py | 2 - tinygrad/codegen/late/linearize.py | 243 ------------------- tinygrad/uop/__init__.py | 3 - tinygrad/uop/spec.py | 2 +- tinygrad/viz/serve.py | 4 +- 10 files changed, 16 insertions(+), 354 deletions(-) delete mode 100644 test/unit/test_block_reorder.py delete mode 100644 tinygrad/codegen/late/linearize.py diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index 92feedca84..0e91175bd8 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -2,7 +2,7 @@ from extra.models.resnet import ResNet50 from tinygrad import Tensor, nn, Device from tinygrad.helpers import Profiling, Timing, getenv from tinygrad.uop.ops import Ops -from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites, rewrites_for_linearizer +from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites from tinygrad.codegen.control_flow import linearize from tinygrad.uop.spec import type_verify @@ -40,7 +40,7 @@ if __name__ == "__main__": with Timing("***** model linearize in "): uops_line = [] for u in rewritten_uops: - uops_line.append(linearize(apply_rewrites(u, rewrites_for_linearizer))) + uops_line.append(linearize(u)) with Timing("***** model verify in "): for u in uops_line: type_verify(u) print(sum(len(u) for u in uops_line)) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 56bb56f67d..a38ef37af9 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -832,8 +832,6 @@ class TestIFUOps(unittest.TestCase): if_uops = [u for u in sink.toposort() if u.op is Ops.IF] self.assertEqual(len(if_uops), 1) self.assertEqual(if_uops[0].src[0], gate) - for st in sink.src: - self.assertEqual(len(st.src), 2) def test_expand_ifs_one_gate(self): gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) @@ -850,8 +848,6 @@ class TestIFUOps(unittest.TestCase): if_uops = [u for u in sink.toposort() if u.op is Ops.IF] self.assertEqual(len(if_uops), 1) self.assertEqual(if_uops[0].src[0], gate) - for st in sink.src: - self.assertEqual(len(st.src), 2) # this will be fixed with the merge gated stores bounty @unittest.expectedFailure @@ -866,8 +862,6 @@ class TestIFUOps(unittest.TestCase): if_uops = [u for u in sink.toposort() if u.op is Ops.IF] self.assertEqual(len(if_uops), 1) self.assertEqual(if_uops[0].src[0], gate) - for st in sink.src: - self.assertEqual(len(st.src), 2) class TestUOpTags(unittest.TestCase): def test_inc_by_one(self): diff --git a/test/unit/test_block_reorder.py b/test/unit/test_block_reorder.py deleted file mode 100644 index e81b1ff6c5..0000000000 --- a/test/unit/test_block_reorder.py +++ /dev/null @@ -1,76 +0,0 @@ -import unittest, random -from tinygrad.dtype import dtypes -from tinygrad.uop.ops import print_uops, UOp, Ops -from tinygrad.codegen.late.linearize import block_reorder -from tinygrad.renderer.cstyle import OpenCLRenderer - -def is_toposorted(lst:list[UOp]): - seen = set() - for u in lst: - if any(p not in seen for p in u.src): return False - seen.add(u) - return True - -class TestBlockReorder(unittest.TestCase): - def _test_randomize(self, golden:list[UOp]): - # test random order is always same - for _ in range(50): - # shuffle and form a valid toposort - lst = golden[:] - random.shuffle(lst) - topolst = [] - for u in lst: - for p in u.toposort(): - if p not in topolst: topolst.append(p) - assert is_toposorted(topolst) - - for x,y in zip(golden, this_order:=block_reorder(topolst)): - if x is not y: - print_uops(golden) - print_uops(this_order) - self.assertIs(x, y) - - def _test_render(self, golden:list[UOp]): - return OpenCLRenderer().render(golden) - - def test_loads(self): - a = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=0) - b = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=1) - c = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=2) - v1 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx0") - v2 = UOp(Ops.SPECIAL, dtype=dtypes.int, src=(UOp.const(dtypes.int, 4),), arg="gidx1") - v1 = v1*27 - v2 = v2*4 - loads = [ - a.index(v1).load(dtype=dtypes.float), - a.index(v1+1).load(dtype=dtypes.float), - a.index(v1+2).load(dtype=dtypes.float), - a.index(v1+3).load(dtype=dtypes.float), - b.index(v2).load(dtype=dtypes.float), - b.index(v2+1).load(dtype=dtypes.float), - b.index(v2+2).load(dtype=dtypes.float), - b.index(v2+3).load(dtype=dtypes.float)] - #random.shuffle(loads) - sink = c.store(sum(loads)).sink() - - # determine golden order - golden = block_reorder(list(sink.toposort())) - - # render for test - print(self._test_render(golden)) - #print_uops(golden) - - # assert the loads are in this order - self.assertListEqual([g.src[0].src[1].render() for g in golden if g.op is Ops.LOAD], - ['(gidx1*4)', '((gidx1*4)+1)', '((gidx1*4)+2)', '((gidx1*4)+3)', - '(gidx0*27)', '((gidx0*27)+1)', '((gidx0*27)+2)', '((gidx0*27)+3)']) - - # assert math is after loads - first_math = [i for i,g in enumerate(golden) if g.op is Ops.ADD and g.dtype == dtypes.float][0] - assert not any(x.op is Ops.LOAD for x in golden[first_math:]) - - # confirm the sort is stable - self._test_randomize(golden) - -if __name__ == '__main__': - unittest.main() diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index ae5ee00d0a..b0f8108053 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -17,7 +17,6 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in from tinygrad.codegen.opt.postrange import pm_postrange_opt from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen -#from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext from tinygrad.codegen.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize @dataclass @@ -31,19 +30,6 @@ class RewriteStep: def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink) -""" -rewrites_for_linearizer = [ - RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True), - RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"), - RewriteStep(block_merge, name="Linearizer: Merge Blocks"), - RewriteStep(pm_finalize, name="Linearizer: Finalize")] -""" - -rewrites_for_linearizer = [ - RewriteStep(pm_merge_ends, CFGContext, name="merge ends", bottom_up=True), - RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True), -] - def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool=True) -> list[RewriteStep]: # cache with the values of the context vars return _get_rewrites_for_renderer(opts, optimize, linearizer, QUANTIZE.value, DEVECTORIZE.value, TRANSCENDENTAL.value) @@ -109,11 +95,15 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q pm_final_rewrite = pm_decomp+pm_render+extra_matcher ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite")) - # return the list (with optional linearizer) - return ret + (rewrites_for_linearizer if linearizer else []) + # this was the linearizer + ret.append(RewriteStep(pm_merge_ends, name="merge ends")) + ret.append(RewriteStep(pm_add_control_flow, CFGContext, name="add control flow starts", bottom_up=True)) -def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp: - return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer)) + # return the list + return ret + +def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp: + return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize)) def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: """ @@ -127,6 +117,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]: Linear program in UOps. """ - lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True)) + lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None)) if __debug__: type_verify(lst) return lst diff --git a/tinygrad/codegen/control_flow.py b/tinygrad/codegen/control_flow.py index 891015c5a1..3eb9e56931 100644 --- a/tinygrad/codegen/control_flow.py +++ b/tinygrad/codegen/control_flow.py @@ -96,5 +96,7 @@ def do_merge_ends(s:UOp): return ret pm_merge_ends = PatternMatcher([ + # for renderering and linearizing, all ends must end one loop + (UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None), (UPat(Ops.SINK, name="s"), do_merge_ends), ]) \ No newline at end of file diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 5c928a09eb..7ada724c99 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -272,8 +272,6 @@ pm_render = PatternMatcher([ (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True), lambda store,idx: UOp(Ops.ENDIF, src=(uif:=UOp(Ops.IF, src=(idx.src[2],)), UOp(Ops.STORE, src=store.src[:2]+(uif,)+store.src[2:]))) if \ len(store.src) <= 2 or store.src[2].op != Ops.IF else None), - # for renderering and linearizing, all ends must end one loop - (UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None), ]) # *** Ops.REDUCE -> Ops.DEFINE_ACC *** diff --git a/tinygrad/codegen/late/linearize.py b/tinygrad/codegen/late/linearize.py deleted file mode 100644 index af6727819e..0000000000 --- a/tinygrad/codegen/late/linearize.py +++ /dev/null @@ -1,243 +0,0 @@ -from __future__ import annotations -import heapq -from collections import defaultdict -from dataclasses import dataclass, replace -from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate -from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER - -# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed -def block_reorder(lst:list[UOp]) -> list[UOp]: - in_this_block = set(lst) - local_children: defaultdict[UOp, list[UOp]] = defaultdict(list) - in_degree:dict[UOp, int] = {} - priorities:dict[UOp, int] = {} - - # get local children and assign priorities - # NOTE: this requires the lst be locally toposorted - for u in reversed(lst): - in_degree[u] = 0 - for s in u.src: - if s in in_this_block: - local_children[s].append(u) - in_degree[u] += 1 - # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too - priority = [0] + [priorities[x] for x in local_children[u]] - if u.op is Ops.LOAD: priority.append(-1000) - if u.op is Ops.BARRIER: priority.append(-1500) - priorities[u] = min(priority) - - # number the uops in "ideal" order - nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x],)+x.tuplize))} - - # then force then to be toposorted in as close to the ideal order as possible - heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0]) - newlst = [] - while heap: - newlst.append(u:=heapq.heappop(heap)[1]) - for v in local_children[u]: - in_degree[v] -= 1 - if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v)) - - assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}" - return newlst - -# ***** basic block ***** - -def disp(y:UOp) -> str: - if y.op is Ops.IF: return f'IF{id(y)}' - if y.op is Ops.RANGE: return str(y.arg) - return "" - -@dataclass(frozen=True, eq=False) -class BasicBlock: - lst: tuple[UOp, ...] - ctx: tuple[UOp, ...] = () - end: UOp|None = None - cnt: int = 0 - child_ctx: tuple[UOp, ...]|None = None - def __lt__(self, _:BasicBlock): raise RuntimeError("no comparing basic blocks") - def __repr__(self): - return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+f'f{self.cnt} '+\ - f"{[disp(y) for y in self.ctx]} {[disp(y) for y in self.child_ctx] if self.child_ctx is not None else '-'} "+\ - f"{len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst]) - def last_ctx(self): return self.child_ctx if self.child_ctx is not None else self.ctx - -def _sort_ctx(inp): return tuple(sorted(dedup(inp), key=lambda x: x.tuplize)) - -# ***** block context ***** - -@dataclass -class BlockContext: - child_count: dict[UOp, int] - block_ctxs: dict[UOp, tuple[UOp, ...]] - child_ctxs: dict[UOp, tuple[UOp, ...]] - def last_ctx(self, u): return self.child_ctxs.get(u, self.block_ctxs[u]) - @staticmethod - def from_sink(sink:UOp) -> BlockContext: - # get children and all block contexts - ctx = BlockContext({}, {}, {}) - for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL): - this_block_ctx: list[UOp] = [] - ctx.child_count[u] = 0 - - # get children and accumulate the last_ctx - for s in u.src: - if s.op is Ops.SPECIAL: continue - # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child - ctx.child_count[s] += 1 - this_block_ctx += ctx.last_ctx(s) - - # save the block ctx. SINK never has anything - ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else () - - # RANGE/IF add to the next ctx - # STORE/ASSIGN subtract from the next ctx - if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,)) - elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src]) - return ctx - -# ***** make blocks ***** - -DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST} - -def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp: - ends_to_add = [z for z in new_ctx if z not in current_ctx] - while len(ends_to_add): - r:UOp = ends_to_add.pop(-1) - new_ctx = tuple([z for z in new_ctx if z is not r]) - end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.END, src=(r,)) - base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt)) - return base_block - -def make_block_bottom_up(ctx:BlockContext, x:UOp): - if x.op is Ops.BLOCKSTART: - current_ctx, child_ctx = x.arg - lst = list(x.src) - child_count = 1 - else: - current_ctx, child_count, child_ctx = ctx.block_ctxs[x], ctx.child_count[x], ctx.child_ctxs.get(x, None) - lst = [x] - - # count of times we've seen this block, or a seed for a new block if we can't merge it - unmergable: defaultdict[UOp, int] = defaultdict(int) - blockseeds = defaultdict(list) - - # add the srcs of this to the frontier - # NOTE: things may be in here multiple times, that's okay - frontier_nodes = list(flatten(y.src[::-1] for y in lst)) - while len(frontier_nodes): - u = frontier_nodes.pop(0) - if u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == unmergable[u]+1: - # count is correct - if (newctx:=ctx.block_ctxs[u]) == current_ctx: - # block has same context, merge it, and put the srcs on the frontier - lst.append(u) - frontier_nodes.extend(u.src[::-1]) - else: - # block has different context, add it to blockseeds - blockseeds[(newctx, ctx.child_ctxs.get(u, None))].append(u) - del unmergable[u] - else: - # count is incorrect (or it's DONT_PLACE_IN_BLOCK), add it to unmergable - unmergable[u] += 1 - - # add unmergables to sources - srcs = [] - for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt - - # add blockseeds, with blockends as needed - for (new_ctx, new_child_ctx), v in blockseeds.items(): - base_block = UOp(Ops.BLOCKSTART, src=tuple(v), arg=(new_ctx, new_child_ctx)) - srcs.append(add_blockends(base_block, new_ctx, current_ctx)) - - lst = lst[::-1] - if BLOCK_REORDER: lst = block_reorder(lst) - bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx) - return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb) - -# we prevent the source of the SPECIAL from being linearized since its not part of the kernel -def raise_bottom_up_gate(): raise BottomUpGate() - -block_create = PatternMatcher([ - (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up), - (UPat(Ops.SPECIAL), raise_bottom_up_gate) -]) - -# ***** blockend merging **** - -def merge_blockends(sink:UOp) -> UOp|None: - # only run on the final BLOCK with the SINK in it - if sink.arg.lst[-1].op is not Ops.SINK: return None - # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs - blockends_to_arg: dict[UOp, list[UOp]] = {} - for be in sink.toposort(): - if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be) - new_forks = {} - for k,v in blockends_to_arg.items(): - # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails - if len(v) > 1: - bb = BasicBlock(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v)) - out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb) - # NOTE: bb.ctx != u.arg.ctx can cause problems here - for u in v: new_forks[u] = out - if len(new_forks) == 0: return None - return sink.substitute(new_forks) - -pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)]) - -# ***** block merging **** - -def merge_block(x:UOp): - unmergable_blocks, mergable_blocks = [], [] - mergable_dict: defaultdict[UOp, int] = defaultdict(int) - for y in x.src: - if y.op is Ops.BLOCK and x.op is Ops.BLOCK and x.arg.ctx == y.arg.ctx: mergable_dict[y] += 1 - elif y.op is Ops.BLOCK and x.op is Ops.BLOCKEND and x.arg.end in y.arg.ctx: mergable_dict[y] += 1 - else: unmergable_blocks.append(y) - for k,v in mergable_dict.items(): - if v == k.arg.cnt: mergable_blocks.append(k) - else: unmergable_blocks.extend([k]*v) - if len(mergable_blocks) == 0: return None - del mergable_dict - - # create the block - arg = replace(x.arg, lst=tuple(flatten([y.arg.lst for y in mergable_blocks]))+x.arg.lst) - return UOp(x.op, src=tuple(flatten([y.src for y in mergable_blocks])+unmergable_blocks), arg=arg) - -def remove_blockend(x:UOp): - # if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it - if any(x.arg.end in y.arg.ctx for y in x.src if y.op in {Ops.BLOCK, Ops.BLOCKEND}): return None - - if (parent_blocks := [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and x.arg.end in y.arg.child_ctx]): - assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})" - parent_block = parent_blocks[0] - assert len(parent_blocks) == parent_block.arg.cnt - # NOTE: DEFINE_ACC doesn't have to be handled in any special way - late_ops = list(x.arg.lst) - # NOTE: we have to add a barrier at the start if barrier is used in the range - if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.END: - late_ops = [UOp(Ops.BARRIER)] + late_ops - # peephole opt, remove any BARRIERs next to each other - for i in range(len(late_ops)-1): - if late_ops[i].op is Ops.BARRIER and late_ops[i+1].op is Ops.BARRIER: late_ops[i+1] = UOp(Ops.NOOP) - arg = BasicBlock(parent_block.arg.lst+tuple(late_ops), tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt) - return UOp(Ops.BLOCK, src=tuple(y for y in x.src if y is not parent_block)+parent_block.src, arg=arg) - # else the whole context ended by the blockend is already in this block and we can safely turn it into a block - return UOp(Ops.BLOCK, src=x.src, arg=BasicBlock(x.arg.lst, tuple([y for y in x.arg.ctx if y is not x.arg.end]), cnt=x.arg.cnt)) - -block_merge = PatternMatcher([ - (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block), - (UPat(Ops.BLOCKEND, name="x"), remove_blockend), -]) - -# ****** finalize ****** - -def finalize(sink:UOp) -> UOp: - if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src): - raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}") - - # place the early things - lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst) - return UOp(Ops.BLOCKFINAL, arg=BasicBlock(tuple(lst))) - -pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)]) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index f6130bc6e4..1c20da1185 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -24,9 +24,6 @@ class Ops(FastEnum): # ops that adjust the behavior of the scheduler CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702 - # blocks in linearizer (only used there) - BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702 - # movement ops! these only exist in the tensor graph RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702 MULTI = auto() # MULTI is really a movement op diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index d53785edfe..9d1e9794d3 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -268,7 +268,7 @@ full_spec = PatternMatcher([ (UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True), # linearizer: outputs + intermediate KERNELs - (UPat((Ops.BLOCKSTART, Ops.BLOCK, Ops.BLOCKFINAL, Ops.BLOCKEND, Ops.KERNEL), dtype=dtypes.void), lambda: True), + (UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True), # allow index dtype on a restricted set of UOps (UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX, Ops.WHERE, diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 70f8285511..8ac090359e 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -17,8 +17,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", - **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", - Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", + **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", + Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"} From 7d9551ce2e92353da12e6e282a3662f77c98f074 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:15:06 +0800 Subject: [PATCH 30/30] move to late/control_flow.py (#12835) --- test/external/external_benchmark_schedule.py | 2 +- tinygrad/codegen/__init__.py | 2 +- tinygrad/codegen/{ => late}/control_flow.py | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename tinygrad/codegen/{ => late}/control_flow.py (100%) diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index 0e91175bd8..1d0b223506 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -3,7 +3,7 @@ from tinygrad import Tensor, nn, Device from tinygrad.helpers import Profiling, Timing, getenv from tinygrad.uop.ops import Ops from tinygrad.codegen import get_rewrites_for_renderer, apply_rewrites -from tinygrad.codegen.control_flow import linearize +from tinygrad.codegen.late.control_flow import linearize from tinygrad.uop.spec import type_verify if __name__ == "__main__": diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index b0f8108053..bd3d771e2a 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -17,7 +17,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in from tinygrad.codegen.opt.postrange import pm_postrange_opt from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen -from tinygrad.codegen.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize +from tinygrad.codegen.late.control_flow import CFGContext, pm_merge_ends, pm_add_control_flow, linearize @dataclass class RewriteStep: diff --git a/tinygrad/codegen/control_flow.py b/tinygrad/codegen/late/control_flow.py similarity index 100% rename from tinygrad/codegen/control_flow.py rename to tinygrad/codegen/late/control_flow.py