Optimizations in lazy.py (#987)

* Optimizations in lazy.py

* Make mypy happy with stubs and fix the graph import hack

* Resolve merge conflict in helpers.py
Author: Rayan Hatout
Date: 2023-06-26 21:55:42 +01:00
Committed by: GitHub
Parent: 8bea6b6d35
Commit: dedbd970aa
9 changed files with 387 additions and 244 deletions
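
Two micro-optimization patterns recur throughout this diff: generator expressions inside all()/any() become list comprehensions (on CPython, building the short list outright avoids generator-frame overhead and usually wins for small inputs), and isinstance() checks become exact __class__ identity tests. A minimal sketch to measure the first pattern; absolute numbers vary by CPython version:

import timeit

xs = list(range(8))  # short sequences are the common case in these kernels
gen = timeit.timeit(lambda: all(x < 10 for x in xs), number=500_000)
lst = timeit.timeit(lambda: all([x < 10 for x in xs]), number=500_000)
print(f"generator: {gen:.3f}s  list comprehension: {lst:.3f}s")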


@@ -4,13 +4,12 @@ from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same
from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.shape.symbolic import Variable
# bottom ones are asm only
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); \
SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
@@ -94,27 +93,27 @@ class Linearizer:
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
# get the output buffers
self.bufs = [output_buffer] + dedup(get_buffers(ast))
self.bufs = [output_buffer] + dedup(ast.buffers)
# key for lookup in cache (can change, str might not be right)
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
# mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}"
self.key = (ast.map_buffers({x:i for i,x in enumerate(self.bufs)}).key, tuple([x.key for x in self.bufs]))
def process(self) -> None:
if hasattr(self, "sts"): return # already processed
# fetch lazyop info
self.info: FlopCounter = get_lazyop_info(self.ast)
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)
# there's only allowed to be one reduceop
reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps]
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
self.reduceop = reduceops[0] if reduceops else None
# get earlybufs, before the one reduce op
self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else []
self.earlybufs = dedup(self.reduceop.buffers) if self.reduceop else []
# create new shapetrackers inside this kernel, we will permute them
self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs]
@@ -178,7 +177,7 @@ class Linearizer:
for k,out_tokens in self._group_float4(i, load_offset).items():
idxs = [x[2]-out_tokens[0][2] for x in out_tokens]
valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3])
if any(idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
if any([idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))]) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
# idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
for x in out_tokens: load_offset_new[x[1]] = x
else:
@@ -306,13 +305,13 @@ class Linearizer:
loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs)
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True)
self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore
# end the late reduce loop
self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
# load latebufs
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})
# run late AST
val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
@@ -334,17 +333,17 @@ class Linearizer:
return out
def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
if not isinstance(x, LazyOp): return loaded_buffers[x]
if x.__class__ is not LazyOp: return loaded_buffers[x]
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
if x.op in ReduceOps and not do_reduce: return acc
# MULACC fusion. TODO: this is copied from Interpreted
if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL:
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == UnaryOps.CAST and isinstance(x.src[0].src[0], LazyOp) and x.src[0].src[0].op == BinaryOps.MUL:
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg)
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
# TODO: fold float4 into a single uop when possible.
if isinstance(x.op, (ReduceOps, FusedOps)):
if x.op.__class__ in {ReduceOps, FusedOps}:
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
@@ -431,7 +430,7 @@ class Linearizer:
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
if self.shape_len == 0: return
all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
all_ones = [all([st.shape[i]==1 for st in self.sts]) for i in range(self.shape_len)]
# keep at least 1 one
if all(all_ones): all_ones[-1] = False
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
@@ -456,14 +455,14 @@ class Linearizer:
else: rets[j].append((shapes[j][i], strides[j][i]))
# do the reshapes
for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x))
for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))
# ******************** GPU simplifiers ********************
def required_optimizations(self, early_only=False):
for buf_index,buf in enumerate(self.bufs):
unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType):
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
self.shift_to(unit_stride_axes_mul_4[0], 4)
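
The isinstance → __class__ rewrites above trade generality for speed: an identity test on the class object skips isinstance's subclass and __instancecheck__ machinery, but it no longer matches subclasses, which is safe here only because LazyOp and LocalBuffer are not subclassed. A sketch of the semantic difference, using throwaway classes:

class Base: pass
class Sub(Base): pass

obj = Sub()
print(isinstance(obj, Base))  # True: isinstance walks the class hierarchy
print(obj.__class__ is Base)  # False: identity only matches the exact class
print(obj.__class__ is Sub)   # True: equivalent to isinstance when no subclasses exist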


@@ -5,13 +5,11 @@ except ImportError:
nx = None # graph won't work
from collections import defaultdict
from typing import Dict, List, Optional
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import getenv, DEBUG, GlobalCounters
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp
from tinygrad.tensor import LazyBuffer
from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
from tinygrad.runtime.lib import RawConst
GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
# **** debugging and graphing ****
G = nx.DiGraph() if nx is not None else None
@@ -52,8 +50,8 @@ def str_dtype(dtyp):
def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
if show_graph is None: show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
op: List[Op] = [x.op for x in get_lazyops(ast)]
inp: List[LazyBuffer] = [x for x in get_buffers(ast) if not isinstance(x.realized, RawConst) or GRAPH > 1]
op: List[Op] = [x.op for x in ast.get_lazyops()]
inp: List[LazyBuffer] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
oporder = [LoadOps, FusedOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
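
This hunk is the "graph import hack" fix from the commit message: graph.py needs LazyBuffer while lazy.py needs log_op, which is a cycle if both imports run at module scope. The fix moves the GRAPH/PRUNEGRAPH/GRAPHPATH flags into helpers.py (which imports nothing else from tinygrad) and defers lazy.py's import of log_op to the call site. A schematic sketch of the pattern, with hypothetical module names:

# helpers_mod.py -- leaf module, safe to import at module scope from anywhere
GRAPH = 1

# graph_mod.py -- may import lazy_mod at module scope
def log_op(buf): print("logged", buf)

# lazy_mod.py
from helpers_mod import GRAPH  # fine: helpers_mod imports nothing back

def realize(buf):
  if GRAPH:
    from graph_mod import log_op  # deferred: resolved at call time, after both modules are loaded
    log_op(buf)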


@@ -1,18 +1,23 @@
from __future__ import annotations
import platform
from dataclasses import dataclass, asdict
import os, math, functools, time, re
import os, functools, platform, time, re
from weakref import KeyedRef, ref
from _weakref import _remove_dead_weakref # type: ignore
import numpy as np
from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
from math import prod # noqa: F401 # pylint:disable=unused-import
ShapeType = Tuple[int, ...]
# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX = platform.system() == "Darwin"
def dedup(x): return list(dict.fromkeys(x)) # retains list order
def prod(x:Union[List[int], Tuple[int, ...]]) -> int: return math.prod(x)
def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)
def argfix(*x):
  try:
    if x[0].__class__ in {tuple, list}: return tuple(x[0])
  except IndexError: return tuple()
  return tuple(x)
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
def all_same(items): return all([x == items[0] for x in items]) if len(items) > 1 else True
def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
@@ -43,6 +48,7 @@ class ContextVar:
def value(self): return ContextVar.ctx_stack[-1][self.key] if self.key in ContextVar.ctx_stack[-1] else self.initial_value
DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
class Timing(object):
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
@@ -60,6 +66,8 @@ class DType(NamedTuple):
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
sz: int = 1
def __repr__(self): return f"dtypes.{self.name}"
@property
def key(self): return (self.name)
# dependent typing?
class ImageDType(DType):
@@ -70,7 +78,6 @@ class ImageDType(DType):
super().__init__()
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
@dataclass
class dtypes:
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
@@ -79,7 +86,9 @@ class dtypes:
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", bool)
float16: Final[DType] = DType(0, 2, "half", np.float16)
half = float16
@@ -97,6 +106,9 @@ class dtypes:
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
class GlobalCounters:
global_ops: ClassVar[int] = 0
global_mem: ClassVar[int] = 0
@@ -106,3 +118,36 @@ class GlobalCounters:
cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
# Stripped down version of a WeakSet
class LightWeakSet:
  __slots__ = 'data', '_remove', '__weakref__'
  def __init__(self):
    self.data = set()
    def _remove(item, selfref=ref(self)):
      self = selfref()
      if self: self.data.discard(item)
    self._remove = _remove
  def __len__(self): return len(self.data)
  def add(self, item): self.data.add(ref(item, self._remove))
  def discard(self, item): self.data.discard(ref(item))

# Stripped down version of a WeakValueDictionary
class LightWeakValueDictionary:
  __slots__ = 'data', '_remove', '__weakref__'
  def __init__(self):
    def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
      self = selfref()
      if self: _atomic_removal(self.data, wr.key)
    self._remove = remove
    self.data = {}
  def __getitem__(self, key):
    o = self.data[key]()
    if o is None: raise KeyError(key)
    else: return o
  def __setitem__(self, key, value): self.data[key] = KeyedRef(value, self._remove, key)
  def __contains__(self, key): return key in self.data
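
LightWeakSet and LightWeakValueDictionary keep the one property the lazy cache relies on, that entries disappear once the referent is garbage collected, while dropping the iteration support and guard bookkeeping the stdlib versions carry. The preserved behavior, demonstrated with the stdlib class they strip down:

import gc
from weakref import WeakValueDictionary  # LightWeakValueDictionary preserves these semantics

class Buf: pass  # stand-in for LazyBuffer

cache = WeakValueDictionary()
b = Buf()
cache["k"] = b
print("k" in cache)  # True while a strong reference to b exists
del b
gc.collect()         # last strong reference gone; the remove callback purges the entry
print("k" in cache)  # False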


@@ -1,14 +1,16 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, cast
import operator
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
import sys, importlib, inspect, functools, pathlib
from weakref import ref
import numpy as np
from weakref import WeakValueDictionary, ref, WeakSet
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped
from tinygrad.helpers import GRAPH, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, get_contraction
from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, LoadOps, OpType, LazyOp
from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
@@ -25,23 +27,23 @@ PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
def _ast_reduceops(self:LazyBuffer) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
src = self.op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
src = src.op
return LazyOp(self.op.op, (src,), self.op.arg)
if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype == BinaryOps and len(src.children) <= 1:
src = src.op # type: ignore
return LazyOp(self.op.op, (src,), self.op.arg, src.get_buffers())
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(self:LazyBuffer) -> LazyOp:
real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in self.op.buffers}
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape: Tuple[int, ...] = self.shape
if len(psrcs) >= 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and len(psrcs) >= 1:
psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
if psrc[1].optype == ReduceOps:
top = _ast_reduceops(psrc[1])
real_srcs[psrc[0]] = top
real_srcs.update({x:x for x in get_buffers(top)}) # the reduce op buffers are not modified
real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
if psrc[0].shape != psrc[1].shape:
@@ -51,124 +53,91 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
# reshape all the late ops into the output shape
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if real_srcs[x] is None: real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
ast = map_buffers(real_srcs, self.op)
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
if not real_srcs[x]: real_srcs[x] = x.reshape_op(intermediate_shape)
ast = self.op.map_buffers(real_srcs)
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape, ast.buffers) if intermediate_shape != self.shape else ast
# **** lazy operations ****
def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple([get_weakop(x) if x.__class__ is LazyOp else ref(x) for x in op.src]), op.arg)
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.op.src[0]) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(root.op.src[0], allow_contiguous) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(x.op.src[0]) if x.realized is None and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
def replace_with_movement_ops(y:Union[LazyOp, LazyBuffer], ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> LazyBuffer:
if isinstance(y, LazyBuffer):
for op, arg in ops: y = y.movement_op(op, arg)
return y
assert y.op in BinaryOps or y.op in UnaryOps
return elementwise_op(y.op, *[replace_with_movement_ops(z, ops) for z in y.src], arg=y.arg) # type: ignore
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
lazycache: WeakValueDictionary[Tuple[str, DType, OpType, LazyOp], LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp, dtype:DType):
st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
lazycache: LightWeakValueDictionary = LightWeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType):
# fromcpu aren't cached
if optype == LoadOps and op.op in [LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST]: return LazyBuffer(device, st, optype, op, dtype)
if optype == LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}: return LazyBuffer(device, st, optype, op, dtype)
#print("create_lazybuffer", device, shape, optype, op, dtype)
# NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
# get_weakop makes all the LazyBuffers in the op have a weakref
wop = (device, dtype, optype, get_weakop(op))
wop = (device, dtype, optype, ref(op))
if wop not in lazycache: lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
else: ret = lazycache[wop]
if wop in lazycache: return lazycache[wop]
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
return ret
class LazyBuffer:
__slots__ = 'st', 'device', 'shape', 'optype', 'dtype', 'op', 'realized', 'output_buffer', 'children', 'node_id', '__weakref__'
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, src:Union[LazyOp, RawBuffer], dtype:DType):
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
self.st = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype
self.realized: Optional[RawBuffer] = src if isinstance(src, RawBuffer) else None
self.realized: Optional[RawBuffer] = src
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: WeakSet[LazyBuffer] = WeakSet()
self.children: LightWeakSet = LightWeakSet()
# NOTE: op should be read only after construction of LazyBuffer
if isinstance(src, LazyOp):
self.op: LazyOp = src
for x in get_buffers(self.op): x.children.add(self)
if op:
self.op: LazyOp = op
for x in op.buffers: x.children.add(self)
if not LAZY: self.realize()
# log phantom ops to the graph
from tinygrad.graph import log_op, GRAPH
if GRAPH >= 3: log_op(self, self.op, phantom=True)
if GRAPH >= 3:
from tinygrad.graph import log_op
log_op(self, self.op, phantom=True)
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self.realized else self.realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype.key, self.realized.key, self.st.key)
return (self.dtype.key, self.op.op, self.st.key)
def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else self.realized} st:{self.st}>"
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
def realize(self:LazyBuffer) -> LazyBuffer:
if self.realized is None:
if not self.realized:
# get real ops first
if self.op.op == LoadOps.CONTIGUOUS:
realized = self.op.src[0].realize().realized
if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
# no need to run an AST, this is already contiguous
self.realized = realized
else:
# TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though
self.op = LazyOp(UnaryOps.NOOP, self.op.src)
elif self.op.op == LoadOps.CUSTOM:
# this needs to immediately realize
self.realized = self.op.arg(self, *[x.realize() for x in self.op.src])
elif self.op.op == LoadOps.FROM:
rawbuf = self.op.src[0].realize()
# TODO: make this generic
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[self.device].buffer, RawBufferMapped):
self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
rawbuf.realized.readinto(cast(RawBufferMapped, self.realized)._buffer())
else:
self.realized = Device[self.device].buffer.fromCPU(rawbuf.toCPU(), **self._device_extra_args())
elif self.optype == LoadOps:
if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}")
if self.op.op == LoadOps.EMPTY:
self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
elif self.op.op == LoadOps.RAND:
rng = np.random.default_rng(self.op.arg)
assert self.dtype.np is not None, "internal dtypes don't work with LoadOps.RAND"
self.realized = Device[self.device].buffer.fromCPU(rng.random(size=self.shape, dtype=self.dtype.np), **self._device_extra_args())
elif self.op.op == LoadOps.CONST:
if hasattr(Device[self.device].codegen, 'supports_constant_folding'):
self.realized = RawConst(1, self.dtype, float(self.op.arg))
else:
self.realized = Device[self.device].buffer.fromCPU(np.array(self.op.arg, dtype=self.dtype.np), **self._device_extra_args())
# these can be late folded and change the op to go further back in the graph
elif self.optype == ReduceOps: self.op = _ast_reduceops(self)
elif self.optype == BinaryOps: self.op = _ast_binaryops(self) # ISSUE: this can include a reshape
if self.optype in REALIZE_DISPATCHER:
self.op = REALIZE_DISPATCHER[self.optype](self)
elif self.op.op in REALIZE_DISPATCHER:
REALIZE_DISPATCHER[self.op.op](self)
# run the ast if we still have to, and log the op
if self.realized is None:
for x in get_buffers(self.op): x.realize()
if not self.realized:
for x in self.op.buffers: x.realize()
# HACK: image shape can be wrong, hot cast it back to a normal float
if self.optype != MovementOps and isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
if self.optype != MovementOps and self.dtype.__class__ is ImageDType and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])):
if self.op.op == MovementOps.RESHAPE:
# put CAST before the final RESHAPE
self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
else:
self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
self.dtype = dtypes.float32
self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())
assert isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
# HACK: allow hot casting of images
assert self.realized.dtype == self.dtype or self.dtype.name.startswith("image"), f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
self.dtype = self.realized.dtype
# log to the graph
from tinygrad.graph import log_op, GRAPH
if not isinstance(self.realized, RawConst) or GRAPH >= 2: log_op(self, self.op)
if self.realized.__class__ is not RawConst or GRAPH >= 2:
from tinygrad.graph import log_op
log_op(self, self.op)
# no need to keep the op after realization
del self.op
@@ -176,17 +145,17 @@ class LazyBuffer:
@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, RawNumpyBuffer.fromCPU(x), dtypes.from_np(x.dtype))
return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None, ()), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
# create a constant with the shape and dtype of self
def const_like(self, val) -> LazyBuffer:
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \
.movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape)
.reshape_op((1,)*len(self.shape)).expand_op(self.shape)
# NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
def toCPU(self):
@@ -198,89 +167,99 @@ class LazyBuffer:
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
def contiguous(self:LazyBuffer) -> LazyBuffer:
if self.realized is None and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype)
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, tuple(srcs), new_shape), self.dtype)
# shrink -> stride -> permute -> reshape -> pad -> expand
def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
# very instant nop
if op == MovementOps.RESHAPE and self.shape == arg: return self
# TODO: look into why that copy is needed
local_st = ShapeTracker(self.shape).movement_op(op, arg)
# instant nops
if local_st.contiguous and self.shape == local_st.shape: return self
# two ops in a row is one op. merge them if unresolved
if self.realized is None and self.op.op == op:
# TODO: why is deleting self from children needed? shouldn't GC do it?
self.op.src[0].children.discard(self)
if op in [MovementOps.RESHAPE, MovementOps.EXPAND]: return self.op.src[0].movement_op(op, arg)
if op == MovementOps.SHRINK: return self.op.src[0].movement_op(op, tuple((b1+b2, b1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
if op == MovementOps.PERMUTE: return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
if op == MovementOps.PAD: return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
if op == MovementOps.STRIDE: return self.op.src[0].movement_op(op, tuple(i*j for i,j in zip(arg, self.op.arg)))
# push permutes before reduce ops
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.optype == ReduceOps:
# reduceops have one buffer input, permute it
narg = tuple(self.op.arg[arg[i]] for i in range(len(arg)))
src, rop = self.op.src[0], self.op.op
src.children.discard(self)
del self # TODO: why doesn't this delete remove it from the children
return src.movement_op(op, arg).reduce_op(rop, narg)
# some permutes are actually just reshapes
if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
# move permutes before expands (always, this is safe)
if op == MovementOps.PERMUTE and self.realized is None and self.op.op == MovementOps.EXPAND:
self.op.src[0].children.discard(self)
return self.op.src[0].movement_op(MovementOps.PERMUTE, arg).movement_op(MovementOps.EXPAND, tuple(self.op.arg[a] for a in arg))
# move permutes before reshapes if we can
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # this changes nothing?
return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(flatten(shape_idx_groups[i] for i in arg))) \
.movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape)
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and (op in [MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE] or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0: # and op != MovementOps.EXPAND and (op != MovementOps.PAD or (SHUFFLE_PAD_OPS and all(x.op != BinaryOps.DIV for x in get_lazyops(self.op)))):
return replace_with_movement_ops(self.op, [(op, arg)])
# create the buffer
ret = create_lazybuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg), self.dtype)
# if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match)
# NOTE: if ret is in the cache, it can already be realized
if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None, (self,)), self.dtype)
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0:
return self.op.replace_with_movement_ops([(op, arg)])
ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg, (self,)), self.dtype)
if REMOVE_MOVEMENT_NOPS and not self.realized and not ret.realized and ret.st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
return root.movement_op(MovementOps.RESHAPE, ret.st.shape)
return root.reshape_op(ret.st.shape)
return ret
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
def reshape_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].reshape_op(arg)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
def pad_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all([b == 0 and e == 0 for b,e in arg]): return self
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad_op(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
def expand_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND:
return self.op.src[0].expand_op(arg)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)
def permute_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if arg == tuple(range(len(self.shape))): return self
if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute_op(tuple([self.op.arg[i] for i in arg]))
if not self.realized:
if PUSH_PERMUTES and self.optype == ReduceOps:
# reduceops have one buffer input, permute it
narg = tuple([self.op.arg[arg[i]] for i in range(len(arg))])
src, rop = self.op.src[0], self.op.op
src.children.discard(self)
del self # TODO: why doesn't this delete remove it from the children
return src.permute_op(arg).reduce_op(cast(ReduceOps, rop), narg)
# move permutes before expands (always, this is safe)
if self.op.op == MovementOps.EXPAND:
return self.op.src[0].permute_op(arg).expand_op(tuple([self.op.arg[a] for a in arg]))
# move permutes before reshapes if we can
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].permute_op(tuple(flatten(shape_idx_groups[i] for i in arg))) \
.reshape_op(ShapeTracker(self.st).permute(arg).shape)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
def shrink_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink_op(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
def stride_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
local_st = ShapeTracker(self.shape).stride(arg)
if self.shape == local_st.shape and local_st.contiguous: return self
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride_op(tuple(map(operator.mul, arg, self.op.arg)))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)
def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
def get_buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
def get_lazyops(self) -> List[Any]: return []
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
return y
def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
new_srcs = []
for x in srcs:
mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = []
mops: List[Tuple[MovementOps, Any]] = []
bx = x
# backwalk all the movement ops. don't push PAD or EXPAND
while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1:
while not bx.realized and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op != MovementOps.PAD) and len(bx.children) <= 1:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
bx = bx.op.src[0]
bx = cast(LazyBuffer, bx.op.src[0])
# NOTE: can't push pads with a div
if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))):
new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1]))
if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all([x[0] != MovementOps.PAD for x in mops]) or all([x.op != BinaryOps.DIV for x in bx.op.get_lazyops()])):
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
else:
new_srcs.append(x)
return tuple(new_srcs)
@@ -290,29 +269,30 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
# get outputs now
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg)
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
new_srcs = []
if PUSH_CONTIGUOUS and any([not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs]):
new_srcs: List[LazyBuffer] = []
for x in srcs:
if x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
x.op.src[0].children.discard(x)
new_srcs.append(x.op.src[0])
new_srcs.append(cast(LazyBuffer, x.op.src[0]))
else:
new_srcs.append(x)
return elementwise_op(op, *new_srcs, arg=arg).contiguous()
if MERGE_ELEMENTWISE_OPS:
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore
srcs = tuple([x.op if x.optype == BinaryOps and len(x.children) == 0 and not x.realized else x for x in srcs]) # type: ignore
return create_lazybuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs, arg), out_dtype)
return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
class _Device:
def __init__(self) -> None:
self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
@@ -324,3 +304,58 @@ class _Device:
except Exception: pass
return "CPU"
Device = _Device()
def _realize_contiguous(buffer: LazyBuffer) -> None:
realized = buffer.op.src[0].realize().realized
if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and cast(RawBuffer, realized).size == prod(buffer.shape):
# no need to run an AST, this is already contiguous
buffer.realized = realized
else:
# TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though
buffer.op = LazyOp(UnaryOps.NOOP, buffer.op.src)
def _realize_custom(buffer: LazyBuffer) -> None:
# this needs to immediately realize
buffer.realized = buffer.op.arg(buffer, *[x.realize() for x in buffer.op.src])
def _realize_from(buffer: LazyBuffer) -> None:
rawbuf = buffer.op.src[0].realize()
# TODO: make this generic
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
else:
buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())
def _realize_empty(buffer: LazyBuffer) -> None:
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
def _realize_rand(buffer: LazyBuffer) -> None:
rng = np.random.default_rng(buffer.op.arg)
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=buffer.dtype.np), **buffer._device_extra_args()) # type: ignore
def _realize_const(buffer: LazyBuffer) -> None:
if hasattr(Device[buffer.device].codegen, 'supports_constant_folding'):
buffer.realized = RawConst(1, buffer.dtype, float(buffer.op.arg))
else:
buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
REALIZE_DISPATCHER: Dict[Any, Callable] = {
LoadOps.CONTIGUOUS: _realize_contiguous,
LoadOps.CUSTOM: _realize_custom,
LoadOps.FROM: _realize_from,
LoadOps.EMPTY: _realize_empty,
LoadOps.RAND: _realize_rand,
LoadOps.CONST: _realize_const,
ReduceOps: _ast_reduceops,
BinaryOps: _ast_binaryops,
}
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.RESHAPE: LazyBuffer.reshape_op,
MovementOps.EXPAND: LazyBuffer.expand_op,
MovementOps.SHRINK: LazyBuffer.shrink_op,
MovementOps.PERMUTE: LazyBuffer.permute_op,
MovementOps.PAD: LazyBuffer.pad_op,
MovementOps.STRIDE: LazyBuffer.stride_op,
}
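
realize() previously selected behavior with an if/elif chain over self.op.op and self.optype; REALIZE_DISPATCHER collapses that into one dict lookup, and MOVEMENT_OPS_DISPATCHER does the same for replace_with_movement_ops. The shape of the refactor in a toy, self-contained form:

from enum import Enum, auto
from typing import Any, Callable, Dict

class LoadOps(Enum): EMPTY = auto(); CONST = auto()

def _realize_empty(buf: dict) -> None: buf["realized"] = [0.0] * buf["size"]
def _realize_const(buf: dict) -> None: buf["realized"] = [float(buf["arg"])]

DISPATCHER: Dict[Any, Callable] = {LoadOps.EMPTY: _realize_empty, LoadOps.CONST: _realize_const}

def realize(buf: dict) -> dict:
  if buf["realized"] is None and buf["op"] in DISPATCHER:
    DISPATCHER[buf["op"]](buf)  # one hash lookup replaces the if/elif chain
  return buf

print(realize({"op": LoadOps.CONST, "size": 1, "arg": 3, "realized": None}))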


@@ -1,6 +1,6 @@
from typing import Tuple, Optional
from tinygrad.helpers import argsort, ShapeType
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
import math
@@ -10,6 +10,7 @@ class Contiguous(Function):
def backward(self, grad_output): return grad_output
class Cast(Function):
__slots__ = "input_dtype"
def forward(self, x, dtype):
self.input_dtype = x.dtype
return x.cast(dtype)
@@ -19,6 +20,7 @@ class Cast(Function):
# ************* unary ops *************
class Sin(Function):
__slots__ = "x"
def forward(self, x: LazyBuffer) -> LazyBuffer:
self.x = x
return x.unary_op(UnaryOps.SIN)
@@ -26,6 +28,7 @@ class Sin(Function):
return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)
# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
__slots__ = "ret"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
return self.ret
@@ -35,6 +38,7 @@ class Relu(Function):
return mask.binary_op(BinaryOps.MUL, grad_output)
class Log(Function):
__slots__ = "x"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.x = x
return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)/math.log(math.e)))
@@ -43,6 +47,7 @@ class Log(Function):
return grad_output.binary_op(BinaryOps.DIV, self.x)
class Exp(Function):
__slots__ = "ret"
def forward(self, x:LazyBuffer) -> LazyBuffer:
self.ret = x.binary_op(BinaryOps.MUL, x.const_like(math.log(math.e)/math.log(2))).unary_op(UnaryOps.EXP2)
return self.ret
@@ -53,27 +58,29 @@ class Exp(Function):
# ************* reduce ops *************
class Sum(Function):
__slots__ = "input_shape"
def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
self.input_shape = x.shape
return x.reduce_op(ReduceOps.SUM, new_shape)
def backward(self, grad_output):
return grad_output.movement_op(MovementOps.EXPAND, self.input_shape)
return grad_output.expand_op(self.input_shape)
class Max(Function):
__slots__ = "x", "ret"
def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.movement_op(MovementOps.EXPAND, self.x.shape))
max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand_op(self.x.shape))
# sum of locations, averaged
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).movement_op(MovementOps.EXPAND, self.x.shape)
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand_op(self.x.shape)
max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, self.x.shape)
grad_output_expanded = grad_output.expand_op(self.x.shape)
return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
# ************* binary ops *************
@@ -83,6 +90,7 @@ class Equal(Function):
return x.binary_op(BinaryOps.CMPEQ, y)
class Maximum(Function):
__slots__ = "x", "y", "ret"
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
self.ret = x.binary_op(BinaryOps.MAX, y)
@@ -113,6 +121,7 @@ class Sub(Function):
grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
class Mul(Function):
__slots__ = 'x', 'y'
def forward(self, x:LazyBuffer, y:LazyBuffer):
self.x, self.y = x, y
return x.binary_op(BinaryOps.MUL, y)
@@ -122,6 +131,7 @@ class Mul(Function):
self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
class Pow(Function):
__slots__ = 'x', 'y', 'ret'
def forward(self, x:LazyBuffer, y:LazyBuffer):
self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
return self.ret
@@ -131,6 +141,7 @@ class Pow(Function):
grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2)/math.log(math.e))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
class Div(Function):
__slots__ = 'x', 'y'
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
return x.binary_op(BinaryOps.DIV, y)
@@ -143,49 +154,55 @@ class Div(Function):
# NOTE: this is sum in reverse
class Expand(Function):
__slots__ = 'input_shape'
def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
self.input_shape = x.shape
return x.movement_op(MovementOps.EXPAND, shape)
return x.expand_op(shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
class Reshape(Function):
__slots__ = 'input_shape'
def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
self.input_shape = x.shape
return x.movement_op(MovementOps.RESHAPE, shape)
return x.reshape_op(shape)
def backward(self, grad_output):
return grad_output.movement_op(MovementOps.RESHAPE, self.input_shape)
return grad_output.reshape_op(self.input_shape)
class Permute(Function):
__slots__ = 'input_order'
def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
self.input_order = order
return x.movement_op(MovementOps.PERMUTE, order)
return x.permute_op(order)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.movement_op(MovementOps.PERMUTE, argsort(self.input_order))
return grad_output.permute_op(argsort(self.input_order))
class Pad(Function):
__slots__ = 'narg'
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
self.narg = tuple((p[0], s+p[0]) for s,p in zip(x.shape, arg))
return x.movement_op(MovementOps.PAD, arg)
self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
return x.pad_op(arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.movement_op(MovementOps.SHRINK, self.narg)
return grad_output.shrink_op(self.narg)
class Shrink(Function):
__slots__ = 'narg'
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
self.narg = tuple((p[0], s-p[1]) for s,p in zip(x.shape, arg))
return x.movement_op(MovementOps.SHRINK, arg)
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
return x.shrink_op(arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.movement_op(MovementOps.PAD, self.narg)
return grad_output.pad_op(self.narg)
class Flip(Function):
__slots__ = 'arg'
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
self.arg = tuple(-1 if i in axis else 1 for i in range(len(x.shape)))
return x.movement_op(MovementOps.STRIDE, self.arg)
self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
return x.stride_op(self.arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.movement_op(MovementOps.STRIDE, self.arg)
return grad_output.stride_op(self.arg)
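
Each Function subclass in this file gains __slots__ for its saved tensors. A slotted instance has a fixed attribute layout instead of a per-instance __dict__, which shrinks memory and slightly speeds attribute access, and that adds up since one Function object is allocated per autograd op. Note the __dict__ only disappears if every base class is slotted too. A minimal sketch of the effect:

import sys

class Plain:
  def __init__(self): self.x = 1

class Slotted:
  __slots__ = ("x",)
  def __init__(self): self.x = 1

p, s = Plain(), Slotted()
print(sys.getsizeof(p) + sys.getsizeof(p.__dict__), "bytes vs", sys.getsizeof(s), "bytes")
try:
  s.y = 2  # slotted classes reject attributes not listed in __slots__
except AttributeError as e:
  print("rejected:", e)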


@@ -1,10 +1,12 @@
from __future__ import annotations
import functools, operator, time
import functools, time
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable
from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored, ansilen
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
from tinygrad.shape.shapetracker import MovementOps
from tinygrad.runtime.lib import RawBuffer, RawConst
if TYPE_CHECKING:
from tinygrad.lazy import LazyBuffer
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
@@ -19,19 +21,62 @@ class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]
class LazyOp(NamedTuple):
op: Op
# Any == Union[LazyOp, LazyBuffer, DeviceBuffer]
src: Tuple[Any, ...] # type: ignore
arg: Any = None
class LazyOp:
# TODO: add dest to support multiple outputs. on second thought, multiple outputs will have multiple LazyOps.
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
op: Op
src: Tuple[Union[LazyOp, LazyBuffer], ...]
arg: Any
buffers: Tuple[LazyBuffer, ...]
# Any == Union[LazyBuffer, DeviceBuffer]
def get_buffers(op:LazyOp) -> List[Any]: return functools.reduce(operator.add, [get_buffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])
def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp:
if len(real_srcs) and x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
def __init__(self, op: Op, src: Tuple[Union[LazyOp, LazyBuffer], ...], arg: Any = None, buffers: Optional[Tuple[LazyBuffer, ...]] = None):
self.op = op
self.src = src
self.arg = arg
if not buffers:
buffers = tuple()
for s in src:
try: buffers += s.get_buffers()
except AttributeError: pass
self.buffers = buffers
def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
def __eq__(self, __value: object) -> bool:
if __value.__class__ is not LazyOp: return False
__value = cast(LazyOp, __value)
return self.op == __value.op and self.src == __value.src and self.arg == __value.arg
def __hash__(self) -> int: return hash((self.op, self.src, self.arg))
@property
def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
# Any == Union[LazyBuffer, DeviceBuffer]
def map_buffers(self, real_srcs: Dict[Any, Any]): return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
def get_buffers(self) -> Tuple[LazyBuffer, ...]: return self.buffers
def get_lazyops(self) -> List['LazyOp']: return [self] + [item for x in self.src for item in x.get_lazyops()]
def replace_with_movement_ops(self: LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
from tinygrad.lazy import elementwise_op
assert self.op in BinaryOps or self.op in UnaryOps
return elementwise_op(self.op, *[z.replace_with_movement_ops(ops) for z in self.src], arg=self.arg) # type: ignore
@property
def st(self): raise NotImplementedError
@property
def children(self): raise NotImplementedError
@property
def shape(self): raise NotImplementedError
@property
def realized(self): raise NotImplementedError
@property
def optype(self): raise NotImplementedError
def realize(self): raise NotImplementedError
def reshape_op(self, _): raise NotImplementedError
def pad_op(self, _): raise NotImplementedError
def expand_op(self, _): raise NotImplementedError
def permute_op(self, _): raise NotImplementedError
def shrink_op(self, _): raise NotImplementedError
def stride_op(self, _): raise NotImplementedError
# **************** for Interpreted Buffers ****************
@@ -46,12 +91,12 @@ class Interpreted:
self.codegen = None
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and ast.src[0].__class__ is LazyOp and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(FusedOps.MULACC, cast(LazyOp, ast.src[0]).src, ast.arg)
created_context = context is None
if context is None: context = dict()
if not created_context and ast in context: return context[ast]
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
srcs = [self.exec_ast(cast(LazyOp, x), context=context, **kwargs) if x.__class__ is LazyOp else self.from_lazybuffer(x) for x in ast.src]
if DEBUG >= 3: st = time.perf_counter()
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
@@ -90,7 +135,7 @@ class ASTRunner:
return self
def exec(self, bufs) -> Optional[float]:
rawbufs = [x.realized for x in bufs if x.realized is not None and not isinstance(x.realized, RawConst)]
rawbufs = [x.realized for x in bufs if x.realized is not None and x.realized.__class__ is not RawConst]
if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
return self(rawbufs)
@@ -114,21 +159,21 @@ class Compiled:
def exec_ast(self, ast:LazyOp, output, **kwargs):
# all movementops do nothing in a Compiled buffer!
if ast.op in MovementOps and not isinstance(ast.src[0], LazyOp) and ast.src[0].realized is not None: return ast.src[0].realized
if ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized: return ast.src[0].realized
# check if we can reuse the output buffer
# if it's aliased, don't use it
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
output.realized = output.output_buffer
if output.realized is not None:
if isinstance(output.realized, RawConst): output.realized = None # can't assign to RawConst
for a in get_buffers(ast):
if output.realized:
if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
for a in ast.buffers:
if a.realized == output.realized and not a.st.contiguous:
output.realized = None
break
# we don't have an output buffer, we have to create it
if output.realized is None:
if not output.realized:
output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)
# compilation time
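
The LazyOp rewrite above replaces the recursive get_buffers() tree walk with a buffers tuple computed once at construction: each node concatenates its children's tuples, so later queries are a plain attribute read. A sketch of the idea on a toy expression tree (class names hypothetical):

from typing import Any, Tuple

class Leaf:  # stand-in for LazyBuffer
  def __init__(self, name): self.name = name
  @property
  def buffers(self): return (self,)

class Node:  # stand-in for LazyOp
  __slots__ = ("op", "src", "buffers")
  def __init__(self, op, src: Tuple[Any, ...]):
    self.op, self.src = op, src
    # flatten children's buffers once, at construction, instead of re-walking the tree per query
    self.buffers = tuple(b for s in src for b in s.buffers)

x, y = Leaf("x"), Leaf("y")
ast = Node("ADD", (Node("MUL", (x, y)), x))
print([b.name for b in ast.buffers])  # ['x', 'y', 'x']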


@@ -14,6 +14,8 @@ class RawBuffer: # pylint: disable=abstract-method
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
@property
def key(self): return (self.size, self.dtype.key)
# NOTE: this interface allows for 0 copy
@classmethod
@@ -50,3 +52,5 @@ class RawBufferCopyInOut(RawBufferCopyIn):
class RawConst(RawBuffer): # pylint: disable=abstract-method
def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
@property
def key(self): return (str(self._buf), self.dtype.key)
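
These key properties (RawBuffer.key and RawConst.key here, DType.key, LazyBuffer.key, and LazyOp.key earlier) replace the old str(ast)-based kernel cache key with a tuple of small hashable components, so identical kernels hit the cache without paying for string formatting. Roughly, assuming every component exposes such a key:

from typing import NamedTuple

class DType(NamedTuple):
  name: str
  @property
  def key(self): return (self.name,)

class Buf(NamedTuple):  # stand-in for a realized RawBuffer
  size: int
  dtype: DType
  @property
  def key(self): return (self.size, self.dtype.key)

cache = {}
k = (Buf(16, DType("float")).key, "ADD")
cache[k] = "compiled kernel"
print(cache[(Buf(16, DType("float")).key, "ADD")])  # hit: equal key tuples hash equally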