lazy type annotation and cleanups (#1897)
@@ -37,7 +37,7 @@ class Kernel:
     # key for lookup in cache (can change, str might not be right)
     # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
     # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
-    self.key = (ast.map_buffers({x:(self.arg_bufs[x.realized] if x.realized in self.arg_bufs else x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
+    self.key = (ast.map_buffers({x:self.arg_bufs.get(x.realized,x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
 
   def process(self) -> None:
     if hasattr(self, "sts"): return # already processed
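Editor's note on the hunk above: the only functional change is that self.arg_bufs.get(x.realized, x) replaces the equivalent if/else membership lookup. The comments' cache-key reasoning deserves a concrete sketch. The toy classes below are hypothetical, not tinygrad's; they only illustrate why str(ast) alone is an ambiguous key and why mapping buffers to integers disambiguates:

# Toy sketch (hypothetical classes): why str(ast) alone cannot be the cache key.
class Buf:
  def __repr__(self): return "Buf"  # distinct buffers print identically

class Add:
  def __init__(self, a, b): self.a, self.b = a, b
  def __repr__(self): return f"Add({self.a!r}, {self.b!r})"

x, y = Buf(), Buf()
ast1, ast2 = Add(x, x), Add(x, y)  # f(x) = x + x  vs  f(x, y) = x + y
assert repr(ast1) == repr(ast2)    # same string, but different kernels

def key(ast, bufs):
  idx = {b: i for i, b in enumerate(bufs)}  # map each buffer to a stable integer
  return (repr(ast), idx[ast.a], idx[ast.b])

assert key(ast1, [x]) != key(ast2, [x, y])  # (..., 0, 0) vs (..., 0, 1)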
@@ -42,30 +42,30 @@ def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -
 # **** realize functions ****
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
-  # NOTE: mypy doesn't know that if not src.realized, then src.op must be a LazyOp so we have to ignore a bunch of warnings
   src = self.op.src[0]
   if not src.realized:
+    assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp"
     # When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
     # it's equivalent to performing the initial reduction and multiplying the result
     # by the size of the expanded dimension.
-    if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND: # type: ignore
-      expanded = src.op.src[0] # type: ignore
-      if expanded.op.op == MovementOps.RESHAPE: # type: ignore
-        reshaped = expanded.op.src[0] # type: ignore
+    if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND:
+      expanded = src.op.src[0]
+      assert isinstance(expanded.op, LazyOp)
+      if expanded.op.op == MovementOps.RESHAPE:
+        reshaped = expanded.op.src[0]
         simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
       else:
         simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
       if simplified: return simplified
     if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
       # If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
-      if src.shape == self.shape:
-        return src.op # type: ignore
-      src = src.op # type: ignore
+      if src.shape == self.shape: return src.op
+      src = src.op
   return LazyOp(self.op.op, (src,), self.op.arg)
 
 # this supports late merging an upstream Reduce op and even an Elementwise op above that
 def _ast_binaryops(self:LazyBuffer) -> LazyOp:
-  real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in self.op.buffers}
+  real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in self.op.buffers}
   # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
   # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
   psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
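The SIMPLIFY_SUM_RESHAPE_EXPAND_SUM comment above states an algebraic fact that is easy to check numerically. The snippet below is an independent numpy verification of the claim, not tinygrad code; the shapes are arbitrary:

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)

# reduce, expand back along the same axis, reduce again
s = x.sum(axis=1, keepdims=True)             # (3, 1)
expanded = np.broadcast_to(s, (3, 4))        # "expand" back to (3, 4)
twice = expanded.sum(axis=1, keepdims=True)  # second reduction, same axis

# equivalent: one reduction scaled by the size of the expanded dimension
assert np.allclose(twice, s * 4)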
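The new assert isinstance(...) lines are what allow this hunk to drop every "# type: ignore": mypy treats an isinstance assert as type narrowing, and it doubles as a loud runtime check. A minimal sketch of the idiom with hypothetical classes, not tinygrad's:

from typing import Union

class Realized: pass

class Op:
  op = "EXPAND"

def source_op(src_op: Union[Realized, Op]) -> str:
  # the assert narrows the Union for mypy, replacing a scatter of "# type: ignore"
  assert isinstance(src_op, Op), "unrealized sources must carry an Op"
  return src_op.op

assert source_op(Op()) == "EXPAND"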
@@ -85,8 +85,9 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   # reshape all the late ops into the output shape
   # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
-    if not real_srcs[x]: real_srcs[x] = x.reshape(intermediate_shape)
-  ast = self.op.map_buffers(real_srcs)
+    if real_srcs[x] is None: real_srcs[x] = x.reshape(intermediate_shape)
+  # NOTE: cast the type to remove the Optional, and add str to the value Union to match argument types
+  ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer, str]], real_srcs))
   return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
 
 # **** lazy operations ****
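Two details in this hunk generalize beyond tinygrad. First, real_srcs[x] is None replaces a truthiness test that could also match a falsy-but-legitimate value. Second, typing.cast is a runtime no-op that only tells the checker the Nones are gone once the loop has filled them in. A self-contained sketch of both patterns, with hypothetical names:

from typing import Dict, Optional, cast

def fill_defaults(srcs: Dict[str, Optional[int]]) -> Dict[str, int]:
  for k in srcs:
    # `is None`, not `not srcs[k]`: 0 is falsy but is a legitimate value
    if srcs[k] is None: srcs[k] = 0
  # every value is now an int; cast() only informs the checker, no runtime effect
  return cast(Dict[str, int], srcs)

assert fill_defaults({"a": None, "b": 7}) == {"a": 0, "b": 7}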
@@ -312,8 +313,8 @@ class LazyBuffer:
 
   @property
   def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
-  def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
-  def get_lazyops(self) -> List[Any]: return []
+  def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]): return real_srcs.get(self, self)
+  def get_lazyops(self) -> List[LazyOp]: return []
   def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
     y = self
     for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
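For orientation, map_buffers and get_lazyops form the usual leaf/interior split of a recursive tree walk, with real_srcs.get(self, self) as the identity default at a leaf. The toy classes below sketch that pattern only; they are not the real LazyBuffer/LazyOp:

from typing import Dict, List, Tuple, Union

class Leaf:
  def map_buffers(self, real_srcs: Dict["Leaf", "Leaf"]) -> "Leaf":
    return real_srcs.get(self, self)  # substitute if mapped, else keep self
  def get_lazyops(self) -> List["Node"]: return []

class Node:
  def __init__(self, srcs: Tuple[Union["Leaf", "Node"], ...]): self.srcs = srcs
  def map_buffers(self, real_srcs) -> "Node":
    # interior node: rebuild with every child mapped recursively
    return Node(tuple(s.map_buffers(real_srcs) for s in self.srcs))
  def get_lazyops(self) -> List["Node"]:
    return [self] + [op for s in self.srcs for op in s.get_lazyops()]

a, b = Leaf(), Leaf()
tree = Node((a, Node((a, b))))
assert len(tree.get_lazyops()) == 2  # pre-order: root, then the inner node
swapped = tree.map_buffers({a: b})   # every occurrence of a becomes b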
@@ -37,8 +37,7 @@ class LazyOp:
   @property
   def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
 
-  # Any == Union[LazyBuffer, DeviceBuffer]
-  def map_buffers(self, real_srcs: Dict[Any, Any]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
+  def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
   def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
 
   def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
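A final idiom from the key property above: getattr(x, "key", x) lets a source that defines a structural key contribute it, while plain arguments stand in for themselves, keeping the whole tuple hashable. A tiny illustration with a hypothetical class:

class Src:
  key = ("BUFFER", 0)  # hypothetical structural key

print(tuple(getattr(x, "key", x) for x in (Src(), 42, (3, 4))))
# -> (('BUFFER', 0), 42, (3, 4)), hashable and usable for cache lookup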