From 7f0f97aa76e564a135ce36153798a5e5435010ee Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 6 Jun 2025 10:26:28 -0700
Subject: [PATCH] new test_multitensor tests (#10667)

* new test_multitensor tests

* cleanup scheduler
---
 test/test_multitensor.py    | 42 +++++++++++++++++++++++++++++++++++++-----
 tinygrad/engine/grouper.py  | 11 +----------
 tinygrad/engine/multi.py    | 18 ++++++++++++++--
 tinygrad/engine/schedule.py | 25 ++++++++++---------------
 4 files changed, 64 insertions(+), 32 deletions(-)

diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 0a88763825..5f2250f310 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -1173,13 +1173,37 @@ class TestMultiAssign(unittest.TestCase):
     out[:, 2:3].assign(ones).realize()
     self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
 
+  def test_multi_assign_var_offset(self):
+    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
+    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
+    vi = Variable("i", 0, 3).bind(2)
+    out[:, vi:vi+1].assign(ones).realize()
+    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
+
+  def test_multi_assign_var_offset_jit_none(self): self.test_multi_assign_var_offset_jit(None)
+  def test_multi_assign_var_offset_jit(self, shard_axis=0):
+    out = Tensor.zeros(4,6).contiguous().realize().shard(self.device, shard_axis).realize()
+    ones = Tensor.ones(4,1).shard(self.device, shard_axis).contiguous().realize()
+
+    @TinyJit
+    def f(out:Tensor, vi):
+      out[:, vi:vi+1].assign(ones).realize()
+      ones.assign(ones+1).realize()
+
+    vi = Variable("i", 0, 5)
+    for i in range(1,5):
+      GlobalCounters.reset()
+      f(out, vi.bind(i))
+    self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4)
+
 @unittest.skipIf(not_support_multi_device(), "need multi")
 class TestMultiTransformer(unittest.TestCase):
   def test_transformer(self):
     device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
 
     from extra.models.llama import Transformer
-    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
+    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024,
+            "hidden_dim": 64, "max_context": 12}
     real_model = Transformer(**args)
     shard_model = Transformer(**args)
 
@@ -1199,10 +1223,18 @@ class TestMultiTransformer(unittest.TestCase):
 
     last_tok = 0
     for i in range(10):
-      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
-      shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
-      last_tok = real_tok.item()
-      self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")
+      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i).item()
+      shard_tok = shard_model(Tensor([[last_tok]], device=device), i).item()
+
+      # test kv cache
+      kv1 = real_model.layers[0].attention.cache_kv.numpy()
+      kv2 = shard_model.layers[0].attention.cache_kv.numpy()
+      #print(np.concatenate([kv1[:, :, :, :, 0:1], kv2[:, :, :, :, 0:1]], axis=4))
+      np.testing.assert_allclose(kv1, kv2, atol=1e-5, rtol=1e-5, err_msg=f"issue at token {i}")
+
+      # test token
+      self.assertEqual(real_tok, shard_tok, f"issue at token {i}")
+      last_tok = real_tok
 
   @unittest.skip("super slow")
   def test_llama1b_full(self):
diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py
index 829ef7bfd9..7dc6061a06 100644
--- a/tinygrad/engine/grouper.py
+++ b/tinygrad/engine/grouper.py
@@ -50,14 +50,6 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
   if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
   return base.copy_to_device(copy.device).view(view.arg)
 
-def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
-  st = unwrap(view.st)
-  # replace dnum in ShapeTracker with literal const for this mselect
-  if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
-    assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
-    st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
-  return base.mselect(ms.arg).view(st)
-
 ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
 
 sym = symbolic_simple+PatternMatcher([
@@ -79,8 +71,6 @@ sym = symbolic_simple+PatternMatcher([
   (UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
   # store a shrink before COPY, otherwise view after the COPY
   (UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
-  # MSELECT must select a base, if there are views apply them after selecting the base
-  (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
   # remove cast to image when it's already a contiguous image
   (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
    lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
@@ -570,6 +560,7 @@ def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
     if u.op is not Ops.ASSIGN: continue
     kernel_assign[u.buf_uop] = u
     for s in u.src[1].src:
+      # TODO: this is probably broken for MSELECT/MSTACK
      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py
index 82a9dadc5f..b9a1ed36a9 100644
--- a/tinygrad/engine/multi.py
+++ b/tinygrad/engine/multi.py
@@ -1,6 +1,6 @@
 from typing import cast
 import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp
 
 # *** allreduce implementation ***
@@ -53,7 +53,21 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   pads = [((s,numel-e),) for s,e in chunks]
   return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
 
-replace_allreduce = PatternMatcher([(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),])
+# ***** multi rewrite MSELECT/MSTACK *****
+
+def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
+  st = unwrap(view.st)
+  # replace dnum in ShapeTracker with literal const for this mselect
+  if (dnums:=[x for x in st.vars() if x.arg[0] == '_device_num']):
+    assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
+    st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
+  return base.mselect(ms.arg).view(st)
+
+replace_allreduce = PatternMatcher([
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
+  # MSELECT must select a base, if there are views apply them after selecting the base
+  (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), mselect_reorder_view),
+])
 
 # ***** multi functions *****
 
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 3aefd61fb2..5805b7cede 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -77,28 +77,23 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
     ubufs = tuple(s.buf_uop.buffer for s in k.src)
     if any(isinstance(x, MultiBuffer) for x in ubufs):
-      if ast.op is Ops.COPY:
-        assert ast.arg is None, "copy arg is no longer supported"
-        if isinstance(ubufs[1], MultiBuffer): # src is multiple buffers, none selected
-          if isinstance(ubufs[0], MultiBuffer):
-            # COPY ALL -> ALL
-            assert len(ubufs[0].bufs) == len(ubufs[1].bufs), "all to all copy must have matching buffer length"
-            for b1,b2 in zip(ubufs[0].bufs, ubufs[1].bufs): schedule.append(ScheduleItem(ast, (b1, b2), k.arg.metadata))
-          else:
-            # COPY ANY -> ONE. Currently we just select the first
-            schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata))
+      if ast.op is Ops.COPY and (isinstance(ubufs[0], Buffer) or isinstance(ubufs[1], Buffer)):
+        if isinstance(ubufs[1], MultiBuffer) and isinstance(ubufs[0], Buffer): # src is multiple buffers, none selected
+          # COPY ANY -> ONE. Currently we just select the first
+          schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata))
+        elif isinstance(ubufs[0], MultiBuffer) and isinstance(ubufs[1], Buffer):
+          # COPY ONE -> ALL (BROADCAST)
+          for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, ubufs[1]), k.arg.metadata))
         else:
-          assert isinstance(ubufs[1], Buffer), "src can't be MultiBuffer"
-          if isinstance(ubufs[0], MultiBuffer):
-            # COPY ONE -> ALL (BROADCAST)
-            for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, ubufs[1]), k.arg.metadata))
-          else: schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1]), k.arg.metadata)) # COPY ONE -> ONE
+          raise RuntimeError("unsupported copy type")
       else:
         assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
+        # ALL -> ALL
         dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
         for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
           schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {}))
     else:
+      # ONE -> ONE
       schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
     for x in children[k]:
       in_degree[x] -= 1