From 5eb6e1e65a274f7739b46fdedb1cdb1219e4c989 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 5 Jun 2025 21:15:34 -0700 Subject: [PATCH] Revert "hotfix: multitensor transformer test tests kv cache" This reverts commit ad9f88419af3caff2be2e1e9c9722640fbb012b0. --- test/test_multitensor.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 0a88763825..9cd8a58510 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1180,8 +1180,8 @@ class TestMultiTransformer(unittest.TestCase): from extra.models.llama import Transformer args = {"dim": 64, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024, "hidden_dim": 64} - real_model = Transformer(**args) - shard_model = Transformer(**args) + real_model = Transformer(**args, jit=False) + shard_model = Transformer(**args, jit=False) # copy state nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model)) @@ -1198,11 +1198,9 @@ class TestMultiTransformer(unittest.TestCase): else: v.shard_(device, axis=None) last_tok = 0 - for i in range(10): - real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i) - shard_tok = shard_model(Tensor([[last_tok]], device=device), i) - last_tok = real_tok.item() - self.assertEqual(last_tok, shard_tok.item(), f"issue at token {i}") + real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0) + shard_tok = shard_model(Tensor([[last_tok]], device=device), 0) + self.assertEqual(real_tok.item(), shard_tok.item()) @unittest.skip("super slow") def test_llama1b_full(self):