from __future__ import annotations
import sys, math
from collections import defaultdict
from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
from tinygrad.dtype import dtypes, DType, ImageDType
from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same, GRAPH
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from tinygrad.graph import log_lazybuffer
from weakref import ref, ReferenceType

# lazy can recurse a lot
sys.setrecursionlimit(10000)

lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                      base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
  if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None

  cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
  if (rret := lazycache.get(cache_key, None)): return cast(LazyBuffer, rret()) # NOTE: this should always be a live reference

  return LazyBuffer(device, st, dtype, op, arg, srcs, base=base, cache_key=cache_key if enable_cache else None)
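# Editor sketch (not in the original): create_lazybuffer is the constructor callers should use.
# Because lazycache holds weakrefs keyed by (device, st, dtype, op, arg, srcs), building the same
# op on the same sources twice returns the same object, e.g. (assuming a valid device string
# like "CLANG"):
#   a = LazyBuffer.loadop(LoadOps.EMPTY, (4,), dtypes.float32, "CLANG")
#   assert a.e(BinaryOps.ADD, a) is a.e(BinaryOps.ADD, a)   # deduped via lazycache
# Setting LAZYCACHE=0 in the environment disables this deduplication.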
class LazyBuffer:
  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
               op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
               base:Optional[LazyBuffer]=None, cache_key=None):
    self.device, self.st, self.dtype, self.shape, self.size, self.cache_key = device, st, dtype, st.shape, st.size, cache_key
    self._base: Optional[LazyBuffer] = None
    if base is None:
      # properties on base
      self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
      self.realized: Optional[Buffer] = None
      self.output_buffer: Optional[Buffer] = None
      self.forced_realize = False
      self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
    else:
      # properties on view
      assert base.base == base, "base must be a base itself"
      self._base = base
    if cache_key is not None: lazycache[cache_key] = ref(self)

  def __del__(self): lazycache.pop(self.cache_key, None)

  def __repr__(self) -> str:
    return f"<LB {self.device} {self.shape} contig:{self.st.contiguous} {self.st if self.base != self else (self.op, self.realized)}>"

  # NOTE: this has to be a function to prevent self reference
  @property
  def base(self) -> LazyBuffer: return self._base if self._base is not None else self
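  # Editor note (not in the original): a LazyBuffer is either a "base", which owns op/srcs and
  # (eventually) a realized Buffer, or a "view", which only carries a ShapeTracker and points at
  # its base through _base. Anything that needs the underlying op should go through .base; for a
  # non-empty buffer x, x.permute((1,0)).base is x.base.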
  @staticmethod
  def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Optional[LazyBuffer]=None) -> LazyBuffer:
    return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else (), enable_cache=False)

  def const(self, val:Union[float, int]) -> LazyBuffer:
    return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
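  # Editor sketch (not in the original): const builds a 0-d CONST buffer and broadcasts it with
  # views only, so no memory is allocated for it, e.g.
  #   x.const(0.0)  # CONST(arg=0.0) reshaped to (1,)*len(x.shape), then expanded to x.shape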
  def contiguous(self):
    if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
      ret = self.e(LoadOps.CONTIGUOUS)
      if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
      return ret
    self.base.forced_realize = True
    return self
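  # Editor note (not in the original): when the view is not contiguous, contiguous() records the
  # inverse ShapeTracker in base.contiguous_child so a later e() on the base can reuse the
  # CONTIGUOUS result; an already-contiguous base is just flagged forced_realize instead.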
  def cast(self, dtype:DType, bitcast:bool=False):
    if self.dtype == dtype: return self
    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))

  def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST
  def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST

  def schedule(self, seen=None): return create_schedule([self], seen)
  def copy_to_device(self, device:str) -> LazyBuffer:
    # no COPY
    if self.device == device: return self

    # double COPY = one COPY
    if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
      return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

    # const doesn't have to be copied (issues with disk tensor)
    if self.is_unrealized_const():
      return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

    # if it's a shrink, do the shrink before the copy with CONTIGUOUS
    if prod(self.st.shape) < prod(self.base.st.shape):
      return LazyBuffer.loadop(LoadOps.COPY, self.shape, self.dtype, device, src=self.contiguous())

    # copy the base and apply the shapetracker on the new device
    return LazyBuffer.loadop(LoadOps.COPY, self.base.shape, self.dtype, device, src=self.base)._view(self.st)
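  # Editor sketch (not in the original): the "double COPY = one COPY" rule collapses chained
  # transfers, e.g. x.copy_to_device("B").copy_to_device("C") issues a single COPY from x's
  # device straight to "C" (assuming "B" and "C" are valid device strings), and a shrink is
  # made contiguous first so only the shrunken region is transferred.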
  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
    srcs: List[LazyBuffer] = []
    for s in (self,)+in_srcs:
      if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
        srcs.append(root._view(s.base.contiguous_child[1]))
      else:
        srcs.append(s)
    assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
    assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
    if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
    out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
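  # Editor sketch (not in the original): e() is the elementwise entry point; callers must have
  # already broadcast the srcs to a common shape, e.g.
  #   mask.e(TernaryOps.WHERE, a, b)  # mask must be dtypes.bool, a and b must share a dtype
  #   a.e(BinaryOps.CMPLT, b)         # comparisons always produce dtypes.bool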
  # *** reduce ops ***

  def _reduce_op(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
    if self.shape == new_shape: return self
    unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))

  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
    assert len(self.shape)==len(new_shape) and all(ns in (1,s) for s,ns in zip(self.shape,new_shape)), f"not a contraction {self.shape=} {new_shape=}"
    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
    if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
      return self._reduce_op(op, new_shape)
    heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore # noqa: E501
    if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape)
    # choose largest divisor (>=16) to split on, penalize large strides
    def splitted_shape(dim_aft_div):
      return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
    return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
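  # Editor sketch (not in the original): a worked example of the split heuristic, assuming a
  # contiguous buffer of shape (2048, 32768) reduced to (2048, 1):
  #   divisor = math.gcd(256, 32768) = 256, dim_to_split = 1, stride 1 -> heuristic 256
  # so the reduce becomes reshape to (2048, 128, 256), reduce to (2048, 128, 1), reshape to
  # (2048, 128), then a second reduce to (2048, 1), giving the first kernel far more parallelism
  # than one 32768-wide reduction per row.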
  # *** movement ops ***

  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
    return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

  def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
  def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
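  # Editor note (not in the original): movement ops never touch data; they only compose a new
  # ShapeTracker onto the same base, and a view that lands back on a contiguous, same-shaped st
  # collapses to the base itself, e.g. for a contiguous x, x.permute((1,0)).permute((1,0))
  # returns x.base.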
# *** schedule creation ***

# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
  if (buf, st) in cache: return cache[(buf, st)]
  if buf != buf.base:
    st = buf.st + st
    buf = buf.base
  # all buffers here are base now
  assert buf.op is not None

  # consts are always fused and generated
  if buf.op is LoadOps.CONST:
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, unbound_st))

  # if we aren't fusing it, it's a load and we add it to the inputs
  if buf.realized or (buf in realizes and not first):
    if buf not in inputs: inputs.append(buf)
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))

  # if a CONTIGUOUS made it all the way here, just skip it
  if buf.op is LoadOps.CONTIGUOUS:
    assert first
    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)

  # if it's a reduce, we have to change the shapetracker
  if buf.op in ReduceOps:
    assert st.contiguous, "ReduceOps late fusion must be contiguous"
    st = ShapeTracker.from_shape(buf.srcs[0].shape)

  # otherwise we fuse it like normal
  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
  return ret
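# Editor note (not in the original): _recursive_lazyop converts the LazyBuffer DAG rooted at buf
# into the LazyOp AST for one kernel. Realized buffers (and buffers marked in realizes) become
# BufferOps.LOAD with MemBuffer indices starting at 1, because index 0 is reserved for the
# kernel's own BufferOps.STORE output in _recursive_schedule below.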
# recursively walk back in the graph to create the schedule
def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
                        reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
  if out in seen or out.realized or out.op == LoadOps.CONST: return []
  assert out.base == out
  seen.add(out)

  inputs: List[LazyBuffer] = []
  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
  if out.op is LoadOps.COPY:
    op, inputs = LazyOp(LoadOps.COPY, (), out.srcs[0].base), [out.srcs[0].base]
  elif out.op is LoadOps.CUSTOM:
    op, inputs = LazyOp(LoadOps.CUSTOM, (), out.arg), list(out.srcs)
  elif out.op is LoadOps.EMPTY:
    op = LazyOp(LoadOps.EMPTY)
  else:
    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))

  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
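# Editor note (not in the original): each buffer's inputs are scheduled before its own
# ScheduleItem is appended, so the returned list is already topologically sorted and can be
# executed front to back.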
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
  if buf in allbufs or buf.base.realized: return
  if GRAPH: log_lazybuffer(buf, scheduled)
  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
    buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match
  if buf.base != buf:
    # realize all places where the buffer is expanded
    if prod(buf.base.st.shape) < prod(buf.st.shape):
      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
         prod(buf.base.st.shape) == prod([y-x for x,y in buf.st.views[-1].mask]):
        simple_pads.add(buf.base)
      else:
        realizes.add(buf.base)
    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
  if buf.forced_realize: realizes.add(buf)
  allbufs[buf] = None
  if buf.op in LoadOps: realizes.add(buf.base)
  if buf.op == LoadOps.COPY:
    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
    realizes.add(buf.srcs[0].base)
  for x in buf.srcs:
    children[x.base][buf] = None
    _recurse_lb(x, realizes, allbufs, simple_pads, children)
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
  if buf in realizes or buf.realized: return True
  # NOTE: this broke to_image_idx and coder with JIT
  if buf.op in UNSAFE_PAD_OPS: return False
  return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
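# Editor note (not in the original): when a pad is fused into a later kernel, the padded region
# is loaded as 0 inside that kernel. Roughly, ops where f(0) != 0 or that can produce inf/NaN
# from a 0 (e.g. exp2(0) = 1, log2(0) = -inf, division by 0) can leak wrong values out of the
# padded region, so pads feeding them are realized by create_schedule instead of fused.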
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
  if seen is None: seen = set()

  # start by just realizing the buffers passed in
  realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
  allbufs: Dict[LazyBuffer, None] = {}
  simple_pads: Set[LazyBuffer] = set()
  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)

  # check if we have to realize pads
  for p in simple_pads:
    if not _is_padding_okay(p, realizes):
      realizes.add(p)

  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
  for r in allbufs.keys():
    if r != r.base or r.op not in ReduceOps or r in realizes: continue

    # follow the reduce down
    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
    forced_realize = False
    can_chase = True
    while not forced_realize and len(child_set):
      next_child_set = {}
      for tr,st in child_set.items():
        if tr in realizes:
          realized_children[tr] = st
          # can only have one output buffer
          # can only reduce contiguous
          # max one reduceop per kernel
          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
            can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
            forced_realize = True
            break
          continue
        for tr_next in children[tr].keys():
          if not tr_next.realized:
            # max one reduceop per kernel
            if tr_next.op in ReduceOps:
              forced_realize = True
              break
            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
            if len(st_childs) > 1:
              forced_realize = True
              break
            next_child_set[tr_next] = st + st_childs[0].st
      child_set = next_child_set
    if forced_realize:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = tr.st
        while len(children[tr]) == 1:
          tr_next = next(iter(children[tr].keys()))
          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
          if len(st_childs) > 1: break
          if st.size != st_childs[0].st.size: break
          st = st + st_childs[0].st
          if not st.contiguous or tr_next.op in ReduceOps: break
          tr = tr_next
        reduce_for_op[tr] = r
      realizes.add(tr)
    else:
      assert len(realized_children) == 1
      reduce_for_op[next(iter(realized_children.keys()))] = r

  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
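# Editor sketch (not in the original): a minimal end-to-end use of this module, assuming a valid
# device string like "CLANG" and that the resulting ScheduleItems are later run by the realizer:
#   x = LazyBuffer.loadop(LoadOps.EMPTY, (16, 16), dtypes.float32, "CLANG")
#   y = x.e(BinaryOps.MUL, x).r(ReduceOps.SUM, (16, 1))
#   sched = create_schedule([y])   # ScheduleItems in dependency order, one per kernel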