diff --git a/test/test_schedule.py b/test/test_schedule.py
index 4f8fa2542d..a0909a105f 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -146,7 +146,6 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
 
   @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
-  @unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting")
   def test_add_chain_buffers(self):
     N = 31
     with Context(TRACK_MATCH_STATS=0, DEBUG=0):
@@ -1959,7 +1958,6 @@ class TestSchedule(unittest.TestCase):
     self.assertEqual(swizzle_cnt(new_uop), 0)
 
   @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
-  @unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting")
   def test_limit_bufs_with_var(self):
     N = 31
     with Context(TRACK_MATCH_STATS=0, DEBUG=0):
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index b0405f96de..17122485be 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -1,5 +1,5 @@
-from typing import Any, cast
-import functools, operator
+from typing import Any, cast, Iterator
+import functools, operator, itertools
 from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
@@ -134,11 +134,8 @@ class RangeifyContext:
   progress: int = 0
 
   # create ranges
-  range_idx: int = 0
-  def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
-    ret = UOp.range(s, self.range_idx, axistype)
-    self.range_idx += 1
-    return ret
+  range_idx: Iterator[int] = field(default_factory=itertools.count)
+  def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP): return UOp.range(s, next(self.range_idx), axistype)
 
 def map_reshape(idx:UOp, r:UOp):
   acc = 1
@@ -467,6 +464,30 @@ to_bufferview = PatternMatcher([
   (UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
 ])
 
+DEVICE_MAX_BUFS = {"METAL": 31, "WEBGPU": 8} # TODO: get from device?
+def limit_bufs(ctx:RangeifyContext, root:UOp):
+  if (device:=root._device) is None: return None # no device, index related calculations
+  device = device if isinstance(device, str) else device[0].split(":")[0]
+  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
+
+  bufs: set[UOp] = set()
+  def gate_input(u:UOp):
+    # TODO: add cache to fix n^2
+    if is_load:=(u.op in {Ops.BUFFERIZE, Ops.BUFFER, Ops.DEFINE_VAR}): bufs.add(u)
+    return not is_load
+  root.toposort(gate=gate_input)
+
+  if len(bufs) > MAX_BUFS - 1: # NOTE: this -1 is for the output buffer
+    srcs = []
+    for s in root.src:
+      if s.op in GroupOp.Elementwise:
+        # insert bufferize: all AxisType.REDUCE before the bufferize are AxisType.LOOP
+        orig_ranges, end_ranges = s.ranges, [x.replace(arg=(next(ctx.range_idx), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in s.ranges]
+        s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=device)).index(*orig_ranges)
+      srcs.append(s)
+    return root.replace(src=tuple(srcs))
+pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs)])
+
 # *****************
 # 4. put in buffers for bufferize
 # TODO: should BUFFERIZE look a lot more like STORE
@@ -662,10 +683,11 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   tsink = graph_rewrite(tsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="get children")
 
   # rangeify
-  tsink = graph_rewrite(tsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="rangeify")
+  tsink = graph_rewrite(tsink, pm_rangeify, ctx=(rangeify_ctx:=RangeifyContext()), bottom_up=True, name="rangeify")
   # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
   tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding
   tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
+  tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rangeify_ctx, name="limit buffers")
 
   # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
   # MSTACK stacks multiple BUFFERIZEs in one tagged tensor