From 20a232f1c5bea58d7e81ef3cad635416fb7f5d78 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 21 Oct 2025 19:21:02 +0800 Subject: [PATCH] bugfixes from multioutput + PCONTIG=3 for fa bw memory fix (#12837) * bugfixes from multioutput * PCONTIG=3 fixes fa memory usage * that's base --- test/test_rangeify.py | 6 +++--- tinygrad/codegen/late/control_flow.py | 2 +- tinygrad/helpers.py | 1 + tinygrad/schedule/indexing.py | 1 + tinygrad/schedule/rangeify.py | 19 ++++++++++++++----- tinygrad/uop/ops.py | 8 +++++--- 6 files changed, 25 insertions(+), 12 deletions(-) diff --git a/test/test_rangeify.py b/test/test_rangeify.py index 9bed5c1481..a7fe93990e 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,6 +1,6 @@ import unittest from tinygrad import Tensor, nn, Device -from tinygrad.helpers import Context, GlobalCounters, CI, getenv +from tinygrad.helpers import Context, GlobalCounters, CI, getenv, PCONTIG from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.nir import NIRRenderer @@ -64,11 +64,11 @@ class TestPcontig(unittest.TestCase): Tensor.realize(*ret) return ret - with Context(PCONTIG=2, DEBUG=2): + with Context(PCONTIG=max(2, PCONTIG.value), DEBUG=2): grads = fa_bw() print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS") - with Context(DEBUG=2): + with Context(PCONTIG=0, DEBUG=2): cmp_grads = fa_bw() print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS") diff --git a/tinygrad/codegen/late/control_flow.py b/tinygrad/codegen/late/control_flow.py index 3eb9e56931..ce61556866 100644 --- a/tinygrad/codegen/late/control_flow.py +++ b/tinygrad/codegen/late/control_flow.py @@ -86,7 +86,7 @@ def do_merge_ends(s:UOp): replaces = {} for k,v in stacked.items(): if len(v) == 1: continue - rep = UOp(v[0].op, src=tuple([k] + [y for x in v for y in x.src[1:]]), arg=x[0].arg) + rep = UOp(v[0].op, src=tuple([k] + [y for x in v for y in x.src[1:]]), arg=v[0].arg) for x in v: replaces[x] = rep if not len(replaces) and not len(dangling_ifs): return None ret = s.substitute(replaces) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index aeb1fc5d8a..80d7a960c4 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -170,6 +170,7 @@ SPEC = ContextVar("SPEC", 0) # TODO: disable by default due to speed IGNORE_OOB = ContextVar("IGNORE_OOB", 1) PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify +DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0) @dataclass(frozen=True) class Metadata: diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index a78ac20cdc..5b7ca601a2 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -141,6 +141,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO @profile_matches def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: + if debug: print("**************************") rctx = IndexingContext() # get ops to realize diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index eae57227ae..8f1add590d 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -4,7 +4,8 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate from tinygrad.uop.symbolic import symbolic_flat -from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata +from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, DEBUG_RANGEIFY +from tinygrad.helpers import PCONTIG from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented from tinygrad.codegen.opt import Opt from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op @@ -157,16 +158,19 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp): accessed_buffers: list[UOp] = [] reduces: list[UOp] = [] def red_gate(x:UOp): - if x.op is Ops.INDEX: + if x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL: accessed_buffers.append(x) return False + if x.op is Ops.BUFFER: + accessed_buffers.append(x) if x.op is Ops.REDUCE: reduces.append(x) return True src.toposort(gate=red_gate) del red_gate + accessed_buffers = dedup(accessed_buffers) # if this is generated from multiple buffers, don't remove this buffer - if len(dedup([x.src[0] for x in accessed_buffers])) > 2: return None + if len(accessed_buffers) > 2 and not (PCONTIG > 2): return None # if any reduces access a buffer, don't remove this buffer buffer_in_reduce = False @@ -176,7 +180,12 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp): return not buffer_in_reduce UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate) del buf_gate - if buffer_in_reduce: return None + if buffer_in_reduce: + if PCONTIG > 2: + out_in_ratio = (prod(buf.shape)+1) / (sum([x.size for x in accessed_buffers])+1) + if out_in_ratio < 10: return None + else: + return None # if it makes it here, the bufferize is removed # this is the ranges replaced @@ -477,7 +486,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: tsink = graph_rewrite(tsink, earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites") # convert movement ops to ranges - tsink, rctx = run_rangeify(tsink, getenv("DEBUG_RANGEIFY", 0)) + tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY) # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 39dc77c6b7..7e7d2d0d62 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -362,7 +362,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs) def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs) def end(self, *src:UOp, ends:Sequence[UOp]): - if len(ends) == 0: return self + if len(ends) == 0: + if len(src): return UOp(Ops.NOOP, src=(self, *src)) + return self return UOp(Ops.END, src=(*ends, self, *src), arg=len(ends)) def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) @@ -555,8 +557,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.BUFFER: return self if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg) if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src)) - assert self.op is Ops.AFTER, f"must be AFTER {self.op}" - return self.src[0].buf_uop.base + assert self.base.op is Ops.AFTER, f"must be AFTER {self.base.op}" + return self.base.src[0].buf_uop.base def as_buf(self) -> UOp: if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg)