llama compute gradients explicitly + 243 GB of RAM on MP=8 (#15343)

* llama compute gradients explicitly

* apply grads

* fix multi issue

* multi BUFFER_VIEW support

* simpler

* skip the flaky test
This commit is contained in:
George Hotz
2026-03-18 19:54:40 +08:00
committed by GitHub
parent ff004d2114
commit 5524916e39
7 changed files with 194 additions and 21 deletions

View File

@@ -14,6 +14,7 @@ if __name__ == "__main__":
os.environ["ASM_GEMM"] = "1"
from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
from tinygrad.helpers import Timing, colored, GlobalCounters
from tinygrad.uop.ops import Ops, UOp
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
def rmsnorm(x_in:Tensor, eps:float):
@@ -41,8 +42,8 @@ class FlatTransformer:
self.w3 = self.lin_per_layer(dim, hidden_dim)
self.norm_eps = norm_eps
self.attention_norm = Tensor.ones(n_layers, dim)
self.ffn_norm = Tensor.ones(n_layers, dim)
self.attention_norm = Tensor.ones(n_layers, dim).contiguous()
self.ffn_norm = Tensor.ones(n_layers, dim).contiguous()
# output
self.norm = nn.RMSNorm(dim, norm_eps)
@@ -124,31 +125,49 @@ if __name__ == "__main__":
model = FlatTransformer(**model_params, max_context=SEQLEN)
state = nn.state.get_state_dict(model)
print("tensor count:", len(state))
sz = 0
for k,v in state.items():
if v.requires_grad is None: v.requires_grad_(True)
print(f"{colored(k, 'green' if v.requires_grad else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device}")
sz += v.nbytes()
print(f"total sz: {sz/1e9:.2f} GB")
# shard the model
from tinygrad import Device
if (DP := getenv("DP", 1)) > 1:
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)))
if (MP := getenv("MP", 1)) > 1:
model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)), mp=True)
with Timing("realize weights: "): Tensor.realize(*state.values())
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int).realize()
# preallocate all the grad buffers and zero them out
grads = {x:Tensor.zeros_like(x).contiguous() for x in state.values() if x.requires_grad is None}
# print model size
sz = 0
for k,v in state.items():
print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device} {v.nbytes()/1e9:.2f} GB")
sz += v.nbytes()
print(f"total sz: {sz/1e9:.2f} GB")
with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int)
with Timing("realize weights/grads/data: "): Tensor.realize(*state.values(), *grads.values(), tokens)
print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))
if DP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)), axis=0)
if MP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)))
# TODO: this shouldn't be needed, but it prevents a copy of the grads. CAT can help
def apply_grad(old_grad:UOp, new_grad:UOp) -> list[UOp]:
  """Accumulate new_grad into the preallocated old_grad buffer in place.

  Recursively decomposes the gradient UOp so accumulation writes directly into
  (sub-ranges of) the existing buffer instead of materializing a fresh tensor:
  - ADD: accumulate each summand independently.
  - PAD: shrink the destination to the unpadded region and recurse, so the
    zero padding is never stored.  # NOTE(review): assumes new_grad.marg is the per-dim PAD amounts — confirm against UOp
  Returns the list of STORE UOps performing the in-place updates.
  """
  if new_grad.op == Ops.ADD:
    return apply_grad(old_grad, new_grad.src[0])+apply_grad(old_grad, new_grad.src[1])
  elif new_grad.op == Ops.PAD:
    # keep only the region of the destination that maps back onto the unpadded src
    grad_shrink = tuple([(p[0], s+p[0]) for s,p in zip(new_grad.src[0].shape, new_grad.marg)])
    return apply_grad(old_grad.shrink(grad_shrink), new_grad.src[0])
  else:
    # base case: read-modify-write accumulate into the existing grad buffer
    return [old_grad.store(old_grad + new_grad)]
