Mirror of https://github.com/tinygrad/tinygrad.git
symbolic shape type with TypeGuard (#1852)
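In short: this commit introduces a shared symbolic-integer type, sint = Union[Node, int], together with an all_int TypeGuard helper in tinygrad.shape.symbolic, and threads sint through the shape annotations in the codegen, lazy, shapetracker, view, nn and tensor code. math.prod calls are swapped for tinygrad.helpers.prod so shape products can involve symbolic Nodes, and operations that genuinely need concrete shapes (empty, rand, chunk, _pool, triu/tril, the attention query) now assert all_int up front, which is why the test below expects AssertionError rather than TypeError.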
@@ -34,6 +34,7 @@ setup(name='tinygrad',
         "flake8",
         "pylint",
         "mypy",
+        "typing-extensions",
         "pre-commit",
         "ruff",
       ],
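The new typing-extensions dependency is what provides TypeGuard on interpreters older than Python 3.10, where it is not yet in the standard typing module. The hunks below import it from typing_extensions unconditionally; a version-aware import would be a small variation (just a sketch, not what this commit does):

import sys
if sys.version_info >= (3, 10):
  from typing import TypeGuard
else:
  from typing_extensions import TypeGuard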
@@ -58,7 +58,7 @@ class TestSymbolicOps(unittest.TestCase):
   def test_attention_training(self):
     Tensor.training = True
     self.test_attention(dropout_p=0.0)
-    with self.assertRaises(TypeError):
+    with self.assertRaises(AssertionError):
       # symbolic shape dropout is not supported
       self.test_attention(dropout_p=0.5)
@@ -34,7 +34,7 @@ class TestSymbolic(unittest.TestCase):
     st = t.lazydata.st
     assert st.shape == (3, i+j+k)
     assert st.real_strides() == (i+j+k, 1)
-    t = Tensor.rand(i, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
+    t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
     st = t.lazydata.st
     assert st.shape == (2*i+3, 3)
     assert st.real_strides() == (3, 1)
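For context, i, j and k in this test are symbolic Variables. A minimal sketch of the setup these assertions assume (the 1..10 bounds are an illustrative assumption, not taken from the diff):

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

i, j, k = [Variable(name, 1, 10) for name in ("i", "j", "k")]
t = Tensor.rand(3, 3).reshape(i, 3)   # concrete data, symbolic first dimension
print(t.shape)                        # (i, 3): a Tuple[sint, ...] after this change

The change to the test itself follows from the new assert in Tensor.rand further down: tensors must be created with concrete shapes and only then reshaped to symbolic ones.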
@@ -5,6 +5,7 @@ from tinygrad.lazy import LazyBuffer
 from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
 from tinygrad.runtime.lib import buf_is_kernel_arg
 from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import strides_for_shape

 class LocalBuffer(NamedTuple):
@@ -101,13 +102,13 @@ class Kernel:
   def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)

   @property
-  def output_shape(self) -> Tuple[int, ...]: return self.sts[0].shape
+  def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape

   @property
-  def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape
+  def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape

   @property
-  def full_unupcasted_shape(self) -> Tuple[int, ...]: return self.full_shape[:self.shape_len-self.upcasted]
+  def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted]

   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
@@ -8,7 +8,7 @@ from tinygrad.graph import log_op
 from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition
 from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
 from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
-from tinygrad.shape.symbolic import Node, Variable
+from tinygrad.shape.symbolic import Variable, NumNode, sint, all_int

 from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
@@ -69,7 +69,7 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
   # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
   psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
-  intermediate_shape: Tuple[int, ...] = self.shape
+  intermediate_shape: Tuple[sint, ...] = self.shape
   if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
     psrc = psrcs[0]  # NOTE: right now we can't handle multiple, as we'd have to check for loop
     if psrc[1].optype == ReduceOps:
@@ -195,9 +195,10 @@ class LazyBuffer:
     assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
     self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
     realized = self_casted.contiguous().realize().realized
-    # TODO: how does this work with numpy and a symbolic shape?
-    #assert all(isinstance(x, int) for x in self.shape), "no toCPU if shape is symbolic"
-    return cast(RawBuffer, realized).toCPU().reshape(self.shape)
+    # TODO: replace NumNode with int in shape
+    output_shape = tuple(s.b if isinstance(s, NumNode) else s for s in self.shape)
+    assert all_int(output_shape), f"no toCPU if shape is symbolic, {output_shape=}"
+    return cast(RawBuffer, realized).toCPU().reshape(output_shape)

   def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
     # srcs includes self
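The toCPU change relies on NumNode being the constant node of the symbolic library, with its value stored in .b, so a shape that is only nominally symbolic can be lowered to plain ints before numpy sees it. A minimal sketch (assuming this commit is applied, since all_int now lives in tinygrad.shape.symbolic):

from tinygrad.shape.symbolic import NumNode, all_int

shape = (NumNode(3), 4)   # NumNode(3) is a constant node; NumNode(3).b == 3
output_shape = tuple(s.b if isinstance(s, NumNode) else s for s in shape)
assert all_int(output_shape) and output_shape == (3, 4)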
@@ -226,7 +227,7 @@ class LazyBuffer:

     return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)

-  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
+  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
       return self.op.replace_with_movement_ops([(op, arg)])
     if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -236,19 +237,19 @@ class LazyBuffer:
       return root.reshape(st.shape)
     return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)

-  def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
+  def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
     return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)

-  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
+  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape)  # The amount of work should be big enough to take the benefit of "2 kernels" approach.
     heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new)  # type: ignore
     if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape)  # Choose largest divisor (>=16) to split on, penalize large strides.
     def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
     return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)

-  def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer:
+  def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
     if new_nodes and all(isinstance(s, int) for s in self.shape):
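The reshape path above uses tinygrad.helpers.partition, which splits a sequence by a predicate; here it separates the concrete ints of the requested shape from the symbolic nodes so the branch just below can handle a reshape from a fully concrete shape onto a symbolic one. A small sketch (the Variable bounds are an assumption):

from tinygrad.helpers import partition
from tinygrad.shape.symbolic import Variable

arg = (2, Variable("i", 1, 10), 4)
new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
# new_ints == [2, 4]; new_nodes holds the single symbolic dimension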
@@ -271,7 +272,7 @@ class LazyBuffer:
     if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)

-  def expand(self: LazyBuffer, arg:Tuple[Union[Node,int], ...]) -> LazyBuffer:
+  def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.EXPAND:
       return self.op.src[0].expand(arg)
@@ -294,7 +295,7 @@ class LazyBuffer:
       return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))

     # move permutes before reshapes if we can
-    if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
+    if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
       if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
         self.op.src[0].children.discard(self)  # NOTE: this is only required in reshape and when pushing permutes, why??
         return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
@@ -67,6 +67,8 @@ class ConvTranspose2d:
 class Linear:
   def __init__(self, in_features, out_features, bias=True):
     self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
+    # TODO: remove this once we can represent Tensor with int shape in typing
+    assert isinstance(self.weight.shape[1], int), "does not support symbolic shape"
     bound = 1 / math.sqrt(self.weight.shape[1])
     self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

@@ -3,8 +3,8 @@ from __future__ import annotations
 import functools
 from typing import Tuple, Union, List, Optional, cast
 from tinygrad.helpers import prod, DEBUG
-from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
-from tinygrad.shape.view import View, sint
+from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
+from tinygrad.shape.view import View

 @functools.lru_cache(maxsize=None)
 def merge_views(vm2:View, vm1:View) -> Optional[View]:
@@ -26,7 +26,7 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:

 class ShapeTracker:
   __slots__ = "views"
-  def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
+  def __init__(self, shape:Union[ShapeTracker, Tuple[sint, ...]], views:Optional[List[View]]=None):
     self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View.create(shape)]
   def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
   def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
@@ -34,10 +34,8 @@ class ShapeTracker:
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

-  # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)
-  # TODO: this needs to be fixed
   @property
-  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape  # type: ignore
+  def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape

   @property
   def key(self) -> Tuple[View, ...]: return tuple(self.views)
@@ -53,11 +51,11 @@ class ShapeTracker:
     return real_offset.b

   # NOTE: if a stride is not always valid, it will be None
-  def real_strides(self, ignore_valid=False) -> Tuple[Optional[Union[Node, int]], ...]:
+  def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
     if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
     idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
     idx, valid = self.expr_idxs(idxs)
-    ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
+    ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
     for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
       if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
         ret[idxs.index(this_dim.a)] = this_dim.b
@@ -121,7 +119,7 @@ class ShapeTracker:
     self.views[-1] = self.views[-1].stride(mul)
     return self

-  def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
+  def reshape(self, new_shape: Tuple[sint, ...]):
     new_view = self.views[-1].reshape(new_shape)
     if new_view is None:
       extra_view = View.create(new_shape)
@@ -136,7 +134,7 @@

 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 # TODO: if we remove movementops from lazy.py we can delete this
-def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Optional[List[List[int]]]:
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
   # Pre-allocate all groups.
   axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
   # Index for new_shape and axis_groups.
@@ -5,6 +5,7 @@ from math import gcd
 from itertools import product
 from tinygrad.helpers import partition
 from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
+from typing_extensions import TypeGuard

 # NOTE: Python has different behavior for negative mod and floor div than c
 # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -320,6 +321,11 @@ def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
   assert isinstance(ret, NumNode)
   return ret.b

+# symbolic int
+sint = Union[Node, int]
+
+def all_int(t: Tuple[sint, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
+
 VariableOrNum = Union[Variable, NumNode]

 render_python: Dict[Type, Callable] = {
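sint and all_int are the core of this change: sint is the symbolic-int union used throughout the shape code, and because all_int is declared as a TypeGuard, a successful check narrows a Tuple[sint, ...] to Tuple[int, ...] for mypy, so plain-int code type-checks without # type: ignore. A minimal sketch of the narrowing (the size helper is hypothetical, not part of the diff):

import math
from typing import Tuple
from tinygrad.shape.symbolic import sint, all_int

def size(shape: Tuple[sint, ...]) -> int:
  if all_int(shape):
    return math.prod(shape)   # mypy sees Tuple[int, ...] inside this branch
  raise ValueError(f"symbolic shape {shape} has no fixed size")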
@@ -1,8 +1,8 @@
 from __future__ import annotations
 import functools
-from typing import Tuple, List, Optional, NamedTuple, Union
+from typing import Tuple, List, Optional, NamedTuple
 from tinygrad.helpers import prod
-from tinygrad.shape.symbolic import Variable, Node, is_sym_int
+from tinygrad.shape.symbolic import Variable, Node, is_sym_int, sint, all_int

 @functools.lru_cache(maxsize=None)
 def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
@@ -27,9 +27,6 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
   for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
   return filter_strides(shape, tuple(strides))

-# symbolic int
-sint = Union[Node, int]
-
 class View(NamedTuple):
   shape:Tuple[sint, ...]
   strides:Tuple[sint, ...]
@@ -128,8 +125,8 @@ class View(NamedTuple):

     assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
     # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
-    if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
-      assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"  # type: ignore  # mypy cannot resolve, all ints here
+    if all_int(self.shape) and all_int(new_shape):
+      assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"

     # after the asserts, it's okay to check contiguous
     if self.contiguous: return View.create(new_shape)
@@ -7,9 +7,10 @@ from itertools import accumulate
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence

-from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
+from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import Device, LoadOps
+from tinygrad.shape.symbolic import NumNode, sint, all_int

 # An instantiation of the Function is the Context
 class Function:
@@ -75,7 +76,7 @@ class Tensor:
   def device(self) -> str: return self.lazydata.device

   @property
-  def shape(self) -> Tuple[int, ...]: return self.lazydata.shape
+  def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape

   @property
   def dtype(self) -> DType: return self.lazydata.dtype
@@ -121,7 +122,9 @@ class Tensor:
     return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)

   @staticmethod
-  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)
+  def empty(*shape, **kwargs):
+    assert all_int(shape), f"cannot create with symbolic shape {shape}"
+    return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)

   _seed: int = int(time.time())
   @staticmethod
@@ -129,13 +132,14 @@ class Tensor:

   @staticmethod
   def rand(*shape, **kwargs):
+    assert all_int(shape), f"cannot create with symbolic shape {shape}"
     Tensor._seed += 1
-    return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
+    return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)

   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
+  def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)

   @staticmethod
   def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
@@ -173,22 +177,22 @@ class Tensor:
     return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

   @staticmethod
-  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
+  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)

   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
   @staticmethod
-  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)
+  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)

   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
+    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
   @staticmethod
   def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
+    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

   # ***** toposort and backward pass *****
@@ -224,11 +228,11 @@ class Tensor:
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
     assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
-    return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
+    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
   def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
     ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
     return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
@@ -299,7 +303,9 @@ class Tensor:
       if isinstance(s, int):
         dim_collapsed += 1
       else:
-        final_shape.append(dim_shape)
+        # TODO: replace NumNode with int in shape
+        assert isinstance(dim_shape, (int, NumNode)), f"does not support symbolic shape {dim_shape}"
+        final_shape.append(int(dim_shape))
         if isinstance(s, Tensor):
           tensors.append(s)
           dim.append(i-dim_collapsed)
@@ -326,7 +332,7 @@ class Tensor:
     return ret

   # NOTE: using slice is discouraged and things should migrate to pad and shrink
-  def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
+  def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
     arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
     padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
     return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
@@ -367,6 +373,7 @@ class Tensor:
     return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

   def chunk(self, num:int, dim:int) -> List[Tensor]:
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
     slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
     return [self[tuple(sl)] for sl in slice_params]
@@ -410,10 +417,10 @@ class Tensor:

   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
-    return out * (math.prod(out.shape)/math.prod(self.shape))
+    return out * (prod(out.shape)/prod(self.shape))
   def std(self, axis=None, keepdim=False, correction=1):
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
+    return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
   def _softmax(self, axis):
     m = self - self.max(axis=axis, keepdim=True)
     e = m.exp()
@@ -429,8 +436,8 @@ class Tensor:

   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
-      return math.prod(self.shape) - idx.max() - 1
+      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
+      return prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
@@ -441,6 +448,7 @@ class Tensor:

   def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
     assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
     assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
     slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
@@ -552,8 +560,12 @@ class Tensor:

   @staticmethod
   def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
-  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
-  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
+  def triu(self, k:int=0) -> Tensor:
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
+  def tril(self, k:int=0) -> Tensor:
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)

   # ***** math functions (unary) *****
   def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
@@ -693,6 +705,8 @@ class Tensor:
     return self * mask * (1/(1.0 - p))

   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+    # NOTE: it works if key, value have symbolic shape
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
     return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
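This assert is what the updated test_attention_training exercises: the query must keep a concrete shape, while key and value may carry a symbolic sequence length. A rough sketch of that usage (shapes and Variable bounds are assumptions):

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

seq = Variable("seq", 1, 128)
q = Tensor.rand(1, 4, 1, 16)                          # all-int shape: passes the assert
k = Tensor.rand(1, 4, 64, 16).reshape(1, 4, seq, 16)  # symbolic key/value are allowed
v = Tensor.rand(1, 4, 64, 16).reshape(1, 4, seq, 16)
out = q.scaled_dot_product_attention(k, v).realize()

With Tensor.training enabled and a non-zero dropout_p, the dropout mask would need Tensor.rand over the symbolic attention shape, which now trips the all_int assert in rand; that is why the test expects AssertionError.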
@@ -716,7 +730,7 @@ class Tensor:

   @property
   def ndim(self) -> int: return len(self.shape)
-  def numel(self) -> int: return math.prod(self.shape)
+  def numel(self) -> int: return prod(self.shape)
   def element_size(self) -> int: return self.dtype.itemsize
   def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)