diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 6466f7d609..8cc4dfbb69 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1750,7 +1750,7 @@ class TestHandCodedOpts(unittest.TestCase): # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous assert k.upcasted == 1 and k.full_shape[-1] == 7 - @unittest.skipIf(Device.DEFAULT == "METAL", "METAL can only run kernels with up to 32 buffers") + @unittest.skipIf(Device.DEFAULT in {"METAL", "WEBGPU"}, "METAL/WEBGPU split this kernel since it has 37 buffers") def test_masked_upcast_wino(self): monster = Tensor.stack(*[Tensor.stack(*[Tensor.empty(16) for _ in range(6)]) for _ in range(6)]) diff --git a/test/test_schedule.py b/test/test_schedule.py index 517c96a6c2..02c03466cc 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -98,6 +98,17 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(xt, 2)) np.testing.assert_equal(xt.numpy(), X.numpy()[1][0]) + @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI") + def test_add_chain_buffers(self): + N = 31 + with Context(TRACK_MATCH_STATS=0, DEBUG=0): + bufs = [Tensor(i).reshape((1,)).contiguous().realize() for i in range(N)] + for X in range(1,N): + root = bufs[0] + for i in range(1,N,X): + root = root + functools.reduce(lambda a,b:a+b, bufs[i:i+X]) + self.assertEqual(root.item(), sum(range(N))) + @unittest.expectedFailure # TODO: failing because of can_chase def test_indexing_scalars_multiple_dims(self): X = Tensor.randn(2, 3).realize() diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 0db8f18ea0..f6feb8db04 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -502,6 +502,29 @@ def get_name(becomes_map:dict[UOp, UOp]) -> str: add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"), lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None 
 else None)])
+
+# TODO: get this from the device through GrouperOpts
+DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
+
+def limit_bufs(root:UOp):
+  # check if backend has a buffer limit
+  device = (root.device if isinstance(root.device, str) else root.device[0]).split(":")[0]
+  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
+  # count number of unique buffers flowing into this op
+  bufs: set[UOp] = set()
+  def gate_input(u:UOp):
+    if (is_buffer:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN})): bufs.add(u)
+    return not is_buffer
+  root.toposort(gate=gate_input)
+  # NOTE: this -1 is for the output buffer
+  if len(bufs)>=MAX_BUFS-1:
+    return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).gbarrier() for s in root.src))
+
+split_kernels = PatternMatcher([
+  (UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
+  (UPat((Ops.GBARRIER, Ops.CONTIGUOUS), src=(UPat(Ops.GBARRIER),), name="x"), lambda x: x.src[0]),
+])
+
 remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])

 @track_rewrites(name_fxn=get_name)
@@ -515,6 +538,7 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
   # insert gbarriers in places determined by the realize map
   realize_map = group_realizes(tensor_map[big_sink])
   tensor_map = graph_rewrite_map(tensor_map[big_sink], add_gbarrier, realize_map, bottom_up=True, input_map=tensor_map, name="insert_gbarrier")
+  tensor_map = graph_rewrite_map(tensor_map[big_sink], split_kernels, input_map=tensor_map, name="split_kernels")
   tensor_map = graph_rewrite_map(tensor_map[big_sink], remove_tags, input_map=tensor_map, name="remove_tags")

   # TODO: move view_left/view_right here