lazy type annotation and cleanups (#1897)

This commit is contained in:
chenyu
2023-09-22 02:20:23 -04:00
committed by GitHub
parent 78576915de
commit b89ee1ac83
3 changed files with 16 additions and 16 deletions

View File

@@ -37,7 +37,7 @@ class Kernel:
# key for lookup in cache (can change, str might not be right)
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
# mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
self.key = (ast.map_buffers({x:(self.arg_bufs[x.realized] if x.realized in self.arg_bufs else x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
self.key = (ast.map_buffers({x:self.arg_bufs.get(x.realized,x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
def process(self) -> None:
if hasattr(self, "sts"): return # already processed

View File

@@ -42,30 +42,30 @@ def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -
# **** realize functions ****
def _ast_reduceops(self:LazyBuffer) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
# NOTE: mypy doesn't know that if not src.realized, then src.op must be a LazyOp so we have to ignore a bunch of warnings
src = self.op.src[0]
if not src.realized:
assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp"
# When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
# it's equivalent to performing the initial reduction and multiplying the result
# by the size of the expanded dimension.
if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND: # type: ignore
expanded = src.op.src[0] # type: ignore
if expanded.op.op == MovementOps.RESHAPE: # type: ignore
reshaped = expanded.op.src[0] # type: ignore
if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND:
expanded = src.op.src[0]
assert isinstance(expanded.op, LazyOp)
if expanded.op.op == MovementOps.RESHAPE:
reshaped = expanded.op.src[0]
simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
else:
simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
if simplified: return simplified
if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
# If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
if src.shape == self.shape:
return src.op # type: ignore
src = src.op # type: ignore
if src.shape == self.shape: return src.op
src = src.op
return LazyOp(self.op.op, (src,), self.op.arg)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(self:LazyBuffer) -> LazyOp:
real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in self.op.buffers}
real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in self.op.buffers}
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
@@ -85,8 +85,9 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
# reshape all the late ops into the output shape
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if not real_srcs[x]: real_srcs[x] = x.reshape(intermediate_shape)
ast = self.op.map_buffers(real_srcs)
if real_srcs[x] is None: real_srcs[x] = x.reshape(intermediate_shape)
# NOTE: cast the type to remove the Optional, and add str to the value Union to match argument types
ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer, str]], real_srcs))
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
# **** lazy operations ****
@@ -312,8 +313,8 @@ class LazyBuffer:
@property
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[Any]: return []
def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]: return []
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)

View File

@@ -37,8 +37,7 @@ class LazyOp:
@property
def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
# Any == Union[LazyBuffer, DeviceBuffer]
def map_buffers(self, real_srcs: Dict[Any, Any]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':