diff --git a/test/test_rangeify.py b/test/test_rangeify.py index ad480d7949..b4ec084a1b 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -76,6 +76,18 @@ class TestRangeifyEdgeCase(unittest.TestCase): res = Tensor.cat(a, c, dim=0) self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16) + def test_pcontig_multi_gather(self): + # regression test: local bufferize must have device set for const_like to work + with Context(PCONTIG=2): + # NOTE: with uint type, this will become a long and fail on WEBGPU + forest = Tensor(list(range(8)), dtype='int') + idx = Tensor([0, 0], dtype='int') + node_val = forest.gather(0, idx) + idx2 = idx * 2 + 1 + node_val2 = forest.gather(0, idx2) + result = (node_val + node_val2).numpy() + self.assertEqual(result.tolist(), [1, 1]) + if getenv("BIG") > 2: # llama 8B (8192) BS, HEADS, SEQLEN, EMB = 4, 32, 8192, 128 diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index c56b635b98..e7bbf33db4 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -74,7 +74,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS # None in the device assigns it a number later opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \ - BufferizeOpts(None, AddrSpace.LOCAL, removable=removable) + BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable) new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges]) new_srcs.append(new_src) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 314ff2eefa..e81f5a6d32 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace import itertools from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo @@ -429,6 +429,9 @@ to_define_global = PatternMatcher([ (UPat(Ops.BIND, name="b"), unbind_kernel), (UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after), + # remove device from local BUFFERIZE + (UPat(Ops.BUFFERIZE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))), + # HACK in case any CONSTs were replaced # this is only needed if you are using symbolic (UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 1d9c36c99e..8603e92396 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -120,7 +120,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]: if u._shape is not None: label += f"\n{shape_to_str(u.shape)}" if u.op in {Ops.INDEX, Ops.BUFFERIZE}: - label += f"\n{u.render()}" + if len(u.toposort()) < 30: label += f"\n{u.render()}" ranges: list[UOp] = [] for us in u.src[1:]: ranges += [s for s in us.toposort() if s.op in {Ops.RANGE, Ops.SPECIAL}] if ranges: label += "\n"+' '.join([f"{s.render()}={s.vmax+1}" for s in ranges])