mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add GPU max thanks to marcelbischoff
This commit is contained in:
@@ -82,7 +82,6 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1), device=self.device)
|
||||
@cpu_only
def test_max(self):
  # Compare tinygrad's Tensor.max against the reference lambda via
  # helper_test_op on a (45,3) input (full reduction, no axis).
  helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
  # Chain a mul after the max so the test also exercises the backward
  # pass flowing through the reduction.
  helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
|
||||
|
||||
@@ -228,6 +228,24 @@ class Sum(Function):
|
||||
return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
|
||||
# Register Sum as the GPU-device implementation of the 'sum' op.
register('sum', Sum, device=Device.GPU)
|
||||
|
||||
class Max(Function):
  """Max reduction as an autograd Function for the GPU backend."""
  @staticmethod
  def forward(ctx, input, axis=None):
    # Normalize a single int axis into a list so the `i not in axis`
    # membership tests below work uniformly.
    axis = [axis] if type(axis) == int else axis
    # GPU reduction kernel: accumulate the running maximum into `out`.
    ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis)
    # Save input values, reduced axes, and the max result for backward.
    ctx.save_for_backward(input, axis, ret)
    if axis is not None:
      # Drop the reduced dimensions from the output shape.
      ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
    return ret

  @staticmethod
  def backward(ctx, grad_output):
    input, axis, ret = ctx.saved_tensors
    # Broadcast shape: every reduced axis becomes size 1 so the saved max
    # and the upstream gradient line up against the full input shape.
    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
    # Mask: 1.0 where the input equals its max, 0.0 elsewhere.
    # NOTE(review): on ties every maximal position gets the full upstream
    # gradient (it is not split among ties) — confirm this is intended.
    ret2 = binary_op(ctx, "1.0*(a == b)", input, GPUBuffer(shape, ret))
    # Route the upstream gradient through the mask.
    return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
# Register Max as the GPU-device implementation of the 'max' op.
register('max', Max, device=Device.GPU)
|
||||
|
||||
class Dot(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
|
||||
@@ -224,15 +224,15 @@ class Tensor:
|
||||
|
||||
def softmax(self):
  """Softmax over the last axis.

  Subtracts the per-row max before exponentiating so large logits do not
  overflow exp(); the shift cancels out in the normalization, leaving the
  result mathematically unchanged.

  Returns a Tensor of the same shape as self.
  """
  # Shape with the last axis collapsed to 1, used to broadcast the
  # per-row max and per-row sum back against self.
  ns = list(self.shape)[:-1]+[1]
  m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
  # Removed dead `e = self.exp()` (immediately overwritten below) and the
  # stale commented-out variant of the same computation.
  e = (self - m).exp()
  ss = e.sum(axis=len(self.shape)-1).reshape(shape=ns)
  return e.div(ss)
|
||||
|
||||
def logsoftmax(self):
  """Log-softmax over the last axis.

  Uses the numerically stable log-sum-exp identity:
    logsumexp(x) = m + log(sum(exp(x - m)))  with  m = max(x),
  so exp() never sees a large positive argument.

  Returns a Tensor of the same shape as self.
  """
  # Shape with the last axis collapsed to 1, for broadcasting the
  # per-row max and log-sum-exp back against self.
  ns = list(self.shape)[:-1]+[1]
  m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
  # Removed the dead unstabilized `ss = self.exp().sum(...).log()`
  # (it was immediately overwritten) and the now-resolved TODO comment.
  ss = m + (self-m).exp().sum(axis=len(self.shape)-1).reshape(shape=ns).log()
  return self - ss
|
||||
|
||||
def dropout(self, p=0.5):
|
||||
|
||||
Reference in New Issue
Block a user