Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
new test_multitensor tests (#10667)
* new test_multitensor tests
* cleanup scheduler
@@ -1173,13 +1173,37 @@ class TestMultiAssign(unittest.TestCase):
     out[:, 2:3].assign(ones).realize()
     self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
 
+  def test_multi_assign_var_offset(self):
+    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
+    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
+    vi = Variable("i", 0, 3).bind(2)
+    out[:, vi:vi+1].assign(ones).realize()
+    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])
+
+  def test_multi_assign_var_offset_jit_none(self): self.test_multi_assign_var_offset_jit(None)
+  def test_multi_assign_var_offset_jit(self, shard_axis=0):
+    out = Tensor.zeros(4,6).contiguous().realize().shard(self.device, shard_axis).realize()
+    ones = Tensor.ones(4,1).shard(self.device, shard_axis).contiguous().realize()
+
+    @TinyJit
+    def f(out:Tensor, vi):
+      out[:, vi:vi+1].assign(ones).realize()
+      ones.assign(ones+1).realize()
+
+    vi = Variable("i", 0, 5)
+    for i in range(1,5):
+      GlobalCounters.reset()
+      f(out, vi.bind(i))
+    self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4)
+
 @unittest.skipIf(not_support_multi_device(), "need multi")
 class TestMultiTransformer(unittest.TestCase):
   def test_transformer(self):
     device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
 
     from extra.models.llama import Transformer
-    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
+    args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024,
+            "hidden_dim": 64, "max_context": 12}
     real_model = Transformer(**args)
     shard_model = Transformer(**args)
 
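For context on the multi-device setup these tests exercise: Tensor.shard places a tensor across a tuple of devices, here two instances of the default backend, the same pattern as shard(self.device, axis) in the diff above. The snippet below is an illustrative sketch, not part of the commit; the devices name is my own, and running it assumes the backend actually exposes at least two devices.

from tinygrad import Tensor, Device

# illustrative only: shard a realized tensor across two copies of the default
# device, then assign into a column slice, mirroring the existing test lines above
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
out = Tensor.zeros(4, 4).contiguous().realize().shard(devices, 0).realize()  # rows split across devices
ones = Tensor.ones(4, 1).shard(devices, 0).contiguous().realize()
out[:, 2:3].assign(ones).realize()   # write a column of ones into the sharded tensor
print(out.tolist())                  # [[0, 0, 1, 0]] * 4 (as floats)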
@@ -1199,10 +1223,18 @@ class TestMultiTransformer(unittest.TestCase):
 
     last_tok = 0
     for i in range(10):
-      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
-      shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
-      last_tok = real_tok.item()
-      self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")
+      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i).item()
+      shard_tok = shard_model(Tensor([[last_tok]], device=device), i).item()
+
+      # test kv cache
+      kv1 = real_model.layers[0].attention.cache_kv.numpy()
+      kv2 = shard_model.layers[0].attention.cache_kv.numpy()
+      #print(np.concatenate([kv1[:, :, :, :, 0:1], kv2[:, :, :, :, 0:1]], axis=4))
+      np.testing.assert_allclose(kv1, kv2, atol=1e-5, rtol=1e-5, err_msg=f"issue at token {i}")
+
+      # test token
+      self.assertEqual(real_tok, shard_tok, f"issue at token {i}")
+      last_tok = real_tok
 
   @unittest.skip("super slow")
   def test_llama1b_full(self):
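The new *_var_offset tests index with a bound Variable rather than a constant offset, which is what vi.bind(i) feeds into the jitted function in the first hunk. As a hedged aside (not part of the commit, and unbind() is shown purely for illustration), this is roughly how bound Variables behave in tinygrad:

from tinygrad import Variable

# illustrative sketch of the bound-Variable pattern used by the new tests
vi = Variable("i", 0, 5)   # symbolic integer named "i", valid range 0..5
bound = vi.bind(2)         # same symbol, now carrying the concrete value 2
var, val = bound.unbind()  # recover the unbound Variable and its value
print(val)                 # 2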