hotfix: multitensor transformer test tests kv cache
@@ -1180,8 +1180,8 @@ class TestMultiTransformer(unittest.TestCase):
     from extra.models.llama import Transformer
     args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
-    real_model = Transformer(**args, jit=False)
-    shard_model = Transformer(**args, jit=False)
+    real_model = Transformer(**args)
+    shard_model = Transformer(**args)

     # copy state
     nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))
@@ -1198,9 +1198,11 @@ class TestMultiTransformer(unittest.TestCase):
       else: v.shard_(device, axis=None)

     last_tok = 0
-    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
-    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
-    self.assertEqual(real_tok.item(), shard_tok.item())
+    for i in range(10):
+      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
+      shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
+      last_tok = real_tok.item()
+      self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")

   @unittest.skip("super slow")
   def test_llama1b_full(self):
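What the fix does: the two Transformers are now built without the explicit jit=False (so they run with whatever JIT behavior Transformer defaults to), and instead of a single forward pass at position 0, the test decodes ten tokens, feeding the real model's previous output back into both models at increasing start_pos. Step i can only be correct if the keys/values cached at positions 0..i-1 are correct, on the single-device model and the sharded one alike, so this is what actually exercises the KV cache.

A minimal sketch of the decode pattern the loop tests. This is illustrative only, not tinygrad's attention: ToyKVAttention and its single-vector shapes are invented for the example, and x stands in for query, key, and value at once to keep it tiny.

import numpy as np

class ToyKVAttention:
  # single-head decode-step attention with a persistent KV cache (toy example)
  def __init__(self, dim: int, max_context: int = 32):
    self.dim = dim
    self.cache_k = np.zeros((max_context, dim))
    self.cache_v = np.zeros((max_context, dim))

  def __call__(self, x: np.ndarray, start_pos: int) -> np.ndarray:
    # write this step's key/value into the cache at its position
    self.cache_k[start_pos] = x
    self.cache_v[start_pos] = x
    # attend over everything cached so far, positions 0..start_pos
    k = self.cache_k[:start_pos + 1]
    v = self.cache_v[:start_pos + 1]
    scores = k @ x / np.sqrt(self.dim)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ v

A call at start_pos=0 never reads a previously written cache entry, which is why the old single-step assertion could pass even with a broken (stale or mis-sharded) cache. In the ten-step loop, step i depends on the cache written by steps 0..i-1, so any corruption surfaces as a token mismatch, and the f"issue at token {i}" message pinpoints the first diverging step.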