mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
dropout, training
This commit is contained in:
@@ -68,9 +68,9 @@ class TransformerBlock:
|
||||
weights = score.softmax() # (bs, num_heads, T, T)
|
||||
attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
|
||||
|
||||
- x = inputs + attention.reshape(shape=(-1, embed_dim)).dot(self.final)
|
||||
+ x = inputs + attention.reshape(shape=(-1, embed_dim)).dot(self.final).dropout(0.1)
|
||||
x = layernorm(x, embed_dim)
|
||||
- x = x + x.dot(self.ff1).relu().dot(self.ff2)
|
||||
+ x = x + x.dot(self.ff1).relu().dot(self.ff2).dropout(0.1)
|
||||
x = layernorm(x, embed_dim)
|
||||
return x.reshape(shape=(bs, -1, embed_dim))
|
||||
|
||||
@@ -107,6 +107,7 @@ if __name__ == "__main__":
|
||||
optim = Adam(get_parameters(model), lr=0.001)
|
||||
train(model, X_train, Y_train, optim, 500, BS=16)
|
||||
|
||||
+ Tensor.training = False
|
||||
evaluate(model, X_test, Y_test, num_classes=10)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user