mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make multidot work on CPU
This commit is contained in:
@@ -48,14 +48,16 @@ class TransformerBlock:
|
||||
|
||||
query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
|
||||
key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T)
|
||||
score = query.dot(key)
|
||||
print(query.shape)
|
||||
print(key.shape)
|
||||
print(score.shape)
|
||||
|
||||
|
||||
#score = query.reshape(shape=(-1, self.projection_dim)).dot(
|
||||
# key.reshape(shape=(-1, self.projection_dim)).transpose(order=(1,0)))
|
||||
#scaled_score = score * (1/np.sqrt(self.projection_dim))
|
||||
|
||||
print(query.shape)
|
||||
print(key.shape)
|
||||
#print(value.shape)
|
||||
#print(scaled_score.shape)
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0
|
||||
|
||||
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
|
||||
|
||||
# TODO: everywhere you see this, make the op work on GPU
|
||||
def cpu_only(func):
|
||||
def wrapper(self):
|
||||
if self.device == Device.CPU:
|
||||
@@ -70,6 +71,9 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
|
||||
@cpu_only
|
||||
def test_multidot(self):
|
||||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
|
||||
def test_sum(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
|
||||
def test_sum_axis(self):
|
||||
@@ -113,7 +117,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
|
||||
|
||||
@cpu_only # TODO: transpose for GPU
|
||||
@cpu_only
|
||||
def test_transpose(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
|
||||
|
||||
|
||||
@@ -77,13 +77,13 @@ class Dot(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
ctx.save_for_backward(input, weight)
|
||||
return input.dot(weight)
|
||||
return input @ weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input = grad_output.dot(weight.T)
|
||||
grad_weight = input.T.dot(grad_output)
|
||||
grad_input = grad_output @ np.swapaxes(weight, -2, -1)
|
||||
grad_weight = np.swapaxes(input, -2, -1) @ grad_output
|
||||
return grad_input, grad_weight
|
||||
register('dot', Dot)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user