mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add device to local, fix PCONTIG=2 (#14266)
* add device to local, fix PCONTIG=2
* regression test
* remove the device when we render
* viz slowness
* no long
This commit is contained in:
@@ -76,6 +76,18 @@ class TestRangeifyEdgeCase(unittest.TestCase):
|
||||
res = Tensor.cat(a, c, dim=0)
|
||||
self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
|
||||
|
||||
def test_pcontig_multi_gather(self):
  # Regression test: the local bufferize must have its device set for const_like to work.
  # Two gathers from the same source tensor are summed under PCONTIG=2; with indices
  # [0, 0] and [1, 1] over range(8) each output element is 0 + 1 = 1.
  with Context(PCONTIG=2):
    # NOTE: with uint type, this will become a long and fail on WEBGPU
    source = Tensor(list(range(8)), dtype='int')
    first_idx = Tensor([0, 0], dtype='int')
    first_pick = source.gather(0, first_idx)
    second_idx = first_idx * 2 + 1
    second_pick = source.gather(0, second_idx)
    summed = first_pick + second_pick
    values = summed.numpy()
    self.assertEqual(values.tolist(), [1, 1])
|
||||
|
||||
if getenv("BIG") > 2:
|
||||
# llama 8B (8192)
|
||||
BS, HEADS, SEQLEN, EMB = 4, 32, 8192, 128
|
||||
|
||||
@@ -74,7 +74,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
|
||||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
|
||||
BufferizeOpts(None, AddrSpace.LOCAL, removable=removable)
|
||||
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||
new_srcs.append(new_src)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
@@ -429,6 +429,9 @@ to_define_global = PatternMatcher([
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
|
||||
|
||||
# remove device from local BUFFERIZE
|
||||
(UPat(Ops.BUFFERIZE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
|
||||
|
||||
# HACK in case any CONSTs were replaced
|
||||
# this is only needed if you are using symbolic
|
||||
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
|
||||
@@ -120,7 +120,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
if u._shape is not None:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
||||
label += f"\n{u.render()}"
|
||||
if len(u.toposort()) < 30: label += f"\n{u.render()}"
|
||||
ranges: list[UOp] = []
|
||||
for us in u.src[1:]: ranges += [s for s in us.toposort() if s.op in {Ops.RANGE, Ops.SPECIAL}]
|
||||
if ranges: label += "\n"+' '.join([f"{s.render()}={s.vmax+1}" for s in ranges])
|
||||
|
||||
Reference in New Issue
Block a user