mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Mamba Implementation (#3456)
* first commit * state back to orig * mamba comparisons * rm file * rename file * use Tensor.einsum and make default model 370M * Cleaned code and made a comparison test * Simplify pull request. Only has 1 mamba implementation now. * Update prompt * rm whitespaces * last space * remove Einops dependency * rm unused code * add tests * rm print statement * rm imports * skip CLANG * Update skipIf description * skip model test in CI and add CLANG fix * rm Device import * don't be stupid * Fix conv assign When the prompt is too short, the logic for conv_state assign messes up. This can be fixed by padding the tokenized array to a minimum length of 4. I padded using the empty string token, but idk if proper practice is to use the PAD token * fix p1 * temp * fix jit import --------- Co-authored-by: schlimeszn <schlimeszn@gmail.com> Co-authored-by: reddyn <nikidsniper@gmail.com> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
24
test/models/test_mamba.py
Normal file
24
test/models/test_mamba.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import unittest
|
||||
from tinygrad.helpers import CI
|
||||
from examples.mamba import Mamba, generate
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Prompt fed to every generation test below; the pinned expected outputs all
# start with this exact text (note the trailing space).
PROMPT = 'Why is gravity '
# Tokenizer fetched from the Hugging Face hub at import time (requires network).
# NOTE(review): presumably the GPT-NeoX-20B tokenizer matches what the Mamba
# checkpoints were trained with — confirm against examples/mamba.
TOKENIZER = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
||||
@unittest.skipIf(CI, "model is slow for CI")
class TestMamba(unittest.TestCase):
  """Regression tests pinning the greedy 10-token continuations produced by
  pretrained Mamba checkpoints for a fixed prompt.

  Each test loads one checkpoint, generates from PROMPT with the shared
  TOKENIZER, and compares against a hard-coded expected string.
  """

  def test_mamba_130M(self):
    # Expected continuation pinned for the 130M checkpoint.
    expected = '''Why is gravity \nnot a good idea?\n\nA:'''
    mamba = Mamba.from_pretrained('130m')
    actual = generate(mamba, TOKENIZER, PROMPT, n_tokens_to_gen=10)
    self.assertEqual(expected, actual)
    # Drop the model reference so memory is reclaimed before the next test.
    del mamba

  def test_mamba_370M(self):
    # Expected continuation pinned for the 370M checkpoint.
    expected = '''Why is gravity \nso important?\nBecause it's the only'''
    mamba = Mamba.from_pretrained('370m')
    actual = generate(mamba, TOKENIZER, PROMPT, n_tokens_to_gen=10)
    self.assertEqual(expected, actual)
    # Drop the model reference so memory is reclaimed before the next test.
    del mamba
||||
# Allow running this test file directly (python test/models/test_mamba.py)
# in addition to pytest discovery.
if __name__ == '__main__':
  unittest.main()
Reference in New Issue
Block a user