mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-18 18:35:12 -05:00)
Optimizations in lazy.py (#987)
* optimizations in lazy.py
* make mypy happy with stubs and fix the graph import hack
* merge conflict in helpers.py
tinygrad/codegen/linearizer.py

@@ -4,13 +4,12 @@ from collections import defaultdict
 from enum import Enum, auto
 from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same
-from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps
+from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
 from tinygrad.shape.symbolic import Variable

 # bottom ones are asm only
 class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); \
                   SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
@@ -94,27 +93,27 @@ class Linearizer:
     self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast

     # get the output buffers
-    self.bufs = [output_buffer] + dedup(get_buffers(ast))
+    self.bufs = [output_buffer] + dedup(ast.buffers)

     # key for lookup in cache (can change, str might not be right)
     # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
     # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
-    self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}"
+    self.key = (ast.map_buffers({x:i for i,x in enumerate(self.bufs)}).key, tuple([x.key for x in self.bufs]))

   def process(self) -> None:
     if hasattr(self, "sts"): return # already processed

     # fetch lazyop info
-    self.info: FlopCounter = get_lazyop_info(self.ast)
+    self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
     self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)

     # there's only allowed to be one reduceop
-    reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps]
+    reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
     assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
     self.reduceop = reduceops[0] if reduceops else None

     # get earlybufs, before the one reduce op
-    self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else []
+    self.earlybufs = dedup(self.reduceop.buffers) if self.reduceop else []

     # create new shapetrackers inside this kernel, we will permute them
     self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs]
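The new tuple key is hashable and cheap to build, replacing the old f-string. For illustration (a sketch, not part of the commit, assuming the usual tinygrad Tensor API), the comment's example of why the buffer list has to be part of the key:

    # f(x) = x + x and f(x, y) = x + y linearize to the same AST once buffers
    # are replaced by their positional index; only the bufs tuple tells the
    # two kernels apart.
    from tinygrad.tensor import Tensor
    a, b = Tensor.rand(4), Tensor.rand(4)
    xx = (a + a).lazydata   # bufs ~ [out, a]
    xy = (a + b).lazydata   # bufs ~ [out, a, b] -> different key, different kernel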
@@ -178,7 +177,7 @@ class Linearizer:
     for k,out_tokens in self._group_float4(i, load_offset).items():
       idxs = [x[2]-out_tokens[0][2] for x in out_tokens]
       valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3])
-      if any(idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
+      if any([idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))]) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
         # idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
         for x in out_tokens: load_offset_new[x[1]] = x
       else:
@@ -306,13 +305,13 @@ class Linearizer:
       loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs)

       # there's no AST here (and there's no shape for the reduce LazyOp)
-      self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True)
+      self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore

       # end the late reduce loop
       self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))

     # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
+    loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})

     # run late AST
     val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
@@ -334,17 +333,17 @@ class Linearizer:
     return out

   def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
-    if not isinstance(x, LazyOp): return loaded_buffers[x]
+    if x.__class__ is not LazyOp: return loaded_buffers[x]
     if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op
     if x.op in ReduceOps and not do_reduce: return acc
     # MULACC fusion. TODO: this is copied from Interpreted
-    if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL:
+    if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL:
       x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
-    if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == UnaryOps.CAST and isinstance(x.src[0].src[0], LazyOp) and x.src[0].src[0].op == BinaryOps.MUL:
+    if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
       x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg)
     values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
     # TODO: fold float4 into a single uop when possible.
-    if isinstance(x.op, (ReduceOps, FusedOps)):
+    if x.op.__class__ in {ReduceOps, FusedOps}:
       ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
     else:
       ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
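The MULACC rewrite above fuses a SUM whose source is a MUL (optionally through a CAST) into one fused multiply-accumulate. A minimal sketch of the rewrite, mirroring the first branch (a, b, and new_shape stand in for real buffers/shapes):

    # SUM(MUL(a, b)) -> MULACC(a, b): the accumulator then does the multiply
    # and the add in a single ALU uop instead of materializing the product.
    x = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (a, b)),), new_shape)
    x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
    # SUM(CAST(MUL(a, b))) is handled the same way via x.src[0].src[0].src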
@@ -431,7 +430,7 @@ class Linearizer:
     # remove places where the shape is all ones
     # TODO: this should be factored in to multi shape stride
     if self.shape_len == 0: return
-    all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
+    all_ones = [all([st.shape[i]==1 for st in self.sts]) for i in range(self.shape_len)]
     # keep at least 1 one
     if all(all_ones): all_ones[-1] = False
     self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
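Most of the mechanical changes in this commit swap generator expressions for list comprehensions inside all()/any(). On CPython the list is built eagerly and skips per-item generator-frame overhead, which wins for the short sequences these loops see (at the cost of short-circuiting). A quick way to check on your own interpreter; exact numbers vary by version:

    import timeit
    xs = list(range(8))
    print(timeit.timeit(lambda: all(x >= 0 for x in xs)))    # genexp
    print(timeit.timeit(lambda: all([x >= 0 for x in xs])))  # list comp, usually faster here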
@@ -456,14 +455,14 @@ class Linearizer:
       else: rets[j].append((shapes[j][i], strides[j][i]))

     # do the reshapes
-    for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x))
+    for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x]))

   # ******************** GPU simplifiers ********************

   def required_optimizations(self, early_only=False):
     for buf_index,buf in enumerate(self.bufs):
       unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
-      if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType):
+      if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
         assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
         if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
           self.shift_to(unit_stride_axes_mul_4[0], 4)
tinygrad/graph.py

@@ -5,13 +5,11 @@ except ImportError:
   nx = None # graph won't work
 from collections import defaultdict
 from typing import Dict, List, Optional
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
-from tinygrad.lazy import LazyBuffer
-from tinygrad.helpers import getenv, DEBUG, GlobalCounters
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp
+from tinygrad.tensor import LazyBuffer
+from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
 from tinygrad.runtime.lib import RawConst

-GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
-
 # **** debugging and graphing ****

 G = nx.DiGraph() if nx is not None else None

@@ -52,8 +50,8 @@ def str_dtype(dtyp):
 def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
   if show_graph is None: show_graph = bool(GRAPH)
   if not DEBUG and not show_graph: return
-  op: List[Op] = [x.op for x in get_lazyops(ast)]
-  inp: List[LazyBuffer] = [x for x in get_buffers(ast) if not isinstance(x.realized, RawConst) or GRAPH > 1]
+  op: List[Op] = [x.op for x in ast.get_lazyops()]
+  inp: List[LazyBuffer] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
   oporder = [LoadOps, FusedOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
   optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
   cnts[optype] += 1
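Reading GRAPH/PRUNEGRAPH/GRAPHPATH from tinygrad.helpers (see the helpers.py hunk below) means lazy.py no longer needs graph.py at module scope; the logging call sites defer the import instead. A sketch of that pattern (the function name here is illustrative, not from the diff):

    def _log_if_graphing(buf):
      from tinygrad.graph import log_op  # deferred: breaks the import cycle at module load
      log_op(buf, buf.op)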
tinygrad/helpers.py

@@ -1,18 +1,23 @@
 from __future__ import annotations
-import platform
 from dataclasses import dataclass, asdict
-import os, math, functools, time, re
+import os, functools, platform, time, re
+from weakref import KeyedRef, ref
+from _weakref import _remove_dead_weakref # type: ignore
 import numpy as np
-from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
+from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
+from math import prod # noqa: F401 # pylint:disable=unused-import

+ShapeType = Tuple[int, ...]
 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"

 def dedup(x): return list(dict.fromkeys(x)) # retains list order
-def prod(x:Union[List[int], Tuple[int, ...]]) -> int: return math.prod(x)
-def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)
+def argfix(*x):
+  if x[0].__class__ in {tuple, list}:
+    try: return tuple(x[0])
+    except IndexError: return tuple()
+  return tuple(x)
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
+def all_same(items): return all([x == items[0] for x in items]) if len(items) > 1 else True
 def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
 def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
 def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
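dedup leans on dict.fromkeys preserving insertion order (guaranteed since Python 3.7), so it removes duplicates while keeping first-seen order; a quick check of it and the rewritten all_same:

    assert dedup([3, 1, 3, 2, 1]) == [3, 1, 2]
    assert all_same([4, 4, 4]) and all_same([]) and not all_same([1, 2])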
@@ -43,6 +48,7 @@ class ContextVar:
   def value(self): return ContextVar.ctx_stack[-1][self.key] if self.key in ContextVar.ctx_stack[-1] else self.initial_value

 DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
+GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")

 class Timing(object):
   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
@@ -60,6 +66,8 @@ class DType(NamedTuple):
   np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
   sz: int = 1
   def __repr__(self): return f"dtypes.{self.name}"
+  @property
+  def key(self): return (self.name)

 # dependent typing?
 class ImageDType(DType):
@@ -70,7 +78,6 @@ class ImageDType(DType):
     super().__init__()
   def __repr__(self): return f"dtypes.{self.name}({self.shape})"

-@dataclass
 class dtypes:
   @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
   def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
@@ -79,7 +86,9 @@ class dtypes:
   @staticmethod
   def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
   @staticmethod
-  def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
+  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
+  @staticmethod
+  def fields() -> Dict[str, DType]: return DTYPES_DICT
   bool: Final[DType] = DType(0, 1, "bool", bool)
   float16: Final[DType] = DType(0, 2, "half", np.float16)
   half = float16
@@ -97,6 +106,9 @@ class dtypes:
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
   _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)

+# HACK: staticmethods are not callable in 3.8 so we have to compare the class
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
+
 class GlobalCounters:
   global_ops: ClassVar[int] = 0
   global_mem: ClassVar[int] = 0
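The staticmethod guard exists because, before Python 3.10, a staticmethod object pulled out of a class __dict__ is not callable(), so callable(v) alone would not filter the methods out of DTYPES_DICT; a sketch of the difference:

    sm = dtypes.__dict__['is_int']
    print(type(sm) is staticmethod)  # True: __dict__ access bypasses the descriptor
    print(callable(sm))              # False on 3.8/3.9, True on 3.10+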
@@ -106,3 +118,36 @@ class GlobalCounters:
   cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
   @staticmethod
   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
+
+# Stripped down version of a WeakSet
+class LightWeakSet:
+  __slots__ = 'data', '_remove', '__weakref__'
+  def __init__(self):
+    self.data = set()
+    def _remove(item, selfref=ref(self)):
+      self = selfref()
+      if self: self.data.discard(item)
+    self._remove = _remove
+
+  def __len__(self): return len(self.data)
+  def add(self, item): self.data.add(ref(item, self._remove))
+  def discard(self, item): self.data.discard(ref(item))
+
+# Stripped down version of a WeakValueDictionary
+class LightWeakValueDictionary:
+  __slots__ = 'data', '_remove', '__weakref__'
+  def __init__(self):
+    def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
+      self = selfref()
+      if self: _atomic_removal(self.data, wr.key)
+    self._remove = remove
+    self.data = {}
+
+  def __getitem__(self, key):
+    o = self.data[key]()
+    if o is None: raise KeyError(key)
+    else: return o
+
+  def __setitem__(self, key, value): self.data[key] = KeyedRef(value, self._remove, key)
+
+  def __contains__(self, key): return key in self.data
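A hedged usage sketch of the stripped-down weak containers: values are only weakly held, so an entry disappears as soon as the last strong reference dies (promptly on CPython, where refcounting fires the KeyedRef callback immediately):

    class Obj: pass                      # plain class, supports weakrefs
    d = LightWeakValueDictionary()
    o = Obj()
    d["k"] = o
    assert "k" in d and d["k"] is o
    del o                                # _remove_dead_weakref drops the entry
    assert "k" not in d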
tinygrad/lazy.py
@@ -1,14 +1,16 @@
 from __future__ import annotations
-from typing import Optional, Tuple, Union, List, Dict, Any, cast
+import operator
+from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
 import sys, importlib, inspect, functools, pathlib
+from weakref import ref

 import numpy as np
-from weakref import WeakValueDictionary, ref, WeakSet
-from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG
-from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
-from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers
-from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped
+from tinygrad.helpers import GRAPH, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
 from tinygrad.runtime.ops_disk import RawDiskBuffer
+from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, get_contraction
+from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, LoadOps, OpType, LazyOp
+from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer

 # lazy can recurse a lot
 sys.setrecursionlimit(10000)
@@ -25,23 +27,23 @@ PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
   src = self.op.src[0]
-  if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
-    src = src.op
-  return LazyOp(self.op.op, (src,), self.op.arg)
+  if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype == BinaryOps and len(src.children) <= 1:
+    src = src.op # type: ignore
+  return LazyOp(self.op.op, (src,), self.op.arg, src.get_buffers())

 # this supports late merging an upstream Reduce op and even an Elementwise op above that
 def _ast_binaryops(self:LazyBuffer) -> LazyOp:
-  real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
+  real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in self.op.buffers}
   # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
   # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
-  psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
+  psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
   intermediate_shape: Tuple[int, ...] = self.shape
-  if len(psrcs) >= 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
+  if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and len(psrcs) >= 1:
     psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
     if psrc[1].optype == ReduceOps:
       top = _ast_reduceops(psrc[1])
     real_srcs[psrc[0]] = top
-    real_srcs.update({x:x for x in get_buffers(top)}) # the reduce op buffers are not modified
+    real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified

     # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
     if psrc[0].shape != psrc[1].shape:
@@ -51,124 +53,91 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   # reshape all the late ops into the output shape
   # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
-    if real_srcs[x] is None: real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
-  ast = map_buffers(real_srcs, self.op)
-  return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
+    if not real_srcs[x]: real_srcs[x] = x.reshape_op(intermediate_shape)
+  ast = self.op.map_buffers(real_srcs)
+  return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape, ast.buffers) if intermediate_shape != self.shape else ast

 # **** lazy operations ****
-def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple([get_weakop(x) if x.__class__ is LazyOp else ref(x) for x in op.src]), op.arg)
-def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.op.src[0]) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
-def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(root.op.src[0], allow_contiguous) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
-def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(x.op.src[0]) if x.realized is None and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
-
-def replace_with_movement_ops(y:Union[LazyOp, LazyBuffer], ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> LazyBuffer:
-  if isinstance(y, LazyBuffer):
-    for op, arg in ops: y = y.movement_op(op, arg)
-    return y
-  assert y.op in BinaryOps or y.op in UnaryOps
-  return elementwise_op(y.op, *[replace_with_movement_ops(z, ops) for z in y.src], arg=y.arg) # type: ignore
+def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
+def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
+def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)

-lazycache: WeakValueDictionary[Tuple[str, DType, OpType, LazyOp], LazyBuffer] = WeakValueDictionary()
-def create_lazybuffer(device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp, dtype:DType):
-  st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
+lazycache: LightWeakValueDictionary = LightWeakValueDictionary()
+def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType):

   # fromcpu aren't cached
-  if optype == LoadOps and op.op in [LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST]: return LazyBuffer(device, st, optype, op, dtype)
+  if optype == LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}: return LazyBuffer(device, st, optype, op, dtype)

   #print("create_lazybuffer", device, shape, optype, op, dtype)

-  # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
-  # get_weakop makes all the LazyBuffers in the op have a weakref
-  wop = (device, dtype, optype, get_weakop(op))
+  wop = (device, dtype, optype, ref(op))

-  if wop not in lazycache: lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
-  else: ret = lazycache[wop]
+  if wop in lazycache: return lazycache[wop]
+  lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
   return ret

 class LazyBuffer:
+  __slots__ = 'st', 'device', 'shape', 'optype', 'dtype', 'op', 'realized', 'output_buffer', 'children', 'node_id', '__weakref__'
+  __deletable__ = ('op',)
-  def __init__(self, device:str, st:ShapeTracker, optype:OpType, src:Union[LazyOp, RawBuffer], dtype:DType):
+  def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
     self.st = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
     self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype
-    self.realized: Optional[RawBuffer] = src if isinstance(src, RawBuffer) else None
+    self.realized: Optional[RawBuffer] = src
     self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
     # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
-    self.children: WeakSet[LazyBuffer] = WeakSet()
+    self.children: LightWeakSet = LightWeakSet()
     # NOTE: op should be read only after construction of LazyBuffer
-    if isinstance(src, LazyOp):
-      self.op: LazyOp = src
-      for x in get_buffers(self.op): x.children.add(self)
+    if op:
+      self.op: LazyOp = op
+      for x in op.buffers: x.children.add(self)
     if not LAZY: self.realize()

     # log phantom ops to the graph
-    from tinygrad.graph import log_op, GRAPH
-    if GRAPH >= 3: log_op(self, self.op, phantom=True)
+    if GRAPH >= 3:
+      from tinygrad.graph import log_op
+      log_op(self, self.op, phantom=True)

-  def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self.realized else self.realized} st={self.st}>"
+  @property
+  def key(self):
+    if self.realized: return (self.dtype.key, self.realized.key, self.st.key)
+    return (self.dtype.key, self.op.op, self.st.key)
+
+  def __repr__(self): return f"<LB {self.shape} {self.dtype} op:{self.op.op if self.realized is None else self.realized} st:{self.st}>"
   def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}

   def realize(self:LazyBuffer) -> LazyBuffer:
-    if self.realized is None:
+    if not self.realized:
       # get real ops first
-      if self.op.op == LoadOps.CONTIGUOUS:
-        realized = self.op.src[0].realize().realized
-        if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape):
-          # no need to run an AST, this is already contiguous
-          self.realized = realized
-        else:
-          # TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though
-          self.op = LazyOp(UnaryOps.NOOP, self.op.src)
-      elif self.op.op == LoadOps.CUSTOM:
-        # this needs to immediately realize
-        self.realized = self.op.arg(self, *[x.realize() for x in self.op.src])
-      elif self.op.op == LoadOps.FROM:
-        rawbuf = self.op.src[0].realize()
-        # TODO: make this generic
-        if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[self.device].buffer, RawBufferMapped):
-          self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
-          rawbuf.realized.readinto(cast(RawBufferMapped, self.realized)._buffer())
-        else:
-          self.realized = Device[self.device].buffer.fromCPU(rawbuf.toCPU(), **self._device_extra_args())
-      elif self.optype == LoadOps:
-        if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}")
-        if self.op.op == LoadOps.EMPTY:
-          self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
-        elif self.op.op == LoadOps.RAND:
-          rng = np.random.default_rng(self.op.arg)
-          assert self.dtype.np is not None, "internal dtypes don't work with LoadOps.RAND"
-          self.realized = Device[self.device].buffer.fromCPU(rng.random(size=self.shape, dtype=self.dtype.np), **self._device_extra_args())
-        elif self.op.op == LoadOps.CONST:
-          if hasattr(Device[self.device].codegen, 'supports_constant_folding'):
-            self.realized = RawConst(1, self.dtype, float(self.op.arg))
-          else:
-            self.realized = Device[self.device].buffer.fromCPU(np.array(self.op.arg, dtype=self.dtype.np), **self._device_extra_args())
-      # these can be late folded and change the op to go further back in the graph
-      elif self.optype == ReduceOps: self.op = _ast_reduceops(self)
-      elif self.optype == BinaryOps: self.op = _ast_binaryops(self) # ISSUE: this can include a reshape

+      if self.optype in REALIZE_DISPATCHER:
+        self.op = REALIZE_DISPATCHER[self.optype](self)
+      elif self.op.op in REALIZE_DISPATCHER:
+        REALIZE_DISPATCHER[self.op.op](self)
       # run the ast if we still have to, and log the op
-      if self.realized is None:
-        for x in get_buffers(self.op): x.realize()
+      if not self.realized:
+        for x in self.op.buffers: x.realize()

         # HACK: image shape can be wrong, hot cast it back to a normal float
-        if self.optype != MovementOps and isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
+        if self.optype != MovementOps and self.dtype.__class__ is ImageDType and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])):
          if self.op.op == MovementOps.RESHAPE:
            # put CAST before the final RESHAPE
            self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
          else:
            self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32)
          self.dtype = dtypes.float32

         self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args())

-      assert isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
+      assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
       # HACK: allow hot casting of images
       assert self.realized.dtype == self.dtype or self.dtype.name.startswith("image"), f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}"
       self.dtype = self.realized.dtype

       # log to the graph
-      from tinygrad.graph import log_op, GRAPH
-      if not isinstance(self.realized, RawConst) or GRAPH >= 2: log_op(self, self.op)
+      if self.realized.__class__ is not RawConst or GRAPH >= 2:
+        from tinygrad.graph import log_op
+        log_op(self, self.op)

       # no need to keep the op after realization
       del self.op
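The cache-key change is the subtle part of this hunk: get_weakop built a structural key (weakrefs to every buffer inside the op), while ref(op) keys on the identity of the LazyOp itself, since weakref.ref objects hash and compare by referent. Two structurally equal but distinct ops no longer share a cache slot, trading a little deduplication for a much cheaper key; this also appears to be why LazyOp stops being a NamedTuple in this commit (plain tuples cannot be weak-referenced; note the ops.py typing import at the end drops NamedTuple). A sketch, with op standing in for any post-commit LazyOp:

    from weakref import ref
    assert ref(op) == ref(op)   # same object -> same cache key
    # ref(op1) != ref(op2) even when op1 and op2 are structurally equal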
@@ -176,17 +145,17 @@ class LazyBuffer:

   @staticmethod
   def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
-    return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
+    return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)

   @staticmethod
   def fromCPU(x: np.ndarray) -> LazyBuffer:
-    return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, RawNumpyBuffer.fromCPU(x), dtypes.from_np(x.dtype))
+    return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None, ()), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))

   # create a constant with the shape and dtype of self
   def const_like(self, val) -> LazyBuffer:
     # NOTE: dtypes.from_np(self.dtype.np) to deal with image types
     return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \
-      .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape)
+      .reshape_op((1,)*len(self.shape)).expand_op(self.shape)

   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
@@ -198,89 +167,99 @@ class LazyBuffer:
   def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
   def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
   def contiguous(self:LazyBuffer) -> LazyBuffer:
-    if self.realized is None and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
-    return create_lazybuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype)
-
-  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    if self.shape == tuple(new_shape): return self
-    srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
-    return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, tuple(srcs), new_shape), self.dtype)
-
-  # shrink -> stride -> permute -> reshape -> pad -> expand
-  def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
-    # very instant nop
-    if op == MovementOps.RESHAPE and self.shape == arg: return self
-
-    # TODO: look into why that copy is needed
-    local_st = ShapeTracker(self.shape).movement_op(op, arg)
-
-    # instant nops
-    if local_st.contiguous and self.shape == local_st.shape: return self
-
-    # two ops in a row is one op. merge them if unresolved
-    if self.realized is None and self.op.op == op:
-      # TODO: why is deleting self from children needed? shouldn't GC do it?
-      self.op.src[0].children.discard(self)
-      if op in [MovementOps.RESHAPE, MovementOps.EXPAND]: return self.op.src[0].movement_op(op, arg)
-      if op == MovementOps.SHRINK: return self.op.src[0].movement_op(op, tuple((b1+b2, b1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
-      if op == MovementOps.PERMUTE: return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
-      if op == MovementOps.PAD: return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
-      if op == MovementOps.STRIDE: return self.op.src[0].movement_op(op, tuple(i*j for i,j in zip(arg, self.op.arg)))
-
-    # push permutes before reduce ops
-    if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.optype == ReduceOps:
-      # reduceops have one buffer input, permute it
-      narg = tuple(self.op.arg[arg[i]] for i in range(len(arg)))
-      src, rop = self.op.src[0], self.op.op
-      src.children.discard(self)
-      del self # TODO: why doesn't this delete remove it from the children
-      return src.movement_op(op, arg).reduce_op(rop, narg)
-
-    # some permutes are actually just reshapes
-    if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
-
-    # move permutes before expands (always, this is safe)
-    if op == MovementOps.PERMUTE and self.realized is None and self.op.op == MovementOps.EXPAND:
-      self.op.src[0].children.discard(self)
-      return self.op.src[0].movement_op(MovementOps.PERMUTE, arg).movement_op(MovementOps.EXPAND, tuple(self.op.arg[a] for a in arg))
-
-    # move permutes before reshapes if we can
-    if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
-      if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
-        self.op.src[0].children.discard(self) # this changes nothing?
-        return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(flatten(shape_idx_groups[i] for i in arg))) \
-          .movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape)
-
-    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType
-    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and (op in [MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE] or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0: # and op != MovementOps.EXPAND and (op != MovementOps.PAD or (SHUFFLE_PAD_OPS and all(x.op != BinaryOps.DIV for x in get_lazyops(self.op)))):
-      return replace_with_movement_ops(self.op, [(op, arg)])
-
-    # create the buffer
-    ret = create_lazybuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg), self.dtype)
-
-    # if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match)
-    # NOTE: if ret is in the cache, it can already be realized
-    if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
+    if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
+    return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None, (self,)), self.dtype)
+
+  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
+    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0:
+      return self.op.replace_with_movement_ops([(op, arg)])
+    ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg, (self,)), self.dtype)
+    if REMOVE_MOVEMENT_NOPS and not self.realized and not ret.realized and ret.st.contiguous:
       # MovementOps aren't stacked any more, they each have one parent, find the root
       root = get_movementroot(self)
       if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
-        return root.movement_op(MovementOps.RESHAPE, ret.st.shape)
-
+        return root.reshape_op(ret.st.shape)
     return ret

+  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
+    if self.shape == tuple(new_shape): return self
+    srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
+    return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
+
+  def reshape_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+    if self.shape == arg: return self
+    if not self.realized and self.op.op == MovementOps.RESHAPE:
+      self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
+      return self.op.src[0].reshape_op(arg)
+    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
+
+  def pad_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+    if all([b == 0 and e == 0 for b,e in arg]): return self
+    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad_op(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
+    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
+
+  def expand_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+    if self.shape == arg: return self
+    if not self.realized and self.op.op == MovementOps.EXPAND:
+      return self.op.src[0].expand_op(arg)
+    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)
+
+  def permute_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+    if arg == tuple(range(len(self.shape))): return self
+    if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute_op(tuple([self.op.arg[i] for i in arg]))
+    if not self.realized:
+      if PUSH_PERMUTES and self.optype == ReduceOps:
+        # reduceops have one buffer input, permute it
+        narg = tuple([self.op.arg[arg[i]] for i in range(len(arg))])
+        src, rop = self.op.src[0], self.op.op
+        src.children.discard(self)
+        del self # TODO: why doesn't this delete remove it from the children
+        return src.permute_op(arg).reduce_op(cast(ReduceOps, rop), narg)
+
+      # move permutes before expands (always, this is safe)
+      if self.op.op == MovementOps.EXPAND:
+        return self.op.src[0].permute_op(arg).expand_op(tuple([self.op.arg[a] for a in arg]))
+
+      # move permutes before reshapes if we can
+      if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
+        if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
+          self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
+          return self.op.src[0].permute_op(tuple(flatten(shape_idx_groups[i] for i in arg))) \
+            .reshape_op(ShapeTracker(self.st).permute(arg).shape)
+    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
+
+  def shrink_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+    if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
+    if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink_op(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
+    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
+
+  def stride_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+    local_st = ShapeTracker(self.shape).stride(arg)
+    if self.shape == local_st.shape and local_st.contiguous: return self
+    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride_op(tuple(map(operator.mul, arg, self.op.arg)))
+    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)
+
+  def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
+  def get_buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
+  def get_lazyops(self) -> List[Any]: return []
+  def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
+    y = self
+    for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
+    return y

 def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
   new_srcs = []
   for x in srcs:
-    mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = []
+    mops: List[Tuple[MovementOps, Any]] = []
     bx = x
     # backwalk all the movement ops. don't push PAD or EXPAND
-    while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1:
+    while not bx.realized and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op != MovementOps.PAD) and len(bx.children) <= 1:
       assert isinstance(bx.op.op, MovementOps)
       mops.append((bx.op.op, bx.op.arg))
-      bx = bx.op.src[0]
+      bx = cast(LazyBuffer, bx.op.src[0])
     # NOTE: can't push pads with a div
-    if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))):
-      new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1]))
+    if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all([x[0] != MovementOps.PAD for x in mops]) or all([x.op != BinaryOps.DIV for x in bx.op.get_lazyops()])):
+      new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
    else:
      new_srcs.append(x)
  return tuple(new_srcs)
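With one method per movement op, consecutive identical ops still fold into one: each *_op checks whether its unrealized source is the same kind of op and recurses into it with a merged argument. A sketch of the effect through the Tensor API (assuming Tensor.reshape routes to reshape_op):

    x = Tensor.rand(2, 3, 4)
    y = x.reshape(6, 4).reshape(24)   # lazily folds into a single RESHAPE of x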
@@ -290,29 +269,30 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional
   if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)

   # get outputs now
-  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg)
+  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg)

   # push all contiguous to the end of BinaryOps. kernels 198 -> 196
-  if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
-    new_srcs = []
+  if PUSH_CONTIGUOUS and any([not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs]):
+    new_srcs: List[LazyBuffer] = []
     for x in srcs:
-      if x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
+      if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
         x.op.src[0].children.discard(x)
-        new_srcs.append(x.op.src[0])
+        new_srcs.append(cast(LazyBuffer, x.op.src[0]))
       else:
         new_srcs.append(x)
     return elementwise_op(op, *new_srcs, arg=arg).contiguous()

   if MERGE_ELEMENTWISE_OPS:
     # remove the buffers from any (childless) BinaryOps that feed into this
-    srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore
+    srcs = tuple([x.op if x.optype == BinaryOps and len(x.children) == 0 and not x.realized else x for x in srcs]) # type: ignore

-  return create_lazybuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs, arg), out_dtype)
+  return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)

 class _Device:
   def __init__(self) -> None:
     self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
     self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
   def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
   def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper())
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
@@ -324,3 +304,58 @@ class _Device:
     except Exception: pass
     return "CPU"
 Device = _Device()
+
+def _realize_contiguous(buffer: LazyBuffer) -> None:
+  realized = buffer.op.src[0].realize().realized
+  if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and cast(RawBuffer, realized).size == prod(buffer.shape):
+    # no need to run an AST, this is already contiguous
+    buffer.realized = realized
+  else:
+    # TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though
+    buffer.op = LazyOp(UnaryOps.NOOP, buffer.op.src)
+
+def _realize_custom(buffer: LazyBuffer) -> None:
+  # this needs to immediately realize
+  buffer.realized = buffer.op.arg(buffer, *[x.realize() for x in buffer.op.src])
+
+def _realize_from(buffer: LazyBuffer) -> None:
+  rawbuf = buffer.op.src[0].realize()
+  # TODO: make this generic
+  if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
+    buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
+    rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
+  else:
+    buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())
+
+def _realize_empty(buffer: LazyBuffer) -> None:
+  buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
+
+def _realize_rand(buffer: LazyBuffer) -> None:
+  rng = np.random.default_rng(buffer.op.arg)
+  buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=buffer.dtype.np), **buffer._device_extra_args()) # type: ignore
+
+def _realize_const(buffer: LazyBuffer) -> None:
+  if hasattr(Device[buffer.device].codegen, 'supports_constant_folding'):
+    buffer.realized = RawConst(1, buffer.dtype, float(buffer.op.arg))
+  else:
+    buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
+
+REALIZE_DISPATCHER: Dict[Any, Callable] = {
+  LoadOps.CONTIGUOUS: _realize_contiguous,
+  LoadOps.CUSTOM: _realize_custom,
+  LoadOps.FROM: _realize_from,
+  LoadOps.EMPTY: _realize_empty,
+  LoadOps.RAND: _realize_rand,
+  LoadOps.CONST: _realize_const,
+  ReduceOps: _ast_reduceops,
+  BinaryOps: _ast_binaryops,
+}
+
+MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
+  MovementOps.RESHAPE: LazyBuffer.reshape_op,
+  MovementOps.EXPAND: LazyBuffer.expand_op,
+  MovementOps.SHRINK: LazyBuffer.shrink_op,
+  MovementOps.PERMUTE: LazyBuffer.permute_op,
+  MovementOps.PAD: LazyBuffer.pad_op,
+  MovementOps.STRIDE: LazyBuffer.stride_op,
+}
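The two dispatch tables replace the long if/elif chain that used to live inside realize(): type-level entries (ReduceOps, BinaryOps) return a rewritten LazyOp, while op-level entries (LoadOps.*) fill in buffer.realized directly. A usage sketch mirroring realize() above:

    if buf.optype in REALIZE_DISPATCHER:        # ReduceOps / BinaryOps -> new op
      buf.op = REALIZE_DISPATCHER[buf.optype](buf)
    elif buf.op.op in REALIZE_DISPATCHER:       # LoadOps.* -> sets buf.realized
      REALIZE_DISPATCHER[buf.op.op](buf)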
tinygrad/mlops.py

@@ -1,6 +1,6 @@
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort, ShapeType
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
 import math
@@ -10,6 +10,7 @@ class Contiguous(Function):
   def backward(self, grad_output): return grad_output

 class Cast(Function):
+  __slots__ = "input_dtype"
   def forward(self, x, dtype):
     self.input_dtype = x.dtype
     return x.cast(dtype)
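The __slots__ lines added to each Function subclass in the hunks below keep autograd nodes small: without __slots__ every instance carries a per-instance __dict__ (this only fully pays off if the Function base class defines __slots__ too). A standalone sketch of the effect:

    class WithDict:
      def __init__(self): self.x = 1
    class WithSlots:
      __slots__ = ('x',)
      def __init__(self): self.x = 1
    assert not hasattr(WithSlots(), '__dict__')  # no dict -> less memory per node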
@@ -19,6 +20,7 @@ class Cast(Function):
 # ************* unary ops *************

 class Sin(Function):
+  __slots__ = "x"
   def forward(self, x: LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.unary_op(UnaryOps.SIN)
@@ -26,6 +28,7 @@ class Sin(Function):
     return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)

 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
+  __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0))
     return self.ret
@@ -35,6 +38,7 @@ class Relu(Function):
     return mask.binary_op(BinaryOps.MUL, grad_output)

 class Log(Function):
+  __slots__ = "x"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)/math.log(math.e)))
@@ -43,6 +47,7 @@ class Log(Function):
     return grad_output.binary_op(BinaryOps.DIV, self.x)

 class Exp(Function):
+  __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.binary_op(BinaryOps.MUL, x.const_like(math.log(math.e)/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret
@@ -53,27 +58,29 @@ class Exp(Function):
 # ************* reduce ops *************

 class Sum(Function):
+  __slots__ = "input_shape"
   def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reduce_op(ReduceOps.SUM, new_shape)

   def backward(self, grad_output):
-    return grad_output.movement_op(MovementOps.EXPAND, self.input_shape)
+    return grad_output.expand_op(self.input_shape)

 class Max(Function):
+  __slots__ = "x", "ret"
   def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
     self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.movement_op(MovementOps.EXPAND, self.x.shape))
+    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand_op(self.x.shape))

     # sum of locations, averaged
-    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).movement_op(MovementOps.EXPAND, self.x.shape)
+    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand_op(self.x.shape)
     max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)

-    grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, self.x.shape)
+    grad_output_expanded = grad_output.expand_op(self.x.shape)
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)

 # ************* binary ops *************
@@ -83,6 +90,7 @@ class Equal(Function):
     return x.binary_op(BinaryOps.CMPEQ, y)

 class Maximum(Function):
+  __slots__ = "x", "y", "ret"
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
     self.ret = x.binary_op(BinaryOps.MAX, y)
@@ -113,6 +121,7 @@ class Sub(Function):
     grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None

 class Mul(Function):
+  __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer):
     self.x, self.y = x, y
     return x.binary_op(BinaryOps.MUL, y)
@@ -122,6 +131,7 @@ class Mul(Function):
     self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

 class Pow(Function):
+  __slots__ = 'x', 'y', 'ret'
   def forward(self, x:LazyBuffer, y:LazyBuffer):
     self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
     return self.ret
@@ -131,6 +141,7 @@ class Pow(Function):
     grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2)/math.log(math.e))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None

 class Div(Function):
+  __slots__ = 'x', 'y'
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
     return x.binary_op(BinaryOps.DIV, y)
@@ -143,49 +154,55 @@ class Div(Function):
|
||||
|
||||
# NOTE: this is sum in reverse
|
||||
class Expand(Function):
|
||||
__slots__ = 'input_shape'
|
||||
def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
|
||||
self.input_shape = x.shape
|
||||
return x.movement_op(MovementOps.EXPAND, shape)
|
||||
return x.expand_op(shape)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)

class Reshape(Function):
  __slots__ = 'input_shape'
  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
    self.input_shape = x.shape
    return x.movement_op(MovementOps.RESHAPE, shape)
    return x.reshape_op(shape)

  def backward(self, grad_output):
    return grad_output.movement_op(MovementOps.RESHAPE, self.input_shape)
    return grad_output.reshape_op(self.input_shape)

class Permute(Function):
  __slots__ = 'input_order'
  def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
    self.input_order = order
    return x.movement_op(MovementOps.PERMUTE, order)
    return x.permute_op(order)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.movement_op(MovementOps.PERMUTE, argsort(self.input_order))
    return grad_output.permute_op(argsort(self.input_order))
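
The argsort of a permutation is its inverse, so permuting the gradient by argsort(order) exactly undoes the forward permute. A two-line check (the helper below follows the same idea as the tuple-based argsort tinygrad defines in helpers):

def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__))

order = (2, 0, 1)
inv = argsort(order)                               # (1, 2, 0)
assert tuple(order[i] for i in inv) == (0, 1, 2)   # composing with inv is the identity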

class Pad(Function):
  __slots__ = 'narg'
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    self.narg = tuple((p[0], s+p[0]) for s,p in zip(x.shape, arg))
    return x.movement_op(MovementOps.PAD, arg)
    self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
    return x.pad_op(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.movement_op(MovementOps.SHRINK, self.narg)
    return grad_output.shrink_op(self.narg)

class Shrink(Function):
  __slots__ = 'narg'
  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    self.narg = tuple((p[0], s-p[1]) for s,p in zip(x.shape, arg))
    return x.movement_op(MovementOps.SHRINK, arg)
    self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
    return x.shrink_op(arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.movement_op(MovementOps.PAD, self.narg)
    return grad_output.pad_op(self.narg)
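
Pad and Shrink are adjoint: Pad's backward shrinks the gradient back to the original region, and narg converts the per-axis (before, after) pad amounts into the (start, end) window that Shrink expects. A concrete check of that coordinate conversion, with a hand-written padded list standing in for what PAD would produce:

shape, pad = (3,), ((1, 2),)                 # pad one left, two right -> length 6
narg = tuple((p[0], s + p[0]) for s, p in zip(shape, pad))
assert narg == ((1, 4),)                     # shrink window [1, 4) recovers length 3

padded = [0.0, 1.0, 2.0, 3.0, 0.0, 0.0]      # PAD applied to [1, 2, 3]
assert padded[narg[0][0]:narg[0][1]] == [1.0, 2.0, 3.0]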

class Flip(Function):
  __slots__ = 'arg'
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
    self.arg = tuple(-1 if i in axis else 1 for i in range(len(x.shape)))
    return x.movement_op(MovementOps.STRIDE, self.arg)
    self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
    return x.stride_op(self.arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.movement_op(MovementOps.STRIDE, self.arg)
    return grad_output.stride_op(self.arg)
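
Flip is its own inverse, which is why backward reuses the same stride argument: striding by -1 along an axis twice restores the original order. In numpy terms (a stand-in for MovementOps.STRIDE with arg (-1,)):

import numpy as np

x = np.arange(4.0)
flipped = x[::-1]                  # STRIDE with arg (-1,)
assert (flipped[::-1] == x).all()  # flipping again is the identity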

@@ -1,10 +1,12 @@
from __future__ import annotations
import functools, operator, time
import functools, time
from enum import Enum, auto
from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable
from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored, ansilen
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
from tinygrad.shape.shapetracker import MovementOps
from tinygrad.runtime.lib import RawBuffer, RawConst
if TYPE_CHECKING:
  from tinygrad.lazy import LazyBuffer

# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
@@ -19,19 +21,62 @@ class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]]

class LazyOp(NamedTuple):
  op: Op
  # Any == Union[LazyOp, LazyBuffer, DeviceBuffer]
  src: Tuple[Any, ...] # type: ignore
  arg: Any = None
class LazyOp:
  # TODO: add dest to support multiple outputs. on second thought, multiple outputs will have multiple LazyOps.
  __slots__ = "op", "src", "arg", "buffers", "__weakref__"
  op: Op
  src: Tuple[Union[LazyOp, LazyBuffer], ...]
  arg: Any
  buffers: Tuple[LazyBuffer, ...]

# Any == Union[LazyBuffer, DeviceBuffer]
def get_buffers(op:LazyOp) -> List[Any]: return functools.reduce(operator.add, [get_buffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])
def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp:
  if len(real_srcs) and x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
  return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
  def __init__(self, op: Op, src: Tuple[Union[LazyOp, LazyBuffer], ...], arg: Any = None, buffers: Optional[Tuple[LazyBuffer, ...]] = None):
    self.op = op
    self.src = src
    self.arg = arg
    if not buffers:
      buffers = tuple()
      for s in src:
        try: buffers += s.get_buffers()
        except AttributeError: pass
      self.buffers = buffers
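
This is the core optimization of the commit: the old module-level get_buffers re-walked the entire LazyOp tree on every call, while the new __init__ collects buffers once at construction, so a parent only concatenates its children's already-flat cached tuples. A standalone sketch of the idea (the Node class is hypothetical, standing in for LazyOp):

class Node:
  __slots__ = "src", "buffers"
  def __init__(self, src, leaf=()):
    self.src = src
    # collected once at construction; parents reuse children's cached tuples
    self.buffers = tuple(leaf) + tuple(b for s in src for b in s.buffers)

leaf_a, leaf_b = Node((), leaf=("A",)), Node((), leaf=("B",))
add = Node((leaf_a, leaf_b))
assert add.buffers == ("A", "B")  # no tree walk needed after construction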

  def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
  def __eq__(self, __value: object) -> bool:
    if __value.__class__ is not LazyOp: return False
    __value = cast(LazyOp, __value)
    return self.op == __value.op and self.src == __value.src and self.arg == __value.arg
  def __hash__(self) -> int: return hash((self.op, self.src, self.arg))
  @property
  def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
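
The key property builds a hashable fingerprint recursively: each src contributes its own key if it has one (LazyOp, LazyBuffer, and the RawBuffers below all do), otherwise the raw value passes through. This is what replaces the old string-based cache key in the Linearizer. A sketch of the getattr fallback pattern with a hypothetical keyed leaf:

class Leaf:
  def __init__(self, key): self.key = key

srcs = (Leaf(("buf", 0)), 42)   # one keyed object, one plain value
key = tuple(getattr(x, "key", x) for x in srcs)
assert key == (("buf", 0), 42)  # keyed objects recurse, plain values pass through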

  # Any == Union[LazyBuffer, DeviceBuffer]
  def map_buffers(self, real_srcs: Dict[Any, Any]): return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)

  def get_buffers(self) -> Tuple[LazyBuffer, ...]: return self.buffers
  def get_lazyops(self) -> List['LazyOp']: return [self] + [item for x in self.src for item in x.get_lazyops()]

  def replace_with_movement_ops(self: LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
    from tinygrad.lazy import elementwise_op
    assert self.op in BinaryOps or self.op in UnaryOps
    return elementwise_op(self.op, *[z.replace_with_movement_ops(ops) for z in self.src], arg=self.arg) # type: ignore

  @property
  def st(self): raise NotImplementedError
  @property
  def children(self): raise NotImplementedError
  @property
  def shape(self): raise NotImplementedError
  @property
  def realized(self): raise NotImplementedError
  @property
  def optype(self): raise NotImplementedError
  def realize(self): raise NotImplementedError
  def reshape_op(self, _): raise NotImplementedError
  def pad_op(self, _): raise NotImplementedError
  def expand_op(self, _): raise NotImplementedError
  def permute_op(self, _): raise NotImplementedError
  def shrink_op(self, _): raise NotImplementedError
  def stride_op(self, _): raise NotImplementedError

# **************** for Interpreted Buffers ****************

@@ -46,12 +91,12 @@ class Interpreted:
|
||||
self.codegen = None
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs):
|
||||
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg)
|
||||
if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and ast.src[0].__class__ is LazyOp and ast.src[0].op == BinaryOps.MUL:
|
||||
ast = LazyOp(FusedOps.MULACC, cast(LazyOp, ast.src[0]).src, ast.arg)
|
||||
created_context = context is None
|
||||
if context is None: context = dict()
|
||||
if not created_context and ast in context: return context[ast]
|
||||
srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src]
|
||||
srcs = [self.exec_ast(cast(LazyOp, x), context=context, **kwargs) if x.__class__ is LazyOp else self.from_lazybuffer(x) for x in ast.src]
|
||||
if DEBUG >= 3: st = time.perf_counter()
|
||||
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
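
The MULACC rewrite pattern-matches SUM(MUL(a, b)) at the root of the AST and swaps it for a single fused multiply-accumulate node, so interpreted backends that supply MULACC never materialize the intermediate product. A minimal sketch of the same rewrite over a toy op tree (toy Enum values and ToyOp class, not the real tinygrad ones):

from enum import Enum, auto
from typing import Any, NamedTuple, Tuple

class Ops(Enum): SUM = auto(); MUL = auto(); MULACC = auto()

class ToyOp(NamedTuple):
  op: Ops
  src: Tuple[Any, ...]
  arg: Any = None

def fuse_mulacc(ast: ToyOp) -> ToyOp:
  # SUM over a MUL collapses into one fused node reusing MUL's sources and SUM's arg
  if ast.op is Ops.SUM and isinstance(ast.src[0], ToyOp) and ast.src[0].op is Ops.MUL:
    return ToyOp(Ops.MULACC, ast.src[0].src, ast.arg)
  return ast

tree = ToyOp(Ops.SUM, (ToyOp(Ops.MUL, ("a", "b")),), arg=(1,))
assert fuse_mulacc(tree) == ToyOp(Ops.MULACC, ("a", "b"), (1,))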
@@ -90,7 +135,7 @@ class ASTRunner:
    return self

  def exec(self, bufs) -> Optional[float]:
    rawbufs = [x.realized for x in bufs if x.realized is not None and not isinstance(x.realized, RawConst)]
    rawbufs = [x.realized for x in bufs if x.realized is not None and x.realized.__class__ is not RawConst]
    if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs))
    return self(rawbufs)
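
Several hunks in this commit swap isinstance(x, C) for x.__class__ is C. The identity check skips subclass handling and __instancecheck__ dispatch, which helps on hot paths like these; the trade-off is that subclasses no longer match. A quick way to compare the two on your machine (timings vary by interpreter and hardware):

import timeit

class C: pass
x = C()

t_inst = timeit.timeit(lambda: isinstance(x, C), number=1_000_000)
t_ident = timeit.timeit(lambda: x.__class__ is C, number=1_000_000)
print(f"isinstance: {t_inst:.3f}s  __class__ is: {t_ident:.3f}s")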

@@ -114,21 +159,21 @@ class Compiled:

  def exec_ast(self, ast:LazyOp, output, **kwargs):
    # all movementops do nothing in a Compiled buffer!
    if ast.op in MovementOps and not isinstance(ast.src[0], LazyOp) and ast.src[0].realized is not None: return ast.src[0].realized
    if ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized: return ast.src[0].realized

    # check if we can reuse the output buffer
    # if it's aliased, don't use it
    # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
    output.realized = output.output_buffer
    if output.realized is not None:
      if isinstance(output.realized, RawConst): output.realized = None # can't assign to RawConst
      for a in get_buffers(ast):
    if output.realized:
      if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
      for a in ast.buffers:
        if a.realized == output.realized and not a.st.contiguous:
          output.realized = None
          break

    # we don't have an output buffer, we have to create it
    if output.realized is None:
    if not output.realized:
      output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)
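
The aliasing guard exists because writing the result into a buffer that an input still reads from through a non-contiguous view would corrupt that input mid-kernel. The classic in-place hazard, with numpy views standing in for aliased realized buffers:

import numpy as np

buf = np.arange(4.0)
alias = buf[::-1]     # non-contiguous view over the same memory
buf[:] = alias + 1.0  # safe only because numpy materializes the right-hand side first
# a compiled kernel writing buf while reading alias in the same pass has no such
# temporary: element 0 would be clobbered before element 3 reads it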

    # compilation time

@@ -14,6 +14,8 @@ class RawBuffer: # pylint: disable=abstract-method
  def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
    if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
  def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
  @property
  def key(self): return (self.size, self.dtype.key)

# NOTE: this interface allows for 0 copy
@classmethod
@@ -50,3 +52,5 @@ class RawBufferCopyInOut(RawBufferCopyIn):

class RawConst(RawBuffer): # pylint: disable=abstract-method
  def __repr__(self): return f"const<{self._buf}, {self.dtype}>"
  @property
  def key(self): return (str(self._buf), self.dtype.key)
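
These key properties complete the chain behind the new tuple-based kernel cache key: a RawBuffer is identified by (size, dtype), while a RawConst bakes its value into the key, since a kernel specialized on one constant is not valid for another. A sketch of why the value must participate (FakeBuf and FakeConst are hypothetical minimal stand-ins):

class FakeBuf:
  def __init__(self, size, dtype): self.size, self.dtype = size, dtype
  @property
  def key(self): return (self.size, self.dtype)

class FakeConst(FakeBuf):
  def __init__(self, val, dtype): super().__init__(1, dtype); self._buf = val
  @property
  def key(self): return (str(self._buf), self.dtype)

# same-shaped buffers may share a cached kernel; different constants must not
assert FakeBuf(16, "float").key == FakeBuf(16, "float").key
assert FakeConst(2.0, "float").key != FakeConst(3.0, "float").key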