no numnode in shape (#1871)

commit cd66c9e249 (parent 18ec5a9e09)
Author: chenyu
Date: 2023-09-16 16:49:45 -07:00, committed by GitHub
5 changed files with 21 additions and 21 deletions

View File

@@ -8,7 +8,7 @@ from tinygrad.graph import log_op
 from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition
 from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
 from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode, sint, all_int
+from tinygrad.shape.symbolic import Variable, sint, all_int
 from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
@@ -195,10 +195,8 @@ class LazyBuffer:
     assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
     self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
     realized = self_casted.contiguous().realize().realized
-    # TODO: replace NumNode with int in shape
-    output_shape = tuple(s.b if isinstance(s, NumNode) else s for s in self.shape)
-    assert all_int(output_shape), f"no toCPU if shape is symbolic, {output_shape=}"
-    return cast(RawBuffer, realized).toCPU().reshape(output_shape)
+    assert all_int(self.shape), f"no toCPU if shape is symbolic, {self.shape=}"
+    return cast(RawBuffer, realized).toCPU().reshape(self.shape)

   def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
     # srcs includes self
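Where the old toCPU had to strip constant NumNode entries out of the shape before checking and reshaping, the shape is now guaranteed to hold only plain ints (or genuinely symbolic Variables), so all_int(self.shape) is the whole guard. A rough, self-contained sketch of the before/after check (NumNodeLike is a stand-in, not tinygrad's real node class):

```python
from typing import Any, Tuple

def all_int(t: Tuple[Any, ...]) -> bool:
  # what tinygrad's all_int helper amounts to
  return all(isinstance(s, int) for s in t)

class NumNodeLike:  # stand-in for a constant symbolic node wrapping an int
  def __init__(self, b: int): self.b = b

# before: a shape could carry constant nodes, so they were unwrapped first
old_shape = (NumNodeLike(3), 4)
assert all_int(tuple(s.b if isinstance(s, NumNodeLike) else s for s in old_shape))

# after: shapes never carry NumNode, so the check runs on the shape directly
new_shape = (3, 4)
assert all_int(new_shape)
```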
@@ -227,7 +225,7 @@ class LazyBuffer:
     return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)

-  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
+  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
       return self.op.replace_with_movement_ops([(op, arg)])
     if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -301,7 +299,7 @@ class LazyBuffer:
       return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)

-  def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
     if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
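The two signature hunks above, and the matching ones below in the Function, ShapeTracker, and View files, only widen the argument types from int to sint so that shrink windows may carry symbolic bounds. For reference, sint in tinygrad.shape.symbolic is the union of a plain int and a symbolic Node; a minimal sketch of what the widened shrink argument admits (Variable usage mirrors the hunks, the rest is illustrative):

```python
from typing import Tuple
from tinygrad.shape.symbolic import Variable, sint

# a shrink argument is a (begin, end) pair per dimension; with the sint
# signatures either bound may now be a symbolic Variable instead of an int
i = Variable("i", 1, 10)
arg: Tuple[Tuple[sint, sint], ...] = ((0, i), (2, 6))
print(arg)
```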

View File

@@ -1,9 +1,10 @@
 import math
-from typing import Tuple, Optional
+from typing import Tuple, Optional, cast
 from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.symbolic import sint

 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -192,12 +193,14 @@ class Pad(Function):
     return grad_output.shrink(self.narg)

 class Shrink(Function):
-  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+    assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
+    # need this cast because mypy cannot narrow the type even with assert
+    return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))

 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
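Shrink.backward undoes the shrink by padding the gradient with narg = (begin, full_size - end). If any of those values is symbolic, Pad cannot take them, which is what the new assert guards against (the cast only exists to satisfy mypy). A plain-Python illustration of how narg is derived:

```python
# plain-Python illustration of Shrink's pad argument; no tinygrad types involved
shape = (8, 8)
arg = ((2, 6), (0, 8))  # keep rows 2..6 and every column
narg = tuple((p[0], s - p[1]) for s, p in zip(shape, arg))
print(narg)  # ((2, 2), (0, 0)) -> pads the gradient back out to (8, 8)
# if shape or arg contained a symbolic Node, s - p[1] would be symbolic too,
# and grad_output.pad could not accept it, hence the assert above
```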

View File

@@ -45,10 +45,9 @@ class ShapeTracker:
   # these are multiview strides, value is None if it's not a simple strided dimension
   # TODO: this can be shared code between simplify and merge_views
-  def real_offset(self) -> int:
-    real_offset, mask = self.expr_node(Variable('zero', 0, 0))
-    assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
-    return real_offset.b
+  def real_offset(self) -> sint:
+    real_offset, _ = self.expr_node(Variable('zero', 0, 0))
+    return real_offset.b if isinstance(real_offset, NumNode) else real_offset

   # NOTE: if a stride is not always valid, it will be None
   def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
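real_offset evaluates the view's index expression at index zero. Previously the result had to be a constant NumNode; with symbolic views the offset itself may stay a Node, so the method now returns an sint: the wrapped int when the expression folds to a constant, the Node otherwise. A stand-in sketch of that unwrapping (not tinygrad's real classes):

```python
# stand-in node classes, only to show the shape of the return value
class Node: pass

class NumNode(Node):  # constant node wrapping an int
  def __init__(self, b: int): self.b = b

class Variable(Node):  # a genuinely symbolic offset
  def __init__(self, expr: str): self.expr = expr

def unwrap(real_offset: Node):
  # mirrors the new return: plain int for constants, the Node for symbolic offsets
  return real_offset.b if isinstance(real_offset, NumNode) else real_offset

print(unwrap(NumNode(5)))                        # 5
print(type(unwrap(Variable("start"))).__name__)  # Variable
```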
@@ -59,7 +58,7 @@ class ShapeTracker:
     for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
       if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
         ret[idxs.index(this_dim.a)] = this_dim.b
-      elif isinstance(this_dim, Variable):
+      elif isinstance(this_dim, Variable) and this_dim in idxs:
         ret[idxs.index(this_dim)] = 1
     idx_vars, valid_vars = idx.vars(), valid.vars()
     for i,tidx in enumerate(idxs):
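The extra `this_dim in idxs` check matters once index expressions can contain symbolic shape variables: a bare Variable in the expression is no longer guaranteed to be one of the per-dimension index variables, and list.index would raise on it. A tiny reproduction of the failure mode, with plain strings standing in for Variables:

```python
idxs = ["idx0", "idx1"]  # stand-ins for the per-dimension index Variables
this_dim = "i"           # a symbolic shape variable that leaked into the expression

try:
  idxs.index(this_dim)   # what the old elif branch effectively did
except ValueError:
  print("old code raised here")

if this_dim in idxs:     # the new guard simply skips unknown variables
  print("never reached for a shape variable")
```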
@@ -103,7 +102,7 @@ class ShapeTracker:
     self.views[-1] = self.views[-1].pad(arg)
     return self

-  def shrink(self, arg: Tuple[Tuple[int, int], ...]):
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]):
     self.views[-1] = self.views[-1].shrink(arg)
     return self

View File

@@ -157,7 +157,6 @@ class NumNode(Node):
     self.b:int = num
     self.min, self.max = num, num
   def __int__(self): return self.b
-  def __index__(self): return self.b
   def __eq__(self, other): return self.b == other
   def __hash__(self): return self.hash # needed with __eq__ override
   def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
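__index__ is what lets a constant node be passed anywhere Python demands a true integer (range, slicing, numpy reshape); __int__ alone does not qualify. With NumNode no longer stored in shapes, that integration point loses its main use. A stand-in illustration of the protocol, unrelated to tinygrad's actual classes:

```python
class Const:  # stand-in constant node with only __int__
  def __init__(self, b: int): self.b = b
  def __int__(self): return self.b

class ConstIdx(Const):  # same, but usable as a real index
  def __index__(self): return self.b

try:
  range(Const(3))        # TypeError: __int__ is not enough for an index
except TypeError as e:
  print(e)

print(list(range(ConstIdx(3))))  # [0, 1, 2]
```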

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 import functools
 from typing import Tuple, List, Optional, NamedTuple
 from tinygrad.helpers import prod
-from tinygrad.shape.symbolic import Variable, Node, is_sym_int, sint, all_int
+from tinygrad.shape.symbolic import Variable, Node, NumNode, is_sym_int, sint, all_int

 @functools.lru_cache(maxsize=None)
 def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
@@ -72,14 +72,15 @@ class View(NamedTuple):
   # MovementOps live here now
-  def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None) -> View:
+  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
       nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
       mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-    return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
+    shape = [y-x for x,y in arg]
+    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
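The new shape computation in __unsafe_resize is where NumNode would otherwise sneak back into shapes: with symbolic shrink bounds, y - x is Node arithmetic, and when the symbolic parts cancel the result is a constant node rather than a plain int. Unwrapping it via .b keeps View.shape int-only. A small check of that idea against tinygrad.shape.symbolic as it stood around this commit (the exact folding behaviour is an assumption here):

```python
from tinygrad.shape.symbolic import Variable, NumNode

i = Variable("i", 1, 10)
begin, end = i, i + 4  # a shrink window whose bounds are both symbolic
dim = end - begin      # symbolic parts cancel; expected to fold to a constant node
print(dim, isinstance(dim, NumNode))
# the View code above stores dim.b (a plain int) rather than the node itself,
# so the resulting View.shape stays free of NumNode entries
```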
@@ -91,7 +92,7 @@ class View(NamedTuple):
     return self

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def shrink(self, arg: Tuple[Tuple[int, int], ...]) -> View:
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
     assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
     return self.__unsafe_resize(arg)