diff --git a/README.md b/README.md
index 14791c9517..6d40b48e3b 100644
--- a/README.md
+++ b/README.md
@@ -88,8 +88,10 @@ python -m pytest
 ```
 
 ### TODO
 
+* Train an EfficientNet
+  * EfficientNet backward pass
+  * Tensors on GPU (GPU support, must support Mac)
 * Reduce code
 * Increase speed
 * Add features
-* In that order
diff --git a/examples/efficientnet.py b/examples/efficientnet.py
index 6c771cdabb..5165a4d8bc 100644
--- a/examples/efficientnet.py
+++ b/examples/efficientnet.py
@@ -1,5 +1,4 @@
-# TODO: implement BatchNorm2d and Swish
-# aka batch_norm, pad, swish, dropout
+# load weights from
 # https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
 # a rough copy of
 # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 95f6a2c7c1..928b5c0fdd 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -48,6 +48,21 @@ class Pow(Function):
     return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output
 register('pow', Pow)
 
+class Sum(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return np.array([input.sum()])
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return grad_output * np.ones_like(input)
+register('sum', Sum)
+
+
+# ************* GEMM *************
+
 class Dot(Function):
   @staticmethod
   def forward(ctx, input, weight):
@@ -63,20 +78,8 @@ class Dot(Function):
 register('dot', Dot)
 register('matmul', Dot)
 
-class Sum(Function):
-  @staticmethod
-  def forward(ctx, input):
-    ctx.save_for_backward(input)
-    return np.array([input.sum()])
 
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output * np.ones_like(input)
-register('sum', Sum)
-
-
-# ************* nn ops *************
+# ************* simple ops *************
 
 class Pad2D(Function):
   @staticmethod
@@ -88,6 +91,21 @@
     raise Exception("write this")
 register('pad2d', Pad2D)
 
+class Reshape(Function):
+  @staticmethod
+  def forward(ctx, x, shape):
+    ctx.save_for_backward(x.shape)
+    return x.reshape(shape)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    in_shape, = ctx.saved_tensors
+    return grad_output.reshape(in_shape)
+register('reshape', Reshape)
+
+
+# ************* activation ops *************
+
 class ReLU(Function):
   @staticmethod
   def forward(ctx, input):
@@ -118,18 +136,6 @@ class Sigmoid(Function):
     return grad_input
 register('sigmoid', Sigmoid)
 
-class Reshape(Function):
-  @staticmethod
-  def forward(ctx, x, shape):
-    ctx.save_for_backward(x.shape)
-    return x.reshape(shape)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    in_shape, = ctx.saved_tensors
-    return grad_output.reshape(in_shape)
-register('reshape', Reshape)
-
 class LogSoftmax(Function):
   @staticmethod
   def forward(ctx, input):
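
Not part of the patch: a rough usage sketch for the Sum op reorganized above, assuming register() in tinygrad/tensor.py wires each Function up as a Tensor method (as the existing tests exercise it).

import numpy as np
from tinygrad.tensor import Tensor

# reduce a small tensor with the registered `sum` op; Sum.forward returns a
# 1-element array, so backward() can seed the gradient from this scalar output
x = Tensor(np.random.randn(3, 3).astype(np.float32))
out = x.relu().sum()
out.backward()

# Sum.backward broadcasts grad_output with np.ones_like(input), so every input
# element (here, the ones that passed the ReLU) receives the same upstream gradient
print(out.data, x.grad)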