ruff format mixin (#13261)

This commit is contained in:
George Hotz
2025-11-13 10:10:38 -08:00
committed by GitHub
parent 3049f3edda
commit 6b1bae6614
3 changed files with 251 additions and 116 deletions

View File

@@ -1,4 +1,6 @@
from tinygrad.mixin.math import MathMixin
from tinygrad.mixin.movement import MovementMixin
class OpMixin(MathMixin, MovementMixin): pass
class OpMixin(MathMixin, MovementMixin):
pass

View File

@@ -2,24 +2,38 @@ from typing import Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType
class MathMixin:
# required to implement
def alu(self, op:Ops, *src:Self) -> Self: raise NotImplementedError
def const_like(self, b:ConstType) -> Self: raise NotImplementedError
def alu(self, op: Ops, *src: Self) -> Self:
raise NotImplementedError
def const_like(self, b: ConstType) -> Self:
raise NotImplementedError
# great functions you get!
def ufix(self, x:Self|ConstType) -> Self: return self.const_like(x) if not isinstance(x, MathMixin) else x
def _binop(self, op:Ops, x:Self|ConstType, reverse:bool) -> Self:
def ufix(self, x: Self | ConstType) -> Self:
return self.const_like(x) if not isinstance(x, MathMixin) else x
def _binop(self, op: Ops, x: Self | ConstType, reverse: bool) -> Self:
return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def logical_not(self): return self.ne(True)
def logical_not(self):
return self.ne(True)
def neg(self):
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
if (dtype := getattr(self, "dtype")) is None:
raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
return self.logical_not() if dtype.scalar() == dtypes.bool else self * (-1)
def _check_dtype(self):
if (dtype:=getattr(self, 'dtype')) is not None:
if isinstance(dtype, tuple): dtype = dtype[0]
if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)): raise RuntimeError(f"{dtype} is not supported")
def add(self, x:Self|ConstType, reverse:bool=False):
if (dtype := getattr(self, "dtype")) is not None:
if isinstance(dtype, tuple):
dtype = dtype[0]
if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)):
raise RuntimeError(f"{dtype} is not supported")
def add(self, x: Self | ConstType, reverse: bool = False):
"""
Adds `self` and `x`.
Equivalent to `self + x`.
@@ -37,7 +51,8 @@ class MathMixin:
```
"""
return self._binop(Ops.ADD, x, reverse)
def mul(self, x:Self|ConstType, reverse:bool=False):
def mul(self, x: Self | ConstType, reverse: bool = False):
"""
Multiplies `self` and `x`.
Equivalent to `self * x`.
@@ -56,7 +71,8 @@ class MathMixin:
```
"""
return self._binop(Ops.MUL, x, reverse)
def bitwise_and(self, x:Self|ConstType, reverse:bool=False):
def bitwise_and(self, x: Self | ConstType, reverse: bool = False):
"""
Computes the bitwise AND of `self` and `x`.
Equivalent to `self & x`.
@@ -70,7 +86,8 @@ class MathMixin:
"""
self._check_dtype()
return self._binop(Ops.AND, x, reverse)
def bitwise_or(self, x:Self|ConstType, reverse:bool=False):
def bitwise_or(self, x: Self | ConstType, reverse: bool = False):
"""
Computes the bitwise OR of `self` and `x`.
Equivalent to `self | x`.
@@ -84,7 +101,8 @@ class MathMixin:
"""
self._check_dtype()
return self._binop(Ops.OR, x, reverse)
def bitwise_xor(self, x:Self|ConstType, reverse:bool=False):
def bitwise_xor(self, x: Self | ConstType, reverse: bool = False):
"""
Computes bitwise xor of `self` and `x`.
Equivalent to `self ^ x`.
@@ -99,7 +117,8 @@ class MathMixin:
"""
self._check_dtype()
return self._binop(Ops.XOR, x, reverse)
def idiv(self, x:Self|ConstType, reverse:bool=False):
def idiv(self, x: Self | ConstType, reverse: bool = False):
"""
Divides `self` by `x`.
Equivalent to `self // x`.
@@ -111,62 +130,150 @@ class MathMixin:
```
"""
return self._binop(Ops.IDIV, x, reverse)
def mod(self, x:Self|ConstType, reverse:bool=False): return self._binop(Ops.MOD, x, reverse)
def sub(self, x:Self|ConstType, reverse:bool=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
def div(self, x:Self|ConstType, reverse:bool=False):
return (self.ufix(x)*self.alu(Ops.RECIPROCAL)) if reverse else (self*self.ufix(x).alu(Ops.RECIPROCAL))
def __neg__(self): return self.neg()
def mod(self, x: Self | ConstType, reverse: bool = False):
return self._binop(Ops.MOD, x, reverse)
def __add__(self, x:Self|ConstType): return self.add(x)
def __sub__(self, x:Self|ConstType): return self.sub(x)
def __mul__(self, x:Self|ConstType): return self.mul(x)
def __truediv__(self, x:Self|ConstType): return self.div(x)
def __floordiv__(self, x:Self|ConstType): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
def __mod__(self, x:Self|ConstType): return self.mod(x)
def __and__(self, x:Self|ConstType): return self.bitwise_and(x)
def __or__(self, x:Self|ConstType): return self.bitwise_or(x)
def __xor__(self, x:Self|ConstType): return self.bitwise_xor(x)
def sub(self, x: Self | ConstType, reverse: bool = False):
return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
def __radd__(self, x:Self|ConstType): return self.add(x, True)
def __rsub__(self, x:Self|ConstType): return self.sub(x, True)
def __rmul__(self, x:Self|ConstType): return self.mul(x, True)
def __rtruediv__(self, x:Self|ConstType): return self.div(x, True)
def __rfloordiv__(self, x:Self|ConstType): return self.idiv(x, True)
def __rand__(self, x:Self|ConstType): return self.bitwise_and(x, True)
def __ror__(self, x:Self|ConstType): return self.bitwise_or(x, True)
def __rxor__(self, x:Self|ConstType): return self.bitwise_xor(x, True)
def __rmod__(self, x:Self|ConstType): return self.mod(x, True)
def div(self, x: Self | ConstType, reverse: bool = False):
return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
def __lt__(self, x:Self|ConstType): return self.alu(Ops.CMPLT, self.ufix(x))
def __gt__(self, x:Self|ConstType): return self.ufix(x).alu(Ops.CMPLT, self)
def __ge__(self, x:Self|ConstType): return (self < x).logical_not()
def __le__(self, x:Self|ConstType): return (self > x).logical_not()
def __neg__(self):
return self.neg()
def __add__(self, x: Self | ConstType):
return self.add(x)
def __sub__(self, x: Self | ConstType):
return self.sub(x)
def __mul__(self, x: Self | ConstType):
return self.mul(x)
def __truediv__(self, x: Self | ConstType):
return self.div(x)
def __floordiv__(self, x: Self | ConstType):
return self.idiv(x) # TODO: idiv is trunc div, not floordiv
def __mod__(self, x: Self | ConstType):
return self.mod(x)
def __and__(self, x: Self | ConstType):
return self.bitwise_and(x)
def __or__(self, x: Self | ConstType):
return self.bitwise_or(x)
def __xor__(self, x: Self | ConstType):
return self.bitwise_xor(x)
def __radd__(self, x: Self | ConstType):
return self.add(x, True)
def __rsub__(self, x: Self | ConstType):
return self.sub(x, True)
def __rmul__(self, x: Self | ConstType):
return self.mul(x, True)
def __rtruediv__(self, x: Self | ConstType):
return self.div(x, True)
def __rfloordiv__(self, x: Self | ConstType):
return self.idiv(x, True)
def __rand__(self, x: Self | ConstType):
return self.bitwise_and(x, True)
def __ror__(self, x: Self | ConstType):
return self.bitwise_or(x, True)
def __rxor__(self, x: Self | ConstType):
return self.bitwise_xor(x, True)
def __rmod__(self, x: Self | ConstType):
return self.mod(x, True)
def __lt__(self, x: Self | ConstType):
return self.alu(Ops.CMPLT, self.ufix(x))
def __gt__(self, x: Self | ConstType):
return self.ufix(x).alu(Ops.CMPLT, self)
def __ge__(self, x: Self | ConstType):
return (self < x).logical_not()
def __le__(self, x: Self | ConstType):
return (self > x).logical_not()
def ne(self, x: Self | ConstType):
return self.alu(Ops.CMPNE, self.ufix(x))
def eq(self, x: Self | ConstType):
return self.ne(x).logical_not()
def __ne__(self, x: Self | ConstType): # type: ignore[override]
return self.ne(x)
def ne(self, x:Self|ConstType): return self.alu(Ops.CMPNE, self.ufix(x))
def eq(self, x:Self|ConstType): return self.ne(x).logical_not()
def __ne__(self, x:Self|ConstType): return self.ne(x) # type: ignore[override]
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
def lshift(self, x:Self|int, reverse:bool=False): return self._binop(Ops.SHL, x, reverse)
def rshift(self, x:Self|int, reverse:bool=False): return self._binop(Ops.SHR, x, reverse)
def __lshift__(self, x:Self|int): return self.lshift(x)
def __rshift__(self, x:Self|int): return self.rshift(x)
def __rlshift__(self, x:Self|int): return self.lshift(x, True)
def __rrshift__(self, x:Self|int): return self.rshift(x, True)
def lshift(self, x: Self | int, reverse: bool = False):
return self._binop(Ops.SHL, x, reverse)
def maximum(self, x:Self|ConstType): return self.alu(Ops.MAX, self.ufix(x))
def minimum(self, x:Self|ConstType): return -(-self).maximum(-x)
def where(self, x:Self|ConstType, y:Self|ConstType):
if isinstance(x, type(self)): return self.alu(Ops.WHERE, x, x.ufix(y))
if isinstance(y, type(self)): return self.alu(Ops.WHERE, y.ufix(x), y)
def rshift(self, x: Self | int, reverse: bool = False):
return self._binop(Ops.SHR, x, reverse)
def __lshift__(self, x: Self | int):
return self.lshift(x)
def __rshift__(self, x: Self | int):
return self.rshift(x)
def __rlshift__(self, x: Self | int):
return self.lshift(x, True)
def __rrshift__(self, x: Self | int):
return self.rshift(x, True)
def maximum(self, x: Self | ConstType):
return self.alu(Ops.MAX, self.ufix(x))
def minimum(self, x: Self | ConstType):
return -(-self).maximum(-x)
def where(self, x: Self | ConstType, y: Self | ConstType):
if isinstance(x, type(self)):
return self.alu(Ops.WHERE, x, x.ufix(y))
if isinstance(y, type(self)):
return self.alu(Ops.WHERE, y.ufix(x), y)
raise RuntimeError("where needs at least one UOp arg")
def threefry(self, seed:Self): return self.alu(Ops.THREEFRY, seed)
def reciprocal(self): return self.alu(Ops.RECIPROCAL)
def trunc(self): return self.alu(Ops.TRUNC)
def sqrt(self): return self.alu(Ops.SQRT)
def sin(self): return self.alu(Ops.SIN)
def log2(self): return self.alu(Ops.LOG2)
def exp2(self): return self.alu(Ops.EXP2)
def pow(self, x:Self|ConstType): return self.alu(Ops.POW, self.ufix(x))
def __pow__(self, x:Self|ConstType): return self.pow(x)
def threefry(self, seed: Self):
return self.alu(Ops.THREEFRY, seed)
def reciprocal(self):
return self.alu(Ops.RECIPROCAL)
def trunc(self):
return self.alu(Ops.TRUNC)
def sqrt(self):
return self.alu(Ops.SQRT)
def sin(self):
return self.alu(Ops.SIN)
def log2(self):
return self.alu(Ops.LOG2)
def exp2(self):
return self.alu(Ops.EXP2)
def pow(self, x: Self | ConstType):
return self.alu(Ops.POW, self.ufix(x))
def __pow__(self, x: Self | ConstType):
return self.pow(x)

View File

@@ -4,19 +4,26 @@ from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
from tinygrad.uop.ops import resolve, smax
if TYPE_CHECKING: from tinygrad.uop.ops import UOp
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
sint: TypeAlias = "UOp | int"
def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
# unsqueeze left to make every shape same length
max_dim = max(len(shape) for shape in shapes)
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
class MovementMixin:
# required to implement
def _mop(self, op:Ops, arg) -> Self: raise NotImplementedError
def _mop(self, op: Ops, arg) -> Self:
raise NotImplementedError
@property
def shape(self) -> tuple[sint, ...]: raise NotImplementedError
def shape(self) -> tuple[sint, ...]:
raise NotImplementedError
# great functions you get!
@property
@@ -42,18 +49,21 @@ class MovementMixin:
"""
return prod(self.shape)
def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
def _resolve_dim(self, dim: int, *, extra: bool = False) -> int:
total = self.ndim + int(extra)
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
if not -max(1, total) <= dim <= max(1, total) - 1:
raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}")
return dim + total if dim < 0 else dim
def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Self:
if self.shape == new_shape: return self
if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
def _broadcast_to(self, new_shape: tuple[sint, ...]) -> Self:
if self.shape == new_shape:
return self
if self.ndim > len(new_shape):
raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
# first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
shape, _ = _align_left(self.shape, new_shape)
# for each dimension, check either dim is 1, or it does not change
if not all(s == ns or s == 1 for s,ns in zip(shape, new_shape)):
if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
reshaped = self.reshape(shape)
ret = reshaped._mop(Ops.EXPAND, arg=new_shape)
@@ -85,15 +95,18 @@ class MovementMixin:
```
"""
# resolve None and args
new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))])
# resolve -1
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
if (c := new_shape.count(-1)) > 1:
raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c:
new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
if prod(self.shape) != prod(new_shape):
raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
ret = self._mop(Ops.RESHAPE, arg=new_shape)
return self if ret.shape == self.shape else ret
def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Self:
def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
"""
Returns a tensor that shrinks the each axis based on input arg.
`arg` must have the same length as `self.ndim`.
@@ -110,8 +123,9 @@ class MovementMixin:
print(t.shrink((((0, 2), (0, 2)))).numpy())
```
"""
if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)])
if self.ndim != len(arg):
raise ValueError(f"{self.ndim=} != {len(arg)=}")
ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
return self if ret.shape == self.shape else ret
def permute(self, order, *args) -> Self:
@@ -129,7 +143,8 @@ class MovementMixin:
```
"""
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
if sorted(order_arg) != list(range(self.ndim)):
raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self
def flip(self, axis, *args) -> Self:
@@ -150,7 +165,8 @@ class MovementMixin:
"""
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
if len(axis_arg) != len(dedup(axis_arg)):
raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self
@@ -163,7 +179,7 @@ class MovementMixin:
"""`.view` is an alias for `.reshape`."""
return self.reshape(shape, *args)
def squeeze(self, dim:int|None=None) -> Self:
def squeeze(self, dim: int | None = None) -> Self:
"""
Returns a tensor with specified dimensions of input of size 1 removed.
If `dim` is not specified, all dimensions with size 1 are removed.
@@ -179,11 +195,12 @@ class MovementMixin:
print(t.squeeze(1).shape)
```
"""
if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
if dim is None:
return self.reshape(tuple(dim for dim in self.shape if dim != 1))
dim = self._resolve_dim(dim)
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim + 1 :])
def unsqueeze(self, dim:int) -> Self:
def unsqueeze(self, dim: int) -> Self:
"""
Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
@@ -234,9 +251,9 @@ class MovementMixin:
```
"""
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :])
def unflatten(self, dim:int, sizes:tuple[int,...]) -> Self:
def unflatten(self, dim: int, sizes: tuple[int, ...]) -> Self:
"""
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
@@ -251,9 +268,9 @@ class MovementMixin:
```
"""
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
return self.reshape(self.shape[:dim] + sizes + self.shape[dim + 1 :])
def rearrange(self, formula:str, **sizes) -> Self:
def rearrange(self, formula: str, **sizes) -> Self:
"""
Rearranges input according to formula
@@ -264,38 +281,43 @@ class MovementMixin:
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
```
"""
def parse_formula(formula: str):
tokens = f" {formula} ".replace("", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
pairs = list(zip(lparens, rparens))
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]
assert formula.count("->") == 1, 'need exactly one "->" in formula'
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
for name in sizes:
assert name in lhs, f"axis {name} is not used in transform"
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
for name in flatten((lhs, rhs)):
assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
# resolve ellipsis
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
if "..." in lhs:
ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs))
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
for i, name in enumerate(lhs):
assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
t = t.permute([lhs.index(name) for name in rhs])
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
# *** movement ops with expand ***
def repeat_interleave(self, repeats:int, dim:int|None=None) -> Self:
def repeat_interleave(self, repeats: int, dim: int | None = None) -> Self:
"""
Repeats elements of a tensor.
@@ -306,7 +328,10 @@ class MovementMixin:
"""
x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
shp = x.shape
return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
x = x.reshape(*shp[: dim + 1], 1, *shp[dim + 1 :])
x = x.expand(*shp[: dim + 1], repeats, *shp[dim + 1 :])
x = x.reshape(*shp[:dim], shp[dim] * repeats, *shp[dim + 1 :])
return x
def repeat(self, repeats, *args) -> Self:
"""
@@ -323,28 +348,29 @@ class MovementMixin:
"""
repeats = argfix(repeats, *args)
base_shape = _align_left(self.shape, repeats)[0]
unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r,s in zip(repeats, base_shape)])
expanded_shape = flatten([[s] if r == 1 else [r, s] for r,s in zip(repeats, base_shape)])
final_shape = [r*s for r,s in zip(repeats, base_shape)]
unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r, s in zip(repeats, base_shape)])
expanded_shape = flatten([[s] if r == 1 else [r, s] for r, s in zip(repeats, base_shape)])
final_shape = [r * s for r, s in zip(repeats, base_shape)]
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
# **** pool level ****
def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Self:
def _pool(self, k_: tuple[sint, ...], stride: int | tuple[int, ...] = 1, dilation: int | tuple[int, ...] = 1) -> Self:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
noop, i_ = [None] * (self.ndim - len(k_)), self.shape[-len(k_) :]
assert all(resolve(d * (k - 1) + 1 <= i) for k, d, i in zip(k_, d_, i_)), "kernel size cannot be greater than actual input size"
o_ = [ceildiv(i - d * (k - 1), s) for i, d, k, s in zip(i_, d_, k_, s_)]
# input size scaling factor to make sure shrink for stride is possible
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
f_ = [smax(1, ceildiv(o * s - d, i)) for o, s, i, d in zip(o_, s_, i_, d_)]
# repeats such that we don't need padding
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
x = self.repeat([1] * len(noop) + [ceildiv(k * (i * f + d), i) for k, i, d, f in zip(k_, i_, d_, f_)])
# handle dilation
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
x = x.shrink_to(noop + [k * (i * f + d) for k, i, d, f in zip(k_, i_, d_, f_)])
x = x.reshape(noop + flatten((k, (i * f + d)) for k, i, d, f in zip(k_, i_, d_, f_)))
# handle stride
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
x = x.shrink_to(noop + flatten((k, o * s) for k, o, s in zip(k_, o_, s_))).reshape(noop + flatten((k, o, s) for k, o, s in zip(k_, o_, s_)))
x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
# permute to move reduce to the end
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])