mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
embedded bwd vocab shard (#15001)
* fix: remove more multi from call
* feat: embedding bwd vocab sharding
* clean: unused import
* clean: don't actually need this pattern
This commit is contained in:
@@ -8,6 +8,7 @@ from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from test.helpers import needs_second_gpu
|
||||
|
||||
class TestArange(unittest.TestCase):
|
||||
def _get_flops(self, tensor, desired):
|
||||
@@ -188,6 +189,33 @@ class TestIndexing(unittest.TestCase):
|
||||
for i in idx.flatten().numpy(): expected_grad[i] += 2
|
||||
np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@needs_second_gpu
@unittest.skipIf(Device.DEFAULT not in ("CPU", "AMD"), "atomics only on AMD/CPU")
@Context(USE_ATOMICS=1, SPEC=1)
def test_embedding_backward_vocab_sharded(self):
  """Check the embedding weight gradient when the weight is sharded along the vocab axis (axis 0).

  The weight is all ones and the target is all zeros, so the loss sum((w[idx] - 0)^2) contributes
  a gradient of 2 to a weight row for every time that row is looked up. The reference gradient is
  accumulated on the host with NumPy and compared against the result from two devices.
  """
  from tinygrad.renderer.cstyle import CStyleLanguage
  running_on_cpu = Device.DEFAULT == "CPU"
  if running_on_cpu and not isinstance(Device["CPU"].renderer, CStyleLanguage):
    self.skipTest("CPU needs Clang renderer")

  shard_devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
  vocab_size, embed_size = 1000, 128
  bs, seqlen = 4, 256

  token_ids = Tensor.randint(bs, seqlen, high=vocab_size)
  embedding = nn.Embedding(vocab_size, embed_size)
  embedding.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True)
  target = Tensor.zeros(bs, seqlen, embed_size)
  Tensor.realize(token_ids, embedding.weight, target)

  # reference gradient built on the host: +2 on a row per occurrence of its token id
  expected_grad = np.zeros((vocab_size, embed_size), dtype=np.float32)
  for token in token_ids.flatten().numpy():
    expected_grad[token] += 2

  # fresh all-ones weight, sharded on the vocab axis; indices and target are sharded with axis=None
  embedding.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True)
  embedding.weight.shard_(shard_devices, axis=0)
  token_ids = token_ids.shard(shard_devices, axis=None)
  target = target.shard(shard_devices, axis=None)
  Tensor.realize(token_ids, embedding.weight, target)

  residual = embedding(token_ids)-target
  loss = residual.square().sum()
  loss.backward()

  actual_grad = embedding.weight.grad.numpy()
  np.testing.assert_allclose(actual_grad, expected_grad, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD" or (Device.DEFAULT == "NULL" and EMULATE.value.startswith("AMD")), "tests AMD bf16 cast overhead")
|
||||
def base_test_llama_8b_rope_backward(self, dtype):
|
||||
from extra.models.llama import precompute_freqs_cis, apply_rotary_emb
|
||||
|
||||
@@ -245,5 +245,36 @@ class TestCallSchedule(unittest.TestCase):
|
||||
out = f(a) + 2
|
||||
np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3)
|
||||
|
||||
def test_call_reduce_sharded(self):
  """Reduce over axis 0 inside a call whose only argument is sharded on axis 0."""
  shard_devices = ("CPU:0", "CPU:1")
  x = Tensor.ones(10, 10).shard(shard_devices, axis=0)
  Tensor.realize(x)

  # the call body sums the 10x10 ones parameter over axis 0, so every output entry is 10
  reduce_body = x.as_param(0).sum(axis=0)
  result = Tensor.call(x, fxn=reduce_body)

  np.testing.assert_equal(result.numpy(), 10 * np.ones(10))
|
||||
|
||||
def test_call_reduce_sharded_mixed_args(self):
  """Call with one argument sharded on axis 0 and one sharded with axis=None."""
  shard_devices = ("CPU:0", "CPU:1")
  matrix = Tensor.ones(10, 10).shard(shard_devices, axis=0)
  bias = Tensor.ones(10).shard(shard_devices, axis=None)
  Tensor.realize(matrix, bias)

  # column sums of the ones matrix (10 each) plus the ones vector gives 11 everywhere
  column_sums = matrix.as_param(0).sum(axis=0)
  call_body = column_sums + bias.as_param(1)
  result = Tensor.call(matrix, bias, fxn=call_body)

  np.testing.assert_equal(result.numpy(), 11 * np.ones(10))
