mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
if you like your transformers twice as slow, use the GPU
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from tinygrad.tensor import Device
|
||||
from extra.utils import get_parameters
|
||||
from extra.training import train, evaluate
|
||||
from extra.transformer import Transformer
|
||||
@@ -23,7 +25,6 @@ def make_dataset():
|
||||
|
||||
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
|
||||
|
||||
|
||||
from tinygrad.optim import Adam
|
||||
if __name__ == "__main__":
|
||||
model = Transformer(10, 6, 2, 128, 4)
|
||||
@@ -32,7 +33,7 @@ if __name__ == "__main__":
|
||||
optim = Adam(get_parameters(model), lr=0.001)
|
||||
|
||||
for i in range(5):
|
||||
train(model, X_train, Y_train, optim, 500, BS=32)
|
||||
train(model, X_train, Y_train, optim, 500, BS=32, device=Device.GPU if os.getenv("GPU") else Device.CPU)
|
||||
evaluate(model, X_test, Y_test, num_classes=10)
|
||||
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class Transformer:
|
||||
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
xnp = x.cpu().data
|
||||
xnp = x.cpu().data.astype(np.int32)
|
||||
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
|
||||
for i in range(x.shape[1]):
|
||||
onehot[range(bs), i, i] = 1
|
||||
|
||||
@@ -77,6 +77,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
|
||||
def test_multidot(self):
|
||||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
|
||||
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
|
||||
def test_sum(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
|
||||
|
||||
@@ -250,9 +250,9 @@ class Dot(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
assert input.shape[-1] == weight.shape[-2]
|
||||
cnt = input.shape[0] if len(input.shape) == 3 else 1
|
||||
cnt = np.prod(input.shape[0:-2]) if len(input.shape) > 2 else 1
|
||||
isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
|
||||
ret = buffer_new(ctx, (isize, osize) if cnt == 1 else (cnt, isize, osize))
|
||||
ret = buffer_new(ctx, list(input.shape[0:-2])+[isize, osize])
|
||||
|
||||
matmul = clbuild(ctx.cl_ctx, "matmul", """
|
||||
__kernel void matmul(
|
||||
|
||||
Reference in New Issue
Block a user