@TinyJit
def jit_step(tokens:Tensor):
  """One jitted training step: forward pass, explicit gradient computation,
  and in-place accumulation into the preallocated `grads` buffers.

  Closes over module-level `model` and `grads`; no optimizer update here,
  only gradient accumulation.
  """
  GlobalCounters.reset()
  print(colored("*** step", "red"))
  with Timing("python forward: "): loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
  with Timing("python backward: "):
    # compute gradients explicitly (loss.gradient, no autograd tape) and fold each
    # into its preallocated buffer via apply_grad, avoiding an extra grad copy;
    # .after(...) sequences the new grad tensor behind the in-place STOREs
    for t,g in zip(grads, loss.gradient(*grads)):
      grads[t] = Tensor(grads[t].uop.after(UOp.group(*apply_grad(grads[t].uop, g.uop))), device=t.device)
  with Timing("run step: "): loss.realize(*grads.values())
jit_step(tokens)
jit_step(tokens)

View File

@@ -77,5 +77,41 @@ class TestFlatLlama(unittest.TestCase):
diff = abs(ref_grads[ref_key] - flat_grads[flat_key][i]).max()
self.assertLess(diff, 1e-4, f"layer {i} {flat_key} grad mismatch: max abs diff {diff}")
@unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
def test_forward_match_mp(self):
  """Model-parallel (mp=True) sharded FlatTransformer must produce the same logits as the reference Transformer."""
  Tensor.manual_seed(42)
  cfg = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
  from tinygrad import Device
  devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
  reference = Transformer(**cfg)
  flat = FlatTransformer(**cfg)
  copy_weights(flat, reference)
  # realize weights on a single device before sharding them across both
  Tensor.realize(*nn.state.get_state_dict(flat).values())
  flat.shard(devs, mp=True)
  tokens = Tensor([[1, 50, 100, 999, 2]], device=devs[0])
  expected = reference(tokens.to(devs[0])).numpy()
  # mp shards the weights; the input is replicated across devices (no axis)
  actual = flat(tokens.shard(devs)).numpy()
  self.assertEqual(expected.shape, actual.shape)
  np.testing.assert_allclose(actual, expected, atol=1e-4, rtol=1e-4)
@unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
def test_forward_match_dp(self):
  """Data-parallel sharded FlatTransformer (batch split on axis 0) must match the reference Transformer logits."""
  Tensor.manual_seed(42)
  cfg = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
  from tinygrad import Device
  devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
  reference = Transformer(**cfg)
  flat = FlatTransformer(**cfg)
  copy_weights(flat, reference)
  # realize weights on a single device before sharding them across both
  Tensor.realize(*nn.state.get_state_dict(flat).values())
  flat.shard(devs)
  tokens = Tensor([[1, 50, 100, 999, 2], [2, 100, 50, 1, 999]], device=devs[0])
  expected = reference(tokens.to(devs[0])).numpy()
  # dp replicates the weights; the two-row batch is split across devices on axis=0
  actual = flat(tokens.shard(devs, axis=0)).numpy()
  self.assertEqual(expected.shape, actual.shape)
  np.testing.assert_allclose(actual, expected, atol=1e-4, rtol=1e-4)
# allow running this test file directly
if __name__ == "__main__":
  unittest.main()

View File

