mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
do not create kernels with more inputs than the backend allows (#10510)
* work
* no itertools + top down pass
* clean viz
* python can do that
* webgpu
* gbarrier of gbarrier is gbarrier
* device can be tuple
* bug in toposort
* failing test for gated toposort
* contiguous of gbarrier is gbarrier
* check for binops
* Revert "check for binops"
This reverts commit 53e3cdf720.
* viz + match on gbarrier, self exists by default
* alt
* green now
* cleanup
This commit is contained in:
@@ -1750,7 +1750,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
# float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous
|
||||
assert k.upcasted == 1 and k.full_shape[-1] == 7
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL", "METAL can only run kernels with up to 32 buffers")
|
||||
@unittest.skipIf(Device.DEFAULT in {"METAL", "WEBGPU"}, "METAL/WEBGPU split this kernel since it has 37 buffers")
|
||||
def test_masked_upcast_wino(self):
|
||||
monster = Tensor.stack(*[Tensor.stack(*[Tensor.empty(16) for _ in range(6)]) for _ in range(6)])
|
||||
|
||||
|
||||
@@ -98,6 +98,17 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(xt, 2))
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
def test_add_chain_buffers(self):
  """Sum 31 single-element buffers via add-chains of every chunk size; all chunkings must agree on the total."""
  n_bufs = 31
  # realize each scalar as its own buffer so the adds below genuinely fan in many inputs
  with Context(TRACK_MATCH_STATS=0, DEBUG=0):
    bufs = [Tensor(v).reshape((1,)).contiguous().realize() for v in range(n_bufs)]
  expected = sum(range(n_bufs))
  for chunk in range(1, n_bufs):
    root = bufs[0]
    for start in range(1, n_bufs, chunk):
      # reduce each chunk into one partial sum, then chain it onto the running root
      partial = functools.reduce(lambda a, b: a + b, bufs[start:start + chunk])
      root = root + partial
    self.assertEqual(root.item(), expected)
|
||||
|
||||
@unittest.expectedFailure # TODO: failing because of can_chase
|
||||
def test_indexing_scalars_multiple_dims(self):
|
||||
X = Tensor.randn(2, 3).realize()
|
||||
|
||||
@@ -502,6 +502,29 @@ def get_name(becomes_map:dict[UOp, UOp]) -> str:
|
||||
|
||||
def _insert_gbarrier(ctx, x):
  # ctx is the realize map: only ops it marks for realization, and that are still untagged, get wrapped
  if x not in ctx or x.tag is not None: return None
  return x.replace(tag=1).gbarrier()

# match every op except GBARRIER/ASSIGN and wrap it in a gbarrier where the realize map requires one
add_gbarrier = PatternMatcher([(UPat(GroupOp.All-{Ops.GBARRIER, Ops.ASSIGN}, name="x"), _insert_gbarrier)])
|
||||
|
||||
# TODO: get this from the device through GrouperOpts
# per-backend cap on how many buffers a single kernel may reference (absent = unlimited)
DEVICE_MAX_BUFS: dict[str, int] = {"METAL": 32, "WEBGPU": 8}
|
||||
|
||||
def limit_bufs(root:UOp):
  """Split kernels that would exceed the backend's buffer limit.

  Counts the unique buffer-like UOps (BUFFER/GBARRIER/ASSIGN) feeding `root`; once that
  count reaches the device cap, every source of `root` that is not itself one of those
  buffers is cut off behind a fresh GBARRIER so it becomes its own kernel.
  Returns the rewritten op, or None when no rewrite is needed.
  """
  # resolve the base device name; a tuple device uses its first entry with any ":N" suffix stripped
  dev = root.device
  device = dev if isinstance(dev, str) else dev[0].split(":")[0]
  # MAX_KERNEL_BUFFERS env var overrides the per-device table; 0 means "no limit"
  max_bufs = getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))
  if not max_bufs: return None
  # collect the unique buffers reachable from root, stopping the traversal at each one
  bufs: set[UOp] = set()
  def gate_input(u:UOp):
    is_buffer = u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN}
    if is_buffer: bufs.add(u)
    return not is_buffer
  root.toposort(gate=gate_input)
  # NOTE: the -1 reserves one slot for the output buffer
  if len(bufs) < max_bufs-1: return None
  return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).gbarrier() for s in root.src))
|
||||
|
||||
# top-down pass: cap per-kernel buffer counts and collapse redundant gbarriers
split_kernels = PatternMatcher([
  # multi-input ops are where buffer counts can blow past the device limit
  (UPat(GroupOp.Binary | GroupOp.Ternary, name="root"), limit_bufs),
  # gbarrier-of-gbarrier (and contiguous-of-gbarrier) collapses to the inner gbarrier
  (UPat((Ops.GBARRIER, Ops.CONTIGUOUS), src=(UPat(Ops.GBARRIER),), name="x"), lambda x: x.src[0]),
])
|
||||
|
||||
def _strip_tag(x):
  # only rewrite tagged nodes; returning None leaves untagged ones untouched
  return x.replace(tag=None) if x.tag is not None else None

# final cleanup pass: drop the tag=1 markers the gbarrier passes left behind
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), _strip_tag)])
|
||||
|
||||
@track_rewrites(name_fxn=get_name)
|
||||
@@ -515,6 +538,7 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
# insert gbarriers in places determined by the realize map
|
||||
realize_map = group_realizes(tensor_map[big_sink])
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], add_gbarrier, realize_map, bottom_up=True, input_map=tensor_map, name="insert_gbarrier")
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], split_kernels, input_map=tensor_map, name="split_kernels")
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], remove_tags, input_map=tensor_map, name="remove_tags")
|
||||
|
||||
# TODO: move view_left/view_right here
|
||||
|
||||
Reference in New Issue
Block a user