Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 23:48:01 -05:00
make lazy the default (#352)
* make lazy the default
* always float32
* while the lazy framework should be default, laziness itself shouldn't be (for now)
* bugfixes
* remove the need for the ops class
* fxn_for_op
* hmm, my contiguous asserts went away
* move small shape thing
* refactor reduce
* remove the weird unused new functions
* only that install works
* that's broken
* unused imports, should be good if it passes
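In practice the commit routes every Tensor through a LazyBuffer but keeps realization eager unless LAZY=1 is set (see the `LAZY = int(os.getenv("LAZY", "0"))` line in the ops.py hunk below). A hedged sketch of that toggle, assuming the repo at this commit is importable; the env var has to be set before tinygrad is imported because ops.py reads it at import time:

```python
# Illustrative sketch only, not part of the diff: exercising the new LAZY flag.
import os
os.environ["LAZY"] = "1"            # defer realization; drop this line for the default eager path
from tinygrad.tensor import Tensor  # assumes the repo at this commit is on PYTHONPATH

a = Tensor([[1.0, -2.0]])
b = Tensor([[3.0, 4.0]])
c = (a + b).relu()                  # with LAZY=1 this only records LazyOps
print(c.numpy())                    # numpy() -> toCPU() -> realize() runs the graph: [[4. 2.]]
```

With the flag unset, LazyBuffer.__init__ calls realize() immediately, so behavior matches the old eager path while the graph plumbing is still exercised everywhere.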
@@ -19,9 +19,6 @@ We are working on support for the Apple Neural Engine and the Google TPU in the
### Installation

```bash
pip3 install git+https://github.com/geohot/tinygrad.git --upgrade

# or for development
git clone https://github.com/geohot/tinygrad.git
cd tinygrad
python3 setup.py develop

@@ -1,8 +1,15 @@
import numpy as np
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
import operator

fxn_for_op = {
  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
  UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(),
  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
  BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: 1.0*(x==y)
}

class CPUBuffer(np.ndarray):
  def __new__(cls, shape, dtype=np.float32): return np.zeros(shape, dtype=dtype).view(CPUBuffer)
  def relu(x): return np.maximum(x, 0)
  def exp(x): return np.exp(x)
  def log(x): return np.log(x)
@@ -14,39 +21,19 @@ class CPUBuffer(np.ndarray):
  def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)

  @staticmethod
  def fromCPU(x): return x
  def fromCPU(x): return x.view(CPUBuffer)
  def toCPU(x): return x

def unary_op(x, op):
  if op == UnaryOps.NOOP: return x[:]
  elif op == UnaryOps.NEG: return -x
  elif op == UnaryOps.RELU: return x.relu()
  elif op == UnaryOps.EXP: return x.exp()
  elif op == UnaryOps.LOG: return x.log()
  elif op == UnaryOps.SIGN: return x.sign()
  else: raise Exception(f"{op} isn't supported")

def binary_op(x, op, y):
  if op == BinaryOps.ADD: return x+y
  elif op == BinaryOps.SUB: return x-y
  elif op == BinaryOps.MUL: return x*y
  elif op == BinaryOps.DIV: return x/y
  elif op == BinaryOps.POW: return x**y
  elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
  else: raise Exception(f"{op} isn't supported")
def unary_op(x, op): return fxn_for_op[op](x)
def binary_op(x, op, y): return fxn_for_op[op](x, y)

def reduce_op(x, op, new_shape):
  if x.shape == new_shape: # this is just a copy, regardless of the reduce op
    return x[:]
  else:
    if new_shape == (1,): # full reduce
      axis = tuple(range(len(x.shape)))
    else:
      assert len(x.shape) == len(new_shape)
      axis = tuple([i for i,(a,b) in enumerate(zip(x.shape, new_shape)) if a != b])
    if op == ReduceOps.SUM: return x.sum(axis, keepdims=True)
    elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True)
    else: raise Exception(f"{op} isn't supported")
  assert len(x.shape) == len(new_shape)
  axis = tuple([i for i,(a,b) in enumerate(zip(x.shape, new_shape)) if a != b])
  if x.shape == new_shape: return x[:] # this is just a copy, regardless of the reduce op
  elif op == ReduceOps.SUM: return x.sum(axis, keepdims=True)
  elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True)
  else: raise Exception(f"{op} isn't supported")

def movement_op(x, op, arg=None):
  if op == MovementOps.RESHAPE: return x.reshape(arg)

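For context, the two ideas in the ops_cpu.py hunk above, restated as a standalone numpy sketch (simplified names, not the tinygrad classes): an op-to-function dict replaces the if/elif chains, and reduce_op derives the axes to reduce by comparing the input shape with the requested output shape.

```python
# Standalone illustration of the dispatch-table and reduce-axis patterns above.
import operator
from enum import Enum
import numpy as np

BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])

fxn_for_op = {
  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
  BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x, y: 1.0*(x == y),
}

def binary_op(x, op, y): return fxn_for_op[op](x, y)   # one dict lookup replaces the elif chain

def reduce_op(x, op, new_shape):
  assert len(x.shape) == len(new_shape)
  # reduce over every axis where the requested output shape differs from the input shape
  axis = tuple(i for i, (a, b) in enumerate(zip(x.shape, new_shape)) if a != b)
  if x.shape == new_shape: return x[:]                 # no-op copy
  elif op == ReduceOps.SUM: return x.sum(axis, keepdims=True)
  elif op == ReduceOps.MAX: return x.max(axis, keepdims=True)

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(binary_op(x, BinaryOps.ADD, np.ones_like(x)))    # elementwise add
print(reduce_op(x, ReduceOps.SUM, (1, 3)))             # sums over axis 0 -> shape (1, 3)
```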
@@ -131,6 +131,7 @@ class GPUBuffer:
    if C.oy == 1 and C.ox == 1: options.append("-DONEBYONE")
    global_size = [C.bs*C.cout, C.oy, C.ox]
    assert bufs[0][0] == "input" and bufs[1][0] == "weight"
    assert bufs[0][1].st.contiguous and bufs[1][1].st.contiguous
    ewbufs = bufs[2:] # input and weight are consumed by the convs
    kernel_name = "conv"
    conv_src = """

@@ -1 +0,0 @@
../../accel/lazy/ops_lazy.py
@@ -1 +0,0 @@
../../accel/opencl/ops_opencl.py
@@ -5,14 +5,7 @@ from tinygrad.ops import MovementOps, ProcessingOps

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class TorchBuffer(torch.Tensor):
  def __new__(cls, shape):
    if isinstance(shape, torch.Tensor):
      return super().__new__(cls, shape)
    else:
      return TorchBuffer(torch.zeros(shape)).to(device)

  def custompad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
  def getdtype(self): return np.float32

  @staticmethod
  def fromCPU(data):

@@ -1,4 +1,3 @@
import os
import numpy as np # TODO: remove this, it's used for np.prod and np.argsort
from tinygrad.helpers import prod, reduce_shape, get_conv_args
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
@@ -9,31 +8,31 @@ from tinygrad.tensor import Function
class _UnaryOp(Function):
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return ctx.unary_op(ctx.fop, input)
    return input.unary_op(ctx.fop)

  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return ctx.binary_op(ctx.bop, input, grad_output)
    return input.binary_op(ctx.bop, grad_output)

class ReLU(_UnaryOp):
  fop = UnaryOps.RELU

  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    ret = ctx.unary_op(UnaryOps.SIGN, input)
    ret = ctx.unary_op(UnaryOps.RELU, ret)
    return ctx.binary_op(BinaryOps.MUL, ret, grad_output)
    ret = input.unary_op(UnaryOps.SIGN)
    ret = ret.unary_op(UnaryOps.RELU)
    return ret.binary_op(BinaryOps.MUL, grad_output)

class Log(_UnaryOp):
  fop = UnaryOps.LOG

  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return ctx.binary_op(BinaryOps.DIV, grad_output, input)
    return grad_output.binary_op(BinaryOps.DIV, input)

class Exp(_UnaryOp):
  def forward(ctx, input):
    ret = ctx.unary_op(UnaryOps.EXP, input)
    ret = input.unary_op(UnaryOps.EXP)
    ctx.save_for_backward(ret) # we save the output here, not the input
    return ret

@@ -46,15 +45,15 @@ class Exp(_UnaryOp):
class Sum(Function):
  def forward(ctx, input, axis=None):
    ctx.save_for_backward(input.shape)
    return ctx.reduce_op(ReduceOps.SUM, input, reduce_shape(input.shape, axis))
    return input.reduce_op(ReduceOps.SUM, reduce_shape(input.shape, axis))

  def backward(ctx, grad_output):
    shape_input, = ctx.saved_tensors
    return ctx.movement_op(MovementOps.EXPAND, grad_output, shape_input)
    return grad_output.movement_op(MovementOps.EXPAND, shape_input)

class Max(Function):
  def forward(ctx, input, axis=None):
    ret = ctx.reduce_op(ReduceOps.MAX, input, reduce_shape(input.shape, axis))
    ret = input.reduce_op(ReduceOps.MAX, reduce_shape(input.shape, axis))
    ctx.save_for_backward(input, ret)
    return ret

@@ -62,21 +61,21 @@ class Max(Function):
    input, ret = ctx.saved_tensors

    # 1s in locations where the max was chosen (can be two locations)
    max_is_1s = ctx.binary_op(BinaryOps.CMPEQ, input, ctx.movement_op(MovementOps.EXPAND, ret, input.shape))
    max_is_1s = input.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.EXPAND, input.shape))

    # sum of locations, averaged
    div = ctx.reduce_op(ReduceOps.SUM, max_is_1s, grad_output.shape)
    div = ctx.movement_op(MovementOps.EXPAND, div, input.shape)
    max_is_amount = ctx.binary_op(BinaryOps.DIV, max_is_1s, div)
    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape)
    div = div.movement_op(MovementOps.EXPAND, input.shape)
    max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)

    grad_output_expanded = ctx.movement_op(MovementOps.EXPAND, grad_output, input.shape)
    return ctx.binary_op(BinaryOps.MUL, max_is_amount, grad_output_expanded)
    grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, input.shape)
    return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)

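The Max.backward refactor above keeps the same arithmetic; here is a plain-numpy check of that gradient rule (illustrative only, no tinygrad imports): the gradient flows only to the positions that attained the max, split evenly when there are ties.

```python
# Worked numpy check of the Max.backward logic above.
import numpy as np

x = np.array([[1., 3., 3.],
              [2., 0., 5.]], dtype=np.float32)
ret = x.max(axis=1, keepdims=True)                 # forward result, shape (2, 1)
grad_output = np.ones_like(ret)                    # upstream gradient

max_is_1s = (x == ret).astype(np.float32)          # CMPEQ against the expanded max
div = max_is_1s.sum(axis=1, keepdims=True)         # how many positions tied for the max
max_is_amount = max_is_1s / div                    # split the gradient between ties
print(max_is_amount * grad_output)
# [[0.  0.5 0.5]
#  [0.  0.  1. ]]
```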
# ************* binary ops *************

class Add(Function):
  def forward(ctx, x, y):
    return ctx.binary_op(BinaryOps.ADD, x, y)
    return x.binary_op(BinaryOps.ADD, y)

  def backward(ctx, grad_output):
    return grad_output if ctx.needs_input_grad[0] else None, \
@@ -84,28 +83,28 @@ class Add(Function):

class Sub(Function):
  def forward(ctx, x, y):
    return ctx.binary_op(BinaryOps.SUB, x, y)
    return x.binary_op(BinaryOps.SUB, y)

  def backward(ctx, grad_output):
    return grad_output if ctx.needs_input_grad[0] else None, \
           ctx.unary_op(UnaryOps.NEG, grad_output) if ctx.needs_input_grad[1] else None
           grad_output.unary_op(UnaryOps.NEG) if ctx.needs_input_grad[1] else None

class Mul(Function):
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return ctx.binary_op(BinaryOps.MUL, x, y)
    return x.binary_op(BinaryOps.MUL, y)

  def backward(ctx, grad_output):
    x,y = ctx.saved_tensors
    grad_x = ctx.binary_op(BinaryOps.MUL, y, grad_output) if ctx.needs_input_grad[0] else None
    grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
    grad_x = y.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
    grad_y = x.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
    return grad_x, grad_y

# TODO: add Div? is the optimizer on Pow good enough?

class Pow(Function):
  def forward(ctx, x, y):
    ret = ctx.binary_op(BinaryOps.POW, x, y)
    ret = x.binary_op(BinaryOps.POW, y)
    ctx.save_for_backward(x, y, ret)
    return ret

@@ -113,12 +112,12 @@ class Pow(Function):
    x,y,powxy = ctx.saved_tensors
    grad_x, grad_y = None, None
    if ctx.needs_input_grad[0]:
      tmp = ctx.binary_op(BinaryOps.DIV, powxy, x) # pow(x,y)/x
      tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
      grad_x = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
      tmp = powxy.binary_op(BinaryOps.DIV, x) # pow(x,y)/x
      tmp = y.binary_op(BinaryOps.MUL, tmp) # y * pow(x,y)/x
      grad_x = grad_output.binary_op(BinaryOps.MUL, tmp)
    if ctx.needs_input_grad[1]:
      tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
      grad_y = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
      tmp = x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy) # log(x) * pow(x,y)
      grad_y = grad_output.binary_op(BinaryOps.MUL, tmp)
    return grad_x, grad_y

# ************* movement ops *************

@@ -127,65 +126,60 @@ class Pow(Function):
class Expand(Function):
  def forward(ctx, x, shape):
    ctx.save_for_backward(x.shape)
    return ctx.movement_op(MovementOps.EXPAND, x, shape)
    return x.movement_op(MovementOps.EXPAND, shape)

  def backward(ctx, grad_output):
    in_shape, = ctx.saved_tensors
    return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
    return grad_output.reduce_op(ReduceOps.SUM, in_shape)

class Reshape(Function):
  def forward(ctx, x, shape):
    ctx.save_for_backward(x.shape)
    shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
    return ctx.movement_op(MovementOps.RESHAPE, x, shape)
    return x.movement_op(MovementOps.RESHAPE, shape)

  def backward(ctx, grad_output):
    in_shape, = ctx.saved_tensors
    return ctx.movement_op(MovementOps.RESHAPE, grad_output, in_shape)
    return grad_output.movement_op(MovementOps.RESHAPE, in_shape)

class Permute(Function):
  def forward(ctx, x, order=(1,0)):
    ctx.save_for_backward(order)
    return ctx.movement_op(MovementOps.PERMUTE, x, order)
    return x.movement_op(MovementOps.PERMUTE, order)

  def backward(ctx, grad_output):
    order, = ctx.saved_tensors
    norder = np.argsort(order).tolist()
    return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder)
    return grad_output.movement_op(MovementOps.PERMUTE, norder)

# TODO: merge Slice and Flip into Stride with the 3 arguments

class Slice(Function):
  def forward(ctx, x, arg=None):
    ctx.save_for_backward(x.shape, arg)
    return ctx.movement_op(MovementOps.SLICE, x, arg)
    return x.movement_op(MovementOps.SLICE, arg)

  def backward(ctx, grad_output):
    shape, arg = ctx.saved_tensors
    narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
    return ctx.movement_op(MovementOps.SLICE, grad_output, narg)
    return grad_output.movement_op(MovementOps.SLICE, narg)

class Flip(Function):
  def forward(ctx, x, axis):
    ctx.save_for_backward(axis)
    return ctx.movement_op(MovementOps.FLIP, x, axis)
    return x.movement_op(MovementOps.FLIP, axis)

  def backward(ctx, grad_output):
    axis, = ctx.saved_tensors
    return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
    return grad_output.movement_op(MovementOps.FLIP, axis)

# ************* processing ops *************

class Conv2D(Function):
  # TODO: this does NOT belong here
  def _conv(ctx, x, w, C):
    if "OPENCL" in ctx.device or int(os.getenv("LAZY_OPENCL", 0)):
      from accel.opencl.preprocessing import preprocessing_op, postprocessing_op
      x,w,Cmod = preprocessing_op(ctx, x, w, C)
      ret = ctx.processing_op(ProcessingOps.CONV, x, w, Cmod)
      return postprocessing_op(ctx, ret, Cmod, C)
    else:
      return ctx.processing_op(ProcessingOps.CONV, x, w, C)
    # TODO: this does NOT belong here
    # was pre/post processing for opencl
    return x.processing_op(ProcessingOps.CONV, w, C)

  def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
    C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
@@ -198,13 +192,13 @@ class Conv2D(Function):
    if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
      xt = grad_output
      if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides.
        xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
        xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
        xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
      wt = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W))
      wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4))
      wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4))
      wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W))
        xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
        xt = xt.movement_op(MovementOps.SLICE, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
        xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
      wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
      wt = wt.movement_op(MovementOps.FLIP, (3, 4))
      wt = wt.movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4))
      wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W))
      py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
      py_ = x.shape[2] - xt.shape[2] + C.py
      px_ = x.shape[3] - xt.shape[3] + C.px
@@ -212,14 +206,14 @@ class Conv2D(Function):
      dx = ctx._conv(xt, wt, Cdx)

    if ctx.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV
      xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
      xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4))
      xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix))
      grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3))
      grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox))
      xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
      xdw = xdw.movement_op(MovementOps.PERMUTE, (2,1,0,3,4))
      xdw = xdw.movement_op(MovementOps.RESHAPE, (C.cin, C.groups*C.bs, C.iy, C.ix))
      grad_output_dw = grad_output.movement_op(MovementOps.PERMUTE, (1,0,2,3))
      grad_output_dw = grad_output_dw.movement_op(MovementOps.RESHAPE, (C.cout, C.bs, C.oy, C.ox))
      py_ = (w.shape[2] - 1) * C.dy - xdw.shape[2] - C.py + C.sy * (grad_output_dw.shape[2]-1) + 1
      px_ = (w.shape[3] - 1) * C.dx - xdw.shape[3] - C.px + C.sx * (grad_output_dw.shape[3]-1) + 1
      Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, px_, C.py, py_), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
      grad_weight = ctx._conv(xdw, grad_output_dw, Cdw)
      dw = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3))
      dw = grad_weight.movement_op(MovementOps.PERMUTE, (1,0,2,3))
    return dx, dw

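What the mlops refactor above means for writing an op after this commit, sketched with a hypothetical Square Function (not part of the diff; the class name and values are illustrative, and it assumes the repo at this commit is importable): forward/backward receive the raw buffers and call op methods on them directly, instead of going through ctx.unary_op/ctx.binary_op helpers on the context.

```python
# Hedged sketch of an mlops-style op after this commit.
from tinygrad.tensor import Function, Tensor
from tinygrad.ops import BinaryOps

class Square(Function):                              # hypothetical op, for illustration only
  def forward(ctx, x):
    ctx.save_for_backward(x)
    return x.binary_op(BinaryOps.MUL, x)             # x*x, called on the buffer itself

  def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    two_x = x.binary_op(BinaryOps.ADD, x)            # 2*x without needing a constant buffer
    return two_x.binary_op(BinaryOps.MUL, grad_output)

t = Tensor([[3.0, -1.0]])
out = Square.apply(t)                                # Function.apply passes t.lazydata into forward
print(out.numpy())                                   # [[9. 1.]]
```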
@@ -1,5 +1,4 @@
from tinygrad.tensor import Tensor
import numpy as np

def batch_normalize(x, weight, bias, mean, var, eps):
  x = (x - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1])

tinygrad/ops.py
@@ -1,11 +1,20 @@
from __future__ import annotations
from enum import Enum
from tinygrad.helpers import prod
from typing import Tuple, NamedTuple, Union, Any, List
import functools, operator
from tinygrad.helpers import ConvArgs
from tinygrad.shapetracker import ShapeTracker
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
ProcessingOps = Enum("ProcessingOps", ["CONV"])
LoadOps = Enum("LoadOps", ["FROMCPU"])
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps]

# lazy can recurse a lot
import sys
sys.setrecursionlimit(10000)

import os
DEBUG = int(os.getenv("DEBUG", "0"))
@@ -16,8 +25,7 @@ cnts = defaultdict(int)
import atexit
if DEBUG:
  def debug_exit():
    for k,v in cnts.items():
      print(k, v)
    for k,v in cnts.items(): print(k, v)
  atexit.register(debug_exit)

if GRAPH:
@@ -42,7 +50,7 @@ def log_op(optype, op, ret, inp):
    global_num_max += 1
    return f"<<< {x.global_num} >>>"

  top_colors = {UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
  top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}

  for x in inp:
  if not isinstance(op, list): op = [op]
@@ -58,45 +66,107 @@ def log_op(optype, op, ret, inp):
      G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in st.views[-1].shape_strides))
    else:
      G.nodes[nm(ret)]['label'] = str(ret.shape)
    if 'contiguous' in str(op).lower():
      G.nodes[nm(ret)]['fillcolor'] = '#FFFF80'
    else:
      G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
    G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
    G.nodes[nm(ret)]['style'] = 'filled, dashed' if non_contiguous else 'filled'

class Ops:
  def unary_op(ctx, op:UnaryOps, x):
    ret = x.unary_op(op)
    if 'LAZY' not in ctx.device: log_op(UnaryOps, op, ret, [x])
    assert isinstance(ret, ctx.buffer)
    assert ret.shape == x.shape
    return ret
# **** enumerate supported devices ****

  def reduce_op(ctx, op:ReduceOps, x, new_shape):
    ret = x.reduce_op(op, tuple(new_shape))
    if 'LAZY' not in ctx.device: log_op(ReduceOps, op, ret, [x])
    assert isinstance(ret, ctx.buffer)
    assert ret.shape == tuple(new_shape)
    return ret
import importlib, inspect
class Device:
  _ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")))
  DEFAULT = None
  buffers = {}
  for i,op in enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]):
    name = op[len("ops_"):].upper()
    vars()[name] = name
    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
    try:
      def find_buffer(llo, name): return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
      buffers[name] = find_buffer(importlib.import_module('tinygrad.llops.'+op), name)
    except ImportError as e:
      print(op, "not available", e)
  DEFAULT = CPU if DEFAULT is None else DEFAULT

  def binary_op(ctx, op:BinaryOps, x, y):
    assert x.shape == y.shape
    ret = x.binary_op(op, y)
    if 'LAZY' not in ctx.device: log_op(BinaryOps, op, ret, [x, y])
    assert isinstance(ret, ctx.buffer)
    assert ret.shape == x.shape
    return ret
# TODO: get device buffer types
DeviceBuffer = Any

  def movement_op(ctx, op:MovementOps, x, arg):
    ret = x.movement_op(op, tuple(arg))
    if 'LAZY' not in ctx.device: log_op(MovementOps, op, ret, [x])
    assert isinstance(ret, ctx.buffer)
    assert ret.shape == ShapeTracker(x.shape).movement_op(op, arg).shape
    return ret
def _realize(self:LazyBuffer) -> DeviceBuffer:
  if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
    return Device.buffers[self.device].fromCPU(self.op.arg), []
  elif self.optype == ReduceOps:
    real_src = self.op.src[0].realize(self.device)
    return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
  elif self.optype == MovementOps:
    real_src = self.op.src[0].realize(self.device)
    return real_src.movement_op(self.op.op, self.op.arg), [real_src]
  elif self.optype == UnaryOps:
    real_src_x = self.op.src[0].realize(self.device)
    return real_src_x.unary_op(self.op.op), [real_src_x]
  elif self.optype == BinaryOps:
    real_src_x = self.op.src[0].realize(self.device)
    real_src_y = self.op.src[1].realize(self.device)
    return real_src_x.binary_op(self.op.op, real_src_y), [real_src_x, real_src_y]
  elif self.optype == ProcessingOps:
    real_src_x = self.op.src[0].realize(self.device)
    real_src_w = self.op.src[1].realize(self.device)
    return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]

# **** lazy operations ****

class LazyOp(NamedTuple):
  op: Op
  src: Tuple[Union[LazyOp, LazyBuffer]]
  arg: Any = None
  # TODO: add dest to support multiple outputs

def get_lazybuffers(op:LazyOp) -> List[LazyBuffer]: return functools.reduce(operator.add, [get_lazybuffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])

LAZY = int(os.getenv("LAZY", "0"))

class LazyBuffer:
  def __init__(self, device, shape:Union[ShapeTracker, Tuple[int]], optype:Op, op:LazyOp):
    self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
    self.shape = self.st.shape
    self.optype, self.op = optype, op
    self.realized = None
    self.device = device
    if not LAZY: self.realize()

  # this produces a device buffer
  def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
    if required_device is not None: assert required_device == self.device
    if self.realized is None:
      # we haven't realized the Buffer yet
      self.realized, real_srcs = _realize(self)
      # in lazy mode, we don't log until we realize
      log_op(self.optype, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
      # no need to keep the op after realization
      del self.op

    assert self.realized.shape == self.shape
    assert isinstance(self.realized, Device.buffers[self.device])
    return self.realized

  @staticmethod
  def fromCPU(x, device):
    return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x))

  def toCPU(x):
    return x.realize().toCPU()

  def unary_op(x:LazyBuffer, op:UnaryOps) -> LazyBuffer:
    return LazyBuffer(x.device, x.shape, UnaryOps, LazyOp(op, (x,)))

  def binary_op(x:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer:
    return LazyBuffer(x.device, x.shape, BinaryOps, LazyOp(op, (x,y)))

  def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int]) -> LazyBuffer:
    return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape)))

  def movement_op(x:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
    return LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps, LazyOp(op, (x,), arg))

  def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
    return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))

  def processing_op(ctx, op:ProcessingOps, x, y, C):
    ret = x.processing_op(op, y, C)
    if 'LAZY' not in ctx.device: log_op(ProcessingOps, op, ret, [x, y])
    assert isinstance(ret, ctx.buffer)
    assert ret.shape == C.out_shape
    return ret
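To make the new ops.py structure concrete, here is a heavily reduced standalone sketch of the LazyOp/LazyBuffer pattern above (a toy class with only two ops and no ShapeTracker, device handling, or log_op): construction records a graph of NamedTuples, and realize() recursively realizes the sources before applying the op, caching the result.

```python
# Minimal standalone sketch of lazy op recording + recursive realize().
from typing import NamedTuple, Tuple, Any
import numpy as np

class LazyOp(NamedTuple):
  op: str                  # "LOAD", "ADD", or "RELU" in this toy version
  src: Tuple[Any, ...]     # ToyLazyBuffers feeding this op
  arg: Any = None

class ToyLazyBuffer:
  def __init__(self, op: LazyOp, shape):
    self.op, self.shape, self.realized = op, shape, None

  @staticmethod
  def fromCPU(x): return ToyLazyBuffer(LazyOp("LOAD", (), x), x.shape)

  def binary_op(self, opname, y): return ToyLazyBuffer(LazyOp(opname, (self, y)), self.shape)
  def unary_op(self, opname): return ToyLazyBuffer(LazyOp(opname, (self,)), self.shape)

  def realize(self):
    if self.realized is None:
      srcs = [s.realize() for s in self.op.src]          # realize inputs first
      if self.op.op == "LOAD": self.realized = self.op.arg
      elif self.op.op == "ADD": self.realized = srcs[0] + srcs[1]
      elif self.op.op == "RELU": self.realized = np.maximum(srcs[0], 0)
    return self.realized

a = ToyLazyBuffer.fromCPU(np.array([[-1., 2.]], dtype=np.float32))
b = ToyLazyBuffer.fromCPU(np.array([[3., -4.]], dtype=np.float32))
out = a.binary_op("ADD", b).unary_op("RELU")   # nothing computed yet
print(out.realize())                           # [[2. 0.]]
```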
@@ -1,68 +1,82 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
import os, inspect, functools, importlib
import inspect, functools, importlib
import numpy as np
from tinygrad.helpers import prod
from typing import List
from tinygrad.ops import Device

# **** enumerate supported devices ****

class Device:
  _ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")))
  imports = dict(enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]))
  DEFAULT = None
  buffers, llops = {}, {}
  for i,op in imports.items():
    name = op[len("ops_"):].upper()
    vars()[name] = name
    DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
    try:
      llops[name] = importlib.import_module('tinygrad.llops.'+op)
      def find_buffer(llo, name): return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
      buffers[name] = find_buffer(llops[name], name)
    except ImportError as e:
      print(op, "not available", e)
  DEFAULT = CPU if DEFAULT is None else DEFAULT
from tinygrad.ops import LazyBuffer

# **** start with two base classes, Tensor and Function ****

class Tensor:
  did_float_warning = False
  training = False

  def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
    self.device, self.data = device, self._move_data(data, device)
    if isinstance(data, list):
      data = np.array(data, dtype=np.float32)
    elif isinstance(data, LazyBuffer) and data.device != device:
      # TODO: this has to realize, it shouldn't have to
      data = data.realize().toCPU()

    if isinstance(data, np.ndarray):
      if data.shape == tuple(): data = data.reshape((1,))
      self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
    elif isinstance(data, LazyBuffer):
      self.lazydata = data
    else:
      raise Exception(f"can't create Tensor from {data}")

    # tensors have gradients, buffers do not
    self.grad, self.requires_grad = None, requires_grad

    # internal variables used for autograd graph construction
    self._ctx = None

  def __repr__(self):
    return f"<Tensor {self.data!r} with grad {(self.grad.data if self.grad else None)!r}>"
    return f"<Tensor {self.lazydata!r} with grad {(self.grad.lazydata if self.grad else None)!r}>"

  @property
  def shape(self): return self.lazydata.shape

  # dtype handling was very broken. it's always float32 now
  @property
  def dtype(self): return np.float32

  @property
  def device(self): return self.lazydata.device

  # ***** data handlers ****

  def realize(self):
    # TODO: once lazy is upstreamed, we can remove this check
    if getattr(self.data, 'realize', None) is not None:
      self.data.realize()
    self.lazydata.realize()

  def assign(self, x):
    if not isinstance(x, Tensor):
      x = Tensor(x)
    assert self.shape == x.shape
    self.data = x.data
    self.lazydata = x.lazydata
    return x

  @property
  def shape(self):
    return self.data.shape

  @staticmethod
  def _get_data_dtype(data):
    return data.getdtype() if getattr(data, 'getdtype', None) else (data.dtype if getattr(data, 'dtype', None) else np.float32)
  def detach(self):
    return Tensor(self.lazydata, device=self.device, requires_grad=False)

  def numpy(self):
    return np.array(self.lazydata.toCPU())

  # TOOD: this keeps the legacy behavior working, remove it after refactor
  @property
  def dtype(self):
    return Tensor._get_data_dtype(self.data)
  def data(self):
    return self.numpy()

  def to_(self, device):
    self.device = device
    if self.grad: self.grad.device = device

  def to(self, device):
    ret = Tensor(self.lazydata, device)
    if self.grad: ret.grad = self.grad.to(device)
    return ret

  # ***** creation helper functions *****

@@ -114,7 +128,7 @@ class Tensor:
      if not any(x.requires_grad for x in t0._ctx.parents):
        continue
      assert (t0.grad is not None)
      grads = t0._ctx.backward(t0.grad.data)
      grads = t0._ctx.backward(t0.grad.lazydata)
      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
               for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
      for t, g in zip(t0._ctx.parents, grads):
@@ -123,41 +137,6 @@ class Tensor:
          f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
        t.grad = g if t.grad is None else (t.grad + g)

  # ***** tinygrad supports many devices *****

  @staticmethod
  def _move_data(data, device):
    if isinstance(data, Device.buffers[device]):
      return data
    if isinstance(data, list):
      # TODO: don't use np.array here, support Tensor creation direct to device
      data = np.array(data, dtype=np.float32)
    if isinstance(data, np.ndarray):
      data = data.view(Device.buffers[Device.CPU])

    if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning:
      # warning? float64 is actually needed for numerical jacobian
      print(f"warning, {data.shape!r} isn't float32, it's {data.dtype}")
      Tensor.did_float_warning = True

    data = data.toCPU().view(Device.buffers[Device.CPU])
    return Device.buffers[device].fromCPU(data)

  def to_(self, device):
    self.data, self.device = self._move_data(self.data, device), device
    if self.grad: self.grad.to_(device)

  def to(self, device):
    ret = Tensor(self.data, device)
    if self.grad: ret.grad = self.grad.to(device)
    return ret

  def detach(self):
    return Tensor(self.data, device=self.device, requires_grad=False)

  def numpy(self):
    return np.array(self.cpu().data)

  # ***** non first class ops (hlops) *****

  def __getitem__(self, val):
@@ -350,8 +329,7 @@ class Tensor:
    return y.div((y*y).mean(axis=-1, keepdim=True).add(eps).sqrt())

# An instantiation of the Function is the Context
from tinygrad.ops import Ops
class Function(Ops):
class Function:
  def __init__(self, device, *tensors):
    self.device = device
    self.parents = tensors
@@ -359,16 +337,14 @@ class Function(Ops):
    self.requires_grad = any(self.needs_input_grad)
    self.saved_tensors = []

  buffer = property(lambda self: Device.buffers[self.device])

  def save_for_backward(self, *x):
    if self.requires_grad:
      self.saved_tensors.extend(x)
    # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
    self.saved_tensors.extend(x)

  @classmethod
  def apply(cls, *x:List[Tensor], **kwargs):
    ctx = cls(x[0].device, *x)
    ret = Tensor(ctx.forward(*[t.data for t in x], **kwargs),
    ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs),
                 device=ctx.device, requires_grad=ctx.requires_grad)
    if ctx.requires_grad: ret._ctx = ctx # used by autograd engine
    return ret
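Tying the tensor.py changes together, a hedged end-to-end check (it assumes the repo at this commit is importable; the printed values follow from ordinary autograd rules, they are not taken from the diff): Tensor data now lives in .lazydata, Function.apply dispatches forward on those buffers, and backward produces gradients that are themselves Tensors wrapping lazy buffers.

```python
# Hedged end-to-end sketch of the refactored Tensor/Function path.
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.array([[1., 2., 3.]], dtype=np.float32))
y = Tensor(np.array([[4., 5., 6.]], dtype=np.float32))
out = (x * y).sum()       # Mul.apply then Sum.apply, both working on lazydata buffers
out.backward()            # gradients flow back through the recorded graph
print(x.grad.numpy())     # should print y's values: [[4. 5. 6.]]
```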