new test_multitensor tests (#10667)

* new test_multitensor tests

* cleanup scheduler
George Hotz
2025-06-06 10:26:28 -07:00
committed by GitHub
parent 5170f387b3
commit 7f0f97aa76
4 changed files with 64 additions and 32 deletions


@@ -1173,13 +1173,37 @@ class TestMultiAssign(unittest.TestCase):
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  def test_multi_assign_var_offset(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    vi = Variable("i", 0, 3).bind(2)
    out[:, vi:vi+1].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  def test_multi_assign_var_offset_jit_none(self): self.test_multi_assign_var_offset_jit(None)
  def test_multi_assign_var_offset_jit(self, shard_axis=0):
    out = Tensor.zeros(4,6).contiguous().realize().shard(self.device, shard_axis).realize()
    ones = Tensor.ones(4,1).shard(self.device, shard_axis).contiguous().realize()
    @TinyJit
    def f(out:Tensor, vi):
      out[:, vi:vi+1].assign(ones).realize()
      ones.assign(ones+1).realize()
    vi = Variable("i", 0, 5)
    for i in range(1,5):
      GlobalCounters.reset()
      f(out, vi.bind(i))
    self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4)

@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiTransformer(unittest.TestCase):
  def test_transformer(self):
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    from extra.models.llama import Transformer
-    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
+    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024,
+            "hidden_dim": 64, "max_context": 12}
    real_model = Transformer(**args)
    shard_model = Transformer(**args)
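
For context, here is a minimal standalone version of the symbolic-offset pattern that `test_multi_assign_var_offset_jit` above exercises. This is a sketch, not part of the diff; it assumes at least two devices of the default backend are available and uses an illustrative helper name `write_col`.

```python
from tinygrad import Tensor, Variable, TinyJit, Device

# illustrative device tuple; assumes two devices of the default backend exist
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
out = Tensor.zeros(4, 6).contiguous().realize().shard(devices, 0).realize()
ones = Tensor.ones(4, 1).shard(devices, 0).contiguous().realize()

@TinyJit
def write_col(out: Tensor, vi):
  # the bound Variable picks which column gets written on this call
  out[:, vi:vi+1].assign(ones).realize()
  ones.assign(ones + 1).realize()

vi = Variable("i", 0, 5)
for i in range(1, 5):
  write_col(out, vi.bind(i))
print(out.tolist())  # columns 1..4 now hold 1..4; columns 0 and 5 stay zero
```

Because the column offset is a bound Variable rather than a Python int, the JIT can replay the same captured kernels on every call; only the bound value of `i` changes between iterations.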
@@ -1199,10 +1223,18 @@ class TestMultiTransformer(unittest.TestCase):
    last_tok = 0
    for i in range(10):
-      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
-      shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
-      last_tok = real_tok.item()
-      self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")
+      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i).item()
+      shard_tok = shard_model(Tensor([[last_tok]], device=device), i).item()
+      # test kv cache
+      kv1 = real_model.layers[0].attention.cache_kv.numpy()
+      kv2 = shard_model.layers[0].attention.cache_kv.numpy()
+      #print(np.concatenate([kv1[:, :, :, :, 0:1], kv2[:, :, :, :, 0:1]], axis=4))
+      np.testing.assert_allclose(kv1, kv2, atol=1e-5, rtol=1e-5, err_msg=f"issue at token {i}")
+      # test token
+      self.assertEqual(real_tok, shard_tok, f"issue at token {i}")
+      last_tok = real_tok
  @unittest.skip("super slow")
  def test_llama1b_full(self):
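
Not shown in this view is how `shard_model`'s parameters end up on `device` before the comparison loop. A common tinygrad pattern for that setup looks roughly like the hypothetical helper below; it is a sketch, not the code from this commit, and `shard_model_params` is an illustrative name.

```python
from tinygrad.nn.state import get_state_dict

# hypothetical helper, not from this commit: shard every parameter of a model
# across a tuple of devices (axis=None replicates; an int splits along that axis)
def shard_model_params(model, devices, axis=None):
  for p in get_state_dict(model).values():
    p.shard_(devices, axis)
```

With `device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))` as in the test, calling `shard_model_params(shard_model, device)` would replicate every weight on both devices, which is enough for the output and kv-cache comparison above to be meaningful.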