symbolic shape type with TypeGuard (#1852)

chenyu
2023-09-12 14:27:22 -07:00
committed by GitHub
parent c91b44f7bf
commit e67306ba04
10 changed files with 73 additions and 53 deletions

View File

@@ -34,6 +34,7 @@ setup(name='tinygrad',
"flake8",
"pylint",
"mypy",
"typing-extensions",
"pre-commit",
"ruff",
],
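
Why the new dev dependency: TypeGuard (PEP 647) only reached the stdlib `typing` module in Python 3.10, so on older interpreters it has to come from the typing-extensions backport. A minimal sketch of the version-aware import (not taken from this commit, which imports it from typing_extensions unconditionally):

```python
# Minimal sketch (not from this commit): TypeGuard is stdlib-only on 3.10+,
# so typing-extensions provides it on the older interpreters tinygrad supports.
import sys

if sys.version_info >= (3, 10):
    from typing import TypeGuard
else:
    from typing_extensions import TypeGuard
```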

View File

@@ -58,7 +58,7 @@ class TestSymbolicOps(unittest.TestCase):
def test_attention_training(self):
Tensor.training = True
self.test_attention(dropout_p=0.0)
with self.assertRaises(TypeError):
with self.assertRaises(AssertionError):
# symbolic shape dropout is not supported
self.test_attention(dropout_p=0.5)
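
The expected exception changes because the unsupported path is now rejected by an explicit shape assert (see the tensor.py hunks below) instead of a TypeError raised deeper in the dropout math. A tiny standalone sketch of the pattern the test now relies on:

```python
# standalone sketch: an assert-based shape guard surfaces as AssertionError in tests
import unittest

def guarded(shape):
    assert all(isinstance(s, int) for s in shape), f"does not support symbolic shape {shape}"

class TestGuard(unittest.TestCase):
    def test_symbolic_rejected(self):
        with self.assertRaises(AssertionError):
            guarded((3, "i"))  # a string as a stand-in for a symbolic dim
```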

View File

@@ -34,7 +34,7 @@ class TestSymbolic(unittest.TestCase):
st = t.lazydata.st
assert st.shape == (3, i+j+k)
assert st.real_strides() == (i+j+k, 1)
t = Tensor.rand(i, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
st = t.lazydata.st
assert st.shape == (2*i+3, 3)
assert st.real_strides() == (3, 1)
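
The fix replaces `Tensor.rand(i, 3)` with `Tensor.rand(3, 3).reshape(i, 3)`: random data must now be created with a fully concrete shape (`rand` asserts `all_int`), and symbolic dimensions are only introduced afterwards via reshape. A hedged sketch of that pattern, assuming `i` is a `Variable` set up as elsewhere in this test class:

```python
# hedged sketch mirroring the fixed test: create with ints, go symbolic via reshape
from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

i = Variable("i", 1, 10)             # assumed setup, as in the test class
t = Tensor.rand(3, 3).reshape(i, 3)  # ok: rand sees only ints; the symbolic dim enters via reshape
# Tensor.rand(i, 3) would now trip "cannot create with symbolic shape"
```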

View File

@@ -5,6 +5,7 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
from tinygrad.runtime.lib import buf_is_kernel_arg
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
class LocalBuffer(NamedTuple):
@@ -101,13 +102,13 @@ class Kernel:
def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)
@property
def output_shape(self) -> Tuple[int, ...]: return self.sts[0].shape
def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape
@property
def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape
def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape
@property
def full_unupcasted_shape(self) -> Tuple[int, ...]: return self.full_shape[:self.shape_len-self.upcasted]
def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]
@property
def shape_len(self) -> int: return len(self.sts[0].shape)
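
`sint` ("symbolic int") is the new `Union[Node, int]` alias, so these properties now say honestly that a dimension may be a plain int or a symbolic Node. A small usage sketch of annotating shape-handling code with it (using the post-commit import path):

```python
# small sketch: code that only inspects structure accepts sint dims as-is
from typing import Tuple
from tinygrad.shape.symbolic import Variable, sint

def ndims(shape: Tuple[sint, ...]) -> int:
    return len(shape)

assert ndims((3, 4)) == 2
assert ndims((Variable("i", 1, 10), 4)) == 2
```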

View File

@@ -8,7 +8,7 @@ from tinygrad.graph import log_op
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Node, Variable
from tinygrad.shape.symbolic import Variable, NumNode, sint, all_int
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
@@ -69,7 +69,7 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape: Tuple[int, ...] = self.shape
intermediate_shape: Tuple[sint, ...] = self.shape
if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
if psrc[1].optype == ReduceOps:
@@ -195,9 +195,10 @@ class LazyBuffer:
assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
realized = self_casted.contiguous().realize().realized
# TODO: how does this work with numpy and a symbolic shape?
#assert all(isinstance(x, int) for x in self.shape), "no toCPU if shape is symbolic"
return cast(RawBuffer, realized).toCPU().reshape(self.shape)
# TODO: replace NumNode with int in shape
output_shape = tuple(s.b if isinstance(s, NumNode) else s for s in self.shape)
assert all_int(output_shape), f"no toCPU if shape is symbolic, {output_shape=}"
return cast(RawBuffer, realized).toCPU().reshape(output_shape)
def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
# srcs includes self
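
`NumNode` is a symbolic node that wraps a known integer in its `.b` attribute, so `toCPU` unwraps those before the `all_int` check and the numpy reshape; genuinely unknown dims are rejected with a clear message. A hedged usage sketch, assuming the post-commit `tinygrad.shape.symbolic` API:

```python
# hedged sketch of the unwrap-then-check step above (post-commit API assumed)
from tinygrad.shape.symbolic import NumNode, all_int

shape = (NumNode(4), 3)                                    # NumNode wraps 4 in .b
out = tuple(s.b if isinstance(s, NumNode) else s for s in shape)
assert all_int(out) and out == (4, 3)                      # safe to hand to numpy
```
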
@@ -226,7 +227,7 @@ class LazyBuffer:
return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -236,19 +237,19 @@ class LazyBuffer:
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer:
def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
if new_nodes and all(isinstance(s, int) for s in self.shape):
@@ -271,7 +272,7 @@ class LazyBuffer:
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
def expand(self: LazyBuffer, arg:Tuple[Union[Node,int], ...]) -> LazyBuffer:
def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND:
return self.op.src[0].expand(arg)
@@ -294,7 +295,7 @@ class LazyBuffer:
return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))
# move permutes before reshapes if we can
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)

View File

@@ -67,6 +67,8 @@ class ConvTranspose2d:
class Linear:
def __init__(self, in_features, out_features, bias=True):
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
# TODO: remove this once we can represent Tensor with int shape in typing
assert isinstance(self.weight.shape[1], int), "does not support symbolic shape"
bound = 1 / math.sqrt(self.weight.shape[1])
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
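
The assert exists because the uniform bound feeds `math.sqrt`, which cannot take a symbolic Node; the guard turns an obscure TypeError into a clear message until shapes can be typed as plain ints here. A standalone sketch of the guarded computation:

```python
# standalone sketch of the guard: the uniform bound needs a concrete fan_in
import math

def uniform_bound(fan_in) -> float:
    assert isinstance(fan_in, int), "does not support symbolic shape"
    return 1 / math.sqrt(fan_in)

assert abs(uniform_bound(16) - 0.25) < 1e-12
```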

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import functools
from typing import Tuple, Union, List, Optional, cast
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
from tinygrad.shape.view import View, sint
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
from tinygrad.shape.view import View
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
@@ -26,7 +26,7 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
class ShapeTracker:
__slots__ = "views"
def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
def __init__(self, shape:Union[ShapeTracker, Tuple[sint, ...]], views:Optional[List[View]]=None):
self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View.create(shape)]
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
@@ -34,10 +34,8 @@ class ShapeTracker:
@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
# NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)
# TODO: this needs to be fixed
@property
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # type: ignore
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
@property
def key(self) -> Tuple[View, ...]: return tuple(self.views)
@@ -53,11 +51,11 @@ class ShapeTracker:
return real_offset.b
# NOTE: if a stride is not always valid, it will be None
def real_strides(self, ignore_valid=False) -> Tuple[Optional[Union[Node, int]], ...]:
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
ret[idxs.index(this_dim.a)] = this_dim.b
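
With `shape` typed as `Tuple[sint, ...]`, the old `# type: ignore` and its apologetic comment go away, and `real_strides` is honest about possibly returning symbolic strides. A hedged usage sketch, assuming the post-commit ShapeTracker behaviour (exact reprs may differ):

```python
# hedged sketch: a ShapeTracker carries a symbolic dim, and a single contiguous
# view short-circuits real_strides to the stored strides
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable

i = Variable("i", 1, 10)
st = ShapeTracker((i, 3))
print(st.shape)           # a Variable and an int -- Tuple[sint, ...] after this change
print(st.real_strides())  # (3, 1)
```
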
@@ -121,7 +119,7 @@ class ShapeTracker:
self.views[-1] = self.views[-1].stride(mul)
return self
def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
def reshape(self, new_shape: Tuple[sint, ...]):
new_view = self.views[-1].reshape(new_shape)
if new_view is None:
extra_view = View.create(new_shape)
@@ -136,7 +134,7 @@ class ShapeTracker:
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
# TODO: if we remove movementops from lazy.py we can delete this
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Optional[List[List[int]]]:
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
# Pre-allocate all groups.
axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
# Index for new_shape and axis_groups.

View File

@@ -5,6 +5,7 @@ from math import gcd
from itertools import product
from tinygrad.helpers import partition
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
from typing_extensions import TypeGuard
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -320,6 +321,11 @@ def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
assert isinstance(ret, NumNode)
return ret.b
# symbolic int
sint = Union[Node, int]
def all_int(t: Tuple[sint, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
VariableOrNum = Union[Variable, NumNode]
render_python: Dict[Type, Callable] = {
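
`sint` and the `TypeGuard`-returning `all_int` are the heart of this commit: once `all_int(shape)` returns True, mypy narrows `shape` from `Tuple[sint, ...]` to `Tuple[int, ...]`, which is what lets the `# type: ignore`s elsewhere in this diff disappear. A standalone sketch of the narrowing (here `Node` is a stand-in class, not the tinygrad one):

```python
# standalone sketch of TypeGuard narrowing; Node is a placeholder class
import math
from typing import Tuple, Union
from typing_extensions import TypeGuard

class Node: ...                     # stand-in for tinygrad's symbolic Node

sint = Union[Node, int]

def all_int(t: Tuple[sint, ...]) -> TypeGuard[Tuple[int, ...]]:
    return all(isinstance(s, int) for s in t)

def static_size(shape: Tuple[sint, ...]) -> int:
    if all_int(shape):
        return math.prod(shape)     # narrowed to Tuple[int, ...]: type-checks cleanly
    raise ValueError(f"symbolic shape {shape} has no static size")

assert static_size((2, 3, 4)) == 24
```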

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import functools
from typing import Tuple, List, Optional, NamedTuple, Union
from typing import Tuple, List, Optional, NamedTuple
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable, Node, is_sym_int
from tinygrad.shape.symbolic import Variable, Node, is_sym_int, sint, all_int
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
@@ -27,9 +27,6 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return filter_strides(shape, tuple(strides))
# symbolic int
sint = Union[Node, int]
class View(NamedTuple):
shape:Tuple[sint, ...]
strides:Tuple[sint, ...]
@@ -128,8 +125,8 @@ class View(NamedTuple):
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}" # type: ignore # mypy cannot resolve, all ints here
if all_int(self.shape) and all_int(new_shape):
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
# after the asserts, it's okay to check contiguous
if self.contiguous: return View.create(new_shape)
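
With `sint`/`all_int` now living in `symbolic.py`, `view.py` drops its local alias, and the reshape size check can be written without the mypy ignore: the element-count assert only runs when both shapes are fully concrete. A hedged sketch using the real helpers (post-commit imports assumed):

```python
# hedged sketch of the guarded size check, using the post-commit helpers
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable, all_int

old, new = (6, 2), (3, 4)
if all_int(old) and all_int(new):
    assert prod(old) == prod(new)             # static sizes must match

print(all_int((Variable("i", 1, 10), 2)))     # False -> the static check is skipped
```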

View File

@@ -7,9 +7,10 @@ from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device, LoadOps
from tinygrad.shape.symbolic import NumNode, sint, all_int
# An instantiation of the Function is the Context
class Function:
@@ -75,7 +76,7 @@ class Tensor:
def device(self) -> str: return self.lazydata.device
@property
def shape(self) -> Tuple[int, ...]: return self.lazydata.shape
def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
@property
def dtype(self) -> DType: return self.lazydata.dtype
@@ -121,7 +122,9 @@ class Tensor:
return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
@staticmethod
def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)
def empty(*shape, **kwargs):
assert all_int(shape), f"cannot create with symbolic shape {shape}"
return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
_seed: int = int(time.time())
@staticmethod
@@ -129,13 +132,14 @@ class Tensor:
@staticmethod
def rand(*shape, **kwargs):
assert all_int(shape), f"cannot create with symbolic shape {shape}"
Tensor._seed += 1
return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
# ***** creation helper functions *****
@staticmethod
def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
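
The creation ops also swap `math.prod` for `tinygrad.helpers.prod`: the helper is a plain reduce over `__mul__`, so its signature type-checks against `Tuple[sint, ...]` where `math.prod`'s int/float stubs do not (treat the reduce-based definition as an assumption about the helper at this commit). A small sketch:

```python
# small sketch: helpers.prod stays usable once shapes may contain symbolic dims
from tinygrad.helpers import prod

assert prod((2, 3, 4)) == 24
assert prod(()) == 1          # empty shape -> scalar size of 1
```
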
@@ -173,22 +177,22 @@ class Tensor:
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
# ***** toposort and backward pass *****
@@ -224,11 +228,11 @@ class Tensor:
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
@@ -299,7 +303,9 @@ class Tensor:
if isinstance(s, int):
dim_collapsed += 1
else:
final_shape.append(dim_shape)
# TODO: replace NumNode with int in shape
assert isinstance(dim_shape, (int, NumNode)), f"does not support symbolic shape {dim_shape}"
final_shape.append(int(dim_shape))
if isinstance(s, Tensor):
tensors.append(s)
dim.append(i-dim_collapsed)
@@ -326,7 +332,7 @@ class Tensor:
return ret
# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
@@ -367,6 +373,7 @@ class Tensor:
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
def chunk(self, num:int, dim:int) -> List[Tensor]:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
return [self[tuple(sl)] for sl in slice_params]
@@ -410,10 +417,10 @@ class Tensor:
def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
return out * (math.prod(out.shape)/math.prod(self.shape))
return out * (prod(out.shape)/prod(self.shape))
def std(self, axis=None, keepdim=False, correction=1):
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
def _softmax(self, axis):
m = self - self.max(axis=axis, keepdim=True)
e = m.exp()
@@ -429,8 +436,8 @@ class Tensor:
def argmax(self, axis=None, keepdim=False):
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
return math.prod(self.shape) - idx.max() - 1
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
return prod(self.shape) - idx.max() - 1
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
@@ -441,6 +448,7 @@ class Tensor:
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
@@ -552,8 +560,12 @@ class Tensor:
@staticmethod
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
def triu(self, k:int=0) -> Tensor:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
# ***** math functions (unary) *****
def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
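
The new asserts in `triu`/`tril` (and in `chunk`, `_pool`, and `scaled_dot_product_attention` below) do double duty: they reject symbolic shapes with a clear message, and because `all_int` returns a `TypeGuard`, the assert also narrows `self.shape` to `Tuple[int, ...]` for mypy in the code that follows. A short sketch of assert-based narrowing with the real helpers:

```python
# short sketch: TypeGuard narrows through an assert as well as through `if`
from typing import Tuple
from tinygrad.shape.symbolic import sint, all_int

def last_two(shape: Tuple[sint, ...]) -> Tuple[int, int]:
    assert all_int(shape), f"does not support symbolic shape {shape}"
    return shape[-2], shape[-1]    # ints as far as mypy is concerned

assert last_two((2, 3, 4)) == (3, 4)
```
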
@@ -693,6 +705,8 @@ class Tensor:
return self * mask * (1/(1.0 - p))
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
# NOTE: it works if key, value have symbolic shape
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
@@ -716,7 +730,7 @@ class Tensor:
@property
def ndim(self) -> int: return len(self.shape)
def numel(self) -> int: return math.prod(self.shape)
def numel(self) -> int: return prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)