do not create kernels with more inputs than the backend allows (#10510)

* work

* no itertools + top down pass

* clean viz

* python can do that

* webgpu

* gbarrier of gbarrier is gbarrier

* device can be tuple

* bug in toposort

* failing test for gated toposort

* contiguous of gbarrier is gbarrier

* check for binops

* Revert "check for binops"

This reverts commit 53e3cdf720.

* viz + match on gbarrier, self exists by default

* alt

* green now

* cleanup
This commit is contained in:
qazal
2025-05-26 18:02:03 +03:00
committed by GitHub
parent deb369417c
commit 9169dcfb49
3 changed files with 36 additions and 1 deletion

View File

@@ -1750,7 +1750,7 @@ class TestHandCodedOpts(unittest.TestCase):
# float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous
assert k.upcasted == 1 and k.full_shape[-1] == 7
@unittest.skipIf(Device.DEFAULT == "METAL", "METAL can only run kernels with up to 32 buffers")
@unittest.skipIf(Device.DEFAULT in {"METAL", "WEBGPU"}, "METAL/WEBGPU split this kernel since it has 37 buffers")
def test_masked_upcast_wino(self):
monster = Tensor.stack(*[Tensor.stack(*[Tensor.empty(16) for _ in range(6)]) for _ in range(6)])

View File

@@ -98,6 +98,17 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(xt, 2))
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
def test_add_chain_buffers(self):
  """Sum N one-element buffers through add-chains of every chunk size.

  Each chunk size stresses a different kernel shape; when a chain would exceed the
  backend's buffer limit the scheduler must split it, but the result never changes.
  """
  total = 31
  # realize the inputs quietly so scheduler stats/debug output don't pollute the test
  with Context(TRACK_MATCH_STATS=0, DEBUG=0):
    bufs = [Tensor(v).reshape((1,)).contiguous().realize() for v in range(total)]
  for chunk in range(1, total):
    acc = bufs[0]
    pos = 1
    while pos < total:
      # fold the next `chunk` buffers into a single add-chain, then attach it to the accumulator
      group = bufs[pos:pos+chunk]
      partial = group[0]
      for t in group[1:]:
        partial = partial + t
      acc = acc + partial
      pos += chunk
    # regardless of how the additions were grouped, the total is 0+1+...+(total-1)
    self.assertEqual(acc.item(), sum(range(total)))
@unittest.expectedFailure # TODO: failing because of can_chase
def test_indexing_scalars_multiple_dims(self):
X = Tensor.randn(2, 3).realize()

View File

@@ -502,6 +502,29 @@ def get_name(becomes_map:dict[UOp, UOp]) -> str:
# Rewrite pass: wrap every op that the realize map (ctx) marked for realization in a GBARRIER.
# GBARRIER/ASSIGN are excluded from the match, and tag=1 marks an op as already wrapped so the
# bottom-up rewrite terminates instead of re-wrapping its own output.
add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"),
lambda ctx,x: x.replace(tag=1).gbarrier() if x in ctx and x.tag is None else None)])
# TODO: get this from the device through GrouperOpts
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
def limit_bufs(root:UOp):
  """Split the kernel rooted at `root` when it would read more buffers than the backend allows.

  Returns `root` rewritten with GBARRIERs inserted on its non-buffer sources (forcing those
  subgraphs into separate kernels), or None when no limit applies or the limit is not hit.
  """
  # check if backend has a buffer limit; device may be a tuple for multi-device tensors.
  # FIX: strip the ":N" suffix in the string case too (e.g. "METAL:1" -> "METAL"), matching the
  # tuple branch — previously a suffixed device string missed the DEVICE_MAX_BUFS lookup and
  # silently bypassed the limit.
  device = (root.device if isinstance(root.device, str) else root.device[0]).split(":")[0]
  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
  # count number of unique buffers flowing into this op; the gate records each buffer-like op
  # and stops traversal there, so only buffers feeding THIS kernel are counted
  bufs: set[UOp] = set()
  def gate_input(u:UOp):
    if (is_buffer:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN})): bufs.add(u)
    return not is_buffer
  root.toposort(gate=gate_input)
  # NOTE: this -1 is for the output buffer
  if len(bufs)>=MAX_BUFS-1:
    # keep sources that already resolve to a counted buffer; barrier everything else
    return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).gbarrier() for s in root.src))
# Rewrite pass run after gbarrier insertion: enforce the per-kernel buffer limit and clean up
# redundant barriers.
split_kernels = PatternMatcher([
# binary/ternary ops are where fan-in grows; cap their buffer count via limit_bufs
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
# GBARRIER-of-GBARRIER (and CONTIGUOUS-of-GBARRIER) is redundant: collapse to the inner GBARRIER
(UPat((Ops.GBARRIER, Ops.CONTIGUOUS), src=(UPat(Ops.GBARRIER),), name="x"), lambda x: x.src[0]),
])
# final pass: clear the tag=1 markers left behind by add_gbarrier/limit_bufs
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
@track_rewrites(name_fxn=get_name)
@@ -515,6 +538,7 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
# insert gbarriers in places determined by the realize map
realize_map = group_realizes(tensor_map[big_sink])
tensor_map = graph_rewrite_map(tensor_map[big_sink], add_gbarrier, realize_map, bottom_up=True, input_map=tensor_map, name="insert_gbarrier")
tensor_map = graph_rewrite_map(tensor_map[big_sink], split_kernels, input_map=tensor_map, name="split_kernels")
tensor_map = graph_rewrite_map(tensor_map[big_sink], remove_tags, input_map=tensor_map, name="remove_tags")
# TODO: move view_left/view_right here