Promote Embedding to nn (#798)
* feat: promote Embedding to nn
* fix: fix failing test
* feat: add test with jit
* feat: rewrite embedding to no longer need stacked for loops
* clean+fix: don't know how that happened
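The fourth bullet refers to the loop-free formulation: an embedding lookup can be expressed as a broadcasted one-hot comparison against arange(vocab_size) followed by a matmul, instead of stacking rows in a Python loop. Below is a minimal NumPy sketch of that idea; the helper name embedding_lookup and the concrete shapes are illustrative, not tinygrad's actual API.

import numpy as np

def embedding_lookup(idx, weight):
  # idx: integer token ids, shape (batch, seqlen); weight: (vocab_size, embed_size).
  # Broadcasting idx against arange(vocab_size) yields a one-hot tensor, and
  # contracting it with weight performs the gather -- no per-token Python loop.
  vocab_size = weight.shape[0]
  onehot = (idx[..., None] == np.arange(vocab_size)).astype(weight.dtype)  # (batch, seqlen, vocab_size)
  return onehot @ weight                                                   # (batch, seqlen, embed_size)

weight = np.random.randn(1000, 512).astype(np.float32)  # vocab_size=1000, dim=512, matching args_tiny below
out = embedding_lookup(np.array([[1, 2, 3, 4]]), weight)
assert out.shape == (1, 4, 512)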
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -81,12 +81,12 @@ class TestInferenceMinKernels(unittest.TestCase):
     out.realize()

   def test_llama(self):
-    from examples.llama import Transformer, onehot_encode
+    from examples.llama import Transformer
     args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
-    with CLCache(85):
-      model(onehot_encode([1,2,3,4], vocab_size=args_tiny['vocab_size']), 0).realize()
+    with CLCache(86):
+      model(Tensor([[1,2,3,4]]), 0).realize()

 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptBinOp(unittest.TestCase):
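Before this change the caller had to one-hot encode token ids with onehot_encode before invoking the model; with Embedding promoted to nn, the Transformer accepts a raw index Tensor directly and performs the encoding internally (presumably via the loop-free one-hot formulation sketched above), which is why the expected kernel count in CLCache moves from 85 to 86.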