hotfix: multitensor transformer test tests kv cache

George Hotz
2025-06-05 21:08:57 -07:00
parent 8325c4f192
commit ad9f88419a

@@ -1180,8 +1180,8 @@ class TestMultiTransformer(unittest.TestCase):
     from extra.models.llama import Transformer
     args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64}
-    real_model = Transformer(**args, jit=False)
-    shard_model = Transformer(**args, jit=False)
+    real_model = Transformer(**args)
+    shard_model = Transformer(**args)
     # copy state
     nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))
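Dropping jit=False builds both models with the constructor's default jit enabled. Assuming extra.models.llama's Transformer uses its jit flag to wrap the per-token decode in TinyJit, this is the configuration where cached key/value buffers are actually reused across calls, so the multi-token loop in the next hunk can catch a broken kv cache.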
@@ -1198,9 +1198,11 @@ class TestMultiTransformer(unittest.TestCase):
       else: v.shard_(device, axis=None)
     last_tok = 0
-    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
-    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
-    self.assertEqual(real_tok.item(), shard_tok.item())
+    for i in range(10):
+      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i)
+      shard_tok = shard_model(Tensor([[last_tok]], device=device), i)
+      last_tok = real_tok.item()
+      self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}")

   @unittest.skip("super slow")
   def test_llama1b_full(self):
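For context, here is a minimal numpy sketch (hypothetical, not tinygrad code; all names are illustrative) of why the 10-token loop matters: on step 0 the kv cache holds only the current token, so the cached and from-scratch paths agree even if the cache is mishandled, and a real cache bug can only surface from token 1 onward. That is exactly the case the old single-call check missed.

    import numpy as np

    rng = np.random.default_rng(0)
    d = 8
    Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

    def attend(q, K, V):
      # single-query softmax attention over all keys/values seen so far
      w = np.exp(q @ K.T / np.sqrt(d)); w /= w.sum()
      return w @ V

    def step_full(xs):
      # reference path: recompute every key/value from scratch each step
      X = np.stack(xs)
      return attend(X[-1] @ Wq, X @ Wk, X @ Wv)

    class KVCache:
      def __init__(self): self.K, self.V = [], []
      def step(self, x):
        # cached path: append this step's key/value, then attend over the cache
        self.K.append(x @ Wk); self.V.append(x @ Wv)
        return attend(x @ Wq, np.stack(self.K), np.stack(self.V))

    xs, cache = [], KVCache()
    for i in range(10):
      xs.append(rng.standard_normal(d))
      assert np.allclose(step_full(xs), cache.step(xs[-1])), f"issue at token {i}"

Here both paths are mathematically identical, so the asserts pass; corrupt the cache append (say, overwrite index 0 every step) and the first failure appears at token 1, never token 0, mirroring the test structure this commit adopts.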