diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 8c238e1316..2152d22416 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -37,7 +37,7 @@ class Kernel:
     # key for lookup in cache (can change, str might not be right)
     # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
     # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
-    self.key = (ast.map_buffers({x:(self.arg_bufs[x.realized] if x.realized in self.arg_bufs else x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
+    self.key = (ast.map_buffers({x:self.arg_bufs.get(x.realized,x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
 
   def process(self) -> None:
     if hasattr(self, "sts"): return # already processed
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 326ea5758f..2aba499423 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -42,30 +42,30 @@ def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -
 # **** realize functions ****
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
-  # NOTE: mypy doesn't know that if not src.realized, then src.op must be a LazyOp so we have to ignore a bunch of warnings
   src = self.op.src[0]
   if not src.realized:
+    assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp"
     # When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
     # it's equivalent to performing the initial reduction and multiplying the result
     # by the size of the expanded dimension.
-    if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND: # type: ignore
-      expanded = src.op.src[0] # type: ignore
-      if expanded.op.op == MovementOps.RESHAPE: # type: ignore
-        reshaped = expanded.op.src[0] # type: ignore
+    if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND:
+      expanded = src.op.src[0]
+      assert isinstance(expanded.op, LazyOp)
+      if expanded.op.op == MovementOps.RESHAPE:
+        reshaped = expanded.op.src[0]
         simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
       else:
         simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
       if simplified: return simplified
   if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
     # If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
-    if src.shape == self.shape:
-      return src.op # type: ignore
-    src = src.op # type: ignore
+    if src.shape == self.shape: return src.op
+    src = src.op
   return LazyOp(self.op.op, (src,), self.op.arg)
 
 # this supports late merging an upstream Reduce op and even an Elementwise op above that
 def _ast_binaryops(self:LazyBuffer) -> LazyOp:
-  real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in self.op.buffers}
+  real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in self.op.buffers}
   # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
   # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
   psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
@@ -85,8 +85,9 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   # reshape all the late ops into the output shape
   # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
-    if not real_srcs[x]: real_srcs[x] = x.reshape(intermediate_shape)
-  ast = self.op.map_buffers(real_srcs)
+    if real_srcs[x] is None: real_srcs[x] = x.reshape(intermediate_shape)
+  # NOTE: cast the type to remove the Optional, and add str to the value Union to match argument types
+  ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer, str]], real_srcs))
   return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
 
 # **** lazy operations ****
@@ -312,8 +313,8 @@ class LazyBuffer:
   @property
   def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
 
-  def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self)
-  def get_lazyops(self) -> List[Any]: return []
+  def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]): return real_srcs.get(self, self)
+  def get_lazyops(self) -> List[LazyOp]: return []
 
   def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
     y = self
     for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 3dcf9c4056..5519faf0b0 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -37,8 +37,7 @@ class LazyOp:
   @property
   def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
 
-  # Any == Union[LazyBuffer, DeviceBuffer]
-  def map_buffers(self, real_srcs: Dict[Any, Any]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
+  def map_buffers(self, real_srcs: Dict[LazyBuffer, Union[LazyBuffer, LazyOp, str]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
   def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
 
   def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
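
Note (not part of the diff): the lazy.py changes replace blanket "# type: ignore" comments with assert-based type narrowing, and the kernel.py change swaps a conditional expression for dict.get. Below is a minimal standalone sketch of both patterns; the Op/Buf/Realized classes are hypothetical stand-ins, not tinygrad's real LazyOp/LazyBuffer types.

  from typing import Dict, List, Optional

  class Realized: pass                        # hypothetical stand-in for a realized device buffer

  class Op:                                   # hypothetical stand-in for LazyOp
    def __init__(self, srcs: List["Buf"]): self.srcs = srcs

  class Buf:                                  # hypothetical stand-in for LazyBuffer
    def __init__(self, op: Optional[Op] = None, realized: Optional[Realized] = None):
      self.op, self.realized = op, realized

  def first_src(b: Buf) -> Buf:
    if not b.realized:
      # assert-based narrowing: after this assert, mypy treats b.op as Op, so no "# type: ignore" is needed
      assert isinstance(b.op, Op), "if not realized, op must be set"
      return b.op.srcs[0]
    return b

  leaf = Buf(realized=Realized())
  root = Buf(op=Op([leaf]))
  assert first_src(root) is leaf

  # dict.get with a default is equivalent to the removed conditional expression:
  #   d[k] if k in d else k   behaves the same as   d.get(k, k)
  arg_bufs: Dict[str, str] = {"a": "buf0"}
  assert arg_bufs.get("a", "a") == "buf0"     # key present -> mapped value
  assert arg_bufs.get("b", "b") == "b"        # key missing -> fall back to the key itself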
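
The cast added in _ast_binaryops works because the loop right above it has replaced every None before map_buffers is called, and the switch to "is None" makes that check explicit; typing.cast only informs the checker and does nothing at runtime. A rough sketch of that shape, using a toy str-keyed dict in place of real_srcs:

  from typing import Dict, List, Optional, Union, cast

  def fill_then_use(keys: List[str]) -> Dict[str, Union[int, str]]:
    # start with Optional values, like real_srcs = {x: None for x in self.op.buffers}
    srcs: Dict[str, Optional[Union[int, str]]] = {k: None for k in keys}
    for k in srcs:
      if srcs[k] is None: srcs[k] = k         # every None is filled before use
    # cast drops the Optional for mypy; at runtime it is a no-op
    return cast(Dict[str, Union[int, str]], srcs)

  assert fill_then_use(["a", "b"]) == {"a": "a", "b": "b"}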