rename log_softmax, support dim, fix onnx Softmax

George Hotz
2023-02-24 10:11:24 -08:00
parent 5cdfeffe2c
commit 2e56a4793e
18 changed files with 63 additions and 40 deletions

View File

@@ -82,7 +82,7 @@ class TinyBobNet:
self.l2 = Tensor.uniform(128, 10)
def forward(self, x):
return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
return x.dot(self.l1).relu().dot(self.l2).log_softmax()
model = TinyBobNet()
optim = optim.SGD([model.l1, model.l2], lr=0.001)

View File

@@ -39,7 +39,7 @@ if __name__ == "__main__":
if i < 3 or not CLCACHE:
st = time.monotonic()
out = model.forward(x_train)
loss = out.logsoftmax().mul(y_train).mean()
loss = out.log_softmax().mul(y_train).mean()
if i == 2 and CLCACHE: GlobalCounters.cache = []
if BACKWARD:
optimizer.zero_grad()

View File

@@ -46,8 +46,8 @@ class SpeedyResNet:
nn.Linear(512, num_classes, bias=False)
]
# note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of logsoftmax
def __call__(self, x): return x.sequential(self.net).logsoftmax()
# note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of log_softmax
def __call__(self, x): return x.sequential(self.net).log_softmax()
from extra.jit import TinyJit
@TinyJit
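The comment above points out that PyTorch typically folds the log-softmax into CrossEntropyLoss rather than calling log_softmax explicitly. A minimal PyTorch sketch of that equivalence (illustrative shapes, not code from this repo):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # batch of 4, 10 classes
target = torch.randint(0, 10, (4,))   # integer class labels

# cross_entropy == log_softmax followed by negative log-likelihood
a = F.cross_entropy(logits, target)
b = F.nll_loss(F.log_softmax(logits, dim=1), target)
assert torch.allclose(a, b)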

View File

@@ -39,7 +39,7 @@ class SpeedyResNet(nn.Module):
])
self.lin = nn.Linear(512, num_classes, bias=False)
# note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of logsoftmax
# note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of log_softmax
def forward(self, x):
x = self.ic(x)
x = self.ib(x)

View File

@@ -33,7 +33,7 @@ class LinearDisc:
x = x.dot(self.l1).leakyrelu(0.2).dropout(0.3)
x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3)
x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3)
x = x.dot(self.l4).logsoftmax()
x = x.dot(self.l4).log_softmax()
return x
def make_batch(images):

View File

@@ -92,7 +92,7 @@ class BigConvNet:
x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
xo = x1.dot(self.weight1) + x2.dot(self.weight2)
return xo.logsoftmax()
return xo.log_softmax()
if __name__ == "__main__":

View File

@@ -75,7 +75,7 @@ if __name__ == "__main__":
y = np.zeros((BS,classes), np.float32)
y[range(y.shape[0]),Y] = -classes
y = Tensor(y, requires_grad=False)
loss = out.logsoftmax().mul(y).mean()
loss = out.log_softmax().mul(y).mean()
optimizer.zero_grad()
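For context on the loss above: filling a one-hot matrix with -classes and taking the mean of log_softmax over all BS*classes entries reproduces the per-sample negative log-likelihood, because the mean divides by classes again. A small numpy sketch of that identity (random illustrative values, not code from this repo):

import numpy as np

BS, classes = 4, 10
logp = np.log(np.random.dirichlet(np.ones(classes), size=BS))  # stand-in for log_softmax output
Y = np.random.randint(0, classes, size=BS)

y = np.zeros((BS, classes), np.float32)
y[range(BS), Y] = -classes
scaled_mean = (logp * y).mean()      # what the training loop computes
nll = -logp[range(BS), Y].mean()     # standard negative log-likelihood
assert np.isclose(scaled_mean, nll)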

View File

@@ -25,8 +25,11 @@ def get_run_onnx(onnx_model):
  def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
  def buffer_parse(inp):
    if inp.data_type in (1,10,7):
      # TODO: this is shared with below
      if len(inp.float_data) > 0:
        ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
      elif len(inp.int64_data) > 0:
        ret = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
      else:
        ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
    else:
@@ -95,19 +98,14 @@ def get_run_onnx(onnx_model):
if n.op_type == "Relu": ret = inp[0].relu()
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
elif n.op_type == "Tanh": ret = inp[0].tanh()
elif n.op_type == "Softmax": ret = inp[0].softmax()
elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
# one liners
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt.get('alpha', 1.0))
elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt.get('min', -3.4e38), opt.get('max', 3.4e38))))
elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
elif n.op_type == "Transpose": ret = inp[0].permute(order=opt.get('perm', list(range(len(inp[0].shape))[::-1])))
elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
elif n.op_type == "ReduceL2": ret = inp[0].pow(2).sum(axis=opt['axes'], keepdim=opt['keepdims']).sqrt()
elif n.op_type == "ReduceSum": ret = inp[0].sum(axis=opt['axes'], keepdim=opt['keepdims'])
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
elif n.op_type == "Expand": ret = inp[0].reshape([1]*(max(len(inp[0].shape), len(inp[1]))-len(inp[0].shape)) + list(inp[0].shape)) # just broadcast
elif n.op_type == "Div": ret = inp[0].div(inp[1])
elif n.op_type == "Constant": ret = opt['value'] if 'value' in opt else opt['value_float']
elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))])

View File

@@ -1,4 +1,5 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from extra.onnx import safe_numpy
import numpy as np
@@ -63,3 +64,25 @@ def Dropout(data, ratio=0.5, training_mode=False, seed=None):
  return data * mask * (1/(1.0 - ratio)), mask
def Shape(data, end=None, start=0): return list(data.shape)[start:end]
# TODO: this doesn't match Tensor.flatten behavior
def Flatten(input, axis=1):
  new_shape = (1, -1) if axis == 0 else (prod(input.shape[0:axis]), -1)
  return input.reshape(new_shape)
# TODO: abstract out the broadcast logic in tensor
def Expand(input, shape):
  x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)]
  # copied from _broadcasted
  x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
  shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
  return input.reshape(x_shape).expand(shape_ret)
def Exp(input): return input.exp()
def Softmax(input, axis=-1): return input.softmax(axis)
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None else ([] if noop_with_empty_axes else None)
def ReduceMax(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL2(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.pow(2).sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
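To make the Expand shape arithmetic above concrete, here is the same broadcast-shape computation run on made-up shapes (plain Python, mirroring the two lines copied from _broadcasted):

# broadcast an input of shape (3, 1) against a requested shape [2, 3, 4]
x_shape, y_shape = (3, 1), [2, 3, 4]
x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
shape_ret = tuple(max(sx, sy) for sx, sy in zip(x_shape, y_shape))
print(x_shape, shape_ret)   # [1, 3, 1] (2, 3, 4)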

View File

@@ -97,7 +97,7 @@ class ResNet:
out = out.sequential(self.layer3)
out = out.sequential(self.layer4)
out = out.mean(3).mean(2)
out = out.linear(**self.fc).logsoftmax()
out = out.linear(**self.fc).log_softmax()
return out
def __call__(self, x):

View File

@@ -67,6 +67,6 @@ class Transformer:
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
x = x.sequential(self.tbs)
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax()
return x.reshape(shape=(bs, -1, x.shape[-1]))

View File

@@ -47,12 +47,15 @@ backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
#backend_test.include('test_gemm_*')
#backend_test.include('test_batchnorm_*')
#backend_test.include('test_transpose_*')
backend_test.include('test_shape_*')
#backend_test.include('test_shape_*')
#backend_test.include('test_flatten_*')
#backend_test.include('test_sum_*')
#backend_test.include('test_expand_*')
# almost passing node tests
#backend_test.include('test_conv_.*')
#backend_test.include('test_dropout_*')
#backend_test.include('test_reshape_*')
# good to investigate
#backend_test.include('test_slice_*')
@@ -62,13 +65,9 @@ backend_test.include('test_shape_*')
#backend_test.include('test_maxpool_2d_*')
"""
backend_test.include('test_sum_*')
backend_test.include('test_tanh_*')
# should be passing (good place to start!)
backend_test.include('test_reshape_*')
backend_test.include('test_flatten_*')
backend_test.include('test_expand_*')
backend_test.include('test_clip_*')
"""
@@ -83,7 +82,7 @@ backend_test.include('test_clip_*')
# the node tests, slowly
#backend_test.include('test_reduce_sum_*')
#backend_test.include('test_softmax_*')
backend_test.include('test_softmax_*')
#backend_test.include('test_lrn_*')
# working big model tests

View File

@@ -19,7 +19,7 @@ class TinyBobNet:
return optim.get_parameters(self)
def forward(self, x):
return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
return x.dot(self.l1).relu().dot(self.l2).log_softmax()
# create a model with a conv layer
class TinyConvNet:
@@ -40,7 +40,7 @@ class TinyConvNet:
x = x.conv2d(self.c1).relu().max_pool2d()
x = x.conv2d(self.c2).relu().max_pool2d()
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1).logsoftmax()
return x.dot(self.l1).log_softmax()
class TestMNIST(unittest.TestCase):
def test_sgd_onestep(self):

View File

@@ -75,7 +75,7 @@ class TestConvSpeed(unittest.TestCase):
x = x.conv2d(c1).relu().avg_pool2d()
x = x.conv2d(c2).relu().max_pool2d()
x = x.reshape(shape=(x.shape[0], -1))
out = x.dot(l1).logsoftmax()
out = x.dot(l1).log_softmax()
out = out.mean()
et1 = time.time()
out.backward()

View File

@@ -146,8 +146,12 @@ class TestOps(unittest.TestCase):
    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
  def test_mean_axis(self):
    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
  def test_logsoftmax(self):
    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7)
  def test_log_softmax(self):
    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
  def test_log_softmax_other_axis(self):
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
  def test_tanh(self):
    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
  def test_topo_sort(self):

View File

@@ -33,7 +33,7 @@ class TinyNet():
def forward(self):
out = self.x.dot(self.W).relu()
out = out.logsoftmax()
out = out.log_softmax()
out = out.mul(self.m).add(self.m).sum()
return out

View File

@@ -41,7 +41,7 @@ class TestTinygrad(unittest.TestCase):
W = Tensor(W_init, requires_grad=True)
m = Tensor(m_init)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.log_softmax()
out = out.mul(m).add(m).sum()
out.backward()
return out.cpu().numpy(), x.grad.cpu().numpy(), W.grad.cpu().numpy()
@@ -67,7 +67,7 @@ class TestTinygrad(unittest.TestCase):
x = u.mul(v).relu()
y = u.mul(w).relu()
out = x.add(y).mul(y).relu()
out = out.logsoftmax()
out = out.log_softmax()
out = out.sum()
out.backward()
return out.cpu().numpy(), u.cpu().grad.numpy(), v.cpu().grad.numpy(), w.cpu().grad.numpy()
@@ -123,7 +123,7 @@ class TestTinygrad(unittest.TestCase):
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
tiny_func = lambda x: x.dot(tiny_W).relu().log_softmax()
J = jacobian(tiny_func, tiny_x)
NJ = numerical_jacobian(tiny_func, tiny_x)
@@ -138,7 +138,7 @@ class TestTinygrad(unittest.TestCase):
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
tiny_func = lambda x: x.dot(tiny_W).relu().log_softmax()
self.assertTrue(gradcheck(tiny_func, tiny_x))

View File

@@ -263,18 +263,17 @@ class Tensor:
    out = self.sum(axis=axis, keepdim=keepdim)
    return out * (prod(out.shape)/prod(self.shape))
  def _softmax(self):
    m = self - self.max(axis=len(self.shape)-1, keepdim=True)
  def _softmax(self, axis):
    m = self - self.max(axis=axis, keepdim=True)
    e = m.exp()
    return m, e, e.sum(axis=len(self.shape)-1, keepdim=True)
    return m, e, e.sum(axis=axis, keepdim=True)
  def softmax(self):
    _, e, ss = self._softmax()
  def softmax(self, axis=-1):
    _, e, ss = self._softmax(axis)
    return e.div(ss)
  # TODO: logsoftmax -> log_softmax and add dim param
  def logsoftmax(self):
    m, _, ss = self._softmax()
  def log_softmax(self, axis=-1):
    m, _, ss = self._softmax(axis)
    return m - ss.log()
  # ***** processing ops *****
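A short usage sketch against the new signatures (shapes and tolerances are illustrative; assumes the numpy interop and the .cpu().numpy() accessor used by this commit's tests):

import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.randn(2, 3, 4).astype(np.float32))

# the default axis=-1 keeps the old last-axis behavior
assert np.allclose(x.log_softmax().cpu().numpy(), x.log_softmax(-1).cpu().numpy())

# softmax over an arbitrary axis normalizes along that axis
s = x.softmax(1).cpu().numpy()
assert np.allclose(s.sum(axis=1), 1.0, atol=1e-5)

# log_softmax(axis) is log(softmax(axis)), via the shared _softmax helper
assert np.allclose(x.log_softmax(1).cpu().numpy(), np.log(s), atol=1e-5)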