mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
llama compute gradients explicitly + 243 GB of RAM on MP=8 (#15343)
* llama compute gradients explicitly
* apply grads
* fix multi issue
* multi BUFFER_VIEW support
* simpler
* skip the flaky test
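The main change swaps autograd bookkeeping (requires_grad plus loss.backward()) for explicit gradient computation with Tensor.gradient, accumulating into preallocated buffers. A minimal sketch of the two styles (toy shapes, not code from the commit):

from tinygrad import Tensor

W = Tensor.randn(16, 16, requires_grad=True)
x = Tensor.randn(4, 16)
loss = (x @ W).square().mean()

# implicit: backward() walks the graph and populates W.grad
loss.backward()

# explicit: ask for exactly the gradients you want, no .grad state on the tensors
(dW,) = loss.gradient(W)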
@@ -14,6 +14,7 @@ if __name__ == "__main__":
 os.environ["ASM_GEMM"] = "1"
 from tinygrad import Tensor, nn, function, getenv, dtypes, TinyJit
 from tinygrad.helpers import Timing, colored, GlobalCounters
+from tinygrad.uop.ops import Ops, UOp
 from extra.models.llama import apply_rotary_emb, precompute_freqs_cis

 def rmsnorm(x_in:Tensor, eps:float):
@@ -41,8 +42,8 @@ class FlatTransformer:
     self.w3 = self.lin_per_layer(dim, hidden_dim)

     self.norm_eps = norm_eps
-    self.attention_norm = Tensor.ones(n_layers, dim)
-    self.ffn_norm = Tensor.ones(n_layers, dim)
+    self.attention_norm = Tensor.ones(n_layers, dim).contiguous()
+    self.ffn_norm = Tensor.ones(n_layers, dim).contiguous()

     # output
     self.norm = nn.RMSNorm(dim, norm_eps)
@@ -124,31 +125,49 @@ if __name__ == "__main__":
 model = FlatTransformer(**model_params, max_context=SEQLEN)
 state = nn.state.get_state_dict(model)
 print("tensor count:", len(state))
-sz = 0
-for k,v in state.items():
-  if v.requires_grad is None: v.requires_grad_(True)
-  print(f"{colored(k, 'green' if v.requires_grad else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device}")
-  sz += v.nbytes()
-print(f"total sz: {sz/1e9:.2f} GB")

 # shard the model
 from tinygrad import Device
 if (DP := getenv("DP", 1)) > 1:
   model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)))
 if (MP := getenv("MP", 1)) > 1:
   model.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)), mp=True)

-with Timing("realize weights: "): Tensor.realize(*state.values())
-with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int).realize()
+# preallocate all the grad buffers and zero them out
+grads = {x:Tensor.zeros_like(x).contiguous() for x in state.values() if x.requires_grad is None}
+
+# print model size
+sz = 0
+for k,v in state.items():
+  print(f"{colored(k, 'green' if v in grads else 'white'):30s} {str(v.shape):30s} {v.dtype} {v.device} {v.nbytes()/1e9:.2f} GB")
+  sz += v.nbytes()
+print(f"total sz: {sz/1e9:.2f} GB")
+
+with Timing("fake data: "): tokens = Tensor.randint(BS, SEQLEN+1, low=0, high=model.vocab_size, dtype=dtypes.int)
+with Timing("realize weights/grads/data: "): Tensor.realize(*state.values(), *grads.values(), tokens)
+print("mem per device: " + ', '.join(f"{dev}: {mem/1e9:.2f} GB" for dev, mem in sorted(GlobalCounters.mem_used_per_device.items())))
 if DP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(DP)), axis=0)
 if MP > 1: tokens = tokens.shard(tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)))

+# TODO: this shouldn't be needed, but it prevents a copy of the grads. CAT can help
+def apply_grad(old_grad:UOp, new_grad:UOp) -> list[UOp]:
+  if new_grad.op == Ops.ADD:
+    return apply_grad(old_grad, new_grad.src[0])+apply_grad(old_grad, new_grad.src[1])
+  elif new_grad.op == Ops.PAD:
+    grad_shrink = tuple([(p[0], s+p[0]) for s,p in zip(new_grad.src[0].shape, new_grad.marg)])
+    return apply_grad(old_grad.shrink(grad_shrink), new_grad.src[0])
+  else:
+    return [old_grad.store(old_grad + new_grad)]
+
 @TinyJit
 def jit_step(tokens:Tensor):
   GlobalCounters.reset()
   print(colored("*** step", "red"))
   with Timing("python forward: "): loss = model(tokens[:, :-1]).sparse_categorical_crossentropy(tokens[:, 1:])
-  with Timing("python backward: "): loss.backward()
-  with Timing("run step: "): loss.realize(*[x.grad for x in state.values() if x.requires_grad])
+  with Timing("python backward: "):
+    for t,g in zip(grads, loss.gradient(*grads)):
+      grads[t] = Tensor(grads[t].uop.after(UOp.group(*apply_grad(grads[t].uop, g.uop))), device=t.device)
+  with Timing("run step: "): loss.realize(*grads.values())

 jit_step(tokens)
 jit_step(tokens)
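The apply_grad helper above avoids a fresh allocation per accumulation: instead of realizing old + new into a new buffer, it stores the sum back into the preallocated buffer, splitting ADD nodes and turning a PAD on the incoming gradient into a shrink of the destination. A toy numpy analogue of the PAD case (illustration only, not tinygrad's UOp API):

import numpy as np

# toy analogue: adding PAD(g, pads) into a big buffer is the same as
# adding g into a shrink of that buffer, with no padded temporary
acc = np.zeros(10)     # preallocated grad buffer
g = np.ones(3)         # gradient for a slice, conceptually padded with (4, 3)
acc[4:4+3] += g        # handle PAD by shrinking the destination instead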
@@ -77,5 +77,41 @@ class TestFlatLlama(unittest.TestCase):
       diff = abs(ref_grads[ref_key] - flat_grads[flat_key][i]).max()
       self.assertLess(diff, 1e-4, f"layer {i} {flat_key} grad mismatch: max abs diff {diff}")

+  @unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
+  def test_forward_match_mp(self):
+    Tensor.manual_seed(42)
+    params = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
+    from tinygrad import Device
+    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
+    ref = Transformer(**params)
+    flat = FlatTransformer(**params)
+    copy_weights(flat, ref)
+    Tensor.realize(*nn.state.get_state_dict(flat).values())
+    flat.shard(devices, mp=True)
+
+    tokens = Tensor([[1, 50, 100, 999, 2]], device=devices[0])
+    ref_logits = ref(tokens.to(devices[0])).numpy()
+    flat_logits = flat(tokens.shard(devices)).numpy()
+    self.assertEqual(ref_logits.shape, flat_logits.shape)
+    np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
+
+  @unittest.skipUnless(os.getenv("CPU", "") == "1", "multi-device CPU test")
+  def test_forward_match_dp(self):
+    Tensor.manual_seed(42)
+    params = dict(dim=128, hidden_dim=256, n_heads=4, n_kv_heads=2, n_layers=2, norm_eps=1e-5, vocab_size=1024, rope_theta=10000, max_context=64)
+    from tinygrad import Device
+    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
+    ref = Transformer(**params)
+    flat = FlatTransformer(**params)
+    copy_weights(flat, ref)
+    Tensor.realize(*nn.state.get_state_dict(flat).values())
+    flat.shard(devices)
+
+    tokens = Tensor([[1, 50, 100, 999, 2], [2, 100, 50, 1, 999]], device=devices[0])
+    ref_logits = ref(tokens.to(devices[0])).numpy()
+    flat_logits = flat(tokens.shard(devices, axis=0)).numpy()
+    self.assertEqual(ref_logits.shape, flat_logits.shape)
+    np.testing.assert_allclose(flat_logits, ref_logits, atol=1e-4, rtol=1e-4)
+
 if __name__ == "__main__":
   unittest.main()
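The two tests differ only in how shard is called: with axis=0 the batch is split across devices (DP), while with no axis the tokens are replicated to every device (MP). A quick sketch of that distinction on the NULL backend (device names chosen for illustration):

from tinygrad import Tensor

devices = ("NULL:0", "NULL:1")
t = Tensor.arange(8).reshape(2, 4).contiguous().realize()
dp = t.shard(devices, axis=0)   # DP: each device holds one row of the batch
mp = t.shard(devices)           # no axis: a full copy on every device
print(dp.shape, mp.shape)       # both report (2, 4); the split lives in placement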
@@ -1130,6 +1130,51 @@ class TestTensorOps(unittest.TestCase):
   def test_bitcast(self):
     helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))

+@unittest.skipIf(not_support_multi_device(), "need multi")
+class TestMultiBufferView(unittest.TestCase):
+  @needs_second_gpu
+  def setUp(self): pass
+
+  def _check(self, a_ref:Tensor, a_multi:Tensor, view_fn):
+    """Apply view_fn to both, verify zero compiled kernels and matching values."""
+    b_ref = view_fn(a_ref)
+    b_multi = view_fn(a_multi).contiguous()
+    sched = b_multi.schedule()
+    compiled = [si for si in sched if isinstance(si.prg, CompiledRunner)]
+    self.assertEqual(len(compiled), 0, f"expected zero compiled kernels, got {len(compiled)}")
+    run_schedule(sched)
+    np.testing.assert_equal(b_multi.numpy(), b_ref.numpy())
+
+  @unittest.skip("flaky on LLVM")
+  def test_shrink_non_shard_axis(self):
+    ref = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().realize()
+    a = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t[3])
+
+  def test_shrink_2d(self):
+    ref = Tensor.arange(6*4).reshape(6, 4).contiguous().realize()
+    a = Tensor.arange(6*4).reshape(6, 4).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t.shrink(((1, 4), None)))
+
+  def test_reshape_then_shrink(self):
+    ref = Tensor.arange(8*6).reshape(8, 6).contiguous().realize()
+    a = Tensor.arange(8*6).reshape(8, 6).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t.reshape(4, 2, 6)[1])
+
+  def test_chained_shrink(self):
+    ref = Tensor.arange(10*8).reshape(10, 8).contiguous().realize()
+    a = Tensor.arange(10*8).reshape(10, 8).contiguous().shard(devices_2, axis=1).realize()
+    self._check(ref, a, lambda t: t.shrink(((2, 8), None)).shrink(((1, 4), None)))
+
+  def test_4_devices(self):
+    ref = Tensor.arange(8*12).reshape(8, 12).contiguous().realize()
+    a = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(devices_4, axis=1).realize()
+    sched = a[5].contiguous().schedule()
+    compiled = [si for si in sched if isinstance(si.prg, CompiledRunner)]
+    self.assertEqual(len(compiled), 0)
+    run_schedule(sched)
+    np.testing.assert_equal(a[5].contiguous().numpy(), ref[5].numpy())
+
 @unittest.skipIf(not_support_multi_device(), "need multi")
 class TestMultiFromUnrenderable(unittest.TestCase):
   @needs_second_gpu
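The zero-kernel expectation holds because shrinking a non-shard axis selects the same contiguous range in every shard. Worked arithmetic for test_shrink_non_shard_axis above, with a (8, 4, 10) tensor sharded on axis=1 across two devices:

# each device holds a row-major (8, 2, 10) shard; t[3] selects elements
# [3*2*10, 4*2*10) of every shard, i.e. one contiguous range per device
shard_shape = (8, 2, 10)
offset = 3 * shard_shape[1] * shard_shape[2]
size = shard_shape[1] * shard_shape[2]
print(offset, size)   # 60 20: a single (size, offset) view, no copy kernel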
@@ -1189,6 +1189,61 @@ class TestBufferView(unittest.TestCase):
     b = a.shrink(((200, 800),)).shrink(((0, 300),)).reshape((30, 10)).shrink(((20, 25), (0, 10))).contiguous()
     run_schedule(check_schedule(b, 0))

+  def test_shrink_non_shard_axis_is_buffer_view_multi(self):
+    # indexing a non-shard axis of a realized sharded tensor should be BUFFER_VIEW on each device, not copy kernels
+    # this is the flat_llama pattern: weight[layer_idx] where weight is (n_layers, out, dim) sharded on axis=1
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(8*4*10).reshape(8, 4, 10).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a[3].contiguous(), 0))
+
+  def test_shrink_2d_non_shard_axis_multi(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(6*4).reshape(6, 4).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.shrink(((1, 4), None)).contiguous(), 0))
+
+  def test_shrink_shard_axis_0_multi(self):
+    # shrinking a middle dim is not contiguous per shard, so this needs copy kernels
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(devices, axis=0).realize()
+    run_schedule(check_schedule(a.shrink((None, (2, 5), None)).contiguous(), 2))
+
+  def test_reshape_then_shrink_multi(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(8*6).reshape(8, 6).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.reshape(4, 2, 6)[1].contiguous(), 0))
+
+  def test_permute_then_shrink_multi(self):
+    # permute makes per-shard view non-contiguous, needs copy kernels
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*6*2).reshape(4, 6, 2).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.permute(1, 0, 2).shrink(((0, 6), (1, 3), None)).contiguous(), 2))
+
+  def test_multi_buffer_view_4_devices(self):
+    devices = tuple(f"NULL:{i}" for i in range(4))
+    a = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a[5].contiguous(), 0))
+
+  def test_chained_shrink_multi(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(10*8).reshape(10, 8).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.shrink(((2, 8), None)).shrink(((1, 4), None)).contiguous(), 0))
+
+  # negative tests: these should NOT become BUFFER_VIEW (non-contiguous per shard)
+  def test_expand_multi_not_buffer_view(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*2).reshape(4, 1, 2).contiguous().shard(devices, axis=2).realize()
+    run_schedule(check_schedule(a.expand(4, 3, 2).contiguous(), 2))
+
+  def test_pad_multi_not_buffer_view(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.pad(((1, 1), (0, 0))).contiguous(), 2))
+
+  def test_flip_multi_not_buffer_view(self):
+    devices = ("NULL:1", "NULL:2")
+    a = Tensor.arange(4*2).reshape(4, 2).contiguous().shard(devices, axis=1).realize()
+    run_schedule(check_schedule(a.flip(0).contiguous(), 2))
+
 class TestInvalidTensor(unittest.TestCase):
   def test_full_invalid_is_zero_kernels(self):
     from tinygrad.dtype import Invalid
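The negative tests follow from what a BUFFER_VIEW can express: a single (size, offset) range with positive unit stride over the underlying buffer. flip, pad, and expand all break that, so they keep their copy kernels. A small numpy illustration of the flip case:

import numpy as np

a = np.arange(4)            # int64 on most platforms, itemsize 8
print(a[::-1].strides)      # (-8,): a negative stride, not expressible as (size, offset)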
@@ -66,6 +66,13 @@ def replace_store_after_with_contig(u:UOp, src:UOp):
   while assigned_to.op in {Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
   if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)

+def _make_buffer_view(src:UOp) -> UOp|None:
+  """If movement ops on src collapse to a contiguous range, return BUFFER_VIEW.reshape(src.shape). Otherwise None."""
+  if (offset := src.contiguous_view_offset()) is None: return None
+  buf = src.base
+  if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
+  return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape)
+
 def contiguous_mops_to_view(c:UOp, src:UOp):
   """CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
   buf = src.base
@@ -76,18 +83,22 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
   if not all_int(c.shape): return None

   # check if view is supported
-  if not isinstance(c.device, str): return None
   from tinygrad.device import Device
-  if not hasattr(Device[c.device].allocator, "_offset"): return None
+  if isinstance(c.device, str):
+    if not hasattr(Device[c.device].allocator, "_offset"): return None
+  elif not all(hasattr(Device[d].allocator, "_offset") for d in c.device): return None

-  # see if this can be a view
-  if (offset := src.contiguous_view_offset()) is None: return None
-
-  # merge BUFFER_VIEWs
-  if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
+  # for MULTI tensors, use multi_pm to resolve per-shard movement ops, then create BUFFER_VIEW on the resolved result
+  if not isinstance(c.device, str):
+    from tinygrad.schedule.multi import multi_pm
+    resolved = graph_rewrite(src, multi_pm, name="multi_buffer_view")
+    if resolved.op is not Ops.MULTI: return None
+    if (view := _make_buffer_view(resolved.src[0])) is None: return None
+    return view.multi(resolved.arg).contiguous(tag=c.tag)

-  # NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
-  return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.size, offset)).reshape(src.shape).contiguous(tag=c.tag)
+  if (view := _make_buffer_view(src)) is None: return None
+  return view.contiguous(tag=c.tag)

 def transform_precompiled_call(c:UOp) -> UOp|None:
   if not c.arg.precompile: return None
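Note how _make_buffer_view merges a view of a view by adding offsets, so nested BUFFER_VIEWs collapse to a single one over the base buffer. A worked example of that offset arithmetic (illustrative numbers, not from the commit):

# the line "offset, buf = offset + buf.arg[1], buf.src[0]" in action:
# a view at offset 100 into an existing BUFFER_VIEW at offset 200 of buf
# becomes one view of buf at offset 300
inner_size, inner_offset = 600, 200   # (size, offset) of the existing BUFFER_VIEW
outer_offset = 100                    # offset computed for the new view on top of it
print(outer_offset + inner_offset)    # 300: the single merged offset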
@@ -151,11 +151,14 @@ multi_pm = PatternMatcher([
     else multi),
   # rewrite into calls explicitly for MULTI
   (UPat(Ops.CALL, name="call"), rewrite_into_call),
-  (UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
+  (UPat((Ops.CALL, Ops.AFTER), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
   # we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call
   (UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
     UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)
     if root.src[0].op is not Ops.TUPLE else None),
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
     src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
+  # remove MULTI from STORE
+  (UPat(Ops.STORE, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True),
+    lambda root,multi: UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg)),
 ])+replace_allreduce
@@ -720,6 +720,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       return buf.view(self.size, self.dtype, 0)
     if self.op is Ops.BUFFER_VIEW:
       buf = self.src[0].buffer
+      if isinstance(buf, MultiBuffer):
+        mbuf = MultiBuffer.__new__(MultiBuffer)
+        mbuf.bufs = [b.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize) for b in buf.bufs]
+        return mbuf
       assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
       return buf.view(self.size, self.dtype, self.arg[1] * self.dtype.itemsize)
     if self.op is Ops.MSELECT:
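For MULTI tensors the MultiBuffer branch takes the same (size, offset) view on every shard's buffer; the offset in arg[1] is stored in elements and scaled to bytes with dtype.itemsize. Illustrative arithmetic (numbers assumed, matching the (8, 4, 10) sharding example above):

# per-shard view of t[3] for int32 shards: 20 elements starting at element 60
size, offset_elems, itemsize = 20, 60, 4
print(offset_elems * itemsize)   # 240: the byte offset passed to each b.view(...)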