mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-16 02:18:27 -05:00
19 lines
588 B
Python
19 lines
588 B
Python
from tinygrad import Tensor, UOp
|
|
from tinygrad.uop.ops import Ops, AxisType
|
|
import unittest
|
|
# This test focuses only on transformers, using a range loop over the layers.
|
|
|
|
class TestOuterworldTransformer(unittest.TestCase):
  """Tests that express a stack of layers as a range loop over per-layer weights."""

  def test_three_mats(self):
    """Build and realize a 3-iteration OUTER range of (1, 1024) @ (1024, 1024) matmuls.

    There are no assertions: the test passes if building the graph and calling realize() do not raise.
    Each iteration reads the running input, multiplies it by the weight slice selected by the loop index,
    and stores the product back into the input's buffer. This presumably chains the three matrices like
    stacked layers (TODO confirm against tinygrad's after/store/end semantics).
    """
    weights = Tensor.empty(3, 1024, 1024)  # one (1024, 1024) matrix per loop iteration
    x = Tensor.empty(1, 1024)
    # Loop index over the 3 weight matrices. The -1 argument is passed as upstream does; its meaning is not shown here.
    layer = UOp.range(3, -1, AxisType.OUTER)
    # Read of x that depends on the loop index, so it can be used inside the loop body.
    x_in_loop = Tensor(x.uop.after(layer))
    y = x_in_loop @ weights[layer]
    # Write the product back into x's buffer, then close the range over `layer`.
    loop_end = x.uop.store(y.uop).end(layer)
    # Rebind x, dropping the original empty tensor, to a contiguous read of its buffer after the loop.
    x = Tensor(x.uop.after(loop_end).contiguous())
    x.realize()
|
|
|
|
if __name__ == "__main__": unittest.main()  # run this module's test cases when executed as a script
|