@@ -1130,6 +1130,51 @@ class TestTensorOps(unittest.TestCase):
def test_bitcast(self):
helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiBufferView(unittest.TestCase):
  """Movement ops (shrink/reshape chains) on a non-shard axis of a realized
  sharded tensor should lower to per-device BUFFER_VIEWs: zero compiled kernels."""
  @needs_second_gpu
  def setUp(self): pass  # setUp exists only to carry the gpu-availability guard
  def _check(self, a_ref:Tensor, a_multi:Tensor, view_fn):
    """Apply view_fn to both, verify zero compiled kernels and matching values."""
    b_ref = view_fn(a_ref)
    b_multi = view_fn(a_multi).contiguous()
    sched = b_multi.schedule()
    # the view should schedule with no CompiledRunner items at all
    compiled = [si for si in sched if isinstance(si.prg, CompiledRunner)]
    self.assertEqual(len(compiled), 0, f"expected zero compiled kernels, got {len(compiled)}")
    run_schedule(sched)
    np.testing.assert_equal(b_multi.numpy(), b_ref.numpy())
  @unittest.skip("flaky on LLVM")
  def test_shrink_non_shard_axis(self):
    # index axis 0 of a tensor sharded on axis 1 (the flat_llama weight[layer_idx] pattern)
    ref = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().realize()
    a = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t[3])
  def test_shrink_2d(self):
    ref = Tensor.arange(6*4).reshape(6, 4).contiguous().realize()
    a = Tensor.arange(6*4).reshape(6, 4).contiguous().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t.shrink(((1, 4), None)))
  def test_reshape_then_shrink(self):
    ref = Tensor.arange(8*6).reshape(8, 6).contiguous().realize()
    a = Tensor.arange(8*6).reshape(8, 6).contiguous().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t.reshape(4, 2, 6)[1])
  def test_chained_shrink(self):
    # two chained shrinks on the non-shard axis should still collapse to one view
    ref = Tensor.arange(10*8).reshape(10, 8).contiguous().realize()
    a = Tensor.arange(10*8).reshape(10, 8).contiguous().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t.shrink(((2, 8), None)).shrink(((1, 4), None)))
  def test_4_devices(self):
    # same zero-kernel check as _check, but across four shards
    ref = Tensor.arange(8*12).reshape(8, 12).contiguous().realize()
    a = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(devices_4, axis=1).realize()
    sched = a[5].contiguous().schedule()
    compiled = [si for si in sched if isinstance(si.prg, CompiledRunner)]
    self.assertEqual(len(compiled), 0)
    run_schedule(sched)
    np.testing.assert_equal(a[5].contiguous().numpy(), ref[5].numpy())
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiFromUnrenderable(unittest.TestCase):
@needs_second_gpu

View File

@@ -1189,6 +1189,61 @@ class TestBufferView(unittest.TestCase):
b = a.shrink(((200, 800),)).shrink(((0, 300),)).reshape((30, 10)).shrink(((20, 25), (0, 10))).contiguous()
run_schedule(check_schedule(b, 0))
def test_shrink_non_shard_axis_is_buffer_view_multi(self):
  # indexing a non-shard axis of a realized sharded tensor should be BUFFER_VIEW on each device, not copy kernels
  # this is the flat_llama pattern: weight[layer_idx] where weight is (n_layers, out, dim) sharded on axis=1
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().shard(devices, axis=1).realize()
  # 0 = expected kernel count: everything lowers to per-device views
  run_schedule(check_schedule(a[3].contiguous(), 0))
def test_shrink_2d_non_shard_axis_multi(self):
  # shrinking the non-shard (outer) axis keeps each shard contiguous: zero kernels
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(6*4).reshape(6, 4).contiguous().shard(devices, axis=1).realize()
  run_schedule(check_schedule(a.shrink(((1, 4), None)).contiguous(), 0))
def test_shrink_shard_axis_0_multi(self):
  # shrinking a middle dim is not contiguous per shard, so this needs copy kernels
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(devices, axis=0).realize()
  # 2 = one copy kernel per device
  run_schedule(check_schedule(a.shrink((None, (2, 5), None)).contiguous(), 2))
def test_reshape_then_shrink_multi(self):
  # reshape followed by an index on the non-shard axis still collapses to a view
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(8*6).reshape(8, 6).contiguous().shard(devices, axis=1).realize()
  run_schedule(check_schedule(a.reshape(4, 2, 6)[1].contiguous(), 0))
def test_permute_then_shrink_multi(self):
  # permute makes per-shard view non-contiguous, needs copy kernels
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(devices, axis=1).realize()
  # 2 = one copy kernel per device
  run_schedule(check_schedule(a.permute(1, 0, 2).shrink(((0, 6), (1, 3), None)).contiguous(), 2))
def test_multi_buffer_view_4_devices(self):
  # the zero-kernel view path must hold for more than two shards
  devices = tuple(f"NULL:{i}" for i in range(4))
  a = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(devices, axis=1).realize()
  run_schedule(check_schedule(a[5].contiguous(), 0))
