mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
transformer is training
This commit is contained in:
@@ -55,6 +55,7 @@ class TransformerBlock:
|
||||
value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
|
||||
|
||||
score = query.dot(key) * (1 / np.sqrt(self.head_size))
|
||||
# TODO: this should be a normal softmax
|
||||
weights = score.logsoftmax() # (bs, num_heads, T, T)
|
||||
attention = weights.dot(value).transpose(order=(0,2,1,3))
|
||||
x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
|
||||
@@ -77,19 +78,18 @@ class Transformer:
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
xnp = x.cpu().data
|
||||
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
|
||||
onehot = np.zeros((bs*x.shape[1], self.maxlen+self.syms), dtype=np.float32)
|
||||
print(onehot.shape)
|
||||
for i in range(x.shape[1]):
|
||||
onehot[range(bs), i, i] = 1
|
||||
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
|
||||
x = Tensor(onehot, device=x.device).dot(self.embed)
|
||||
print(x.shape)
|
||||
onehot[range(bs*i, bs*(i+1)), i] = 1
|
||||
onehot[range(bs*i, bs*(i+1)), self.maxlen + xnp[:, i]] = 1
|
||||
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
|
||||
for t in self.tbs:
|
||||
x = t(x)
|
||||
return x.dot(self.final).logsoftmax()
|
||||
|
||||
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
|
||||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||
|
||||
from tinygrad.optim import Adam
|
||||
from tinygrad.optim import Adam
|
||||
if __name__ == "__main__":
|
||||
model = Transformer(10, 6, 2, 128, 4)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
cat = np.argmax(out.cpu().data, axis=1)
|
||||
cat = np.argmax(out.cpu().data, axis=-1)
|
||||
accuracy = (cat == y).mean()
|
||||
|
||||
# printing
|
||||
|
||||
Reference in New Issue
Block a user