make lazy the default (#352)

* make lazy the default

* always float32

* while the lazy framework should be the default, laziness itself shouldn't be (for now)

* bugfixes

* remove the need for the ops class

* fxn_for_op

* hmm, my contiguous asserts went away

* move small shape thing

* refactor reduce

* remove the weird unused new functions

* only keep the install that works

* that's broken

* remove unused imports; should be good if it passes
Author: George Hotz
Date: 2022-07-03 11:40:27 -07:00
Committed by: GitHub
Parent: bbfdd28a6d
Commit: df16b455a7
10 changed files with 236 additions and 221 deletions
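The user-visible effect: every Tensor now wraps a LazyBuffer, but since laziness itself is opt-in for now, each op still realizes immediately unless LAZY=1 is set. A minimal usage sketch, inferred from the `LAZY = int(os.getenv("LAZY", "0"))` check and the Tensor changes in the diffs below (example values are illustrative):

```python
import os
os.environ["LAZY"] = "1"  # read at import time in tinygrad.ops; the default "0" keeps eager realization

from tinygrad.tensor import Tensor

a = Tensor([[1.0, 2.0], [3.0, 4.0]])
b = (a + a).relu()   # with LAZY=1 this only builds a LazyOp graph, nothing runs yet
print(b.numpy())     # .numpy() realizes the graph and copies the result back to the CPU
```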

View File

@@ -19,9 +19,6 @@ We are working on support for the Apple Neural Engine and the Google TPU in the
### Installation
```bash
pip3 install git+https://github.com/geohot/tinygrad.git --upgrade
# or for development
git clone https://github.com/geohot/tinygrad.git
cd tinygrad
python3 setup.py develop
```

View File

@@ -1,8 +1,15 @@
import numpy as np
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
import operator
fxn_for_op = {
UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(),
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: 1.0*(x==y)
}
class CPUBuffer(np.ndarray):
def __new__(cls, shape, dtype=np.float32): return np.zeros(shape, dtype=dtype).view(CPUBuffer)
def relu(x): return np.maximum(x, 0)
def exp(x): return np.exp(x)
def log(x): return np.log(x)
@@ -14,39 +21,19 @@ class CPUBuffer(np.ndarray):
def expand(x, new_shape): return np.broadcast_to(x, new_shape).view(CPUBuffer)
@staticmethod
def fromCPU(x): return x
def fromCPU(x): return x.view(CPUBuffer)
def toCPU(x): return x
def unary_op(x, op):
if op == UnaryOps.NOOP: return x[:]
elif op == UnaryOps.NEG: return -x
elif op == UnaryOps.RELU: return x.relu()
elif op == UnaryOps.EXP: return x.exp()
elif op == UnaryOps.LOG: return x.log()
elif op == UnaryOps.SIGN: return x.sign()
else: raise Exception(f"{op} isn't supported")
def binary_op(x, op, y):
if op == BinaryOps.ADD: return x+y
elif op == BinaryOps.SUB: return x-y
elif op == BinaryOps.MUL: return x*y
elif op == BinaryOps.DIV: return x/y
elif op == BinaryOps.POW: return x**y
elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
else: raise Exception(f"{op} isn't supported")
def unary_op(x, op): return fxn_for_op[op](x)
def binary_op(x, op, y): return fxn_for_op[op](x, y)
def reduce_op(x, op, new_shape):
if x.shape == new_shape: # this is just a copy, regardless of the reduce op
return x[:]
else:
if new_shape == (1,): # full reduce
axis = tuple(range(len(x.shape)))
else:
assert len(x.shape) == len(new_shape)
axis = tuple([i for i,(a,b) in enumerate(zip(x.shape, new_shape)) if a != b])
if op == ReduceOps.SUM: return x.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")
assert len(x.shape) == len(new_shape)
axis = tuple([i for i,(a,b) in enumerate(zip(x.shape, new_shape)) if a != b])
if x.shape == new_shape: return x[:] # this is just a copy, regardless of the reduce op
elif op == ReduceOps.SUM: return x.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")
def movement_op(x, op, arg=None):
if op == MovementOps.RESHAPE: return x.reshape(arg)
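The if/elif chains in unary_op and binary_op collapse into a single fxn_for_op lookup. A self-contained sketch of the same dispatch-table pattern on plain numpy arrays (the local enums and relu lambda here stand in for the real CPUBuffer machinery):

```python
import operator
from enum import Enum
import numpy as np

# illustrative stand-ins for the tinygrad.ops enums
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU"])
BinaryOps = Enum("BinaryOps", ["ADD", "MUL", "CMPEQ"])

fxn_for_op = {
  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x,
  UnaryOps.RELU: lambda x: np.maximum(x, 0),
  BinaryOps.ADD: operator.add, BinaryOps.MUL: operator.mul,
  BinaryOps.CMPEQ: lambda x, y: 1.0 * (x == y),
}

def unary_op(x, op): return fxn_for_op[op](x)         # one dict lookup replaces the old if/elif chain
def binary_op(x, op, y): return fxn_for_op[op](x, y)

x = np.array([-1.0, 2.0], dtype=np.float32)
print(unary_op(x, UnaryOps.RELU))      # [0. 2.]
print(binary_op(x, BinaryOps.ADD, x))  # [-2.  4.]
```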

View File

@@ -131,6 +131,7 @@ class GPUBuffer:
if C.oy == 1 and C.ox == 1: options.append("-DONEBYONE")
global_size = [C.bs*C.cout, C.oy, C.ox]
assert bufs[0][0] == "input" and bufs[1][0] == "weight"
assert bufs[0][1].st.contiguous and bufs[1][1].st.contiguous
ewbufs = bufs[2:] # input and weight are consumed by the convs
kernel_name = "conv"
conv_src = """

View File

@@ -1 +0,0 @@
../../accel/lazy/ops_lazy.py

View File

@@ -1 +0,0 @@
../../accel/opencl/ops_opencl.py

View File

@@ -5,14 +5,7 @@ from tinygrad.ops import MovementOps, ProcessingOps
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class TorchBuffer(torch.Tensor):
def __new__(cls, shape):
if isinstance(shape, torch.Tensor):
return super().__new__(cls, shape)
else:
return TorchBuffer(torch.zeros(shape)).to(device)
def custompad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
def getdtype(self): return np.float32
@staticmethod
def fromCPU(data):

View File

@@ -1,4 +1,3 @@
import os
import numpy as np # TODO: remove this, it's used for np.prod and np.argsort
from tinygrad.helpers import prod, reduce_shape, get_conv_args
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
@@ -9,31 +8,31 @@ from tinygrad.tensor import Function
class _UnaryOp(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return ctx.unary_op(ctx.fop, input)
return input.unary_op(ctx.fop)
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return ctx.binary_op(ctx.bop, input, grad_output)
return input.binary_op(ctx.bop, grad_output)
class ReLU(_UnaryOp):
fop = UnaryOps.RELU
def backward(ctx, grad_output):
input, = ctx.saved_tensors
ret = ctx.unary_op(UnaryOps.SIGN, input)
ret = ctx.unary_op(UnaryOps.RELU, ret)
return ctx.binary_op(BinaryOps.MUL, ret, grad_output)
ret = input.unary_op(UnaryOps.SIGN)
ret = ret.unary_op(UnaryOps.RELU)
return ret.binary_op(BinaryOps.MUL, grad_output)
class Log(_UnaryOp):
fop = UnaryOps.LOG
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return ctx.binary_op(BinaryOps.DIV, grad_output, input)
return grad_output.binary_op(BinaryOps.DIV, input)
class Exp(_UnaryOp):
def forward(ctx, input):
ret = ctx.unary_op(UnaryOps.EXP, input)
ret = input.unary_op(UnaryOps.EXP)
ctx.save_for_backward(ret) # we save the output here, not the input
return ret
@@ -46,15 +45,15 @@ class Exp(_UnaryOp):
class Sum(Function):
def forward(ctx, input, axis=None):
ctx.save_for_backward(input.shape)
return ctx.reduce_op(ReduceOps.SUM, input, reduce_shape(input.shape, axis))
return input.reduce_op(ReduceOps.SUM, reduce_shape(input.shape, axis))
def backward(ctx, grad_output):
shape_input, = ctx.saved_tensors
return ctx.movement_op(MovementOps.EXPAND, grad_output, shape_input)
return grad_output.movement_op(MovementOps.EXPAND, shape_input)
class Max(Function):
def forward(ctx, input, axis=None):
ret = ctx.reduce_op(ReduceOps.MAX, input, reduce_shape(input.shape, axis))
ret = input.reduce_op(ReduceOps.MAX, reduce_shape(input.shape, axis))
ctx.save_for_backward(input, ret)
return ret
@@ -62,21 +61,21 @@ class Max(Function):
input, ret = ctx.saved_tensors
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = ctx.binary_op(BinaryOps.CMPEQ, input, ctx.movement_op(MovementOps.EXPAND, ret, input.shape))
max_is_1s = input.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.EXPAND, input.shape))
# sum of locations, averaged
div = ctx.reduce_op(ReduceOps.SUM, max_is_1s, grad_output.shape)
div = ctx.movement_op(MovementOps.EXPAND, div, input.shape)
max_is_amount = ctx.binary_op(BinaryOps.DIV, max_is_1s, div)
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape)
div = div.movement_op(MovementOps.EXPAND, input.shape)
max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
grad_output_expanded = ctx.movement_op(MovementOps.EXPAND, grad_output, input.shape)
return ctx.binary_op(BinaryOps.MUL, max_is_amount, grad_output_expanded)
grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, input.shape)
return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
# ************* binary ops *************
class Add(Function):
def forward(ctx, x, y):
return ctx.binary_op(BinaryOps.ADD, x, y)
return x.binary_op(BinaryOps.ADD, y)
def backward(ctx, grad_output):
return grad_output if ctx.needs_input_grad[0] else None, \
@@ -84,28 +83,28 @@ class Add(Function):
class Sub(Function):
def forward(ctx, x, y):
return ctx.binary_op(BinaryOps.SUB, x, y)
return x.binary_op(BinaryOps.SUB, y)
def backward(ctx, grad_output):
return grad_output if ctx.needs_input_grad[0] else None, \
ctx.unary_op(UnaryOps.NEG, grad_output) if ctx.needs_input_grad[1] else None
grad_output.unary_op(UnaryOps.NEG) if ctx.needs_input_grad[1] else None
class Mul(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return ctx.binary_op(BinaryOps.MUL, x, y)
return x.binary_op(BinaryOps.MUL, y)
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
grad_x = ctx.binary_op(BinaryOps.MUL, y, grad_output) if ctx.needs_input_grad[0] else None
grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
grad_x = y.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[0] else None
grad_y = x.binary_op(BinaryOps.MUL, grad_output) if ctx.needs_input_grad[1] else None
return grad_x, grad_y
# TODO: add Div? is the optimizer on Pow good enough?
class Pow(Function):
def forward(ctx, x, y):
ret = ctx.binary_op(BinaryOps.POW, x, y)
ret = x.binary_op(BinaryOps.POW, y)
ctx.save_for_backward(x, y, ret)
return ret
@@ -113,12 +112,12 @@ class Pow(Function):
x,y,powxy = ctx.saved_tensors
grad_x, grad_y = None, None
if ctx.needs_input_grad[0]:
tmp = ctx.binary_op(BinaryOps.DIV, powxy, x) # pow(x,y)/x
tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
grad_x = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
tmp = powxy.binary_op(BinaryOps.DIV, x) # pow(x,y)/x
tmp = y.binary_op(BinaryOps.MUL, tmp) # y * pow(x,y)/x
grad_x = grad_output.binary_op(BinaryOps.MUL, tmp)
if ctx.needs_input_grad[1]:
tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
grad_y = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
tmp = x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy) # log(x) * pow(x,y)
grad_y = grad_output.binary_op(BinaryOps.MUL, tmp)
return grad_x, grad_y
# ************* movement ops *************
@@ -127,65 +126,60 @@ class Pow(Function):
class Expand(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
return ctx.movement_op(MovementOps.EXPAND, x, shape)
return x.movement_op(MovementOps.EXPAND, shape)
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
return grad_output.reduce_op(ReduceOps.SUM, in_shape)
class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
return ctx.movement_op(MovementOps.RESHAPE, x, shape)
return x.movement_op(MovementOps.RESHAPE, shape)
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ctx.movement_op(MovementOps.RESHAPE, grad_output, in_shape)
return grad_output.movement_op(MovementOps.RESHAPE, in_shape)
class Permute(Function):
def forward(ctx, x, order=(1,0)):
ctx.save_for_backward(order)
return ctx.movement_op(MovementOps.PERMUTE, x, order)
return x.movement_op(MovementOps.PERMUTE, order)
def backward(ctx, grad_output):
order, = ctx.saved_tensors
norder = np.argsort(order).tolist()
return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder)
return grad_output.movement_op(MovementOps.PERMUTE, norder)
# TODO: merge Slice and Flip into Stride with the 3 arguments
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape, arg)
return ctx.movement_op(MovementOps.SLICE, x, arg)
return x.movement_op(MovementOps.SLICE, arg)
def backward(ctx, grad_output):
shape, arg = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
return ctx.movement_op(MovementOps.SLICE, grad_output, narg)
return grad_output.movement_op(MovementOps.SLICE, narg)
class Flip(Function):
def forward(ctx, x, axis):
ctx.save_for_backward(axis)
return ctx.movement_op(MovementOps.FLIP, x, axis)
return x.movement_op(MovementOps.FLIP, axis)
def backward(ctx, grad_output):
axis, = ctx.saved_tensors
return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
return grad_output.movement_op(MovementOps.FLIP, axis)
# ************* processing ops *************
class Conv2D(Function):
# TODO: this does NOT belong here
def _conv(ctx, x, w, C):
if "OPENCL" in ctx.device or int(os.getenv("LAZY_OPENCL", 0)):
from accel.opencl.preprocessing import preprocessing_op, postprocessing_op
x,w,Cmod = preprocessing_op(ctx, x, w, C)
ret = ctx.processing_op(ProcessingOps.CONV, x, w, Cmod)
return postprocessing_op(ctx, ret, Cmod, C)
else:
return ctx.processing_op(ProcessingOps.CONV, x, w, C)
# TODO: this does NOT belong here
# was pre/post processing for opencl
return x.processing_op(ProcessingOps.CONV, w, C)
def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
@@ -198,13 +192,13 @@ class Conv2D(Function):
if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
xt = grad_output
if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides.
xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
wt = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W))
wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4))
wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4))
wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W))
xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
xt = xt.movement_op(MovementOps.SLICE, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.sy), (0,xt.shape[4]), (0,C.sx)))
xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
wt = wt.movement_op(MovementOps.FLIP, (3, 4))
wt = wt.movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4))
wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W))
py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
py_ = x.shape[2] - xt.shape[2] + C.py
px_ = x.shape[3] - xt.shape[3] + C.px
@@ -212,14 +206,14 @@ class Conv2D(Function):
dx = ctx._conv(xt, wt, Cdx)
if ctx.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV
xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4))
xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix))
grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3))
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox))
xdw = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
xdw = xdw.movement_op(MovementOps.PERMUTE, (2,1,0,3,4))
xdw = xdw.movement_op(MovementOps.RESHAPE, (C.cin, C.groups*C.bs, C.iy, C.ix))
grad_output_dw = grad_output.movement_op(MovementOps.PERMUTE, (1,0,2,3))
grad_output_dw = grad_output_dw.movement_op(MovementOps.RESHAPE, (C.cout, C.bs, C.oy, C.ox))
py_ = (w.shape[2] - 1) * C.dy - xdw.shape[2] - C.py + C.sy * (grad_output_dw.shape[2]-1) + 1
px_ = (w.shape[3] - 1) * C.dx - xdw.shape[3] - C.px + C.sx * (grad_output_dw.shape[3]-1) + 1
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, px_, C.py, py_), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
grad_weight = ctx._conv(xdw, grad_output_dw, Cdw)
dw = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3))
dw = grad_weight.movement_op(MovementOps.PERMUTE, (1,0,2,3))
return dx, dw
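Throughout mlops, ops are now called as methods on the buffer (x.unary_op(op)) rather than through ctx helpers (ctx.unary_op(op, x)), which is what lets the same Function code drive both eager device buffers and LazyBuffers. A condensed restatement of the ReLU class from this diff (the real class inherits from the _UnaryOp helper; forward is inlined here for readability):

```python
from tinygrad.ops import UnaryOps, BinaryOps
from tinygrad.tensor import Function

class ReLU(Function):
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return input.unary_op(UnaryOps.RELU)   # the op is a method on the buffer itself now
  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    mask = input.unary_op(UnaryOps.SIGN).unary_op(UnaryOps.RELU)  # 1.0 where input > 0, else 0.0
    return mask.binary_op(BinaryOps.MUL, grad_output)
```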

View File

@@ -1,5 +1,4 @@
from tinygrad.tensor import Tensor
import numpy as np
def batch_normalize(x, weight, bias, mean, var, eps):
x = (x - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1])

View File

@@ -1,11 +1,20 @@
from __future__ import annotations
from enum import Enum
from tinygrad.helpers import prod
from typing import Tuple, NamedTuple, Union, Any, List
import functools, operator
from tinygrad.helpers import ConvArgs
from tinygrad.shapetracker import ShapeTracker
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
ProcessingOps = Enum("ProcessingOps", ["CONV"])
LoadOps = Enum("LoadOps", ["FROMCPU"])
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps]
# lazy can recurse a lot
import sys
sys.setrecursionlimit(10000)
import os
DEBUG = int(os.getenv("DEBUG", "0"))
@@ -16,8 +25,7 @@ cnts = defaultdict(int)
import atexit
if DEBUG:
def debug_exit():
for k,v in cnts.items():
print(k, v)
for k,v in cnts.items(): print(k, v)
atexit.register(debug_exit)
if GRAPH:
@@ -42,7 +50,7 @@ def log_op(optype, op, ret, inp):
global_num_max += 1
return f"<<< {x.global_num} >>>"
top_colors = {UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
for x in inp:
if not isinstance(op, list): op = [op]
@@ -58,45 +66,107 @@ def log_op(optype, op, ret, inp):
G.nodes[nm(ret)]['label'] = str(tuple(x[0] if x[1]!=0 else 0 for x in st.views[-1].shape_strides))
else:
G.nodes[nm(ret)]['label'] = str(ret.shape)
if 'contiguous' in str(op).lower():
G.nodes[nm(ret)]['fillcolor'] = '#FFFF80'
else:
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if non_contiguous else '')) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['style'] = 'filled, dashed' if non_contiguous else 'filled'
class Ops:
def unary_op(ctx, op:UnaryOps, x):
ret = x.unary_op(op)
if 'LAZY' not in ctx.device: log_op(UnaryOps, op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
return ret
# **** enumerate supported devices ****
def reduce_op(ctx, op:ReduceOps, x, new_shape):
ret = x.reduce_op(op, tuple(new_shape))
if 'LAZY' not in ctx.device: log_op(ReduceOps, op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == tuple(new_shape)
return ret
import importlib, inspect
class Device:
_ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")))
DEFAULT = None
buffers = {}
for i,op in enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]):
name = op[len("ops_"):].upper()
vars()[name] = name
DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
try:
def find_buffer(llo, name): return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
buffers[name] = find_buffer(importlib.import_module('tinygrad.llops.'+op), name)
except ImportError as e:
print(op, "not available", e)
DEFAULT = CPU if DEFAULT is None else DEFAULT
def binary_op(ctx, op:BinaryOps, x, y):
assert x.shape == y.shape
ret = x.binary_op(op, y)
if 'LAZY' not in ctx.device: log_op(BinaryOps, op, ret, [x, y])
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
return ret
# TODO: get device buffer types
DeviceBuffer = Any
def movement_op(ctx, op:MovementOps, x, arg):
ret = x.movement_op(op, tuple(arg))
if 'LAZY' not in ctx.device: log_op(MovementOps, op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == ShapeTracker(x.shape).movement_op(op, arg).shape
return ret
def _realize(self:LazyBuffer) -> DeviceBuffer:
if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
return Device.buffers[self.device].fromCPU(self.op.arg), []
elif self.optype == ReduceOps:
real_src = self.op.src[0].realize(self.device)
return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
elif self.optype == MovementOps:
real_src = self.op.src[0].realize(self.device)
return real_src.movement_op(self.op.op, self.op.arg), [real_src]
elif self.optype == UnaryOps:
real_src_x = self.op.src[0].realize(self.device)
return real_src_x.unary_op(self.op.op), [real_src_x]
elif self.optype == BinaryOps:
real_src_x = self.op.src[0].realize(self.device)
real_src_y = self.op.src[1].realize(self.device)
return real_src_x.binary_op(self.op.op, real_src_y), [real_src_x, real_src_y]
elif self.optype == ProcessingOps:
real_src_x = self.op.src[0].realize(self.device)
real_src_w = self.op.src[1].realize(self.device)
return real_src_x.processing_op(self.op.op, real_src_w, self.op.arg), [real_src_x, real_src_w]
# **** lazy operations ****
class LazyOp(NamedTuple):
op: Op
src: Tuple[Union[LazyOp, LazyBuffer]]
arg: Any = None
# TODO: add dest to support multiple outputs
def get_lazybuffers(op:LazyOp) -> List[LazyBuffer]: return functools.reduce(operator.add, [get_lazybuffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])
LAZY = int(os.getenv("LAZY", "0"))
class LazyBuffer:
def __init__(self, device, shape:Union[ShapeTracker, Tuple[int]], optype:Op, op:LazyOp):
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape = self.st.shape
self.optype, self.op = optype, op
self.realized = None
self.device = device
if not LAZY: self.realize()
# this produces a device buffer
def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
if required_device is not None: assert required_device == self.device
if self.realized is None:
# we haven't realized the Buffer yet
self.realized, real_srcs = _realize(self)
# in lazy mode, we don't log until we realize
log_op(self.optype, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
# no need to keep the op after realization
del self.op
assert self.realized.shape == self.shape
assert isinstance(self.realized, Device.buffers[self.device])
return self.realized
@staticmethod
def fromCPU(x, device):
return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x))
def toCPU(x):
return x.realize().toCPU()
def unary_op(x:LazyBuffer, op:UnaryOps) -> LazyBuffer:
return LazyBuffer(x.device, x.shape, UnaryOps, LazyOp(op, (x,)))
def binary_op(x:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer:
return LazyBuffer(x.device, x.shape, BinaryOps, LazyOp(op, (x,y)))
def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int]) -> LazyBuffer:
return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape)))
def movement_op(x:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
return LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps, LazyOp(op, (x,), arg))
def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
def processing_op(ctx, op:ProcessingOps, x, y, C):
ret = x.processing_op(op, y, C)
if 'LAZY' not in ctx.device: log_op(ProcessingOps, op, ret, [x, y])
assert isinstance(ret, ctx.buffer)
assert ret.shape == C.out_shape
return ret
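This is the core of the new framework: a LazyOp records (op, src, arg), a LazyBuffer pairs one LazyOp with a ShapeTracker-derived shape, and realize() evaluates the graph once and caches the resulting device buffer. A stripped-down standalone sketch of that structure, using NumPy arrays in place of real device buffers and only two ops (names simplified for illustration):

```python
from enum import Enum
from typing import Any, NamedTuple, Tuple
import numpy as np

BinaryOps = Enum("BinaryOps", ["ADD", "MUL"])  # illustrative subset of the real op enums

class LazyOp(NamedTuple):
  op: Any
  src: Tuple[Any, ...]   # LazyBuffers (the real code also allows nested LazyOps)
  arg: Any = None

class LazyBuffer:
  def __init__(self, shape, op: LazyOp):
    self.shape, self.op, self.realized = shape, op, None
  @staticmethod
  def fromCPU(x): return LazyBuffer(x.shape, LazyOp("FROMCPU", (), x))
  def binary_op(self, op, y): return LazyBuffer(self.shape, LazyOp(op, (self, y)))
  def realize(self):
    if self.realized is None:   # evaluate once, then cache the result
      if self.op.op == "FROMCPU": self.realized = self.op.arg
      elif self.op.op == BinaryOps.ADD: self.realized = self.op.src[0].realize() + self.op.src[1].realize()
      elif self.op.op == BinaryOps.MUL: self.realized = self.op.src[0].realize() * self.op.src[1].realize()
    return self.realized

a = LazyBuffer.fromCPU(np.array([1.0, 2.0], dtype=np.float32))
b = a.binary_op(BinaryOps.ADD, a).binary_op(BinaryOps.MUL, a)  # builds the graph, computes nothing
print(b.realize())  # [2. 8.]
```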

View File

@@ -1,68 +1,82 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
import os, inspect, functools, importlib
import inspect, functools, importlib
import numpy as np
from tinygrad.helpers import prod
from typing import List
from tinygrad.ops import Device
# **** enumerate supported devices ****
class Device:
_ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llops")))
imports = dict(enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]))
DEFAULT = None
buffers, llops = {}, {}
for i,op in imports.items():
name = op[len("ops_"):].upper()
vars()[name] = name
DEFAULT = name if os.environ.get(name, 0) == "1" else DEFAULT
try:
llops[name] = importlib.import_module('tinygrad.llops.'+op)
def find_buffer(llo, name): return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
buffers[name] = find_buffer(llops[name], name)
except ImportError as e:
print(op, "not available", e)
DEFAULT = CPU if DEFAULT is None else DEFAULT
from tinygrad.ops import LazyBuffer
# **** start with two base classes, Tensor and Function ****
class Tensor:
did_float_warning = False
training = False
def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
self.device, self.data = device, self._move_data(data, device)
if isinstance(data, list):
data = np.array(data, dtype=np.float32)
elif isinstance(data, LazyBuffer) and data.device != device:
# TODO: this has to realize, it shouldn't have to
data = data.realize().toCPU()
if isinstance(data, np.ndarray):
if data.shape == tuple(): data = data.reshape((1,))
self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
elif isinstance(data, LazyBuffer):
self.lazydata = data
else:
raise Exception(f"can't create Tensor from {data}")
# tensors have gradients, buffers do not
self.grad, self.requires_grad = None, requires_grad
# internal variables used for autograd graph construction
self._ctx = None
def __repr__(self):
return f"<Tensor {self.data!r} with grad {(self.grad.data if self.grad else None)!r}>"
return f"<Tensor {self.lazydata!r} with grad {(self.grad.lazydata if self.grad else None)!r}>"
@property
def shape(self): return self.lazydata.shape
# dtype handling was very broken. it's always float32 now
@property
def dtype(self): return np.float32
@property
def device(self): return self.lazydata.device
# ***** data handlers ****
def realize(self):
# TODO: once lazy is upstreamed, we can remove this check
if getattr(self.data, 'realize', None) is not None:
self.data.realize()
self.lazydata.realize()
def assign(self, x):
if not isinstance(x, Tensor):
x = Tensor(x)
assert self.shape == x.shape
self.data = x.data
self.lazydata = x.lazydata
return x
@property
def shape(self):
return self.data.shape
@staticmethod
def _get_data_dtype(data):
return data.getdtype() if getattr(data, 'getdtype', None) else (data.dtype if getattr(data, 'dtype', None) else np.float32)
def detach(self):
return Tensor(self.lazydata, device=self.device, requires_grad=False)
def numpy(self):
return np.array(self.lazydata.toCPU())
# TODO: this keeps the legacy behavior working, remove it after refactor
@property
def dtype(self):
return Tensor._get_data_dtype(self.data)
def data(self):
return self.numpy()
def to_(self, device):
self.device = device
if self.grad: self.grad.device = device
def to(self, device):
ret = Tensor(self.lazydata, device)
if self.grad: ret.grad = self.grad.to(device)
return ret
# ***** creation helper functions *****
@@ -114,7 +128,7 @@ class Tensor:
if not any(x.requires_grad for x in t0._ctx.parents):
continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.data)
grads = t0._ctx.backward(t0.grad.lazydata)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
@@ -123,41 +137,6 @@ class Tensor:
f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
t.grad = g if t.grad is None else (t.grad + g)
# ***** tinygrad supports many devices *****
@staticmethod
def _move_data(data, device):
if isinstance(data, Device.buffers[device]):
return data
if isinstance(data, list):
# TODO: don't use np.array here, support Tensor creation direct to device
data = np.array(data, dtype=np.float32)
if isinstance(data, np.ndarray):
data = data.view(Device.buffers[Device.CPU])
if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning:
# warning? float64 is actually needed for numerical jacobian
print(f"warning, {data.shape!r} isn't float32, it's {data.dtype}")
Tensor.did_float_warning = True
data = data.toCPU().view(Device.buffers[Device.CPU])
return Device.buffers[device].fromCPU(data)
def to_(self, device):
self.data, self.device = self._move_data(self.data, device), device
if self.grad: self.grad.to_(device)
def to(self, device):
ret = Tensor(self.data, device)
if self.grad: ret.grad = self.grad.to(device)
return ret
def detach(self):
return Tensor(self.data, device=self.device, requires_grad=False)
def numpy(self):
return np.array(self.cpu().data)
# ***** non first class ops (hlops) *****
def __getitem__(self, val):
@@ -350,8 +329,7 @@ class Tensor:
return y.div((y*y).mean(axis=-1, keepdim=True).add(eps).sqrt())
# An instantiation of the Function is the Context
from tinygrad.ops import Ops
class Function(Ops):
class Function:
def __init__(self, device, *tensors):
self.device = device
self.parents = tensors
@@ -359,16 +337,14 @@ class Function(Ops):
self.requires_grad = any(self.needs_input_grad)
self.saved_tensors = []
buffer = property(lambda self: Device.buffers[self.device])
def save_for_backward(self, *x):
if self.requires_grad:
self.saved_tensors.extend(x)
# NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
self.saved_tensors.extend(x)
@classmethod
def apply(cls, *x:List[Tensor], **kwargs):
ctx = cls(x[0].device, *x)
ret = Tensor(ctx.forward(*[t.data for t in x], **kwargs),
ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs),
device=ctx.device, requires_grad=ctx.requires_grad)
if ctx.requires_grad: ret._ctx = ctx # used by autograd engine
return ret
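After this change a Tensor always wraps a LazyBuffer: shape and device are read from lazydata, dtype is pinned to float32, and the old .data attribute survives only as a legacy property that realizes and returns a numpy array. A short usage sketch of the resulting API (example values are illustrative):

```python
import numpy as np
from tinygrad.tensor import Tensor

t = Tensor(np.arange(6, dtype=np.float64).reshape(2, 3))  # any numpy input is cast to float32
print(t.dtype)             # always np.float32 now
print(t.shape, t.device)   # both are read straight from t.lazydata
print(t.numpy())           # realizes the LazyBuffer and copies it back to the CPU
print(t.data)              # legacy property, now just an alias for t.numpy()
```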