From 593233b668be6eaee70cedbb40a2e708a00756dd Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Mon, 28 Dec 2020 10:00:30 -0500
Subject: [PATCH] log and exp are first class ops

---
 README.md               |  6 +++---
 examples/transformer.py |  3 +--
 test/test_ops.py        |  4 ++++
 tinygrad/ops_cpu.py     | 46 ++++++++++++++++-------------------
 tinygrad/ops_gpu.py     | 38 +++++++++++++++------------------
 tinygrad/tensor.py      | 13 ++++++++++++
 6 files changed, 56 insertions(+), 54 deletions(-)

diff --git a/README.md b/README.md
index 44e3e70fa5..0ce446d1ef 100644
--- a/README.md
+++ b/README.md
@@ -105,12 +105,12 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
 
 ### Adding an accelerator
 
-You need to support 14 basic ops:
+You need to support 15 basic ops:
 
 ```
 Add, Sub, Mul, Pow, Sum, Dot
-Pad2D, Reshape
-Relu, Sigmoid, LogSoftmax
+Pad2D, Reshape, Transpose
+Relu, Log, Exp
 Conv2D, MaxPool2D, AvgPool2D
 ```
 
diff --git a/examples/transformer.py b/examples/transformer.py
index 1a6bf6d57b..b72c4af067 100755
--- a/examples/transformer.py
+++ b/examples/transformer.py
@@ -55,8 +55,7 @@ class TransformerBlock:
     value = value.transpose(order=(0,2,1,3))  # (bs, num_heads, T, head_size)
 
     score = query.dot(key) * (1 / np.sqrt(self.head_size))
-    # TODO: this should be a normal softmax
-    weights = score.logsoftmax()  # (bs, num_heads, T, T)
+    weights = score.softmax()  # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))
     x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
     # layernorm
diff --git a/test/test_ops.py b/test/test_ops.py
index 10b9c411e8..f69c9ccf52 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -67,6 +67,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device)
   def test_abs(self):
     helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device)
+  def test_log(self):
+    helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log, device=self.device)
+  def test_exp(self):
+    helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp, device=self.device)
   def test_sigmoid(self):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
   def test_dot(self):
diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index cffbb0a46d..5939bad682 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -59,13 +59,14 @@ register('pow', Pow)
 
 class Sum(Function):
   @staticmethod
-  def forward(ctx, input,axis=None):
+  def forward(ctx, input, axis=None):
     ctx.save_for_backward(input, axis)
     return np.array([input.sum()]) if axis is None else input.sum(axis=axis)
 
   @staticmethod
   def backward(ctx, grad_output):
     input, axis = ctx.saved_tensors
+    axis = [axis] if type(axis) == int else axis
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
     return grad_output.reshape(shape) + np.zeros_like(input)
 register('sum', Sum)
@@ -138,41 +139,30 @@ class ReLU(Function):
     return grad_output * (input >= 0)
 register('relu', ReLU)
 
-class Sigmoid(Function):
+class Log(Function):
   @staticmethod
   def forward(ctx, input):
-    with np.warnings.catch_warnings():
-      np.warnings.filterwarnings('ignore')
-      ret = np.where(input >= 0,
-          1/(1 + np.exp(-input)),
-          np.exp(input)/(1 + np.exp(input))
-      )
+    ctx.save_for_backward(input)
+    return np.log(input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return grad_output / input
+register('log', Log)
+
+class Exp(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ret = np.exp(input)
     ctx.save_for_backward(ret)
     return ret
 
   @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
-    return grad_output * (ret * (1 - ret))
-register('sigmoid', Sigmoid)
-
-def _exp_normalize(x, axis=None):
-  y = np.exp(x - x.max(axis=axis, keepdims=True))
-  return y / y.sum(axis=axis, keepdims=True)
-
-class LogSoftmax(Function):
-  @staticmethod
-  def forward(ctx, input):
-    softmax = _exp_normalize(input, axis=-1)
-    ctx.save_for_backward(softmax)
-    return np.log(softmax)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    softmax, = ctx.saved_tensors
-    return grad_output - grad_output.sum(axis=-1, keepdims=True)*softmax
-register('logsoftmax', LogSoftmax)
-
+    return grad_output * ret
+register('exp', Exp)
 
 # ************* conv ops *************
 
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 78e78b170b..eb25b4af52 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -226,6 +226,7 @@ register('pow', Pow, device=Device.GPU)
 class Sum(Function):
   @staticmethod
   def forward(ctx, input, axis=None):
+    axis = [axis] if type(axis) == int else axis
     ctx.save_for_backward(input, axis)
     ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
     if axis is not None:
@@ -363,18 +364,30 @@ class ReLU(Function):
     return binary_op(ctx, 'a * (b >= 0)', grad_output, input)
 register('relu', ReLU, device=Device.GPU)
 
-class Sigmoid(Function):
+class Log(Function):
   @staticmethod
   def forward(ctx, input):
-    ret = unary_op(ctx, '1./(1+exp(-a))', input)
+    ctx.save_for_backward(input)
+    return unary_op(ctx, 'log(a)', input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return binary_op(ctx, 'a / b', grad_output, input)
+register('log', Log, device=Device.GPU)
+
+class Exp(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ret = unary_op(ctx, 'exp(a)', input)
     ctx.save_for_backward(ret)
     return ret
 
   @staticmethod
   def backward(ctx, grad_output):
     ret, = ctx.saved_tensors
-    return binary_op(ctx, 'a * (b * (1 - b));', grad_output, ret)
-register('sigmoid', Sigmoid, device=Device.GPU)
+    return binary_op(ctx, 'a * b', grad_output, ret)
+register('exp', Exp, device=Device.GPU)
 
 class AvgPool2D(Function):
   @staticmethod
@@ -411,23 +424,6 @@ class MaxPool2D(Function):
                           input2=idxs)
 register('max_pool2d', MaxPool2D, device=Device.GPU)
 
-class LogSoftmax(Function):
-  @staticmethod
-  def forward(ctx, input):
-    # TODO: stability?
-    lsum = reduce_op(ctx, "out += exp(a)", "log(out)", input, axis=[1])
-    output = binary_op(ctx, 'a-b', input, lsum)
-    ctx.save_for_backward(output)
-    return output
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    output, = ctx.saved_tensors
-    lsum = reduce_op(ctx, "out += a", "out", grad_output, axis=[1])
-    texp = binary_op(ctx, "exp(a) * b", output, lsum)
-    return binary_op(ctx, "a - b", grad_output, texp)
-register('logsoftmax', LogSoftmax, device=Device.GPU)
-
 # ************* conv ops *************
 
 class Conv2D(Function):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b24cf11e4b..3b9ed473a3 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -211,6 +211,10 @@ class Tensor:
   def div(self, y):
     return self * (y ** -1.0)
 
+  def sigmoid(self):
+    e = self.exp()
+    return e.div(1 + e)
+
   def swish(self):
     return self * self.sigmoid()
 
@@ -220,6 +224,15 @@ class Tensor:
   def leakyrelu(self, neg_slope=0.01):
     return self.relu() - (-neg_slope*self).relu()
 
+  def softmax(self):
+    # Replace with (self - self.max())
+    e = self.exp()
+    ss = e.sum(axis=len(self.shape)-1).reshape(shape=list(self.shape)[:-1]+[1])
+    return e.div(ss)
+
+  def logsoftmax(self):
+    return self.softmax().log()
+
   def dropout(self, p=0.5):
     _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
     ret = self * Tensor(_mask, requires_grad=False, device=self.device)
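The removed GPU LogSoftmax carried a `# TODO: stability?` note, and the new `Tensor.softmax` keeps a `# Replace with (self - self.max())` reminder. Both refer to the standard max-subtraction trick: softmax is shift-invariant, so subtracting the row-wise maximum before exponentiating avoids overflow on large logits without changing the result. A minimal NumPy sketch of that follow-up (illustration only, not part of this patch; `stable_softmax` is a hypothetical helper):

```python
import numpy as np

def stable_softmax(x, axis=-1):
  # softmax(x) == softmax(x - c) for any constant c, so shift by the
  # row-wise max to keep exp() from overflowing on large inputs
  shifted = x - x.max(axis=axis, keepdims=True)
  e = np.exp(shifted)
  return e / e.sum(axis=axis, keepdims=True)

x = np.array([[1000.0, 1001.0, 1002.0]])
print(stable_softmax(x))            # finite, rows sum to 1
print(np.exp(x) / np.exp(x).sum())  # naive form overflows to inf/nan
```

Applying the same shift inside `Tensor.softmax` is what the TODO comment proposes, once a suitable `max` reduction op is available.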