mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
* work to make GEMV fast * half8 cast * align struct * fix amd * float8 is a later problem
13 lines
346 B
Python
13 lines
346 B
Python
#!/usr/bin/env python3
|
|
from tinygrad import Tensor, TinyJit, nn
|
|
from extra.models.llama import FeedForward
|
|
|
|
if __name__ == "__main__":
  # Build the feed-forward block under test.
  # NOTE(review): the arguments are presumably (dim, hidden_dim) -- confirm
  # against extra.models.llama.FeedForward.
  model = FeedForward(4096, 14336)

  # Swap every parameter for its half() version in place and realize it
  # before the JIT-wrapped runs start.
  for param in nn.state.get_parameters(model):
    param.replace(param.half()).realize()

  # Wrap the model in TinyJit and call it several times with fresh random
  # input, printing a marker before each call so runs can be told apart in
  # the output.
  jrun = TinyJit(model)
  for run_idx in range(5):
    print(f"*** run {run_idx}")
    jrun(Tensor.rand(1, 4096))
|
|
|