|
||||
|
||||
def test_call_reduce_sharded_backward(self):
  """Backward through a call with a custom grad_fxn when both inputs are sharded on axis 0."""
  shard_devices = ("CPU:0", "CPU:1")
  lhs = Tensor.randn(10, 10, requires_grad=True).shard(shard_devices, axis=0)
  rhs = Tensor.randn(10, 10, requires_grad=True).shard(shard_devices, axis=0)
  Tensor.realize(lhs, rhs)

  def grad_fxn(grad, call):
    # src[1] and src[2] of the call are taken as the two call arguments
    lhs_arg, rhs_arg = call.src[1], call.src[2]
    # product rule: each input's gradient is the incoming grad (expanded to the input shape) times the other input
    lhs_grad = grad.expand(lhs_arg.shape) * rhs_arg
    rhs_grad = grad.expand(rhs_arg.shape) * lhs_arg
    return (lhs_grad, rhs_grad)

  product_sum = (lhs.as_param(0) * rhs.as_param(1)).sum(axis=0)
  out = Tensor.call(lhs, rhs, fxn=product_sum, grad_fxn=grad_fxn)
  out.sum().backward()

  # for sum(lhs * rhs), the gradient w.r.t. one operand is the other operand
  np.testing.assert_allclose(lhs.grad.numpy(), rhs.numpy(), rtol=1e-5)
  np.testing.assert_allclose(rhs.grad.numpy(), lhs.numpy(), rtol=1e-5)
|
||||
|
||||
# run the unittest runner only when this file is executed directly
if __name__ == '__main__':
  unittest.main()
|
||||
|
||||
@@ -306,13 +306,19 @@ class RMSNorm:
|
||||
from tinygrad.uop.ops import UOp, KernelInfo, Ops, AxisType
|
||||
def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
|
||||
weight, idx = call.src[1:]
|
||||
# for multi-device: unshard inputs to one device
|
||||
is_vocab_sharded = isinstance(weight.device, tuple) and weight.axis == 0
|
||||
# for multi-device: replicate grad_emb and idx on all devices
|
||||
if isinstance(weight.device, tuple):
|
||||
assert weight.axis is None, "sharded weights on Embedding not supported with USE_ATOMICS"
|
||||
assert weight.axis is None or weight.axis == 0, "only vocab (axis=0) sharding supported on Embedding with USE_ATOMICS"
|
||||
grad_emb = grad_emb.copy_to_device(weight.device)
|
||||
idx = idx.copy_to_device(weight.device)
|
||||
# weight is replicated, grad_weight should match
|
||||
grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop
|
||||
if is_vocab_sharded:
|
||||
ndev = len(weight.device)
|
||||
local_vocab_size = weight.shape[0] // ndev
|
||||
grad_weight_uop = Tensor.empty(local_vocab_size, weight.shape[1], dtype=dtypes.float, device=weight.device).uop.multi(axis=0)
|
||||
else:
|
||||
# weight is replicated (or single device), grad_weight should match
|
||||
grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop
|
||||
|
||||
# TODO: how do we remove this dumb kernel and use Tensor.zeros?
|
||||
def _zero_kernel(out:UOp) -> UOp:
|
||||
@@ -327,9 +333,6 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
|
||||
def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
|
||||
idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
|
||||
|
||||
i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length
|
||||
j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size
|
||||
|
||||
embed_size = grad_weight.shape[-1]
|
||||
BLOCK_J = min(256, embed_size)
|
||||
assert embed_size % BLOCK_J == 0, f"embed_size {embed_size} must be divisible by {BLOCK_J}"
|
||||
@@ -340,13 +343,22 @@ def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
|
||||
j_outer = UOp.range(n_j_blocks, 1)
|
||||
j = j_outer * BLOCK_J + j_inner
|
||||
|
||||
token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
|
||||
|
||||
if is_vocab_sharded:
|
||||
# each device owns [offset, offset+local_vocab_size) of the global vocabulary
|
||||
dnum = UOp.variable("_device_num", 0, ndev-1)
|
||||
offset = dnum * local_vocab_size
|
||||
global_token_id = idx_flat[i].cast(dtypes.index)
|
||||
local_token_id = (global_token_id - offset).clip(0, grad_weight.shape[0]-1)
|
||||
in_range = (global_token_id >= offset) & (global_token_id < (offset + local_vocab_size))
|
||||
grad_val = in_range.where(grad_emb_flat[i, j].cast(dtypes.float), 0.0)
|
||||
else:
|
||||
local_token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
|
||||
grad_val = grad_emb_flat[i, j].cast(dtypes.float)
|
||||
# atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
|
||||
if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
|
||||
elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
|
||||
else: raise NotImplementedError(f"no atomics for device {device}")
|
||||
atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg = atomic_arg)
|
||||
atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(local_token_id, j, ptr=True), grad_val), arg = atomic_arg)
|
||||
return atomic.end(i, j_outer, j_inner).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
|
||||
|
||||
grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
|
||||
|
||||
Reference in New Issue
Block a user