mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
isolate the 134ms kernel in train_gpt2.py (#4773)
133ms on tinybox red with BEAM=2
This commit is contained in:
10
test/external/external_test_lm_head.py
vendored
Normal file
10
test/external/external_test_lm_head.py
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
from tinygrad import Tensor, nn
|
||||
|
||||
if __name__ == "__main__":
|
||||
vocab_size = 50257
|
||||
n_embd = 768
|
||||
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
|
||||
bs = 4
|
||||
seq_len = 1024
|
||||
x = Tensor.rand(bs, seq_len, n_embd)
|
||||
ret = lm_head(x).realize()
|
||||
Reference in New Issue
Block a user