diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index eccede4bc1..059e73d4ee 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2,6 +2,12 @@
 from functools import partialmethod
 import numpy as np
 
+# optional jit
+try:
+  from numba import jit
+except ImportError:
+  jit = lambda x: x
+
 # **** start with two base classes ****
 
 class Tensor:
@@ -170,10 +176,11 @@ class LogSoftmax(Function):
     return grad_output - np.exp(output)*grad_output.sum(axis=1).reshape((-1, 1))
 register('logsoftmax', LogSoftmax)
 
+
 class Conv2D(Function):
   @staticmethod
-  def forward(ctx, x, w):
-    ctx.save_for_backward(x, w)
+  @jit
+  def inner_forward(x, w):
     cout,cin,H,W = w.shape
     ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
     for j in range(H):
@@ -185,8 +192,8 @@ class Conv2D(Function):
     return ret
 
   @staticmethod
-  def backward(ctx, grad_output):
-    x, w = ctx.saved_tensors
+  @jit
+  def inner_backward(grad_output, x, w):
     dx = np.zeros_like(x)
     dw = np.zeros_like(w)
     cout,cin,H,W = w.shape
@@ -200,5 +207,14 @@ class Conv2D(Function):
             dx[:, :, Y+j, X+i] += gg.dot(tw)
             dw[:, :, j, i] += gg.T.dot(tx)
     return dx, dw
+
+  @staticmethod
+  def forward(ctx, x, w):
+    ctx.save_for_backward(x, w)
+    return Conv2D.inner_forward(x, w)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    return Conv2D.inner_backward(grad_output, *ctx.saved_tensors)
 
 register('conv2d', Conv2D)
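
For reference, a minimal smoke test of the refactored path (a sketch, not part of the diff). It assumes the Tensor API at this point in the repo: numpy-backed `.data`/`.grad`, `mean()` and `backward()` available, and `conv2d` exposed on Tensor via `register('conv2d', Conv2D)`. With numba installed, `inner_forward`/`inner_backward` get jit-compiled; without it, the identity fallback keeps behavior unchanged.

```python
# Smoke test for the Conv2D refactor (a sketch; assumes the Tensor API at this
# commit: numpy-backed .data/.grad, mean(), backward(), and conv2d registered
# on Tensor via register('conv2d', Conv2D)).
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.randn(1, 3, 8, 8).astype(np.float32))
w = Tensor(np.random.randn(4, 3, 3, 3).astype(np.float32))

out = x.conv2d(w)        # forward dispatches to Conv2D.inner_forward (jit'd if numba is present)
out.mean().backward()    # backward dispatches to Conv2D.inner_backward

print(out.data.shape)    # (1, 4, 6, 6): a valid 3x3 convolution shrinks each spatial dim by 2
print(x.grad.shape, w.grad.shape)
```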