log and exp are first class ops

George Hotz
2020-12-28 10:00:30 -05:00
parent ffff98db78
commit 593233b668
6 changed files with 56 additions and 54 deletions

View File

@@ -105,12 +105,12 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
### Adding an accelerator
You need to support 14 basic ops:
You need to support 15 basic ops:
```
Add, Sub, Mul, Pow, Sum, Dot
Pad2D, Reshape
Relu, Sigmoid, LogSoftmax
Pad2D, Reshape, Transpose
Relu, Log, Exp
Conv2D, MaxPool2D, AvgPool2D
```
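
Each of the fifteen ops is a `Function` with a `forward` and a `backward`, registered by name for the target device; with Log and Exp as primitives, things like Sigmoid and LogSoftmax no longer need a per-accelerator implementation. A minimal sketch of the pattern, using simplified stand-ins for the `Function` base class and the registry rather than tinygrad's real ones (the actual CPU and GPU versions appear further down this diff):

```
# Simplified stand-ins for the Function base class and op registry,
# illustrating what a backend supplies per op.
import numpy as np

class Function:
    def __init__(self):
        self.saved_tensors = []
    def save_for_backward(self, *x):
        self.saved_tensors.extend(x)

class Log(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return np.log(input)
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output / input   # d/dx log(x) = 1/x

ops = {}                             # stand-in for the per-device op table
def register(name, fxn):
    ops[name] = fxn

register('log', Log)                 # one entry per op in the list above

ctx = Log()
out = Log.forward(ctx, np.array([1.0, 2.0]))   # [0.      0.6931...]
grad = Log.backward(ctx, np.ones(2))           # [1.  0.5]
```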

View File

@@ -55,8 +55,7 @@ class TransformerBlock:
value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
score = query.dot(key) * (1 / np.sqrt(self.head_size))
# TODO: this should be a normal softmax
weights = score.logsoftmax() # (bs, num_heads, T, T)
weights = score.softmax() # (bs, num_heads, T, T)
attention = weights.dot(value).transpose(order=(0,2,1,3))
x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
# layernorm
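
With exp and sum available everywhere, the attention weights can use a real softmax instead of the logsoftmax stand-in the TODO pointed at. An illustrative numpy version of the weighting above, with shapes reduced to a single head: softmax rows sum to 1, which is what the weighted sum over `value` expects, whereas logsoftmax would feed log-probabilities into that dot product.

```
# Single-head illustration of the attention weighting; shapes are reduced
# from (bs, num_heads, T, head_size) to (T, head_size) for brevity.
import numpy as np

T, head_size = 4, 8
query = np.random.randn(T, head_size)
key = np.random.randn(head_size, T)
score = query.dot(key) / np.sqrt(head_size)

weights = np.exp(score - score.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
assert np.allclose(weights.sum(axis=-1), 1.0)   # each row is a proper distribution
```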

View File

@@ -67,6 +67,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device)
def test_abs(self):
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device)
def test_log(self):
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log, device=self.device)
def test_exp(self):
helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp, device=self.device)
def test_sigmoid(self):
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
def test_dot(self):
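
`helper_test_op` presumably runs the torch reference and the tinygrad op on the same data and compares outputs and gradients; its body isn't part of this diff, so the sketch below only shows the shape of that check for `log`, with the tinygrad side replaced by the closed forms the new `Log` Function uses.

```
# Not helper_test_op itself, just the shape of the check for log: compare a
# torch reference against the formulas the new Log Function implements, on
# positive data so log() stays in-domain.
import numpy as np
import torch

data = np.random.rand(45, 65).astype(np.float32) + 0.1
t = torch.tensor(data, requires_grad=True)

out = torch.log(t)
out.sum().backward()

np.testing.assert_allclose(out.detach().numpy(), np.log(data), rtol=1e-5)
np.testing.assert_allclose(t.grad.numpy(), 1.0 / data, rtol=1e-5)   # Log.backward: grad_output / input
```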

View File

@@ -66,6 +66,7 @@ class Sum(Function):
@staticmethod
def backward(ctx, grad_output):
input, axis = ctx.saved_tensors
axis = [axis] if type(axis) == int else axis
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
return grad_output.reshape(shape) + np.zeros_like(input)
register('sum', Sum)
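
The added `axis = [axis] if type(axis) == int else axis` line lets `Sum.backward` treat an integer axis the same as a list of axes when it rebuilds a broadcastable gradient shape. What that backward does for `axis=1`, in plain numpy:

```
# Normalize an int axis to a list, reshape the incoming gradient so the
# reduced axis keeps a size-1 slot, then broadcast it back over the input.
import numpy as np

input = np.random.randn(3, 4)
grad_output = np.ones(3)             # gradient of input.sum(axis=1)

axis = 1
axis = [axis] if type(axis) == int else axis
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
grad_input = grad_output.reshape(shape) + np.zeros_like(input)   # (3, 4), columns all equal
assert grad_input.shape == input.shape
```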
@@ -138,41 +139,30 @@ class ReLU(Function):
return grad_output * (input >= 0)
register('relu', ReLU)
class Sigmoid(Function):
class Log(Function):
@staticmethod
def forward(ctx, input):
with np.warnings.catch_warnings():
np.warnings.filterwarnings('ignore')
ret = np.where(input >= 0,
1/(1 + np.exp(-input)),
np.exp(input)/(1 + np.exp(input))
)
ctx.save_for_backward(input)
return np.log(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output / input
register('log', Log)
class Exp(Function):
@staticmethod
def forward(ctx, input):
ret = np.exp(input)
ctx.save_for_backward(ret)
return ret
@staticmethod
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return grad_output * (ret * (1 - ret))
register('sigmoid', Sigmoid)
def _exp_normalize(x, axis=None):
y = np.exp(x - x.max(axis=axis, keepdims=True))
return y / y.sum(axis=axis, keepdims=True)
class LogSoftmax(Function):
@staticmethod
def forward(ctx, input):
softmax = _exp_normalize(input, axis=-1)
ctx.save_for_backward(softmax)
return np.log(softmax)
@staticmethod
def backward(ctx, grad_output):
softmax, = ctx.saved_tensors
return grad_output - grad_output.sum(axis=-1, keepdims=True)*softmax
register('logsoftmax', LogSoftmax)
return grad_output * ret
register('exp', Exp)
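
The CPU `Log` saves its input and returns `grad_output / input`; `Exp` saves its own forward output, so its backward is a single multiply. A quick finite-difference sanity check of those two derivative rules (not part of the repo's test suite):

```
# Finite-difference check of the two new backward rules:
# d/dx log(x) = 1/x and d/dx exp(x) = exp(x).
import numpy as np

x = np.random.rand(16) + 0.5
eps = 1e-6

fd_log = (np.log(x + eps) - np.log(x - eps)) / (2 * eps)
fd_exp = (np.exp(x + eps) - np.exp(x - eps)) / (2 * eps)

np.testing.assert_allclose(fd_log, 1.0 / x, rtol=1e-5)     # Log.backward: grad_output / input
np.testing.assert_allclose(fd_exp, np.exp(x), rtol=1e-5)   # Exp.backward: grad_output * ret
```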
# ************* conv ops *************

View File

@@ -226,6 +226,7 @@ register('pow', Pow, device=Device.GPU)
class Sum(Function):
@staticmethod
def forward(ctx, input, axis=None):
axis = [axis] if type(axis) == int else axis
ctx.save_for_backward(input, axis)
ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
if axis is not None:
@@ -363,18 +364,30 @@ class ReLU(Function):
return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
register('relu', ReLU, device=Device.GPU)
class Sigmoid(Function):
class Log(Function):
@staticmethod
def forward(ctx, input):
ret = unary_op(ctx, '1./(1+exp(-a))', input)
ctx.save_for_backward(input)
return unary_op(ctx, 'log(a)', input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return binary_op(ctx, 'a / b', grad_output, input)
register('log', Log, device=Device.GPU)
class Exp(Function):
@staticmethod
def forward(ctx, input):
ret = unary_op(ctx, 'exp(a)', input)
ctx.save_for_backward(ret)
return ret
@staticmethod
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return binary_op(ctx, 'a * (b * (1 - b));', grad_output, ret)
register('sigmoid', Sigmoid, device=Device.GPU)
return binary_op(ctx, 'a * b', grad_output, ret)
register('exp', Exp, device=Device.GPU)
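
On the GPU backend the same rules become elementwise expression strings handed to `unary_op`/`binary_op`; `Exp` again saves its forward result, so the backward kernel reuses it instead of recomputing `exp`. A numpy stand-in for what the two backward expressions compute, purely for illustration:

```
# Numpy stand-ins for the two GPU backward kernels above; 'a' is grad_output
# and 'b' is the saved tensor in both expressions.
import numpy as np

def log_backward(grad_output, saved_input):
    return grad_output / saved_input      # kernel string: 'a / b'

def exp_backward(grad_output, saved_output):
    return grad_output * saved_output     # kernel string: 'a * b' (reuses the forward result)
```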
class AvgPool2D(Function):
@staticmethod
@@ -411,23 +424,6 @@ class MaxPool2D(Function):
input2=idxs)
register('max_pool2d', MaxPool2D, device=Device.GPU)
class LogSoftmax(Function):
@staticmethod
def forward(ctx, input):
# TODO: stability?
lsum = reduce_op(ctx, "out += exp(a)", "log(out)", input, axis=[1])
output = binary_op(ctx, 'a-b', input, lsum)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
lsum = reduce_op(ctx, "out += a", "out", grad_output, axis=[1])
texp = binary_op(ctx, "exp(a) * b", output, lsum)
return binary_op(ctx, "a - b", grad_output, texp)
register('logsoftmax', LogSoftmax, device=Device.GPU)
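
The dedicated GPU LogSoftmax kernel can be deleted because the remaining primitives compose it, which is exactly what tensor.py does below. A numpy check that `softmax().log()` matches the `x - log(sum(exp(x)))` form the removed kernel computed (setting aside the stability caveat noted in tensor.py):

```
# softmax-then-log equals the fused log-softmax the removed kernel computed.
import numpy as np

x = np.random.randn(4, 10)
e = np.exp(x)
composed = np.log(e / e.sum(axis=1, keepdims=True))         # softmax().log()
fused = x - np.log(np.exp(x).sum(axis=1, keepdims=True))    # old single kernel
np.testing.assert_allclose(composed, fused, atol=1e-6)
```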
# ************* conv ops *************
class Conv2D(Function):

View File

@@ -211,6 +211,10 @@ class Tensor:
def div(self, y):
return self * (y ** -1.0)
def sigmoid(self):
e = self.exp()
return e.div(1 + e)
def swish(self):
return self * self.sigmoid()
@@ -220,6 +224,15 @@ class Tensor:
def leakyrelu(self, neg_slope=0.01):
return self.relu() - (-neg_slope*self).relu()
def softmax(self):
# Replace with (self - self.max())
e = self.exp()
ss = e.sum(axis=len(self.shape)-1).reshape(shape=list(self.shape)[:-1]+[1])
return e.div(ss)
def logsoftmax(self):
return self.softmax().log()
def dropout(self, p=0.5):
_mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
ret = self * Tensor(_mask, requires_grad=False, device=self.device)
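
The `# Replace with (self - self.max())` note in `softmax` above refers to the standard stability trick: shifting by the row max before exponentiating leaves the softmax unchanged but keeps `exp()` from overflowing. A numpy sketch:

```
# Stability trick the TODO points at: subtract the row max before exp().
import numpy as np

x = np.array([[1000.0, 1001.0, 1002.0]])      # naive exp() overflows here

with np.errstate(over='ignore', invalid='ignore'):
    naive = np.exp(x)
    print(naive / naive.sum(axis=-1, keepdims=True))      # [[nan nan nan]]

shifted = np.exp(x - x.max(axis=-1, keepdims=True))
print(shifted / shifted.sum(axis=-1, keepdims=True))      # [[0.0900 0.2447 0.6652]]
```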