mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 14:43:57 -05:00
log and exp are first class ops
@@ -105,12 +105,12 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
 ### Adding an accelerator

-You need to support 14 basic ops:
+You need to support 15 basic ops:

 ```
 Add, Sub, Mul, Pow, Sum, Dot
-Pad2D, Reshape
-Relu, Sigmoid, LogSoftmax
+Pad2D, Reshape, Transpose
+Relu, Log, Exp
 Conv2D, MaxPool2D, AvgPool2D
 ```
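For context: each of these ops is implemented as a `Function` with a `forward` and a `backward` staticmethod, then registered under a name, exactly as the ops_cpu.py hunks below do for `Log`. A minimal self-contained sketch of that pattern; the tiny `Function`/`register` scaffold here is illustrative, not tinygrad's real one:

```
import numpy as np

# illustrative stand-ins for tinygrad's Function base class and register() helper
class Function:
  def __init__(self):
    self.saved_tensors = []
  def save_for_backward(self, *x):
    self.saved_tensors.extend(x)

registry = {}
def register(name, fxn):
  registry[name] = fxn

class Log(Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return np.log(input)

  @staticmethod
  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return grad_output / input    # d/dx log(x) = 1/x
register('log', Log)

# manual use of the op pair; the real framework wires these into the autograd graph
ctx = Function()
out = Log.forward(ctx, np.array([1.0, 2.0, 4.0]))
grad_in = Log.backward(ctx, np.ones_like(out))    # [1.0, 0.5, 0.25]
```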
@@ -55,8 +55,7 @@ class TransformerBlock:
     value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)

     score = query.dot(key) * (1 / np.sqrt(self.head_size))
-    # TODO: this should be a normal softmax
-    weights = score.logsoftmax() # (bs, num_heads, T, T)
+    weights = score.softmax() # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))
     x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
     # layernorm
@@ -67,6 +67,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device)
   def test_abs(self):
     helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device)
+  def test_log(self):
+    helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log, device=self.device)
+  def test_exp(self):
+    helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp, device=self.device)
   def test_sigmoid(self):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
   def test_dot(self):
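For context: the two new tests pin `Tensor.log` and `Tensor.exp` against their torch counterparts on (45,65) inputs. The same gradient rules the CPU ops below implement can be checked standalone with torch and numpy; a rough sketch, not part of the test suite:

```
import numpy as np
import torch

x_np = np.random.rand(45, 65).astype(np.float32) + 0.1   # keep log() away from zero

x = torch.tensor(x_np, requires_grad=True)
torch.log(x).sum().backward()
y = torch.tensor(x_np, requires_grad=True)
torch.exp(y).sum().backward()

# analytic rules used by the new ops: d/dx log(x) = 1/x, d/dx exp(x) = exp(x)
np.testing.assert_allclose(x.grad.numpy(), 1.0 / x_np, rtol=1e-5)
np.testing.assert_allclose(y.grad.numpy(), np.exp(x_np), rtol=1e-5)
```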
@@ -59,13 +59,14 @@ register('pow', Pow)
 class Sum(Function):
   @staticmethod
-  def forward(ctx, input,axis=None):
+  def forward(ctx, input, axis=None):
     ctx.save_for_backward(input, axis)
     return np.array([input.sum()]) if axis is None else input.sum(axis=axis)

   @staticmethod
   def backward(ctx, grad_output):
     input, axis = ctx.saved_tensors
+    axis = [axis] if type(axis) == int else axis
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
     return grad_output.reshape(shape) + np.zeros_like(input)
 register('sum', Sum)
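For context: the Sum backward reshapes the incoming gradient to a keepdims-style shape and broadcasts it back over the reduced axes; the added `axis = [axis] if type(axis) == int else axis` line lets a plain int axis go through that same path. A small numpy illustration of the exact lines above:

```
import numpy as np

input = np.random.randn(3, 4, 5)
axis = 1                                            # a plain int axis
grad_output = np.ones(input.sum(axis=axis).shape)   # shape (3, 5)

axis = [axis] if type(axis) == int else axis
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
grad_input = grad_output.reshape(shape) + np.zeros_like(input)

assert grad_input.shape == input.shape              # (3, 4, 5): broadcast over the summed axis
assert np.all(grad_input == 1.0)                    # d(sum)/dx is 1 everywhere
```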
@@ -138,41 +139,30 @@ class ReLU(Function):
     return grad_output * (input >= 0)
 register('relu', ReLU)

-class Sigmoid(Function):
+class Log(Function):
   @staticmethod
   def forward(ctx, input):
-    with np.warnings.catch_warnings():
-      np.warnings.filterwarnings('ignore')
-      ret = np.where(input >= 0,
-        1/(1 + np.exp(-input)),
-        np.exp(input)/(1 + np.exp(input))
-      )
+    ctx.save_for_backward(input)
+    return np.log(input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return grad_output / input
+register('log', Log)
+
+class Exp(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ret = np.exp(input)
     ctx.save_for_backward(ret)
     return ret

   @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
-    return grad_output * (ret * (1 - ret))
-register('sigmoid', Sigmoid)
-
-def _exp_normalize(x, axis=None):
-  y = np.exp(x - x.max(axis=axis, keepdims=True))
-  return y / y.sum(axis=axis, keepdims=True)
-
-class LogSoftmax(Function):
-  @staticmethod
-  def forward(ctx, input):
-    softmax = _exp_normalize(input, axis=-1)
-    ctx.save_for_backward(softmax)
-    return np.log(softmax)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    softmax, = ctx.saved_tensors
-    return grad_output - grad_output.sum(axis=-1, keepdims=True)*softmax
-register('logsoftmax', LogSoftmax)
+    return grad_output * ret
+register('exp', Exp)

 # ************* conv ops *************
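For context: the deleted Sigmoid and LogSoftmax ops are recoverable as compositions of the new Log/Exp primitives, which is exactly what tensor.py does later in this commit. A quick numpy check of that equivalence, illustrative only:

```
import numpy as np

x = np.random.randn(4, 10).astype(np.float32)

# sigmoid as exp/div, the way Tensor.sigmoid now composes it: e / (1 + e)
e = np.exp(x)
sigmoid_composed = e / (1 + e)
assert np.allclose(sigmoid_composed, 1 / (1 + np.exp(-x)), atol=1e-5)

# logsoftmax as log(exp(x) / sum(exp(x))), the way Tensor.logsoftmax now composes it
s = e / e.sum(axis=-1, keepdims=True)
logsoftmax_composed = np.log(s)

# reference: the removed LogSoftmax op (max-shifted exp, then normalize and log)
y = np.exp(x - x.max(axis=-1, keepdims=True))
logsoftmax_removed = np.log(y / y.sum(axis=-1, keepdims=True))
assert np.allclose(logsoftmax_composed, logsoftmax_removed, atol=1e-5)
```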
@@ -226,6 +226,7 @@ register('pow', Pow, device=Device.GPU)
 class Sum(Function):
   @staticmethod
   def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
     ctx.save_for_backward(input, axis)
     ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
     if axis is not None:
@@ -363,18 +364,30 @@ class ReLU(Function):
     return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
 register('relu', ReLU, device=Device.GPU)

-class Sigmoid(Function):
+class Log(Function):
   @staticmethod
   def forward(ctx, input):
-    ret = unary_op(ctx, '1./(1+exp(-a))', input)
+    ctx.save_for_backward(input)
+    return unary_op(ctx, 'log(a)', input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return binary_op(ctx, 'a / b', grad_output, input)
+register('log', Log, device=Device.GPU)
+
+class Exp(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ret = unary_op(ctx, 'exp(a)', input)
     ctx.save_for_backward(ret)
     return ret

   @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
-    return binary_op(ctx, 'a * (b * (1 - b));', grad_output, ret)
-register('sigmoid', Sigmoid, device=Device.GPU)
+    return binary_op(ctx, 'a * b', grad_output, ret)
+register('exp', Exp, device=Device.GPU)

 class AvgPool2D(Function):
   @staticmethod
@@ -411,23 +424,6 @@ class MaxPool2D(Function):
                       input2=idxs)
 register('max_pool2d', MaxPool2D, device=Device.GPU)

-class LogSoftmax(Function):
-  @staticmethod
-  def forward(ctx, input):
-    # TODO: stability?
-    lsum = reduce_op(ctx, "out += exp(a)", "log(out)", input, axis=[1])
-    output = binary_op(ctx, 'a-b', input, lsum)
-    ctx.save_for_backward(output)
-    return output
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    output, = ctx.saved_tensors
-    lsum = reduce_op(ctx, "out += a", "out", grad_output, axis=[1])
-    texp = binary_op(ctx, "exp(a) * b", output, lsum)
-    return binary_op(ctx, "a - b", grad_output, texp)
-register('logsoftmax', LogSoftmax, device=Device.GPU)
-
 # ************* conv ops *************

 class Conv2D(Function):
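For context: the deleted GPU LogSoftmax backward encodes the vector-Jacobian product of logsoftmax, grad - softmax * sum(grad), with softmax recovered as exp(output). A numpy sanity check of that identity, illustrative only:

```
import numpy as np

x = np.random.randn(3, 5)
g = np.random.randn(3, 5)                                 # upstream gradient

y = x - np.log(np.exp(x).sum(axis=1, keepdims=True))      # the removed forward: x - log(sum(exp(x)))

# the removed backward: grad - exp(output) * sum(grad)
analytic = g - np.exp(y) * g.sum(axis=1, keepdims=True)

# reference: push g through the logsoftmax Jacobian, dy_i/dx_k = delta_ik - softmax_k
ref = np.empty_like(x)
for r in range(x.shape[0]):
  s = np.exp(y[r])                                        # softmax of row r
  J = np.eye(x.shape[1]) - np.tile(s, (x.shape[1], 1))
  ref[r] = J.T @ g[r]

assert np.allclose(analytic, ref)
```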
@@ -211,6 +211,10 @@ class Tensor:
   def div(self, y):
     return self * (y ** -1.0)

+  def sigmoid(self):
+    e = self.exp()
+    return e.div(1 + e)
+
   def swish(self):
     return self * self.sigmoid()
@@ -220,6 +224,15 @@ class Tensor:
   def leakyrelu(self, neg_slope=0.01):
     return self.relu() - (-neg_slope*self).relu()

+  def softmax(self):
+    # Replace with (self - self.max())
+    e = self.exp()
+    ss = e.sum(axis=len(self.shape)-1).reshape(shape=list(self.shape)[:-1]+[1])
+    return e.div(ss)
+
+  def logsoftmax(self):
+    return self.softmax().log()
+
   def dropout(self, p=0.5):
     _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
     ret = self * Tensor(_mask, requires_grad=False, device=self.device)
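For context: the `# Replace with (self - self.max())` note in the new softmax is about numerical stability. exp overflows for large logits, while subtracting the row max first gives the same softmax without overflow. A small numpy illustration of why the suggested change matters:

```
import numpy as np

logits = np.array([[10.0, 1000.0, -5.0]], dtype=np.float32)

# naive softmax, as Tensor.softmax currently computes it: exp overflows to inf -> nan
e = np.exp(logits)                        # emits an overflow RuntimeWarning
naive = e / e.sum(axis=-1, keepdims=True)
print(naive)                              # [[ 0. nan  0.]]

# max-shifted softmax: mathematically identical, numerically safe
shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
stable = shifted / shifted.sum(axis=-1, keepdims=True)
print(stable)                             # [[0. 1. 0.]] (approximately)
```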