cpu only decorator

George Hotz
2020-12-27 17:18:55 -05:00
parent 2f1b2c0a3b
commit 131e04c90c
2 changed files with 26 additions and 12 deletions


@@ -25,8 +25,8 @@ class TransformerBlock:
   def __init__(self, embed_dim, num_heads):
     # Multi-Head Attention
     self.num_heads = num_heads
-    self.projection_dim = embed_dim // num_heads
-    assert self.projection_dim * self.num_heads == embed_dim
+    self.head_size = embed_dim // num_heads
+    assert self.head_size * self.num_heads == embed_dim
     # looks like bias is useless
     self.query_dense = Tensor.uniform(embed_dim, embed_dim)
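
For concreteness: with the sizes the __main__ block below uses, embed_dim=128 and num_heads=4, the renamed quantity works out to head_size=32, so each head gets an even slice of the embedding (illustrative numbers only):

  embed_dim, num_heads = 128, 4
  head_size = embed_dim // num_heads         # 32
  assert head_size * num_heads == embed_dim  # 4 * 32 == 128, splits evenly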
@@ -37,16 +37,27 @@ class TransformerBlock:
     self.ff2 = Tensor.uniform(embed_dim, embed_dim)

   def __call__(self, x):
     # bs x T x embed_dim
     bs = x.shape[0]
-    x = x.reshape(shape=(-1, self.num_heads * self.projection_dim))
+    x = x.reshape(shape=(-1, self.num_heads * self.head_size))
+
-    # run multi head attention
-    qkv = [x.dot(y) \
-      .reshape(shape=(bs, -1, self.num_heads, self.projection_dim)) \
-      .transpose(order=(0,2,1,3)) \
+    # run multi head attention (bs, T, num_heads, head_size)
+    query, key, value = [x.dot(y) \
+      .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
-    print(qkv[0].shape)
+
+    query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
+    key = key.transpose(order=(0,2,3,1))     # (bs, num_heads, head_size, T)
+
+    #score = query.reshape(shape=(-1, self.projection_dim)).dot(
+    #  key.reshape(shape=(-1, self.projection_dim)).transpose(order=(1,0)))
+    #scaled_score = score * (1/np.sqrt(self.projection_dim))
+
+    print(query.shape)
+    print(key.shape)
+    #print(value.shape)
+    #print(scaled_score.shape)

     #query = self.query_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim))
     #key = self.key_dense(x).reshape((bs, -1, self.num_heads, self.projection_dim))
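
The commented-out score lines show where this is going: a scaled dot-product over the last two axes. Their reshape-to-2D workaround suggests dot here only handles matrices, which is presumably why the batched form stays dead code for now. A NumPy sketch of the intended computation, using the shapes established by the two transposes above (the random inputs and softmax spelling are this sketch's assumptions, not the committed code):

  import numpy as np

  bs, T, num_heads, head_size = 20, 10, 4, 32
  query = np.random.randn(bs, num_heads, T, head_size)  # after transpose(0,2,1,3)
  key   = np.random.randn(bs, num_heads, head_size, T)  # after transpose(0,2,3,1)
  value = np.random.randn(bs, num_heads, T, head_size)

  score = query @ key                            # batched matmul -> (bs, num_heads, T, T)
  scaled_score = score * (1/np.sqrt(head_size))  # keep variance ~1 regardless of head_size

  # softmax over the key axis
  weights = np.exp(scaled_score - scaled_score.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)

  out = weights @ value                          # (bs, num_heads, T, head_size)
  print(out.shape)                               # (20, 4, 10, 32)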
@@ -59,6 +70,5 @@ if __name__ == "__main__":
   tb = TransformerBlock(128, 4)
   tmp = Tensor.zeros(20, 10, 128)
   ret = tb(tmp)
-  ret.backward()
   print(ret)


@@ -39,6 +39,12 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
def cpu_only(func):
def wrapper(self):
if self.device == Device.CPU:
func(self)
return wrapper
class TestOps(unittest.TestCase):
device=Device.CPU
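
cpu_only reads self.device rather than a global because the device is a class attribute on TestOps; presumably a GPU run reuses the same test methods through a subclass along these lines (a hypothetical sketch, not part of this commit):

  class TestOpsGPU(TestOps):
    device = Device.GPU  # inherited tests run on GPU; @cpu_only ones become no-ops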
@@ -107,10 +113,8 @@ class TestOps(unittest.TestCase):
   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)

+  @cpu_only # TODO: transpose for GPU
   def test_transpose(self):
-    # TODO: transpose for GPU
-    if self.device == Device.GPU:
-      return
     helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)

   def test_reshape(self):
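
One design note on the decorator: on a non-CPU device the wrapper just returns, so a skipped test is indistinguishable from a passing one in the report. A variant using unittest's own skip machinery (a sketch, not the committed code) would surface it:

  def cpu_only(func):
    def wrapper(self):
      if self.device != Device.CPU:
        self.skipTest("op not implemented on GPU")  # shows up as a skip, not a pass
      func(self)
    return wrapper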