mirror of https://github.com/tinygrad/tinygrad.git
realize hotspots (#1059)
* realize hotspots
* no str check
* minor changes
* make this an assert
* faster and more readable
* nicer self.buffers
* tests for weak op + LAZYCACHE=0
@@ -11,7 +11,7 @@ ShapeType = Tuple[int, ...]
 OSX = platform.system() == "Darwin"

 def dedup(x): return list(dict.fromkeys(x)) # retains list order
 def argfix(*x):
   if x[0].__class__ in {tuple, list}:
     try: return tuple(x[0])
     except IndexError: return tuple()
@@ -148,6 +148,7 @@ class LightWeakValueDictionary:
     if o is None: raise KeyError(key)
     else: return o

   def __len__(self): return len(self.data)
   def __delitem__(self, key): del self.data[key]
   def __setitem__(self, key, value): self.data[key] = KeyedRef(value, self._remove, key)
+  def __contains__(self, key): return key in self.data
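Note on the added __contains__: create_lazybuffer below probes the cache with `wop in lazycache`, and without an explicit __contains__ Python's `in` falls back to the iteration/indexing protocols. A minimal sketch of the difference (Cache is a hypothetical stand-in, not the real LightWeakValueDictionary):

class Cache:
  def __init__(self): self.data = {"k": 1}
  def __getitem__(self, key): return self.data[key]

class CacheWithContains(Cache):
  def __contains__(self, key): return key in self.data

assert "k" in CacheWithContains()  # one dict probe
# "k" in Cache() has no __contains__/__iter__, so `in` falls back to
# __getitem__(0), __getitem__(1), ... and raises KeyError here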
@@ -5,7 +5,7 @@ import sys, importlib, inspect, functools, pathlib
 from weakref import ref

 import numpy as np
-from tinygrad.helpers import GRAPH, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary
+from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
 from tinygrad.runtime.ops_disk import RawDiskBuffer
 from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, get_contraction
@@ -17,6 +17,7 @@ sys.setrecursionlimit(10000)
 OPT = getenv("OPT", 2)
 LAZY = getenv("LAZY", 1)
+LAZYCACHE = getenv("LAZYCACHE", 1)

 # TODO: movement ops that only change shape are really nops. treat them as such
 REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
@@ -27,7 +28,7 @@ PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
   src = self.op.src[0]
-  if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype == BinaryOps and len(src.children) <= 1:
+  if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype is BinaryOps and len(src.children) <= 1:
     src = src.op # type: ignore
   return LazyOp(self.op.op, (src,), self.op.arg)
@@ -65,17 +66,15 @@ def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movement
 lazycache: LightWeakValueDictionary = LightWeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType):
-  # fromcpu aren't cached
-  if optype == LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}: return LazyBuffer(device, st, optype, op, dtype)
-
-  #print("create_lazybuffer", device, shape, optype, op, dtype)
-
-  # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
-  # get_weakop makes all the LazyBuffers in the op have a weakref
-  wop = (device, dtype, optype, ref(op))
+  # fromcpu aren't cached
+  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype)
+
+  # wop is the deduping key. i feel this used to compare more deeply
+  wop = (device, dtype, optype, ref(op))
   if wop in lazycache: return lazycache[wop]

   lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
   return ret
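The cache key is (device, dtype, optype, ref(op)) and the values are held weakly, so a LazyBuffer nobody references anymore falls out of the cache on its own; LAZYCACHE=0 now bypasses the whole path. A minimal sketch of the weak-value semantics, using the stdlib WeakValueDictionary in place of LightWeakValueDictionary (Buf and create are hypothetical stand-ins for LazyBuffer and create_lazybuffer):

from weakref import WeakValueDictionary

class Buf:
  def __init__(self, key): self.key = key

lazycache = WeakValueDictionary()

def create(key):
  if key in lazycache: return lazycache[key]  # dedup hit: the same object comes back
  lazycache[key] = ret = Buf(key)
  return ret

a, b = create("mul"), create("mul")
assert a is b                  # deduped while a live reference exists
del a, b
assert "mul" not in lazycache  # weakly-held entry dropped with its last reference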
@@ -83,7 +82,7 @@ class LazyBuffer:
   __slots__ = 'st', 'device', 'shape', 'optype', 'dtype', 'op', 'realized', 'output_buffer', 'children', 'node_id', '__weakref__'
   __deletable__ = ('op',)
   def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
-    self.st = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
+    self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
     self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype
     self.realized: Optional[RawBuffer] = src
     self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
@@ -110,16 +109,15 @@ class LazyBuffer:
   def realize(self:LazyBuffer) -> LazyBuffer:
     if not self.realized:
       # get real ops first
-      if self.optype in REALIZE_DISPATCHER:
-        self.op = REALIZE_DISPATCHER[self.optype](self)
-      elif self.op.op in REALIZE_DISPATCHER:
-        REALIZE_DISPATCHER[self.op.op](self)
+      if self.optype is BinaryOps: self.op = _ast_binaryops(self)
+      elif self.optype is ReduceOps: self.op = _ast_reduceops(self)
+      elif self.optype is LoadOps: LOAD_OPS_DISPATCHER[cast(LoadOps, self.op.op)](self)
       # run the ast if we still have to, and log the op
       if not self.realized:
         for x in self.op.buffers: x.realize()

         # HACK: image shape can be wrong, hot cast it back to a normal float
-        if self.optype != MovementOps and self.dtype.__class__ is ImageDType and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])):
+        if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])):
           if self.op.op == MovementOps.RESHAPE:
             # put CAST before the final RESHAPE
             self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
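Two hot-path tweaks here: the generic dict dispatch becomes direct `is` checks on the three optypes, and the image condition is reordered so the almost-always-false `self.dtype.__class__ is ImageDType` short-circuits first. A rough standalone micro-benchmark of the dispatch change (illustrative stand-in classes, not tinygrad's):

import timeit

class BinaryOps: pass
class ReduceOps: pass
DISPATCHER = {BinaryOps: lambda: 1, ReduceOps: lambda: 2}
optype = BinaryOps

def via_dict():
  # hash the class, probe the dict, call through the stored function
  if optype in DISPATCHER: return DISPATCHER[optype]()

def via_is():
  # plain identity compares, no hashing and no indirection
  if optype is BinaryOps: return 1
  elif optype is ReduceOps: return 2

print(timeit.timeit(via_dict), timeit.timeit(via_is))  # via_is should come out ahead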
@@ -130,11 +128,11 @@ class LazyBuffer:
       assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
       # HACK: allow hot casting of images
-      assert self.realized.dtype == self.dtype or self.dtype.name.startswith("image"), f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
+      assert self.realized.dtype == self.dtype or self.dtype.__class__ is ImageDType, f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
       self.dtype = self.realized.dtype

       # log to the graph
-      if self.realized.__class__ is not RawConst or GRAPH >= 2:
+      if (DEBUG or GRAPH) and (self.realized.__class__ is not RawConst or GRAPH >= 2):
         from tinygrad.graph import log_op
         log_op(self, self.op)
@@ -338,15 +336,13 @@ def _realize_const(buffer: LazyBuffer) -> None:
   else:
     buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())

-REALIZE_DISPATCHER: Dict[Any, Callable] = {
+LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
   LoadOps.CONTIGUOUS: _realize_contiguous,
   LoadOps.CUSTOM: _realize_custom,
   LoadOps.FROM: _realize_from,
   LoadOps.EMPTY: _realize_empty,
   LoadOps.RAND: _realize_rand,
   LoadOps.CONST: _realize_const,
-  ReduceOps: _ast_reduceops,
-  BinaryOps: _ast_binaryops,
 }

 MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
@@ -7,7 +7,8 @@ base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "image
 def image_dot(self, w):
   # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-  if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
+  n1, n2 = len(self.shape), len(w.shape)
+  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
   bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
   cin, cout = w.shape[-2], w.shape[-1]
   out_shape_t = self.shape[0:-2] + (cout,-1)
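The NOTE above checks out in plain numpy: a 1x1 conv is a channel contraction at each spatial position, so an mxk @ kxn matmul is a (1,k,m,1) activation against an (n,k,1,1) weight. A standalone sanity sketch (einsum standing in for conv2d; not tinygrad code):

import numpy as np

m, k, n = 2, 3, 4
A, W = np.random.randn(m, k), np.random.randn(k, n)
x = A.T.reshape(1, k, m, 1)  # (1, k, m, 1): k input channels over an m-by-1 spatial grid
w = W.T.reshape(n, k, 1, 1)  # (n, k, 1, 1): n output channels, 1x1 kernel
# contract over the channel axis at every spatial position, exactly what a 1x1 conv2d does
out = np.einsum('ikhw,okab->iohw', x, w).reshape(n, m).T
assert np.allclose(out, A @ W)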
@@ -33,8 +33,11 @@ class LazyOp:
     self.op = op
     self.src = src
     self.arg = arg
-    # TODO: this hasattr is required because the linearizer's key function maps the buffers to ints
-    self.buffers = functools.reduce(lambda x,s: (x+s.buffers) if hasattr(s, 'buffers') else x, src, tuple())
+    try:
+      self.buffers = tuple([y for x in src for y in x.buffers])
+    except AttributeError:
+      # NOTE: the linearizer's key function maps the buffers to ints, and LOCAL_BUFFER is used. we don't care about buffers in these cases
+      pass

   def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
   def __eq__(self, __value: object) -> bool:
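This is the "nicer self.buffers" item: the per-element hasattr guard becomes one EAFP try around a flat comprehension, so the common all-LazyBuffer case pays nothing for the rare int sources. A standalone sketch of the equivalence (Src is a hypothetical stand-in for a buffer-carrying source):

import functools

class Src:
  def __init__(self, buffers): self.buffers = buffers

src = (Src((1, 2)), Src((3,)))

# old shape: guard every element with hasattr
old = functools.reduce(lambda x, s: (x + s.buffers) if hasattr(s, 'buffers') else x, src, tuple())

# new shape: assume .buffers and catch the exceptional case (linearizer keys map buffers to ints)
try:
  new = tuple([y for x in src for y in x.buffers])
except AttributeError:
  new = tuple()

assert old == new == (1, 2, 3)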
@@ -46,10 +49,10 @@ class LazyOp:
   def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))

   # Any == Union[LazyBuffer, DeviceBuffer]
-  def map_buffers(self, real_srcs: Dict[Any, Any]): return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
-  def get_lazyops(self) -> List['LazyOp']: return [self] + [item for x in self.src for item in x.get_lazyops()]
+  def map_buffers(self, real_srcs: Dict[Any, Any]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
+  def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]

-def replace_with_movement_ops(self: LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
+def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
   from tinygrad.lazy import elementwise_op
   assert self.op in BinaryOps or self.op in UnaryOps
   return elementwise_op(self.op, *[z.replace_with_movement_ops(ops) for z in self.src], arg=self.arg) # type: ignore
@@ -468,7 +468,8 @@ class Tensor:
     return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))

   def dot(self, w:Tensor) -> Tensor:
-    if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
+    n1, n2 = len(self.shape), len(w.shape)
+    assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
     x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
     return (x*w).sum(-1)
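The "make this an assert" bullet: the hard RuntimeError becomes a debug-time guard, and asserts are stripped under python3 -O, so the shape check costs nothing in optimized runs. A plain-Python illustration of the same check (dot_shape_check is a made-up helper, not tinygrad API):

def dot_shape_check(s1, s2):
  n1, n2 = len(s1), len(s2)
  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"

dot_shape_check((2, 3), (3, 4))  # passes silently
# dot_shape_check((), (3,))      # AssertionError normally; skipped entirely under python3 -O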