add device to local, fix PCONTIG=2 (#14266)

* add device to local, fix PCONTIG=2

* regression test

* remove the device when we render

* fix viz slowness (only render small INDEX/BUFFERIZE graphs)

* no long (use int dtype in the regression test)
This commit is contained in:
George Hotz
2026-01-21 22:12:18 +09:00
committed by GitHub
parent c1d14ea832
commit 41d00a046d
4 changed files with 18 additions and 3 deletions

View File

@@ -76,6 +76,18 @@ class TestRangeifyEdgeCase(unittest.TestCase):
res = Tensor.cat(a, c, dim=0)
self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
def test_pcontig_multi_gather(self):
# regression test: local bufferize must have device set for const_like to work
# Sums two gathers taken from the same tensor while PCONTIG=2 is active.
with Context(PCONTIG=2):
# NOTE: with uint type, this will become a long and fail on WEBGPU
# forest holds the values 0..7
forest = Tensor(list(range(8)), dtype='int')
idx = Tensor([0, 0], dtype='int')
# first gather reads forest at index 0 for both elements
node_val = forest.gather(0, idx)
# idx2 evaluates to [1, 1], derived arithmetically from idx
idx2 = idx * 2 + 1
# second gather reads forest at index 1 for both elements
node_val2 = forest.gather(0, idx2)
result = (node_val + node_val2).numpy()
# expected: forest[0] + forest[1] == 0 + 1 for each element
self.assertEqual(result.tolist(), [1, 1])
if getenv("BIG") > 2:
# llama 8B (8192)
BS, HEADS, SEQLEN, EMB = 4, 32, 8192, 128

View File

@@ -74,7 +74,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
BufferizeOpts(None, AddrSpace.LOCAL, removable=removable)
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
new_srcs.append(new_src)

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
@@ -429,6 +429,9 @@ to_define_global = PatternMatcher([
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
# remove device from local BUFFERIZE
(UPat(Ops.BUFFERIZE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
# HACK in case any CONSTs were replaced
# this is only needed if you are using symbolic
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),

View File

@@ -120,7 +120,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u._shape is not None:
label += f"\n{shape_to_str(u.shape)}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
label += f"\n{u.render()}"
if len(u.toposort()) < 30: label += f"\n{u.render()}"
ranges: list[UOp] = []
for us in u.src[1:]: ranges += [s for s in us.toposort() if s.op in {Ops.RANGE, Ops.SPECIAL}]
if ranges: label += "\n"+' '.join([f"{s.render()}={s.vmax+1}" for s in ranges])