mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
* mark slow tests as slow instead of as CI * CI shouldn't have different behavior * more skips / CI * slow
25 lines
849 B
Python
import unittest
|
|
from test.helpers import slow
|
|
from examples.mamba import Mamba, generate
|
|
from transformers import AutoTokenizer
|
|
|
|
# Prompt shared by every generation test; the trailing space is part of the
# prompt text and also appears at the start of each expected output.
PROMPT = "Why is gravity "

# NOTE(review): this call executes at import time, so merely importing or
# collecting this module runs AutoTokenizer.from_pretrained (which may need the
# network if the tokenizer is not cached), even when the @slow tests are skipped.
TOKENIZER = AutoTokenizer.from_pretrained(
  "EleutherAI/gpt-neox-20b",
)
|
|
|
|
@slow
class TestMamba(unittest.TestCase):
  """End-to-end text-generation checks for pretrained Mamba checkpoints.

  Each test loads a pretrained checkpoint via ``Mamba.from_pretrained`` and
  generates 10 tokens after ``PROMPT`` with ``generate``. It then compares the
  decoded text against a stored reference string.

  The class is decorated with ``@slow`` (from ``test.helpers``); presumably that
  skips it unless slow tests are enabled — confirm against the helper.
  """

  def _check_generation(self, size: str, expected: str) -> None:
    """Generate 10 tokens from the ``size`` checkpoint and assert the text equals ``expected``.

    Args:
      size: checkpoint identifier passed to ``Mamba.from_pretrained`` (e.g. ``'130m'``).
      expected: exact reference output, including the prompt prefix.
    """
    model = Mamba.from_pretrained(size)
    try:
      tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=10)
    finally:
      # Drop this frame's reference before asserting. A failing assertion's
      # traceback would otherwise keep the frame, and with it the model
      # weights, alive for the rest of the run.
      del model
    self.assertEqual(expected, tinyoutput)

  def test_mamba_130M(self):
    self._check_generation('130m', '''Why is gravity \nnot a good idea?\n\nA:''')

  def test_mamba_370M(self):
    self._check_generation('370m', '''Why is gravity \nso important?\nBecause it's the only''')
|
|
if __name__ == "__main__":
  # Allow running this test file directly as a script.
  unittest.main()
|