no numnode in shape (#1871)
@@ -8,7 +8,7 @@ from tinygrad.graph import log_op
 from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition
 from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
 from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode, sint, all_int
+from tinygrad.shape.symbolic import Variable, sint, all_int
 
 from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
@@ -195,10 +195,8 @@ class LazyBuffer:
     assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
     self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
     realized = self_casted.contiguous().realize().realized
-    # TODO: replace NumNode with int in shape
-    output_shape = tuple(s.b if isinstance(s, NumNode) else s for s in self.shape)
-    assert all_int(output_shape), f"no toCPU if shape is symbolic, {output_shape=}"
-    return cast(RawBuffer, realized).toCPU().reshape(output_shape)
+    assert all_int(self.shape), f"no toCPU if shape is symbolic, {self.shape=}"
+    return cast(RawBuffer, realized).toCPU().reshape(self.shape)
 
   def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
     # srcs includes self
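For readers outside the codebase: `sint` is tinygrad's "symbolic int" (a plain `int` or a symbolic `Node`), and `all_int` checks that a shape has no symbolic dimensions, which is why `toCPU` can now assert directly on `self.shape` instead of stripping `NumNode`s first. A minimal standalone sketch of that pattern, with mock classes rather than the library's own:

    from typing import Tuple, Union

    class Node:  # stand-in for tinygrad's symbolic Node; real Nodes carry bounds and algebra
      pass

    class Variable(Node):
      def __init__(self, name: str, vmin: int, vmax: int):
        self.name, self.min, self.max = name, vmin, vmax

    sint = Union[int, Node]  # "symbolic int": a concrete int or a symbolic Node

    def all_int(t: Tuple[sint, ...]) -> bool:
      # a shape is concrete only if every dimension is a plain int
      return all(isinstance(s, int) for s in t)

    assert all_int((4, 8))
    assert not all_int((Variable("i", 1, 10), 8))  # symbolic dim: toCPU rejects this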
@@ -227,7 +225,7 @@ class LazyBuffer:
 
     return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
 
-  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
+  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
       return self.op.replace_with_movement_ops([(op, arg)])
     if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -301,7 +299,7 @@ class LazyBuffer:
       return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
 
-  def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
     if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
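The SHRINK-of-SHRINK branch above folds two nested shrinks into one: if the inner shrink took the window `[b1, e1)` on an axis and the outer takes `[b2, e2)` relative to that, the combined window in the original buffer is `[b1+b2, b1+e2)`. A quick 1-D check of that arithmetic in plain Python, not tinygrad code:

    def shrink(data, arg):
      # 1-D shrink: arg is a single (begin, end) window
      (b, e), = arg
      return data[b:e]

    data = list(range(10))
    first, second = ((2, 8),), ((1, 4),)
    # compose: the second window is relative to the first, so begins add
    fused = tuple((b1+b2, b1+e2) for (b1,_), (b2,e2) in zip(first, second))
    assert fused == ((3, 6),)
    assert shrink(shrink(data, first), second) == shrink(data, fused) == [3, 4, 5]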
@@ -1,9 +1,10 @@
 import math
-from typing import Tuple, Optional
+from typing import Tuple, Optional, cast
 from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.symbolic import sint
 
 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -192,12 +193,14 @@ class Pad(Function):
     return grad_output.shrink(self.narg)
 
 class Shrink(Function):
-  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+    assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
+    # need this cast because mypy cannot narrow the type even with assert
+    return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
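Why the `cast` in `Shrink.backward`: an `assert all(isinstance(...))` over tuple elements is a runtime check that mypy cannot use to narrow `Tuple[Tuple[sint, sint], ...]` to `Tuple[Tuple[int, int], ...]`, so an explicit `cast` is needed. A self-contained illustration of the idiom; the `SintLike` alias and `pad` stub are hypothetical stand-ins, not tinygrad code:

    from typing import Tuple, Union, cast

    SintLike = Union[int, str]  # pretend str is the symbolic case

    def pad(narg: Tuple[Tuple[int, int], ...]) -> str:
      return f"pad{narg}"

    def backward(narg: Tuple[Tuple[SintLike, SintLike], ...]) -> str:
      # runtime guard, mirroring the commit: symbolic entries are rejected
      assert all(isinstance(a, int) and isinstance(b, int) for a, b in narg)
      # mypy cannot narrow the tuple's element type from that assert,
      # so an explicit cast is required to call pad() cleanly
      return pad(cast(Tuple[Tuple[int, int], ...], narg))

    assert backward(((0, 2), (1, 3))) == "pad((0, 2), (1, 3))"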
@@ -45,10 +45,9 @@ class ShapeTracker:
 
   # these are multiview strides, value is None if it's not a simple strided dimension
   # TODO: this can be shared code between simplify and merge_views
-  def real_offset(self) -> int:
-    real_offset, mask = self.expr_node(Variable('zero', 0, 0))
-    assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
-    return real_offset.b
+  def real_offset(self) -> sint:
+    real_offset, _ = self.expr_node(Variable('zero', 0, 0))
+    return real_offset.b if isinstance(real_offset, NumNode) else real_offset
 
   # NOTE: if a stride is not always valid, it will be None
   def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
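The rewritten `real_offset` drops the old assumption that the offset always folds to a constant: it unwraps a `NumNode` to its plain `int` and otherwise returns the symbolic node as-is, matching the new `sint` return type. A toy version of the unwrap idiom, with mock node classes rather than tinygrad's:

    from typing import Union

    class Node:  # mock symbolic node
      pass

    class NumNode(Node):  # a node that folded to a constant
      def __init__(self, b: int): self.b = b

    class SumNode(Node):  # hypothetical: a node that stayed symbolic
      pass

    sint = Union[int, Node]

    def unwrap(n: Node) -> sint:
      # constants come back as plain ints; anything still symbolic stays a Node
      return n.b if isinstance(n, NumNode) else n

    assert unwrap(NumNode(5)) == 5
    assert isinstance(unwrap(SumNode()), SumNode)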
@@ -59,7 +58,7 @@ class ShapeTracker:
     for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
       if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
         ret[idxs.index(this_dim.a)] = this_dim.b
-      elif isinstance(this_dim, Variable):
+      elif isinstance(this_dim, Variable) and this_dim in idxs:
         ret[idxs.index(this_dim)] = 1
     idx_vars, valid_vars = idx.vars(), valid.vars()
     for i,tidx in enumerate(idxs):
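The one-line fix in this hunk guards the bare-`Variable` case. `real_strides` reads strides off the flattened index expression: a `MulNode(var, c)` term means `var` strides by `c`, and a bare loop variable strides by 1. Without the `and this_dim in idxs` guard, a variable that is not one of the loop `idxs` would crash `idxs.index`. A toy matcher showing the failure mode the guard prevents, with mock classes:

    class Variable:
      def __init__(self, name): self.name = name

    class MulNode:
      def __init__(self, a, b): self.a, self.b = a, b  # a * b, with b an int

    i, j, gate = Variable("i"), Variable("j"), Variable("gate")
    idxs = [i, j]                      # the loop variables we want strides for
    terms = [MulNode(i, 3), j, gate]   # flattened index: i*3 + j (+ an unrelated var)

    ret = [None] * len(idxs)
    for t in terms:
      if isinstance(t, MulNode) and t.a in idxs: ret[idxs.index(t.a)] = t.b
      elif isinstance(t, Variable) and t in idxs: ret[idxs.index(t)] = 1  # the added guard
    assert ret == [3, 1]  # without "and t in idxs", idxs.index(gate) raises ValueError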
@@ -103,7 +102,7 @@ class ShapeTracker:
     self.views[-1] = self.views[-1].pad(arg)
     return self
 
-  def shrink(self, arg: Tuple[Tuple[int, int], ...]):
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]):
     self.views[-1] = self.views[-1].shrink(arg)
     return self
 
@@ -157,7 +157,6 @@ class NumNode(Node):
     self.b:int = num
     self.min, self.max = num, num
   def __int__(self): return self.b
-  def __index__(self): return self.b
   def __eq__(self, other): return self.b == other
   def __hash__(self): return self.hash # needed with __eq__ override
   def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
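Context on the `__hash__` line kept in this hunk: Python sets `__hash__` to `None` on any class that defines `__eq__`, so `NumNode` must redefine it explicitly to stay usable in sets and dict keys. A standalone demonstration of that rule:

    class Broken:
      def __init__(self, b): self.b = b
      def __eq__(self, other): return self.b == other  # defining __eq__ sets __hash__ to None

    class Fixed:
      def __init__(self, b): self.b = b
      def __eq__(self, other): return self.b == other
      def __hash__(self): return hash(self.b)  # restore hashability explicitly

    try:
      {Broken(3)}
    except TypeError as e:
      print(e)  # unhashable type: 'Broken'
    print({Fixed(3)} == {3})  # True: equal by value and hashable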
@@ -2,7 +2,7 @@ from __future__ import annotations
 import functools
 from typing import Tuple, List, Optional, NamedTuple
 from tinygrad.helpers import prod
-from tinygrad.shape.symbolic import Variable, Node, is_sym_int, sint, all_int
+from tinygrad.shape.symbolic import Variable, Node, NumNode, is_sym_int, sint, all_int
 
 @functools.lru_cache(maxsize=None)
 def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
@@ -72,14 +72,15 @@ class View(NamedTuple):
 
   # MovementOps live here now
 
-  def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None) -> View:
+  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
       nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
       mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-    return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
+    shape = [y-x for x,y in arg]
+    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
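The new two-step return in `__unsafe_resize` exists because `y - x` over symbolic bounds can fold to a `NumNode` constant, and the whole point of this commit is that shapes carry plain `int`s whenever a dimension is concrete. A toy version of the stripping, with mock classes:

    class Node:
      pass

    class NumNode(Node):
      def __init__(self, b: int): self.b = b

    # dims as they might come out of symbolic bound arithmetic (y - x per axis)
    shape = [NumNode(4), 8, NumNode(1)]
    concrete = tuple(s.b if isinstance(s, NumNode) else s for s in shape)
    assert concrete == (4, 8, 1)  # the View is created with plain ints where possible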
@@ -91,7 +92,7 @@ class View(NamedTuple):
     return self
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def shrink(self, arg: Tuple[Tuple[int, int], ...]) -> View:
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
     assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
     return self.__unsafe_resize(arg)
 