ruff format mixin (#13261)
@@ -1,4 +1,6 @@
 from tinygrad.mixin.math import MathMixin
 from tinygrad.mixin.movement import MovementMixin
 
-class OpMixin(MathMixin, MovementMixin): pass
+
+class OpMixin(MathMixin, MovementMixin):
+  pass
@@ -2,24 +2,38 @@ from typing import Self
 from tinygrad.uop import Ops
 from tinygrad.dtype import dtypes, ConstType
 
+
 class MathMixin:
   # required to implement
-  def alu(self, op:Ops, *src:Self) -> Self: raise NotImplementedError
-  def const_like(self, b:ConstType) -> Self: raise NotImplementedError
+  def alu(self, op: Ops, *src: Self) -> Self:
+    raise NotImplementedError
+
+  def const_like(self, b: ConstType) -> Self:
+    raise NotImplementedError
 
   # great functions you get!
-  def ufix(self, x:Self|ConstType) -> Self: return self.const_like(x) if not isinstance(x, MathMixin) else x
-  def _binop(self, op:Ops, x:Self|ConstType, reverse:bool) -> Self:
+  def ufix(self, x: Self | ConstType) -> Self:
+    return self.const_like(x) if not isinstance(x, MathMixin) else x
+
+  def _binop(self, op: Ops, x: Self | ConstType, reverse: bool) -> Self:
     return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
-  def logical_not(self): return self.ne(True)
+
+  def logical_not(self):
+    return self.ne(True)
+
   def neg(self):
-    if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
-    return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
+    if (dtype := getattr(self, "dtype")) is None:
+      raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
+    return self.logical_not() if dtype.scalar() == dtypes.bool else self * (-1)
+
   def _check_dtype(self):
-    if (dtype:=getattr(self, 'dtype')) is not None:
-      if isinstance(dtype, tuple): dtype = dtype[0]
-      if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)): raise RuntimeError(f"{dtype} is not supported")
-  def add(self, x:Self|ConstType, reverse:bool=False):
+    if (dtype := getattr(self, "dtype")) is not None:
+      if isinstance(dtype, tuple):
+        dtype = dtype[0]
+      if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)):
+        raise RuntimeError(f"{dtype} is not supported")
+
+  def add(self, x: Self | ConstType, reverse: bool = False):
     """
     Adds `self` and `x`.
     Equivalent to `self + x`.
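The contract above is the whole point of the mixin: a class supplies `alu` and `const_like`, and everything else is derived. A minimal runnable sketch of how `ufix` and `_binop` interact, using a stand-in `Op` enum and a toy `Scalar` class instead of tinygrad's real `Ops`/`UOp` (every name below is illustrative, not part of this diff):

```python
from enum import Enum, auto

class Op(Enum):  # stand-in for tinygrad.uop.Ops
  ADD = auto()
  MUL = auto()

class Scalar:
  def __init__(self, v: float): self.v = v
  # the two methods the mixin requires
  def alu(self, op: Op, *src: "Scalar") -> "Scalar":
    if op is Op.ADD: return Scalar(self.v + src[0].v)
    if op is Op.MUL: return Scalar(self.v * src[0].v)
    raise NotImplementedError(op)
  def const_like(self, b) -> "Scalar":
    return Scalar(float(b))
  # what the mixin derives: lift constants, then dispatch to alu
  def ufix(self, x): return x if isinstance(x, Scalar) else self.const_like(x)
  def _binop(self, op: Op, x, reverse: bool) -> "Scalar":
    return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))

a = Scalar(3.0)
print(a._binop(Op.ADD, 4, reverse=False).v)  # 7.0, i.e. a + 4
print(a._binop(Op.ADD, 4, reverse=True).v)   # 7.0, i.e. 4 + a
```

`reverse=True` is what the reflected dunders (`__radd__` and friends, later in this diff) pass, so that `4 + a` puts the lifted constant on the left operand of the ALU op.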
@@ -37,7 +51,8 @@ class MathMixin:
     ```
     """
     return self._binop(Ops.ADD, x, reverse)
-  def mul(self, x:Self|ConstType, reverse:bool=False):
+
+  def mul(self, x: Self | ConstType, reverse: bool = False):
     """
     Multiplies `self` and `x`.
     Equivalent to `self * x`.
@@ -56,7 +71,8 @@ class MathMixin:
     ```
     """
     return self._binop(Ops.MUL, x, reverse)
-  def bitwise_and(self, x:Self|ConstType, reverse:bool=False):
+
+  def bitwise_and(self, x: Self | ConstType, reverse: bool = False):
     """
     Computes the bitwise AND of `self` and `x`.
     Equivalent to `self & x`.
@@ -70,7 +86,8 @@ class MathMixin:
     """
     self._check_dtype()
     return self._binop(Ops.AND, x, reverse)
-  def bitwise_or(self, x:Self|ConstType, reverse:bool=False):
+
+  def bitwise_or(self, x: Self | ConstType, reverse: bool = False):
     """
     Computes the bitwise OR of `self` and `x`.
     Equivalent to `self | x`.
@@ -84,7 +101,8 @@ class MathMixin:
     """
     self._check_dtype()
     return self._binop(Ops.OR, x, reverse)
-  def bitwise_xor(self, x:Self|ConstType, reverse:bool=False):
+
+  def bitwise_xor(self, x: Self | ConstType, reverse: bool = False):
     """
     Computes bitwise xor of `self` and `x`.
     Equivalent to `self ^ x`.
@@ -99,7 +117,8 @@ class MathMixin:
     """
     self._check_dtype()
     return self._binop(Ops.XOR, x, reverse)
-  def idiv(self, x:Self|ConstType, reverse:bool=False):
+
+  def idiv(self, x: Self | ConstType, reverse: bool = False):
     """
     Divides `self` by `x`.
     Equivalent to `self // x`.
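All three bitwise ops route through `_check_dtype` first, so they accept only bool and integer dtypes. A quick check through the public `Tensor` API (a sketch assuming a recent tinygrad; the exact exception raised on the failing case may differ by code path):

```python
from tinygrad import Tensor, dtypes

a = Tensor([0b1100], dtype=dtypes.int32)
b = Tensor([0b1010], dtype=dtypes.int32)
print((a & b).tolist())  # [8], i.e. 0b1000
print((a | b).tolist())  # [14], i.e. 0b1110
print((a ^ b).tolist())  # [6], i.e. 0b0110

try:
  Tensor([1.0]) & Tensor([2.0])  # float dtype: rejected by the dtype check
except Exception as e:
  print(type(e).__name__, e)
```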
@@ -111,62 +130,150 @@ class MathMixin:
     ```
     """
     return self._binop(Ops.IDIV, x, reverse)
-  def mod(self, x:Self|ConstType, reverse:bool=False): return self._binop(Ops.MOD, x, reverse)
-  def sub(self, x:Self|ConstType, reverse:bool=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
-  def div(self, x:Self|ConstType, reverse:bool=False):
-    return (self.ufix(x)*self.alu(Ops.RECIPROCAL)) if reverse else (self*self.ufix(x).alu(Ops.RECIPROCAL))
-
-  def __neg__(self): return self.neg()
-  def __add__(self, x:Self|ConstType): return self.add(x)
-  def __sub__(self, x:Self|ConstType): return self.sub(x)
-  def __mul__(self, x:Self|ConstType): return self.mul(x)
-  def __truediv__(self, x:Self|ConstType): return self.div(x)
-  def __floordiv__(self, x:Self|ConstType): return self.idiv(x)  # TODO: idiv is trunc div, not floordiv
-  def __mod__(self, x:Self|ConstType): return self.mod(x)
-  def __and__(self, x:Self|ConstType): return self.bitwise_and(x)
-  def __or__(self, x:Self|ConstType): return self.bitwise_or(x)
-  def __xor__(self, x:Self|ConstType): return self.bitwise_xor(x)
-
-  def __radd__(self, x:Self|ConstType): return self.add(x, True)
-  def __rsub__(self, x:Self|ConstType): return self.sub(x, True)
-  def __rmul__(self, x:Self|ConstType): return self.mul(x, True)
-  def __rtruediv__(self, x:Self|ConstType): return self.div(x, True)
-  def __rfloordiv__(self, x:Self|ConstType): return self.idiv(x, True)
-  def __rand__(self, x:Self|ConstType): return self.bitwise_and(x, True)
-  def __ror__(self, x:Self|ConstType): return self.bitwise_or(x, True)
-  def __rxor__(self, x:Self|ConstType): return self.bitwise_xor(x, True)
-  def __rmod__(self, x:Self|ConstType): return self.mod(x, True)
-
-  def __lt__(self, x:Self|ConstType): return self.alu(Ops.CMPLT, self.ufix(x))
-  def __gt__(self, x:Self|ConstType): return self.ufix(x).alu(Ops.CMPLT, self)
-  def __ge__(self, x:Self|ConstType): return (self < x).logical_not()
-  def __le__(self, x:Self|ConstType): return (self > x).logical_not()
-
-  def ne(self, x:Self|ConstType): return self.alu(Ops.CMPNE, self.ufix(x))
-  def eq(self, x:Self|ConstType): return self.ne(x).logical_not()
-  def __ne__(self, x:Self|ConstType): return self.ne(x)  # type: ignore[override]
+  def mod(self, x: Self | ConstType, reverse: bool = False):
+    return self._binop(Ops.MOD, x, reverse)
+
+  def sub(self, x: Self | ConstType, reverse: bool = False):
+    return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
+
+  def div(self, x: Self | ConstType, reverse: bool = False):
+    return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
+
+  def __neg__(self):
+    return self.neg()
+
+  def __add__(self, x: Self | ConstType):
+    return self.add(x)
+
+  def __sub__(self, x: Self | ConstType):
+    return self.sub(x)
+
+  def __mul__(self, x: Self | ConstType):
+    return self.mul(x)
+
+  def __truediv__(self, x: Self | ConstType):
+    return self.div(x)
+
+  def __floordiv__(self, x: Self | ConstType):
+    return self.idiv(x)  # TODO: idiv is trunc div, not floordiv
+
+  def __mod__(self, x: Self | ConstType):
+    return self.mod(x)
+
+  def __and__(self, x: Self | ConstType):
+    return self.bitwise_and(x)
+
+  def __or__(self, x: Self | ConstType):
+    return self.bitwise_or(x)
+
+  def __xor__(self, x: Self | ConstType):
+    return self.bitwise_xor(x)
+
+  def __radd__(self, x: Self | ConstType):
+    return self.add(x, True)
+
+  def __rsub__(self, x: Self | ConstType):
+    return self.sub(x, True)
+
+  def __rmul__(self, x: Self | ConstType):
+    return self.mul(x, True)
+
+  def __rtruediv__(self, x: Self | ConstType):
+    return self.div(x, True)
+
+  def __rfloordiv__(self, x: Self | ConstType):
+    return self.idiv(x, True)
+
+  def __rand__(self, x: Self | ConstType):
+    return self.bitwise_and(x, True)
+
+  def __ror__(self, x: Self | ConstType):
+    return self.bitwise_or(x, True)
+
+  def __rxor__(self, x: Self | ConstType):
+    return self.bitwise_xor(x, True)
+
+  def __rmod__(self, x: Self | ConstType):
+    return self.mod(x, True)
+
+  def __lt__(self, x: Self | ConstType):
+    return self.alu(Ops.CMPLT, self.ufix(x))
+
+  def __gt__(self, x: Self | ConstType):
+    return self.ufix(x).alu(Ops.CMPLT, self)
+
+  def __ge__(self, x: Self | ConstType):
+    return (self < x).logical_not()
+
+  def __le__(self, x: Self | ConstType):
+    return (self > x).logical_not()
+
+  def ne(self, x: Self | ConstType):
+    return self.alu(Ops.CMPNE, self.ufix(x))
+
+  def eq(self, x: Self | ConstType):
+    return self.ne(x).logical_not()
+
+  def __ne__(self, x: Self | ConstType):  # type: ignore[override]
+    return self.ne(x)
+
   # NOTE: __eq__ isn't overridden, and means the same thing as is by default
 
-  def lshift(self, x:Self|int, reverse:bool=False): return self._binop(Ops.SHL, x, reverse)
-  def rshift(self, x:Self|int, reverse:bool=False): return self._binop(Ops.SHR, x, reverse)
-  def __lshift__(self, x:Self|int): return self.lshift(x)
-  def __rshift__(self, x:Self|int): return self.rshift(x)
-  def __rlshift__(self, x:Self|int): return self.lshift(x, True)
-  def __rrshift__(self, x:Self|int): return self.rshift(x, True)
-
-  def maximum(self, x:Self|ConstType): return self.alu(Ops.MAX, self.ufix(x))
-  def minimum(self, x:Self|ConstType): return -(-self).maximum(-x)
-  def where(self, x:Self|ConstType, y:Self|ConstType):
-    if isinstance(x, type(self)): return self.alu(Ops.WHERE, x, x.ufix(y))
-    if isinstance(y, type(self)): return self.alu(Ops.WHERE, y.ufix(x), y)
+  def lshift(self, x: Self | int, reverse: bool = False):
+    return self._binop(Ops.SHL, x, reverse)
+
+  def rshift(self, x: Self | int, reverse: bool = False):
+    return self._binop(Ops.SHR, x, reverse)
+
+  def __lshift__(self, x: Self | int):
+    return self.lshift(x)
+
+  def __rshift__(self, x: Self | int):
+    return self.rshift(x)
+
+  def __rlshift__(self, x: Self | int):
+    return self.lshift(x, True)
+
+  def __rrshift__(self, x: Self | int):
+    return self.rshift(x, True)
+
+  def maximum(self, x: Self | ConstType):
+    return self.alu(Ops.MAX, self.ufix(x))
+
+  def minimum(self, x: Self | ConstType):
+    return -(-self).maximum(-x)
+
+  def where(self, x: Self | ConstType, y: Self | ConstType):
+    if isinstance(x, type(self)):
+      return self.alu(Ops.WHERE, x, x.ufix(y))
+    if isinstance(y, type(self)):
+      return self.alu(Ops.WHERE, y.ufix(x), y)
     raise RuntimeError("where needs at least one UOp arg")
-  def threefry(self, seed:Self): return self.alu(Ops.THREEFRY, seed)
-  def reciprocal(self): return self.alu(Ops.RECIPROCAL)
-  def trunc(self): return self.alu(Ops.TRUNC)
-  def sqrt(self): return self.alu(Ops.SQRT)
-  def sin(self): return self.alu(Ops.SIN)
-  def log2(self): return self.alu(Ops.LOG2)
-  def exp2(self): return self.alu(Ops.EXP2)
-  def pow(self, x:Self|ConstType): return self.alu(Ops.POW, self.ufix(x))
-  def __pow__(self, x:Self|ConstType): return self.pow(x)
+
+  def threefry(self, seed: Self):
+    return self.alu(Ops.THREEFRY, seed)
+
+  def reciprocal(self):
+    return self.alu(Ops.RECIPROCAL)
+
+  def trunc(self):
+    return self.alu(Ops.TRUNC)
+
+  def sqrt(self):
+    return self.alu(Ops.SQRT)
+
+  def sin(self):
+    return self.alu(Ops.SIN)
+
+  def log2(self):
+    return self.alu(Ops.LOG2)
+
+  def exp2(self):
+    return self.alu(Ops.EXP2)
+
+  def pow(self, x: Self | ConstType):
+    return self.alu(Ops.POW, self.ufix(x))
+
+  def __pow__(self, x: Self | ConstType):
+    return self.pow(x)
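Nothing in this hunk changes semantics; it only reflows the derivations: `sub` is an ADD of a negation, `div` multiplies by a RECIPROCAL, `minimum` is `-(-self).maximum(-x)`, and `eq` is `ne(x).logical_not()`. A small sanity check through the public `Tensor` API (a sketch assuming a recent tinygrad):

```python
from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0])
print((2 - t).tolist())       # [1.0, 0.0, -1.0], the reverse=True path of sub
print((t / 2).tolist())       # [0.5, 1.0, 1.5], mul by reciprocal
print(t.minimum(2).tolist())  # [1.0, 2.0, 2.0], derived from maximum by negation
print((t != 2.0).tolist())    # [True, False, True], CMPNE
```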
@@ -4,19 +4,26 @@ from typing import TypeAlias, TYPE_CHECKING, Self
 from tinygrad.uop import Ops
 from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
 from tinygrad.uop.ops import resolve, smax
-if TYPE_CHECKING: from tinygrad.uop.ops import UOp
+
+if TYPE_CHECKING:
+  from tinygrad.uop.ops import UOp
 sint: TypeAlias = "UOp | int"
 
-def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
+
+def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
   # unsqueeze left to make every shape same length
   max_dim = max(len(shape) for shape in shapes)
   return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
 
+
 class MovementMixin:
   # required to implement
-  def _mop(self, op:Ops, arg) -> Self: raise NotImplementedError
+  def _mop(self, op: Ops, arg) -> Self:
+    raise NotImplementedError
+
   @property
-  def shape(self) -> tuple[sint, ...]: raise NotImplementedError
+  def shape(self) -> tuple[sint, ...]:
+    raise NotImplementedError
 
   # great functions you get!
   @property
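`_align_left` is a pure helper, so it can be exercised standalone; this snippet copies its three-line body straight from the diff:

```python
def _align_left(*shapes):
  # unsqueeze left to make every shape same length
  max_dim = max(len(shape) for shape in shapes)
  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)

print(_align_left((3, 4), (2, 3, 4), (4,)))
# ((1, 3, 4), (2, 3, 4), (1, 1, 4))
```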
@@ -42,18 +49,21 @@ class MovementMixin:
     """
     return prod(self.shape)
 
-  def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
+  def _resolve_dim(self, dim: int, *, extra: bool = False) -> int:
     total = self.ndim + int(extra)
-    if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
+    if not -max(1, total) <= dim <= max(1, total) - 1:
+      raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}")
     return dim + total if dim < 0 else dim
 
-  def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Self:
-    if self.shape == new_shape: return self
-    if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
+  def _broadcast_to(self, new_shape: tuple[sint, ...]) -> Self:
+    if self.shape == new_shape:
+      return self
+    if self.ndim > len(new_shape):
+      raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
     # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
     shape, _ = _align_left(self.shape, new_shape)
     # for each dimension, check either dim is 1, or it does not change
-    if not all(s == ns or s == 1 for s,ns in zip(shape, new_shape)):
+    if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)):
       raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
     reshaped = self.reshape(shape)
     ret = reshaped._mop(Ops.EXPAND, arg=new_shape)
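`_broadcast_to` enforces the array-API broadcasting rule the comment links to: shapes are left-aligned with 1s, then each dimension must either match or be 1. Through the public `expand` (a sketch assuming a recent tinygrad):

```python
from tinygrad import Tensor

t = Tensor.ones(3, 1)
print(t.expand(3, 4).shape)     # (3, 4): the size-1 axis grows
print(t.expand(2, 3, 4).shape)  # (2, 3, 4): (3, 1) is first aligned to (1, 3, 1)
```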
@@ -85,15 +95,18 @@ class MovementMixin:
     ```
     """
     # resolve None and args
-    new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
+    new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))])
     # resolve -1
-    if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
-    if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
-    if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
+    if (c := new_shape.count(-1)) > 1:
+      raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
+    if c:
+      new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
+    if prod(self.shape) != prod(new_shape):
+      raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
     ret = self._mop(Ops.RESHAPE, arg=new_shape)
     return self if ret.shape == self.shape else ret
 
-  def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Self:
+  def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
     """
     Returns a tensor that shrinks each axis based on the input arg.
     `arg` must have the same length as `self.ndim`.
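The `-1` inference relies on a sign trick: with one `-1` present, `prod(new_shape)` is negative, so `-prod(self.shape) // prod(new_shape)` floor-divides two negative numbers and yields the positive inferred size (e.g. `-24 // -4 == 6`). For example (a sketch assuming a recent tinygrad):

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3, 4)             # 24 elements
print(t.reshape(4, -1).shape)        # (4, 6): 6 inferred as -24 // -4
print(t.reshape(-1).shape)           # (24,)
print(t.reshape(4, None, -1).shape)  # (4, 3, 2): None keeps self.shape[1]
```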
@@ -110,8 +123,9 @@ class MovementMixin:
     print(t.shrink((((0, 2), (0, 2)))).numpy())
     ```
     """
-    if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
-    ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)])
+    if self.ndim != len(arg):
+      raise ValueError(f"{self.ndim=} != {len(arg)=}")
+    ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
     return self if ret.shape == self.shape else ret
 
   def permute(self, order, *args) -> Self:
@@ -129,7 +143,8 @@ class MovementMixin:
     ```
     """
     order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
-    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
+    if sorted(order_arg) != list(range(self.ndim)):
+      raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
     return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self
 
   def flip(self, axis, *args) -> Self:
@@ -150,7 +165,8 @@ class MovementMixin:
     """
     axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
     assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
-    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
+    if len(axis_arg) != len(dedup(axis_arg)):
+      raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
     flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
     return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self
 
@@ -163,7 +179,7 @@ class MovementMixin:
     """`.view` is an alias for `.reshape`."""
     return self.reshape(shape, *args)
 
-  def squeeze(self, dim:int|None=None) -> Self:
+  def squeeze(self, dim: int | None = None) -> Self:
     """
     Returns a tensor with specified dimensions of input of size 1 removed.
     If `dim` is not specified, all dimensions with size 1 are removed.
@@ -179,11 +195,12 @@ class MovementMixin:
     print(t.squeeze(1).shape)
     ```
     """
-    if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
+    if dim is None:
+      return self.reshape(tuple(dim for dim in self.shape if dim != 1))
     dim = self._resolve_dim(dim)
-    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
+    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim + 1 :])
 
-  def unsqueeze(self, dim:int) -> Self:
+  def unsqueeze(self, dim: int) -> Self:
     """
     Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
 
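Both methods bottom out in a single `reshape` that drops or inserts size-1 entries. Expected behavior (a sketch assuming a recent tinygrad):

```python
from tinygrad import Tensor

t = Tensor.ones(2, 1, 3)
print(t.squeeze().shape)     # (2, 3): all size-1 dims removed
print(t.squeeze(1).shape)    # (2, 3)
print(t.squeeze(0).shape)    # (2, 1, 3): dim 0 is not size 1, unchanged
print(t.unsqueeze(0).shape)  # (1, 2, 1, 3)
```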
@@ -234,9 +251,9 @@ class MovementMixin:
     ```
     """
     start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
-    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
+    return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :])
 
-  def unflatten(self, dim:int, sizes:tuple[int,...]) -> Self:
+  def unflatten(self, dim: int, sizes: tuple[int, ...]) -> Self:
     """
     Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
 
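`flatten` and `unflatten` are inverses, each implemented as one `reshape` over the affected span of dims (a sketch assuming a recent tinygrad):

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3, 4, 5)
f = t.flatten(1, 2)
print(f.shape)                       # (2, 12, 5): dims 1..2 collapse to prod((3, 4))
print(f.unflatten(1, (3, 4)).shape)  # (2, 3, 4, 5): round trip
```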
@@ -251,9 +268,9 @@ class MovementMixin:
     ```
     """
     dim = self._resolve_dim(dim)
-    return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
+    return self.reshape(self.shape[:dim] + sizes + self.shape[dim + 1 :])
 
-  def rearrange(self, formula:str, **sizes) -> Self:
+  def rearrange(self, formula: str, **sizes) -> Self:
     """
     Rearranges input according to formula
 
@@ -264,38 +281,43 @@ class MovementMixin:
     print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
     ```
     """
 
     def parse_formula(formula: str):
       tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace("  ", " ").replace(" 1 ", " ( ) ").split()
       lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
       pairs = list(zip(lparens, rparens))
       assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
-      return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
+      return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]
 
     assert formula.count("->") == 1, 'need exactly one "->" in formula'
 
     (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
 
-    for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
+    for name in sizes:
+      assert name in lhs, f"axis {name} is not used in transform"
     assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
-    for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
+    for name in flatten((lhs, rhs)):
+      assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
     assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
     assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
 
     # resolve ellipsis
-    if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
-    lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
+    if "..." in lhs:
+      ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
+    lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs))
     unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
     flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
 
     # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
     t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
-    for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
+    for i, name in enumerate(lhs):
+      assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
     t = t.permute([lhs.index(name) for name in rhs])
-    return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
+    return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
 
   # *** movement ops with expand ***
 
-  def repeat_interleave(self, repeats:int, dim:int|None=None) -> Self:
+  def repeat_interleave(self, repeats: int, dim: int | None = None) -> Self:
     """
     Repeats elements of a tensor.
 
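`rearrange` parses the einops-style formula, then composes the primitives above in a fixed order: unflatten the grouped lhs axes, permute, then flatten or unsqueeze the rhs groups. Typical usage (a sketch assuming a recent tinygrad):

```python
from tinygrad import Tensor

x = Tensor.ones(2, 3, 4)
print(x.rearrange("b c h -> b (c h)").shape)           # (2, 12)
print(x.rearrange("b c h -> h b c").shape)             # (4, 2, 3)
print(x.rearrange("(a b) c h -> a b c h", a=2).shape)  # (2, 1, 3, 4)
```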
@@ -306,7 +328,10 @@ class MovementMixin:
     """
     x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
     shp = x.shape
-    return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
+    x = x.reshape(*shp[: dim + 1], 1, *shp[dim + 1 :])
+    x = x.expand(*shp[: dim + 1], repeats, *shp[dim + 1 :])
+    x = x.reshape(*shp[:dim], shp[dim] * repeats, *shp[dim + 1 :])
+    return x
 
   def repeat(self, repeats, *args) -> Self:
     """
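The rewrite above splits `repeat_interleave`'s single reshape/expand/reshape chain into named steps; the result is unchanged: insert a size-1 axis after `dim`, expand it to `repeats`, and fold it back in. Expected behavior (a sketch assuming a recent tinygrad):

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
print(t.repeat_interleave(2).tolist())  # [1, 1, 2, 2, 3, 3]

m = Tensor([[1, 2], [3, 4]])
print(m.repeat_interleave(2, dim=0).tolist())  # [[1, 2], [1, 2], [3, 4], [3, 4]]
```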
@@ -323,28 +348,29 @@ class MovementMixin:
     """
     repeats = argfix(repeats, *args)
    base_shape = _align_left(self.shape, repeats)[0]
-    unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r,s in zip(repeats, base_shape)])
-    expanded_shape = flatten([[s] if r == 1 else [r, s] for r,s in zip(repeats, base_shape)])
-    final_shape = [r*s for r,s in zip(repeats, base_shape)]
+    unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r, s in zip(repeats, base_shape)])
+    expanded_shape = flatten([[s] if r == 1 else [r, s] for r, s in zip(repeats, base_shape)])
+    final_shape = [r * s for r, s in zip(repeats, base_shape)]
     return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
 
   # **** pool level ****
 
-  def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Self:
+  def _pool(self, k_: tuple[sint, ...], stride: int | tuple[int, ...] = 1, dilation: int | tuple[int, ...] = 1) -> Self:
     assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
     s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
     assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
-    noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
-    assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
-    o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
+    noop, i_ = [None] * (self.ndim - len(k_)), self.shape[-len(k_) :]
+    assert all(resolve(d * (k - 1) + 1 <= i) for k, d, i in zip(k_, d_, i_)), "kernel size cannot be greater than actual input size"
+    o_ = [ceildiv(i - d * (k - 1), s) for i, d, k, s in zip(i_, d_, k_, s_)]
     # input size scaling factor to make sure shrink for stride is possible
-    f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
+    f_ = [smax(1, ceildiv(o * s - d, i)) for o, s, i, d in zip(o_, s_, i_, d_)]
     # repeats such that we don't need padding
-    x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
+    x = self.repeat([1] * len(noop) + [ceildiv(k * (i * f + d), i) for k, i, d, f in zip(k_, i_, d_, f_)])
     # handle dilation
-    x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
+    x = x.shrink_to(noop + [k * (i * f + d) for k, i, d, f in zip(k_, i_, d_, f_)])
+    x = x.reshape(noop + flatten((k, (i * f + d)) for k, i, d, f in zip(k_, i_, d_, f_)))
     # handle stride
-    x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
-    x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
+    x = x.shrink_to(noop + flatten((k, o * s) for k, o, s in zip(k_, o_, s_))).reshape(noop + flatten((k, o, s) for k, o, s in zip(k_, o_, s_)))
+    x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
     # permute to move reduce to the end
-    return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
+    return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])
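The shape arithmetic in `_pool` is the standard pooling/convolution output size, `o = ceildiv(i - d*(k - 1), s)` per pooled axis, with `f_` only upscaling when the stride-shrink would not otherwise fit. A standalone check of the formula, in plain Python mirroring the `o_` line above:

```python
import math

def out_size(i: int, k: int, s: int = 1, d: int = 1) -> int:
  # mirrors o_ = [ceildiv(i - d*(k - 1), s) ...] from _pool
  return math.ceil((i - d * (k - 1)) / s)

print(out_size(10, 3))       # 8: eight 3-wide windows over length 10
print(out_size(10, 3, s=2))  # 4
print(out_size(10, 3, d=2))  # 6: dilated kernel spans d*(k - 1) + 1 = 5
```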