make multidot work on CPU

George Hotz
2020-12-27 17:25:37 -05:00
parent 131e04c90c
commit f15bec6dbc
3 changed files with 12 additions and 6 deletions

@@ -48,14 +48,16 @@ class TransformerBlock:
    query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
    key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T)
    score = query.dot(key)
    print(query.shape)
    print(key.shape)
    print(score.shape)
    #score = query.reshape(shape=(-1, self.projection_dim)).dot(
    #  key.reshape(shape=(-1, self.projection_dim)).transpose(order=(1,0)))
    #scaled_score = score * (1/np.sqrt(self.projection_dim))
    print(query.shape)
    print(key.shape)
    #print(value.shape)
    #print(scaled_score.shape)
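
For reference, a minimal numpy sketch of the shapes the debug prints above are checking; the concrete bs, num_heads, T, and head_size values are made up for illustration and are not from the commit:

import numpy as np

bs, num_heads, T, head_size = 2, 4, 8, 16  # illustrative values only
query = np.zeros((bs, num_heads, T, head_size))
key = np.zeros((bs, num_heads, head_size, T))
score = query @ key  # batched matmul: contracts head_size, broadcasts over (bs, num_heads)
print(score.shape)   # (2, 4, 8, 8), i.e. (bs, num_heads, T, T)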

@@ -39,6 +39,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
# TODO: everywhere you see this, make the op work on GPU
def cpu_only(func):
def wrapper(self):
if self.device == Device.CPU:
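
The hunk is truncated inside the decorator; a minimal sketch of how a skip-on-GPU helper like this could be completed (everything after the if is not shown in the diff, so the func(self) call and the return wrapper line are assumptions):

def cpu_only(func):
  def wrapper(self):
    if self.device == Device.CPU:
      func(self)  # assumption: only run the test body when the backend under test is CPU
  return wrapper  # assumption: hand the wrapper back as the decorated test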
@@ -70,6 +71,9 @@ class TestOps(unittest.TestCase):
    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
  def test_dot(self):
    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
  @cpu_only
  def test_multidot(self):
    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
  def test_sum(self):
    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
  def test_sum_axis(self):
@@ -113,7 +117,7 @@ class TestOps(unittest.TestCase):
  def test_pad2d(self):
    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
  @cpu_only # TODO: transpose for GPU
  @cpu_only
  def test_transpose(self):
    helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
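
The new test_multidot case exercises a batched 3-D matmul; roughly, the torch reference side of that test does the following (the helper_test_op plumbing that builds matching tinygrad tensors and compares forward values and gradients is omitted):

import torch

x = torch.randn(10, 45, 65, requires_grad=True)
y = torch.randn(10, 65, 45, requires_grad=True)
out = x @ y       # torch batched matmul over the leading dimension
print(out.shape)  # torch.Size([10, 45, 45])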

@@ -77,13 +77,13 @@ class Dot(Function):
  @staticmethod
  def forward(ctx, input, weight):
    ctx.save_for_backward(input, weight)
    return input.dot(weight)
    return input @ weight

  @staticmethod
  def backward(ctx, grad_output):
    input, weight = ctx.saved_tensors
    grad_input = grad_output.dot(weight.T)
    grad_weight = input.T.dot(grad_output)
    grad_input = grad_output @ np.swapaxes(weight, -2, -1)
    grad_weight = np.swapaxes(input, -2, -1) @ grad_output
    return grad_input, grad_weight

register('dot', Dot)
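
Why np.swapaxes(..., -2, -1) replaces .T here: for a per-batch product C = A @ B, the usual 2-D gradient rules dL/dA = dL/dC @ B.T and dL/dB = A.T @ dL/dC still apply to each batch element, and swapping the last two axes is the batched transpose. A small self-contained check, with shapes chosen to match the multidot test above rather than taken from the commit:

import numpy as np

A = np.random.randn(10, 45, 65)
B = np.random.randn(10, 65, 45)
G = np.random.randn(10, 45, 45)  # stand-in for grad_output of C = A @ B

grad_A = G @ np.swapaxes(B, -2, -1)  # batched form of G @ B.T
grad_B = np.swapaxes(A, -2, -1) @ G  # batched form of A.T @ G

for i in range(A.shape[0]):
  assert np.allclose(grad_A[i], G[i] @ B[i].T)
  assert np.allclose(grad_B[i], A[i].T @ G[i])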