Lazy refactor (#538)

* refactor lazy to return ASTs

* a lil cleaner

* oops, compare ids

* gate on GRAPH

* cleanups

* less calls to log_op

* simpler

* realize_buffers -> map_buffers

* even simpler

* think in asts

* a lil cleaner

* NOOP means contiguous

George Hotz
2023-02-07 11:53:21 -06:00
committed by GitHub
parent 02d8cb0959
commit 001cc96e25
3 changed files with 55 additions and 72 deletions
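
The bullets above are terse; the core of the change is that LazyBuffer.realize() now builds a single LazyOp tree whose leaves are already-realized buffers (map_buffers does that substitution) and hands the whole tree to the backend's exec_ast, instead of each _realize_* helper calling into the device one op at a time. A minimal stand-alone sketch of that shape, using plain numpy and made-up string ops rather than real tinygrad types:

from typing import Any, Tuple, NamedTuple
import numpy as np

class LazyOp(NamedTuple):               # same field layout as tinygrad's LazyOp(op, src, arg)
  op: str
  src: Tuple[Any, ...]
  arg: Any = None

def exec_ast(x):
  # leaves of the tree are already-realized buffers (plain numpy arrays in this toy)
  if isinstance(x, np.ndarray): return x
  srcs = [exec_ast(s) for s in x.src]
  if x.op == "ADD": return srcs[0] + srcs[1]
  if x.op == "MUL": return srcs[0] * srcs[1]
  if x.op == "SUM": return srcs[0].sum(axis=x.arg, keepdims=True)
  raise NotImplementedError(x.op)

a, b = np.ones((2, 3)), np.full((2, 3), 2.0)
ast = LazyOp("SUM", (LazyOp("MUL", (LazyOp("ADD", (a, b)), b)),), 1)  # sum((a + b) * b, axis=1)
print(exec_ast(ast))                    # [[18.] [18.]] -- one exec_ast call for the whole tree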

tinygrad/graph.py

@@ -3,7 +3,7 @@ import atexit
import itertools
from collections import defaultdict
from typing import Dict, List
from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType
from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.helpers import getenv
GRAPH = getenv("GRAPH", 0)
@@ -34,7 +34,13 @@ if GRAPH:
atexit.register(save_graph_exit)
global_num_max = 0
def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[DeviceBuffer]):
def log_op(ret : DeviceBuffer, ast : LazyOp):
if not DEBUG and not GRAPH: return
op : List[Op] = [x.op for x in get_lazyops(ast)]
inp : List[DeviceBuffer] = get_buffers(ast)
if len(inp) == 1 and inp[0] == ret: return # don't log self loops
oporder = [LoadOps, ProcessingOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if DEBUG >= 3:
print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")

tinygrad/lazy.py

@@ -4,7 +4,7 @@ from copy import copy
import sys, weakref
from tinygrad.helpers import ConvArgs, get_available_llops, prod
from tinygrad.shape import ShapeTracker
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, DEBUG
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, DEBUG
from tinygrad.graph import log_op
from tinygrad.helpers import getenv
@@ -28,59 +28,22 @@ class Device:
vars()[name] = name
# **** realize helpers ****
def realize_buffers(real_srcs, x:LazyOp) -> LazyOp:
def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
if x in real_srcs:
return realize_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
return LazyOp(x.op, tuple(realize_buffers(real_srcs, y) for y in x.src), x.arg)
return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
return LazyOp(x.op, tuple(map_buffers(real_srcs, y) for y in x.src), x.arg)
# **** realize functions ****
# TODO: make all _realize functions return an AST, perhaps unrealized
# NOTE: loadops and movementops aren't valid ASTs and won't become kernels
def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], Optional[OpType]]:
if self.op.op == LoadOps.FROMCPU:
return Device._buffers[self.device].fromCPU(self.op.arg), [], LoadOps
elif self.op.op == LoadOps.CONTIGUOUS:
# under the hood, this is an AST or a no op. rename to MetaOps?
real_src = self.op.src[0].realize(self.device)
ret = real_src.contiguous()
return ret, [real_src], LoadOps if id(ret) != id(real_src) else None
else:
raise NotImplementedError(f"unknown LoadOp {self.op.op}")
def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
src = self.op.src[0]
# fuse RESHAPE and ReduceOps
# TODO: add MetaOps.TOIMAGE instead?
if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
return _realize_reduceops_w_shape(src, output_shape = self.op.arg)
real_src = src.realize(self.device)
return real_src.movement_op(self.op.op, self.op.arg), [real_src], MovementOps
def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
ast = LazyOp(self.op.op, tuple(x.realize(self.device) for x in self.op.src), self.op.arg)
return self.dbuffer.exec_ast(ast), get_buffers(ast), ProcessingOps
# this supports late merging an upstream Elementwise op
def _realize_reduceops_w_shape(self:LazyBuffer, output_shape=None) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
def _ast_reduceops(self:LazyBuffer) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
src = self.op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
# this is the new version, deprecate _processing_op
real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:x.realize(self.device) for x in get_buffers(src.op)}
ast = LazyOp(self.op.op, (realize_buffers(real_srcs, src.op),), self.op.arg)
else:
ast = LazyOp(self.op.op, (src.realize(self.device),), self.op.arg)
if output_shape is not None: ast = LazyOp(MovementOps.RESHAPE, (ast, ), output_shape)
return self.dbuffer.exec_ast(ast), get_buffers(ast), ReduceOps
def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]: return _realize_reduceops_w_shape(self)
src = src.op
return LazyOp(self.op.op, (src,), self.op.arg)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
op_type : OpType = BinaryOps
def _ast_binaryops(self:LazyBuffer) -> LazyOp:
real_srcs : Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
if DEBUG >= 3:
for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())):
if x.optype in [ProcessingOps,ReduceOps] and x.realized is None:
@@ -89,22 +52,15 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
for tx in x.children: print("x", tx)
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape = self.shape
intermediate_shape : Tuple[int, ...] = self.shape
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
if psrcs[0][1].optype == ProcessingOps:
real_srcs[psrcs[0][0]] = psrcs[0][1].op
for x in psrcs[0][1].op.src:
real_srcs[x] = x.realize(self.device)
op_type = ProcessingOps
top = psrcs[0][1].op # _ast_processingops
elif psrcs[0][1].optype == ReduceOps:
src = psrcs[0][1].op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
src = src.op
real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
for x in get_buffers(real_srcs[psrcs[0][0]]): # type: ignore
# these are the early buffers
real_srcs[x] = x.realize(self.device)
op_type = ReduceOps
top = _ast_reduceops(psrcs[0][1])
real_srcs[psrcs[0][0]] = top
real_srcs.update({x:x for x in get_buffers(top)}) # the reduce op buffers are not modified
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
if psrcs[0][0].shape != psrcs[0][1].shape:
intermediate_shape = psrcs[0][1].shape
@@ -114,11 +70,8 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if real_srcs[x] is None:
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
ast = LazyOp(MovementOps.RESHAPE, (realize_buffers(real_srcs, self.op), ), self.shape)
return self.dbuffer.exec_ast(ast), get_buffers(ast), op_type
_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
return LazyOp(MovementOps.RESHAPE, (map_buffers(real_srcs, self.op), ), self.shape)
# **** lazy operations ****
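
The intermediate_shape bookkeeping above is the "push this reshape before all the ElementwiseOp inputs" comment in action: elementwise ops commute with reshapes as long as every input is reshaped consistently, so the binary op can run at the reduce's own output shape and the final LazyOp(MovementOps.RESHAPE, ..., self.shape) restores the user-visible shape. A plain-numpy sketch of why that reordering is safe (shapes are illustrative only):

import numpy as np
x = np.arange(12.0).reshape(3, 4)
other = np.ones(3)

reduced = x.sum(axis=1, keepdims=True)                 # (3, 1): the reduce's own output shape
fused   = (reduced + other.reshape(3, 1)).reshape(3)   # reshape pushed onto the other input, one RESHAPE at the end
plain   = reduced.reshape(3) + other                   # the unfused ordering
assert np.allclose(fused, plain)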
@@ -159,14 +112,38 @@ class LazyBuffer:
if required_device is not None:
assert required_device == self.device
if self.realized is None:
# we haven't realized the Buffer yet
self.realized, real_srcs, real_type = _realize[self.optype](self)
# in lazy mode, we don't log until we realize
if real_type is not None:
log_op(real_type, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
# get real ops first
if self.op.op == LoadOps.FROMCPU:
self.realized = Device._buffers[self.device].fromCPU(self.op.arg)
ast = LazyOp(self.op.op, tuple())
elif self.op.op == LoadOps.CONTIGUOUS:
real_src = self.op.src[0].realize(self.device)
self.realized = real_src.contiguous()
ast = LazyOp(self.op.op, (real_src, ))
elif self.optype == MovementOps:
src = self.op.src[0]
# fuse RESHAPE and ReduceOps
if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
# it's okay to add a RESHAPE to the ast here
ast = LazyOp(MovementOps.RESHAPE, (_ast_reduceops(src), ), self.op.arg)
else:
# movement ops aren't an AST, just run them
real_src = src.realize(self.device)
self.realized = real_src.movement_op(self.op.op, self.op.arg)
ast = LazyOp(self.op.op, (real_src, ))
elif self.optype == ProcessingOps: ast = self.op # no ast modifications for ProcessingOps
elif self.optype == ReduceOps: ast = _ast_reduceops(self)
elif self.optype == BinaryOps: ast = _ast_binaryops(self)
# no need to keep the op after realization
del self.op
# run the ast if we still have to, and log the op
if self.realized is None:
self.realized = self.dbuffer.exec_ast(map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast))
log_op(self.realized, ast)
assert self.realized.shape == self.shape
assert isinstance(self.realized, Device._buffers[self.device])
return self.realized
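
The realize() rewrite above routes everything through an AST: FROMCPU and CONTIGUOUS realize directly, MovementOps either fold a trailing RESHAPE into an upstream reduce or just run, and ReduceOps/BinaryOps go through _ast_reduceops/_ast_binaryops before a single exec_ast call. The point of MERGE_ELEMENTWISE_INTO_REDUCE (guarded by len(src.children) <= 1, presumably so a shared intermediate isn't computed twice) is that an unrealized elementwise source gets spliced under the ReduceOp, so the backend sees one tree instead of two kernels. A plain-numpy sketch of the same arithmetic, with einsum standing in for the fused kernel:

import numpy as np
a, b = np.random.rand(4, 8), np.random.rand(4, 8)

# unfused: the elementwise op is realized into its own buffer, then reduced
mul_buf = a * b
unfused = mul_buf.sum(axis=1)

# fused: one tree, roughly SUM(MUL(a, b)), evaluated in a single pass
fused = np.einsum("ij,ij->i", a, b)

assert np.allclose(unfused, fused)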

ops_cpu (CPUBuffer)

@@ -6,7 +6,7 @@ from tinygrad.helpers import shape_to_axis
class CPUBuffer(np.ndarray, GenericExecAST):
fxn_for_op = {
UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
UnaryOps.NOOP: lambda x: x[:].contiguous(), UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
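
The last bullet, "NOOP means contiguous", points at this change: once UnaryOps.NOOP is relied on to produce a contiguous buffer, a bare x[:] is not enough, because numpy basic slicing returns a view that shares memory and keeps the source's strides. A small numpy illustration (np.ascontiguousarray stands in for whatever CPUBuffer.contiguous actually does):

import numpy as np
x = np.arange(6.0).reshape(2, 3).T     # a non-contiguous (transposed) buffer
view = x[:]                            # what the old lambda returned: a view over the same memory
copy = np.ascontiguousarray(x)         # what a contiguous op must produce: a fresh, C-contiguous buffer

print(np.shares_memory(view, x), view.flags["C_CONTIGUOUS"])   # True False
print(np.shares_memory(copy, x), copy.flags["C_CONTIGUOUS"])   # False True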