def test_chained_shrink_multi(self):
  # two chained shrinks on the non-shard axis should merge into a single view
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(10*8).reshape(10, 8).contiguous().shard(devices, axis=1).realize()
  run_schedule(check_schedule(a.shrink(((2, 8), None)).shrink(((1, 4), None)).contiguous(), 0))
# negative tests: these should NOT become BUFFER_VIEW (non-contiguous per shard)
def test_expand_multi_not_buffer_view(self):
  # expand duplicates data, so it cannot be a zero-copy view: one kernel per device
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(4*2).reshape(4, 1, 2).contiguous().shard(devices, axis=2).realize()
  run_schedule(check_schedule(a.expand(4, 3, 2).contiguous(), 2))
def test_pad_multi_not_buffer_view(self):
  # pad introduces new (zero) elements, so it cannot be a view: one kernel per device
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(devices, axis=1).realize()
  run_schedule(check_schedule(a.pad(((1, 1), (0, 0))).contiguous(), 2))
def test_flip_multi_not_buffer_view(self):
  # flip reverses the iteration order, which a contiguous-range view cannot express
  devices = ("NULL:1", "NULL:2")
  a = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(devices, axis=1).realize()
  run_schedule(check_schedule(a.flip(0).contiguous(), 2))
class TestInvalidTensor(unittest.TestCase):
def test_full_invalid_is_zero_kernels(self):
from tinygrad.dtype import Invalid

View File

@@ -66,6 +66,13 @@ def replace_store_after_with_contig(u:UOp, src:UOp):
while assigned_to.op in {Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)
def _make_buffer_view(src:UOp) -> UOp|None:
  """If movement ops on src collapse to a contiguous range, return BUFFER_VIEW.reshape(src.shape). Otherwise None."""
  # contiguous_view_offset returns the start offset when the mops form one contiguous range, else None
  if (offset := src.contiguous_view_offset()) is None: return None
  buf = src.base
  # merge nested BUFFER_VIEWs: fold the inner view's offset (arg[1]) into ours and point at its underlying buffer
  if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
  return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape)
def contiguous_mops_to_view(c:UOp, src:UOp):
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
buf = src.base
@@ -76,18 +83,22 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
if not all_int(c.shape): return None
# check if view is supported
if not isinstance(c.device, str): return None
from tinygrad.device import Device
if not hasattr(Device[c.device].allocator, "_offset"): return None
if isinstance(c.device, str):
if not hasattr(Device[c.device].allocator, "_offset"): return None
elif not all(hasattr(Device[d].allocator, "_offset") for d in c.device): return None
# see if this can be a view
if (offset := src.contiguous_view_offset()) is None: return None
# merge BUFFER_VIEWs
if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
# for MULTI tensors, use multi_pm to resolve per-shard movement ops, then create BUFFER_VIEW on the resolved result
if not isinstance(c.device, str):
from tinygrad.schedule.multi import multi_pm
resolved = graph_rewrite(src, multi_pm, name="multi_buffer_view")
if resolved.op is not Ops.MULTI: return None
if (view := _make_buffer_view(resolved.src[0])) is None: return None
return view.multi(resolved.arg).contiguous(tag=c.tag)
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)
if (view := _make_buffer_view(src)) is None: return None
return view.contiguous(tag=c.tag)
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None

View File

@@ -151,11 +151,14 @@ multi_pm = PatternMatcher([
else multi),
# rewrite into calls explicitly for MULTI
(UPat(Ops.CALL, name="call"), rewrite_into_call),
(UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
(UPat((Ops.CALL, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
# we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)
if root.src[0].op is not Ops.TUPLE else None),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# remove MULTI from STORE
(UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True),
lambda root,multi: UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg)),
])+replace_allreduce

View File

@@ -720,6 +720,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return buf.view(self.size, self.dtype, 0)
if self.op is Ops.BUFFER_VIEW:
buf = self.src[0].buffer
if isinstance(buf, MultiBuffer):
mbuf = MultiBuffer.__new__(MultiBuffer)
mbuf.bufs = [b.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize) for b in buf.bufs]
return mbuf
assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
return buf.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize)
if self.op is Ops.MSELECT: