fixup training loop
@@ -3,6 +3,10 @@ import numpy as np
import random
from tinygrad.tensor import Tensor

from extra.utils import get_parameters
from extra.training import train, evaluate
from tinygrad.optim import Adam

# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb
def make_dataset():
  ds = []
@@ -19,8 +23,6 @@ def make_dataset():

  return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test

#X_train, Y_train, X_test, Y_test = make_dataset()

class TransformerBlock:
  def __init__(self, embed_dim, num_heads):
    # Multi-Head Attention
@@ -33,44 +35,72 @@ class TransformerBlock:
    self.key_dense = Tensor.uniform(embed_dim, embed_dim)
    self.value_dense = Tensor.uniform(embed_dim, embed_dim)

    self.final = Tensor.uniform(embed_dim, embed_dim)

    self.ff1 = Tensor.uniform(embed_dim, embed_dim)
    self.ff2 = Tensor.uniform(embed_dim, embed_dim)

  def __call__(self, x):
    # bs x T x embed_dim
    bs = x.shape[0]
    x = x.reshape(shape=(-1, self.num_heads * self.head_size))
    inputs = x.reshape(shape=(-1, self.num_heads * self.head_size))

    # run multi head attention (bs, T, num_heads, head_size)
    query, key, value = [x.dot(y) \
    query, key, value = [inputs.dot(y) \
      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
      for y in [self.query_dense, self.key_dense, self.value_dense]]

    query = query.transpose(order=(0,2,1,3))  # (bs, num_heads, T, head_size)
    key = key.transpose(order=(0,2,3,1))      # (bs, num_heads, head_size, T)
    score = query.dot(key)
    print(query.shape)
    print(key.shape)
    print(score.shape)
    value = value.transpose(order=(0,2,1,3))  # (bs, num_heads, T, head_size)

    score = query.dot(key) * (1 / np.sqrt(self.head_size))
    weights = score.logsoftmax()  # (bs, num_heads, T, T)
    attention = weights.dot(value).transpose(order=(0,2,1,3))
    x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
    print(x.shape)
    # layernorm
    x = x + x.dot(self.ff1).relu().dot(self.ff2)
    print(x.shape)
    # layernorm
    return x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
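
For reference, a minimal NumPy sketch of the shape bookkeeping that TransformerBlock.__call__ is doing; the sizes (bs=2, T=6, num_heads=4, head_size=32) are made up for illustration, and a plain softmax stands in for the logsoftmax used in the diff:

import numpy as np

bs, T, num_heads, head_size = 2, 6, 4, 32
embed_dim = num_heads * head_size

x = np.random.randn(bs, T, embed_dim).astype(np.float32)
Wq, Wk, Wv = [np.random.randn(embed_dim, embed_dim).astype(np.float32) for _ in range(3)]

inputs = x.reshape(-1, embed_dim)                          # (bs*T, embed_dim)
q, k, v = [(inputs @ W).reshape(bs, -1, num_heads, head_size) for W in (Wq, Wk, Wv)]

q = q.transpose(0, 2, 1, 3)                                # (bs, num_heads, T, head_size)
k = k.transpose(0, 2, 3, 1)                                # (bs, num_heads, head_size, T)
v = v.transpose(0, 2, 1, 3)                                # (bs, num_heads, T, head_size)

score = (q @ k) * (1 / np.sqrt(head_size))                 # (bs, num_heads, T, T)
weights = np.exp(score - score.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)             # softmax over the key axis
attention = (weights @ v).transpose(0, 2, 1, 3)            # back to (bs, T, num_heads, head_size)
out = attention.reshape(bs, T, embed_dim)
print(out.shape)                                           # (2, 6, 128)
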
class Transformer:
  def __init__(self, syms, maxlen, cnt, embed_dim, num_heads):
    self.maxlen, self.syms = maxlen, syms
    self.embed = Tensor.uniform(maxlen+syms, embed_dim)
    self.tbs = []
    for i in range(cnt):
      self.tbs.append(TransformerBlock(embed_dim, num_heads))
    self.final = Tensor.uniform(embed_dim, syms)

  def forward(self, x):
    bs = x.shape[0]
    xnp = x.cpu().data
    onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
    print(onehot.shape)
    for i in range(x.shape[1]):
      onehot[range(bs), i, i] = 1
      onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
    x = Tensor(onehot, device=x.device).dot(self.embed)
    print(x.shape)
    for t in self.tbs:
      x = t(x)
    return x.dot(self.final).logsoftmax()
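
Transformer.forward builds its input embedding by one-hot encoding both the position (the first maxlen slots) and the symbol (the next syms slots), then multiplying by a single (maxlen+syms, embed_dim) matrix. A small NumPy sketch of that encoding, with made-up values (maxlen=6, syms=10):

import numpy as np

bs, maxlen, syms, embed_dim = 2, 6, 10, 128
xnp = np.random.randint(0, syms, size=(bs, maxlen))        # token ids, shape (bs, T)
embed = np.random.randn(maxlen+syms, embed_dim).astype(np.float32)

onehot = np.zeros((bs, maxlen, maxlen+syms), dtype=np.float32)
for i in range(maxlen):
  onehot[range(bs), i, i] = 1                              # position one-hot in the first maxlen slots
  onehot[range(bs), i, maxlen + xnp[:, i]] = 1             # symbol one-hot in the remaining syms slots

x = onehot @ embed                                         # (bs, T, embed_dim): position + token embedding
print(x.shape)                                             # (2, 6, 128)
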
#score = query.reshape(shape=(-1, self.projection_dim)).dot(
#  key.reshape(shape=(-1, self.projection_dim)).transpose(order=(1,0)))
#scaled_score = score * (1/np.sqrt(self.projection_dim))

#print(value.shape)
#print(scaled_score.shape)

#query = self.query_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim))
#key = self.key_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim))
#value = self.value_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim))

#x = self.ff2(self.ff1(x).relu())
#return x

from tinygrad.optim import Adam
if __name__ == "__main__":
  tb = TransformerBlock(128, 4)
  tmp = Tensor.zeros(20, 10, 128)
  ret = tb(tmp)
  print(ret)
  model = Transformer(10, 6, 2, 128, 4)

  #in1 = Tensor.zeros(20, 6, 128)
  #ret = model.forward(in1)
  #print(ret.shape)

  X_train, Y_train, X_test, Y_test = make_dataset()
  optim = Adam(get_parameters(model), lr=0.001)
  train(model, X_train, Y_train, optim, 100)

@@ -4,20 +4,25 @@ from tqdm import trange
from extra.utils import get_parameters
from tinygrad.tensor import Tensor, GPU, Device

def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, device=Device.CPU, lossfn = lambda out,y: out.mul(y).mean()):
def sparse_categorical_crossentropy(out, Y):
  num_classes = out.shape[-1]
  YY = Y.flatten()
  y = np.zeros((YY.shape[0], num_classes), np.float32)
  # correct loss for NLL, torch NLL loss returns one per row
  y[range(y.shape[0]),YY] = -1.0*num_classes
  y = y.reshape(list(Y.shape)+[num_classes])
  y = Tensor(y, device=out.device)
  return out.mul(y).mean()

def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, lossfn=sparse_categorical_crossentropy):
  if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
  elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
  if num_classes is None: num_classes = Y_train.max().astype(int)+1
  losses, accuracies = [], []
  for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
    samp = np.random.randint(0, X_train.shape[0], size=(BS))

    x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), device=device)
    Y = Y_train[samp]
    y = np.zeros((len(samp),num_classes), np.float32)
    # correct loss for NLL, torch NLL loss returns one per row
    y[range(y.shape[0]),Y] = -1.0*num_classes
    y = Tensor(y, device=device)
    x = Tensor(X_train[samp], device=device)
    y = Y_train[samp]

    # network
    out = model.forward(x)
@@ -29,7 +34,7 @@ def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, devic
    optim.step()

    cat = np.argmax(out.cpu().data, axis=1)
    accuracy = (cat == Y).mean()
    accuracy = (cat == y).mean()

    # printing
    loss = loss.cpu().data
@@ -41,7 +46,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128)
  def numpy_eval(num_classes):
    Y_test_preds_out = np.zeros((len(Y_test),num_classes))
    for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), device=device)).cpu().data
      Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
    Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
    return (Y_test == Y_test_preds).mean()

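
The comment "correct loss for NLL, torch NLL loss returns one per row" hides a little arithmetic: because the one-hot targets are scaled by -num_classes, taking the mean over all N*num_classes entries of out.mul(y) recovers the usual per-row negative log-likelihood. A quick NumPy check with made-up shapes:

import numpy as np

N, C = 4, 10
logits = np.random.randn(N, C).astype(np.float32)
m = logits.max(axis=-1, keepdims=True)
logprobs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))  # logsoftmax
Y = np.random.randint(0, C, size=N)

y = np.zeros((N, C), np.float32)
y[range(N), Y] = -1.0 * C                 # the -num_classes scaling from the loss above
loss_trick = (logprobs * y).mean()        # what out.mul(y).mean() computes

loss_nll = -logprobs[range(N), Y].mean()  # textbook mean negative log-likelihood
assert np.isclose(loss_trick, loss_nll)
print(loss_trick, loss_nll)
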
@@ -11,9 +11,9 @@ from extra.utils import fetch, get_parameters
def fetch_mnist():
  import gzip
  parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
  X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
  X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
  Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
  X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
  X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
  Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
  return X_train, Y_train, X_test, Y_test

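
The reshape and float conversion now happen once in fetch_mnist, so train() and evaluate() above can pass batches straight to the model. A tiny illustration with fake image data (the real arrays come from the gzip-decompressed IDX files, skipping the 16-byte header for images and 8 bytes for labels):

import numpy as np

fake_images = np.random.randint(0, 256, size=(3, 28, 28), dtype=np.uint8)  # stand-in for the parsed IDX data
X = fake_images.reshape((-1, 28*28)).astype(np.float32)                    # what fetch_mnist now returns
print(X.shape, X.dtype)                                                    # (3, 784) float32
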
@@ -138,10 +138,6 @@ class ReLU(Function):
    return grad_output * (input >= 0)
register('relu', ReLU)

def _exp_normalize(x, axis=None):
  y = np.exp(x - x.max(axis=axis, keepdims=True))
  return y / y.sum(axis=axis, keepdims=True)

class Sigmoid(Function):
  @staticmethod
  def forward(ctx, input):
@@ -160,17 +156,21 @@ class Sigmoid(Function):
    return grad_output * (ret * (1 - ret))
register('sigmoid', Sigmoid)

def _exp_normalize(x, axis=None):
  y = np.exp(x - x.max(axis=axis, keepdims=True))
  return y / y.sum(axis=axis, keepdims=True)

class LogSoftmax(Function):
  @staticmethod
  def forward(ctx, input):
    softmax = _exp_normalize(input, axis=1)
    softmax = _exp_normalize(input, axis=-1)
    ctx.save_for_backward(softmax)
    return np.log(softmax)

  @staticmethod
  def backward(ctx, grad_output):
    softmax, = ctx.saved_tensors
    return grad_output - grad_output.sum(axis=1, keepdims=True)*softmax
    return grad_output - grad_output.sum(axis=-1, keepdims=True)*softmax
register('logsoftmax', LogSoftmax)
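
The axis=1 to axis=-1 change matters once LogSoftmax sees 3-D transformer outputs of shape (bs, T, syms): axis=1 would normalize over the sequence dimension rather than the classes. A small NumPy check of the forward, with made-up shapes and _exp_normalize copied from the diff above:

import numpy as np

def _exp_normalize(x, axis=None):
  y = np.exp(x - x.max(axis=axis, keepdims=True))
  return y / y.sum(axis=axis, keepdims=True)

x = np.random.randn(2, 6, 10).astype(np.float32)           # (bs, T, syms)
print(_exp_normalize(x, axis=1).sum(axis=1)[0, 0])          # 1.0, but normalized over T (the wrong dimension)
print(_exp_normalize(x, axis=-1).sum(axis=-1)[0, 0])        # 1.0, normalized over the syms/classes dimension
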