From 5524916e39e5a2e22db7ec888e8e091fb2b7e10b Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 18 Mar 2026 19:54:40 +0800
Subject: [PATCH] llama compute gradients explicitly + 243 GB of RAM on MP=8
 (#15343)

* llama compute gradients explicitly

* apply grads

* fix multi issue

* multi BUFFER_VIEW support

* simpler

* skip the flaky test
---
 examples/mlperf/models/flat_llama.py      | 43 +++++++++++++-----
 examples/mlperf/models/test_flat_llama.py | 36 +++++++++++++++
 test/backend/test_multitensor.py          | 45 +++++++++++++++++++
 test/null/test_schedule.py                | 55 +++++++++++++++++++++++
 tinygrad/engine/allocations.py            | 27 +++++++----
 tinygrad/schedule/multi.py                |  5 ++-
 tinygrad/uop/ops.py                       |  4 ++
 7 files changed, 194 insertions(+), 21 deletions(-)

diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py
index 67f635811c..e5c235e45e 100644
--- a/examples/mlperf/models/flat_llama.py
+++ b/examples/mlperf/models/flat_llama.py
@@ -14,6 +14,7 @@ if __name__ == "__main__": os.environ["ASM_GEMM"] = "1"
 from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
 from tinygrad.helpers import Timing, colored, GlobalCounters
+from tinygrad.uop.ops import Ops, UOp
 from extra.models.llama import apply_rotary_emb, precompute_freqs_cis

 def rmsnorm(x_in:Tensor, eps:float):
@@ -41,8 +42,8 @@ class FlatTransformer:
     self.w3 = self.lin_per_layer(dim, hidden_dim)
     self.norm_eps = norm_eps

-    self.attention_norm = Tensor.ones(n_layers, dim)
-    self.ffn_norm = Tensor.ones(n_layers, dim)
+    self.attention_norm = Tensor.ones(n_layers, dim).contiguous()
+    self.ffn_norm = Tensor.ones(n_layers, dim).contiguous()

     # output
     self.norm = nn.RMSNorm(dim, norm_eps)
@@ -124,31 +125,49 @@ if __name__ == "__main__":
   model = FlatTransformer(**model_params, max_context=SEQLEN)
   state = nn.state.get_state_dict(model)
   print("tensor count:", len(state))
-  sz = 0
-  for k,v in state.items():
-    if v.requires_grad is None: v.requires_grad_(True)
-    print(f"{colored(k, 'green' if v.requires_grad else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device}")
-    sz += v.nbytes()
-  print(f"total sz: {sz/1e9:.2f} GB")

+  # shard the model
   from tinygrad import Device
   if (DP := getenv("DP", 1)) > 1: model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)))
   if (MP := getenv("MP", 1)) > 1: model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)), mp=True)

-  with Timing("realize weights: "): Tensor.realize(*state.values())
-  with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int).realize()
+  # preallocate all the grad buffers and zero them out
+  grads = {x:Tensor.zeros_like(x).contiguous() for x in state.values() if x.requires_grad is None}
+
+  # print model size
+  sz = 0
+  for k,v in state.items():
+    print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device} {v.nbytes()/1e9:.2f} GB")
+    sz += v.nbytes()
+  print(f"total sz: {sz/1e9:.2f} GB")
+
+  with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int)
+  with Timing("realize weights/grads/data: "): Tensor.realize(*state.values(), *grads.values(), tokens)
+  print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))

   if DP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)), axis=0)
   if MP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)))

+  # TODO: this shouldn't be needed, but it prevents a copy of the grads. CAT can help
+  def apply_grad(old_grad:UOp, new_grad:UOp) -> list[UOp]:
+    if new_grad.op == Ops.ADD:
+      return apply_grad(old_grad, new_grad.src[0])+apply_grad(old_grad, new_grad.src[1])
+    elif new_grad.op == Ops.PAD:
+      grad_shrink = tuple([(p[0], s+p[0]) for s,p in zip(new_grad.src[0].shape, new_grad.marg)])
+      return apply_grad(old_grad.shrink(grad_shrink), new_grad.src[0])
+    else:
+      return [old_grad.store(old_grad + new_grad)]
+
   @TinyJit
   def jit_step(tokens:Tensor):
     GlobalCounters.reset()
     print(colored("*** step", "red"))
     with Timing("python forward: "): loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
-    with Timing("python backward: "): loss.backward()
-    with Timing("run step: "): loss.realize(*[x.grad for x in state.values() if x.requires_grad])
+    with Timing("python backward: "):
+      for t,g in zip(grads, loss.gradient(*grads)):
+        grads[t] = Tensor(grads[t].uop.after(UOp.group(*apply_grad(grads[t].uop, g.uop))), device=t.device)
+    with Timing("run step: "): loss.realize(*grads.values())

   jit_step(tokens)
   jit_step(tokens)
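Reviewer note: the PAD branch of apply_grad relies on a simple identity: adding a zero-padded gradient to the full buffer is the same as accumulating the unpadded gradient into the matching slice, so the rewrite can shrink the destination and store in place instead of materializing a full-size padded copy. A minimal numpy sketch of that identity (numpy arrays standing in for tinygrad Tensors, shapes chosen arbitrarily):

import numpy as np

full = np.zeros((4, 3))           # stands in for a preallocated grad buffer
layer_grad = np.ones((1, 3))      # gradient for one layer's slice
pad = ((2, 1), (0, 0))            # this layer's slice sits at row 2 of 4

# naive path: materialize the padded gradient, then add (a full-size copy)
full_naive = full + np.pad(layer_grad, pad)

# apply_grad path: shrink the destination and accumulate in place (no copy)
full[2:3, :] += layer_grad
assert np.array_equal(full, full_naive)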
diff --git a/examples/mlperf/models/test_flat_llama.py b/examples/mlperf/models/test_flat_llama.py
index 27032f14f7..fe99de4ac7 100644
--- a/examples/mlperf/models/test_flat_llama.py
+++ b/examples/mlperf/models/test_flat_llama.py
@@ -77,5 +77,41 @@ class TestFlatLlama(unittest.TestCase):
       diff = abs(ref_grads[ref_key] - flat_grads[flat_key][i]).max()
       self.assertLess(diff, 1e-4, f"layer {i} {flat_key} grad mismatch: max abs diff {diff}")

+  @unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
+  def test_forward_match_mp(self):
+    Tensor.manual_seed(42)
+    params = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
+    from tinygrad import Device
+    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
+    ref = Transformer(**params)
+    flat = FlatTransformer(**params)
+    copy_weights(flat, ref)
+    Tensor.realize(*nn.state.get_state_dict(flat).values())
+    flat.shard(devices, mp=True)
+
+    tokens = Tensor([[1, 50, 100, 999, 2]], device=devices[0])
+    ref_logits = ref(tokens.to(devices[0])).numpy()
+    flat_logits = flat(tokens.shard(devices)).numpy()
+    self.assertEqual(ref_logits.shape, flat_logits.shape)
+    np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
+
+  @unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
+  def test_forward_match_dp(self):
+    Tensor.manual_seed(42)
+    params = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
+    from tinygrad import Device
+    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
+    ref = Transformer(**params)
+    flat = FlatTransformer(**params)
+    copy_weights(flat, ref)
+    Tensor.realize(*nn.state.get_state_dict(flat).values())
+    flat.shard(devices)
+
+    tokens = Tensor([[1, 50, 100, 999, 2], [2, 100, 50, 1, 999]], device=devices[0])
+    ref_logits = ref(tokens.to(devices[0])).numpy()
+    flat_logits = flat(tokens.shard(devices, axis=0)).numpy()
+    self.assertEqual(ref_logits.shape, flat_logits.shape)
+    np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
+
 if __name__ == "__main__":
   unittest.main()
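Reviewer note: the two tests above differ only in how tokens and weights are distributed. A standalone sketch of the distinction, with numpy arrays standing in for per-device shards (this shard helper is hypothetical, not the tinygrad API):

import numpy as np

def shard(x, n_devices, axis=None):
  # axis=None replicates the array (MP inputs); an integer axis splits it (DP batch)
  return [x.copy() for _ in range(n_devices)] if axis is None else np.split(x, n_devices, axis=axis)

tokens = np.array([[1, 50, 100, 999, 2], [2, 100, 50, 1, 999]])  # (batch=2, seqlen=5)
dp = shard(tokens, 2, axis=0)   # DP: each device gets half the batch
mp = shard(tokens, 2)           # MP: each device sees the full batch
assert dp[0].shape == (1, 5) and mp[0].shape == (2, 5)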
diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py
index e5a90a2ea7..c002478f2d 100644
--- a/test/backend/test_multitensor.py
+++ b/test/backend/test_multitensor.py
@@ -1130,6 +1130,51 @@ class TestTensorOps(unittest.TestCase):
   def test_bitcast(self):
     helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))

+@unittest.skipIf(not_support_multi_device(), "need multi")
+class TestMultiBufferView(unittest.TestCase):
+  @needs_second_gpu
+  def setUp(self): pass
+
+  def _check(self, a_ref:Tensor, a_multi:Tensor, view_fn):
+    """Apply view_fn to both, verify zero compiled kernels and matching values."""
+    b_ref = view_fn(a_ref)
+    b_multi = view_fn(a_multi).contiguous()
+    sched = b_multi.schedule()
+    compiled = [si for si in sched if isinstance(si.prg, CompiledRunner)]
+    self.assertEqual(len(compiled), 0, f"expected zero compiled kernels, got {len(compiled)}")
+    run_schedule(sched)
+    np.testing.assert_equal(b_multi.numpy(), b_ref.numpy())
+
+  @unittest.skip("flaky on LLVM")
+  def test_shrink_non_shard_axis(self):
+    ref = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().realize()
+    a = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t[3])
+
+  def test_shrink_2d(self):
+    ref = Tensor.arange(6*4).reshape(6, 4).contiguous().realize()
+    a = Tensor.arange(6*4).reshape(6, 4).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t.shrink(((1, 4), None)))
+
+  def test_reshape_then_shrink(self):
+    ref = Tensor.arange(8*6).reshape(8, 6).contiguous().realize()
+    a = Tensor.arange(8*6).reshape(8, 6).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t.reshape(4, 2, 6)[1])
+
+  def test_chained_shrink(self):
+    ref = Tensor.arange(10*8).reshape(10, 8).contiguous().realize()
+    a = Tensor.arange(10*8).reshape(10, 8).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t.shrink(((2, 8), None)).shrink(((1, 4), None)))
+
+  def test_4_devices(self):
+    ref = Tensor.arange(8*12).reshape(8, 12).contiguous().realize()
+    a = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(devices_4, axis=1).realize()
+    sched = a[5].contiguous().schedule()
+    compiled = [si for si in sched if isinstance(si.prg, CompiledRunner)]
+    self.assertEqual(len(compiled), 0)
+    run_schedule(sched)
+    np.testing.assert_equal(a[5].contiguous().numpy(), ref[5].numpy())
+
 @unittest.skipIf(not_support_multi_device(), "need multi")
 class TestMultiFromUnrenderable(unittest.TestCase):
   @needs_second_gpu
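Reviewer note: which of these cases can be a zero-kernel BUFFER_VIEW follows the usual row-major contiguity rule: a slice of the leading axis is one flat range of the backing buffer, while a slice of an inner axis is strided. A quick numpy check of that rule (numpy standing in for a realized device buffer):

import numpy as np

a = np.arange(8 * 4 * 10).reshape(8, 4, 10)
assert a[3].flags['C_CONTIGUOUS']               # leading-axis shrink: one flat range, view-eligible
assert not a[:, 2:5, :].flags['C_CONTIGUOUS']   # middle-axis shrink: strided, needs a copy kernel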
diff --git a/test/null/test_schedule.py b/test/null/test_schedule.py
index a7072a1eed..6c38de9c73 100644
--- a/test/null/test_schedule.py
+++ b/test/null/test_schedule.py
@@ -1189,6 +1189,61 @@ class TestBufferView(unittest.TestCase):
     b = a.shrink(((200, 800),)).shrink(((0, 300),)).reshape((30, 10)).shrink(((20, 25), (0, 10))).contiguous()
     run_schedule(check_schedule(b, 0))

+  def test_shrink_non_shard_axis_is_buffer_view_multi(self):
+    # indexing a non-shard axis of a realized sharded tensor should be BUFFER_VIEW on each device, not copy kernels
+    # this is the flat_llama pattern: weight[layer_idx] where weight is (n_layers, out, dim) sharded on axis=1
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a[3].contiguous(), 0))
+
+  def test_shrink_2d_non_shard_axis_multi(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(6*4).reshape(6, 4).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.shrink(((1, 4), None)).contiguous(), 0))
+
+  def test_shrink_shard_axis_0_multi(self):
+    # shrinking a middle dim is not contiguous per shard, so this needs copy kernels
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(devices, axis=0).realize()
+    run_schedule(check_schedule(a.shrink((None, (2, 5), None)).contiguous(), 2))
+
+  def test_reshape_then_shrink_multi(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(8*6).reshape(8, 6).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.reshape(4, 2, 6)[1].contiguous(), 0))
+
+  def test_permute_then_shrink_multi(self):
+    # permute makes per-shard view non-contiguous, needs copy kernels
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.permute(1, 0, 2).shrink(((0, 6), (1, 3), None)).contiguous(), 2))
+
+  def test_multi_buffer_view_4_devices(self):
+    devices = tuple(f"NULL:{i}" for i in range(4))
+    a = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a[5].contiguous(), 0))
+
+  def test_chained_shrink_multi(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(10*8).reshape(10, 8).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.shrink(((2, 8), None)).shrink(((1, 4), None)).contiguous(), 0))
+
+  # negative tests: these should NOT become BUFFER_VIEW (non-contiguous per shard)
+  def test_expand_multi_not_buffer_view(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*2).reshape(4, 1, 2).contiguous().shard(devices, axis=2).realize()
+    run_schedule(check_schedule(a.expand(4, 3, 2).contiguous(), 2))
+
+  def test_pad_multi_not_buffer_view(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.pad(((1, 1), (0, 0))).contiguous(), 2))
+
+  def test_flip_multi_not_buffer_view(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.flip(0).contiguous(), 2))
+
 class TestInvalidTensor(unittest.TestCase):
   def test_full_invalid_is_zero_kernels(self):
     from tinygrad.dtype import Invalid
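Reviewer note: a sketch of why the non-shard-axis shrinks above schedule zero kernels. After shard(..., axis=1).realize(), each device owns its own contiguous buffer for its column slice, so a shrink on axis 0 (as in test_shrink_2d_non_shard_axis_multi) is a leading-axis slice of every shard. Numpy copies stand in for the realized per-device buffers:

import numpy as np

a = np.arange(6 * 4).reshape(6, 4)
shards = [s.copy() for s in np.split(a, 2, axis=1)]   # realized per-device buffers, each (6, 2)
viewed = [s[1:4] for s in shards]                     # the shrink(((1, 4), None)) on every shard
assert all(v.flags['C_CONTIGUOUS'] for v in viewed)   # one flat range per shard: BUFFER_VIEW eligible
assert np.array_equal(np.concatenate(viewed, axis=1), a[1:4])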
diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py
index c9a6004b24..64d67023b9 100644
--- a/tinygrad/engine/allocations.py
+++ b/tinygrad/engine/allocations.py
@@ -66,6 +66,13 @@ def replace_store_after_with_contig(u:UOp, src:UOp):
     while assigned_to.op in {Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
   if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)

+def _make_buffer_view(src:UOp) -> UOp|None:
+  """If movement ops on src collapse to a contiguous range, return BUFFER_VIEW.reshape(src.shape). Otherwise None."""
+  if (offset := src.contiguous_view_offset()) is None: return None
+  buf = src.base
+  if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
+  return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape)
+
 def contiguous_mops_to_view(c:UOp, src:UOp):
   """CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
   buf = src.base
@@ -76,18 +83,22 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
   if not all_int(c.shape): return None

   # check if view is supported
-  if not isinstance(c.device, str): return None
   from tinygrad.device import Device
-  if not hasattr(Device[c.device].allocator, "_offset"): return None
+  if isinstance(c.device, str):
+    if not hasattr(Device[c.device].allocator, "_offset"): return None
+  elif not all(hasattr(Device[d].allocator, "_offset") for d in c.device): return None

-  # see if this can be a view
-  if (offset := src.contiguous_view_offset()) is None: return None
-
-  # merge BUFFER_VIEWs
-  if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
+  # for MULTI tensors, use multi_pm to resolve per-shard movement ops, then create BUFFER_VIEW on the resolved result
+  if not isinstance(c.device, str):
+    from tinygrad.schedule.multi import multi_pm
+    resolved = graph_rewrite(src, multi_pm, name="multi_buffer_view")
+    if resolved.op is not Ops.MULTI: return None
+    if (view := _make_buffer_view(resolved.src[0])) is None: return None
+    return view.multi(resolved.arg).contiguous(tag=c.tag)

   # NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
-  return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)
+  if (view := _make_buffer_view(src)) is None: return None
+  return view.contiguous(tag=c.tag)

 def transform_precompiled_call(c:UOp) -> UOp|None:
   if not c.arg.precompile: return None
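Reviewer note: the BUFFER_VIEW merge in _make_buffer_view (offset + buf.arg[1]) is plain offset arithmetic: a view of a view collapses to a single view on the base buffer with the offsets summed. A self-contained sketch with a hypothetical View class (not the tinygrad UOp):

class View:
  def __init__(self, base, size, offset):
    # merge: if base is itself a View, point at its base and add the offsets
    if isinstance(base, View):
      offset, base = offset + base.offset, base.base
    self.base, self.size, self.offset = base, size, offset

buf = list(range(100))
inner = View(buf, 60, 20)           # elements 20..79 of buf
outer = View(inner, 10, 5)          # 5 elements into the inner view
assert outer.base is buf and outer.offset == 25   # collapsed: 20 + 5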
diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index afaf6b480f..7d3b50232b 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -151,11 +151,14 @@ multi_pm = PatternMatcher([
     else multi),
   # rewrite into calls explicitly for MULTI
   (UPat(Ops.CALL, name="call"), rewrite_into_call),
-  (UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
+  (UPat((Ops.CALL, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
   # we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.), TUPLE body CALLs are handled by rewrite_into_call
   (UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])),
    lambda root: UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg) if root.src[0].op is not Ops.TUPLE else None),
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
+  # remove MULTI from STORE
+  (UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True),
+   lambda root,multi: UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg)),
 ])+replace_allreduce

diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 97423bdb37..eab281e551 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -720,6 +720,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       return buf.view(self.size, self.dtype, 0)
     if self.op is Ops.BUFFER_VIEW:
       buf = self.src[0].buffer
+      if isinstance(buf, MultiBuffer):
+        mbuf = MultiBuffer.__new__(MultiBuffer)
+        mbuf.bufs = [b.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize) for b in buf.bufs]
+        return mbuf
       assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
       return buf.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize)
     if self.op is Ops.MSELECT:
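Reviewer note: the MultiBuffer branch added to ops.py builds the same (size, offset) view on every shard. A standalone sketch of that behavior, with list slices standing in for Buffer.view and a skeletal MultiBuffer (illustrative only; the real code also scales the offset by dtype.itemsize):

class MultiBuffer:
  def __init__(self, bufs): self.bufs = bufs
  def view(self, size, offset):
    out = MultiBuffer.__new__(MultiBuffer)   # bypass __init__, like the patch does
    out.bufs = [b[offset:offset + size] for b in self.bufs]
    return out

mbuf = MultiBuffer([list(range(0, 8)), list(range(8, 16))])  # one backing buffer per device
v = mbuf.view(3, 2)                                          # same range on every shard
assert v.bufs == [[2, 3, 4], [10, 11, 12]]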