diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index b86315feac..4b1630a8d2 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -8,7 +8,7 @@ from tinygrad.graph import log_op
 from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition
 from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
 from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode, sint, all_int
+from tinygrad.shape.symbolic import Variable, sint, all_int
 from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
@@ -195,10 +195,8 @@ class LazyBuffer:
     assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
     self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
     realized = self_casted.contiguous().realize().realized
-    # TODO: replace NumNode with int in shape
-    output_shape = tuple(s.b if isinstance(s, NumNode) else s for s in self.shape)
-    assert all_int(output_shape), f"no toCPU if shape is symbolic, {output_shape=}"
-    return cast(RawBuffer, realized).toCPU().reshape(output_shape)
+    assert all_int(self.shape), f"no toCPU if shape is symbolic, {self.shape=}"
+    return cast(RawBuffer, realized).toCPU().reshape(self.shape)
 
   def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
     # srcs includes self
@@ -227,7 +225,7 @@
     return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
 
-  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
+  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
       return self.op.replace_with_movement_ops([(op, arg)])
     if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -301,7 +299,7 @@
         return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
 
-  def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
     if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index a9b5fa3486..fd0f1c7db7 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -1,9 +1,10 @@
 import math
-from typing import Tuple, Optional
+from typing import Tuple, Optional, cast
 from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.symbolic import sint
 
 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -192,12 +193,14 @@ class Pad(Function):
     return grad_output.shrink(self.narg)
 
 class Shrink(Function):
-  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+    assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
+    # need this cast because mypy cannot narrow the type even with assert
+    return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 448351e377..8dc0643319 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -45,10 +45,9 @@ class ShapeTracker:
 
   # these are multiview strides, value is None if it's not a simple strided dimension
   # TODO: this can be shared code between simplify and merge_views
-  def real_offset(self) -> int:
-    real_offset, mask = self.expr_node(Variable('zero', 0, 0))
-    assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
-    return real_offset.b
+  def real_offset(self) -> sint:
+    real_offset, _ = self.expr_node(Variable('zero', 0, 0))
+    return real_offset.b if isinstance(real_offset, NumNode) else real_offset
 
   # NOTE: if a stride is not always valid, it will be None
   def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
@@ -59,7 +58,7 @@
     for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
       if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
         ret[idxs.index(this_dim.a)] = this_dim.b
-      elif isinstance(this_dim, Variable):
+      elif isinstance(this_dim, Variable) and this_dim in idxs:
         ret[idxs.index(this_dim)] = 1
     idx_vars, valid_vars = idx.vars(), valid.vars()
     for i,tidx in enumerate(idxs):
@@ -103,7 +102,7 @@
     self.views[-1] = self.views[-1].pad(arg)
     return self
 
-  def shrink(self, arg: Tuple[Tuple[int, int], ...]):
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]):
     self.views[-1] = self.views[-1].shrink(arg)
     return self
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 2043db1648..9510e15606 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -157,7 +157,6 @@ class NumNode(Node):
     self.b:int = num
     self.min, self.max = num, num
   def __int__(self): return self.b
-  def __index__(self): return self.b
   def __eq__(self, other): return self.b == other
   def __hash__(self): return self.hash  # needed with __eq__ override
   def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 0af74810bb..8be7d61fb7 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import functools
 from typing import Tuple, List, Optional, NamedTuple
 from tinygrad.helpers import prod
-from tinygrad.shape.symbolic import Variable, Node, is_sym_int, sint, all_int
+from tinygrad.shape.symbolic import Variable, Node, NumNode, is_sym_int, sint, all_int
 
 @functools.lru_cache(maxsize=None)
 def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
@@ -72,14 +72,15 @@ class View(NamedTuple):
 
   # MovementOps live here now
 
-  def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None) -> View:
+  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
       nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
       mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-    return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
+    shape = [y-x for x,y in arg]
+    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
   def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
@@ -91,7 +92,7 @@ class View(NamedTuple):
     return self
 
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def shrink(self, arg: Tuple[Tuple[int, int], ...]) -> View:
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
     assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
     return self.__unsafe_resize(arg)
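
Not part of the patch: a minimal sketch of what the sint-typed shrink arguments allow, using only the Variable/NumNode API already shown in the diff. The `seq_len` variable and the shapes are hypothetical, for illustration only.

# sketch.py -- illustrative only, not applied by this diff
from tinygrad.shape.symbolic import Variable, NumNode

seq_len = Variable("seq_len", 1, 512)  # hypothetical symbolic dimension, bounded 1..512

# A shrink arg may now mix plain ints and symbolic Nodes per dimension,
# e.g. keep the first `seq_len` rows and the first 64 columns:
arg = ((0, seq_len), (0, 64))

# View.__unsafe_resize builds the new shape as y - x per dimension; a symbolic end
# point keeps that dimension a Node, and the NumNode check added in view.py collapses
# fully-constant results back to plain ints.
new_shape = tuple(y - x for x, y in arg)
new_shape = tuple(s.b if isinstance(s, NumNode) else s for s in new_shape)
print(new_shape)  # first dim stays symbolic, second is the plain int 64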