diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 73b61284dc..9cd8a58510 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -1132,5 +1132,92 @@ class TestMultiRamUsage(unittest.TestCase):
     _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
     self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
 
+@unittest.skipIf(not_support_multi_device(), "need multi")
+class TestMultiAssign(unittest.TestCase):
+  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+
+  def test_multi_assign_realized(self):
+    out = Tensor.zeros(4).shard(self.device, 0).contiguous().realize()
+    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
+    out.assign(ones).realize()
+    self.assertListEqual(out.tolist(), [1,1,1,1])
+
+  def test_multi_assign_unrealized(self):
+    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
+    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
+    out.assign(ones).realize()
+    self.assertListEqual(out.tolist(), [1,1,1,1])
+
+  def test_multi_assign_both_unrealized(self):
+    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
+    ones = Tensor.ones(4).contiguous().realize().shard(self.device, 0)
+    out.assign(ones).realize()
+    self.assertListEqual(out.tolist(), [1,1,1,1])
+
+  def test_multi_assign_piece(self):
+    out = Tensor.zeros(4,4).shard(self.device, 0).contiguous().realize()
+    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
+    out[:, 2:3].assign(ones).realize()
+    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
+
+  def test_multi_assign_piece_noncontig(self):
+    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
+    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
+    out[:, 2:3].assign(ones).realize()
+    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
+
+  @unittest.expectedFailure
+  def test_multi_assign_piece_unrealized(self):
+    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0)
+    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
+    out[:, 2:3].assign(ones).realize()
+    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
+
+@unittest.skipIf(not_support_multi_device(), "need multi")
+class TestMultiTransformer(unittest.TestCase):
+  def test_transformer(self):
+    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+
+    from extra.models.llama import Transformer
+    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
+    real_model = Transformer(**args, jit=False)
+    shard_model = Transformer(**args, jit=False)
+
+    # copy state
+    nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))
+
+    # shard
+    for k,v in nn.state.get_state_dict(shard_model).items():
+      if 'scale' in k: v.shard_(device, axis=None) # from quantized
+      elif '.attention.' in k: v.shard_(device, axis=-1)
+      elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
+      elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
+      elif '.feed_forward.' in k: v.shard_(device, axis=-1)
+      elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
+      elif 'output.weight' in k: v.shard_(device, axis=0)
+      else: v.shard_(device, axis=None)
+
+    last_tok = 0
+    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
+    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
+    self.assertEqual(real_tok.item(), shard_tok.item())
+
+  @unittest.skip("super slow")
+  def test_llama1b_full(self):
+    from tinygrad.helpers import fetch
+    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
+    model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
+                  "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir="llama3-1b-instruct")
+
+    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+    from examples.llama3 import build_transformer
+    real_model = build_transformer(model, model_size="1B", device=Device.DEFAULT)
+    shard_model = build_transformer(model, model_size="1B", device=device)
+
+    last_tok = 0
+    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
+    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
+    self.assertEqual(real_tok.item(), shard_tok.item())
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index ff8a8c497c..25546b1f40 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -554,6 +554,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.MSTACK:
       ret = MultiBuffer.__new__(MultiBuffer)
       ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
+      assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
       return ret
     assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
     if (cret:=buffers.get(self)) is not None: return cret