tests for multi assign (#10658)
* tests for multi assign
* transformer tests
* add that assert
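In tinygrad terms, "multi assign" means assigning one tensor that is sharded across several devices into another sharded tensor. A minimal sketch of the pattern the new tests exercise (assuming two devices of the default backend; names and shapes here are illustrative, not taken from the diff):

    from tinygrad import Tensor, Device

    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    # split both tensors along axis 0, one half per device
    out = Tensor.zeros(4).shard(devices, 0).contiguous().realize()
    ones = Tensor.ones(4).shard(devices, 0).contiguous().realize()
    out.assign(ones).realize()  # each device updates its own shard in place
    assert out.tolist() == [1, 1, 1, 1]

Assuming the first hunk lands in test/test_multitensor.py, the new cases can presumably be run with: python -m pytest test/test_multitensor.py -k "MultiAssign or MultiTransformer"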
@@ -1132,5 +1132,92 @@ class TestMultiRamUsage(unittest.TestCase):
    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
    self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage

@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiAssign(unittest.TestCase):
  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

  def test_multi_assign_realized(self):
    out = Tensor.zeros(4).shard(self.device, 0).contiguous().realize()
    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  def test_multi_assign_unrealized(self):
    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  def test_multi_assign_both_unrealized(self):
    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4).contiguous().realize().shard(self.device, 0)
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  def test_multi_assign_piece(self):
    out = Tensor.zeros(4,4).shard(self.device, 0).contiguous().realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  def test_multi_assign_piece_noncontig(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
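
  # NOTE (added comment, not in the original test): the case below, assigning into a
  # slice of a sharded tensor that has not yet been realized as a multi tensor, is
  # marked expectedFailure; presumably the piece-assign path still requires a
  # realized destination.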
  @unittest.expectedFailure
  def test_multi_assign_piece_unrealized(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiTransformer(unittest.TestCase):
  def test_transformer(self):
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

    from extra.models.llama import Transformer
    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
    real_model = Transformer(**args, jit=False)
    shard_model = Transformer(**args, jit=False)

    # copy state
    nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))

    # shard
    for k,v in nn.state.get_state_dict(shard_model).items():
      if 'scale' in k: v.shard_(device, axis=None) # from quantized
      elif '.attention.' in k: v.shard_(device, axis=-1)
      elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
      elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
      elif '.feed_forward.' in k: v.shard_(device, axis=-1)
      elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
      elif 'output.weight' in k: v.shard_(device, axis=0)
      else: v.shard_(device, axis=None)
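    # NOTE (added comment): this mirrors a standard tensor-parallel layout: attention
    # and the remaining feed_forward matmul split on the last axis, w1/w3 and the
    # embedding/output tables split on axis 0, norms/scales replicated (axis=None)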

    last_tok = 0
    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
    self.assertEqual(real_tok.item(), shard_tok.item())

  @unittest.skip("super slow")
  def test_llama1b_full(self):
    from tinygrad.helpers import fetch
    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
    model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
                  "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir="llama3-1b-instruct")

    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    from examples.llama3 import build_transformer
    real_model = build_transformer(model, model_size="1B", device=Device.DEFAULT)
    shard_model = build_transformer(model, model_size="1B", device=device)

    last_tok = 0
    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
    self.assertEqual(real_tok.item(), shard_tok.item())

if __name__ == '__main__':
  unittest.main()

@@ -554,6 +554,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
    if self.op is Ops.MSTACK:
      ret = MultiBuffer.__new__(MultiBuffer)
      ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
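      # NOTE (added comment): this is the new assert from the commit message; every
      # buffer stacked into a MultiBuffer must agree on size and dtype (all_same is
      # tinygrad's "all elements equal the first" helper)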
      assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
      return ret
    assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
    if (cret:=buffers.get(self)) is not None: return cret