add bias term to transformer

This commit is contained in:
George Hotz
2021-11-29 12:45:27 -05:00
parent 99b6051467
commit 30eb3afbe1
2 changed files with 6 additions and 6 deletions

View File

@@ -21,7 +21,7 @@ def get_parameters(obj):
parameters = []
if isinstance(obj, Tensor):
parameters.append(obj)
elif isinstance(obj, list):
elif isinstance(obj, list) or isinstance(obj, tuple):
for x in obj:
parameters.extend(get_parameters(x))
elif hasattr(obj, '__dict__'):

View File

@@ -17,10 +17,10 @@ class TransformerBlock:
self.head_size = embed_dim // num_heads
assert self.head_size * self.num_heads == embed_dim
# looks like bias is useless
self.query_dense = Tensor.uniform(embed_dim, embed_dim)
self.key_dense = Tensor.uniform(embed_dim, embed_dim)
self.value_dense = Tensor.uniform(embed_dim, embed_dim)
# added bias
self.query_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
self.key_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
self.value_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.uniform(embed_dim))
self.final = Tensor.uniform(embed_dim, embed_dim)
@@ -34,7 +34,7 @@ class TransformerBlock:
inputs = x.reshape(shape=(-1, embed_dim))
# run multi head attention (bs, T, num_heads, head_size)
query, key, value = [inputs.dot(y) \
query, key, value = [inputs.dot(y[0]).add(y[1].reshape(shape=[1, -1])) \
.reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
for y in [self.query_dense, self.key_dense, self.value_dense]]