tests for multi assign (#10658)

* tests for multi assign

* transformer tests

* add that assert
Author: George Hotz
Date: 2025-06-05 20:56:40 -07:00
Committed by: GitHub
Parent: 0d86f8d375
Commit: 8325c4f192
2 changed files with 88 additions and 0 deletions
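For readers skimming the diff below: the new TestMultiAssign cases check that Tensor.assign works when both the target and the source are sharded across several devices. A minimal sketch of that pattern, assuming at least two instances of Device.DEFAULT are reachable; the variable names and the final print are illustrative, not part of the commit:

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
out = Tensor.zeros(4).shard(devices, 0).contiguous().realize()   # target, split across devices on axis 0
ones = Tensor.ones(4).shard(devices, 0).contiguous().realize()   # source, sharded the same way
out.assign(ones).realize()                                        # in-place assign on every shard
print(out.tolist())                                               # expected: [1.0, 1.0, 1.0, 1.0]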

@@ -1132,5 +1132,92 @@ class TestMultiRamUsage(unittest.TestCase):
    _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
    self.assertUsed(self.N*self.N*4)  # sharding should not increase total ram usage

@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiAssign(unittest.TestCase):
  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

  def test_multi_assign_realized(self):
    out = Tensor.zeros(4).shard(self.device, 0).contiguous().realize()
    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  def test_multi_assign_unrealized(self):
    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  def test_multi_assign_both_unrealized(self):
    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4).contiguous().realize().shard(self.device, 0)
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  def test_multi_assign_piece(self):
    out = Tensor.zeros(4,4).shard(self.device, 0).contiguous().realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  def test_multi_assign_piece_noncontig(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  @unittest.expectedFailure
  def test_multi_assign_piece_unrealized(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiTransformer(unittest.TestCase):
  def test_transformer(self):
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    from extra.models.llama import Transformer
    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
    real_model = Transformer(**args, jit=False)
    shard_model = Transformer(**args, jit=False)

    # copy state
    nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))

    # shard
    for k,v in nn.state.get_state_dict(shard_model).items():
      if 'scale' in k: v.shard_(device, axis=None)  # from quantized
      elif '.attention.' in k: v.shard_(device, axis=-1)
      elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
      elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
      elif '.feed_forward.' in k: v.shard_(device, axis=-1)
      elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
      elif 'output.weight' in k: v.shard_(device, axis=0)
      else: v.shard_(device, axis=None)

    last_tok = 0
    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
    self.assertEqual(real_tok.item(), shard_tok.item())

  @unittest.skip("super slow")
  def test_llama1b_full(self):
    from tinygrad.helpers import fetch
    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
    model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
                  "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir="llama3-1b-instruct")

    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    from examples.llama3 import build_transformer
    real_model = build_transformer(model, model_size="1B", device=Device.DEFAULT)
    shard_model = build_transformer(model, model_size="1B", device=device)

    last_tok = 0
    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
    self.assertEqual(real_tok.item(), shard_tok.item())

if __name__ == '__main__':
  unittest.main()
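The axis choices in test_transformer above amount to a simple tensor-parallel split: '.attention.' weights and the remaining '.feed_forward.' weights are split along their last axis, w1/w3 and the embeddings along axis 0, and scale/norm parameters are replicated. A minimal sketch of what those axis arguments mean for two devices; the variable names and shapes below are illustrative, not from the commit:

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
w = Tensor.ones(64, 64).contiguous().realize()

col_split = w.shard(devices, axis=-1)     # like '.attention.' weights: each device holds a 64x32 slice
row_split = w.shard(devices, axis=0)      # like '.feed_forward.w1.': each device holds a 32x64 slice
replicated = w.shard(devices, axis=None)  # like 'scale' params: a full copy on every device

# the logical shape is unchanged; only the per-device storage differs
print(col_split.shape, row_split.shape, replicated.shape)  # (64, 64) three times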

@@ -554,6 +554,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
    if self.op is Ops.MSTACK:
      ret = MultiBuffer.__new__(MultiBuffer)
      ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
      assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
      return ret
    assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
    if (cret:=buffers.get(self)) is not None: return cret
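The single added line here ("add that assert" from the commit message) guards MultiBuffer construction from an MSTACK: every per-device buffer must agree on size and dtype. A small sketch of that invariant, assuming all_same behaves like the tinygrad.helpers helper (True when every element equals the first); the lists below are illustrative stand-ins for the per-device buffers:

# illustrative stand-in for tinygrad.helpers.all_same
def all_same(items): return all(x == items[0] for x in items)

# pretend per-device shards backing one MSTACK
sizes, dtypes = [1024, 1024], ["float32", "float32"]
assert all_same(sizes) and all_same(dtypes)  # matching shards pass the new assert
assert not all_same([1024, 2048])            # a size mismatch is what it now rejects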