move permute/flip/shrink to mixins (#13113)

* move permute to mixins

* move more stuff

* two more

* fix local mypy

* fix tests

* fix shrink
This commit is contained in:
George Hotz
2025-11-05 14:14:15 -08:00
committed by GitHub
parent 2d4f01fda0
commit bcfe42937f
7 changed files with 198 additions and 192 deletions

View File

@@ -57,7 +57,7 @@ class _ROCParseCtx:
self.inst_execs:dict[tuple[str, int, int, int], list[InstExec]] = {}
for prog in prog_evs:
arch = "gfx%d%x%x" % ((trgt:=dev_evs[prog.device].props['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
arch = "gfx%d%x%x" % ((trgt:=unwrap(dev_evs[prog.device].props)['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
for addr, info in llvm_disasm(arch, unwrap(prog.lib)).items():
self.disasms[unwrap(prog.base) + addr] = info
self.addr2prg[unwrap(prog.base) + addr] = prog

View File

@@ -894,7 +894,7 @@ class TestNumpy(unittest.TestCase):
a = Tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
self.assertIsNot(a[...], a)
self.assertIs(a[...], a)
numpy_testing_assert_equal_helper(a[...], a)
# `a[...]` was `a` in numpy <1.9.
#numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
@@ -1037,9 +1037,9 @@ class TestNumpy(unittest.TestCase):
# Before `...` would return a itself.
a = Tensor([5])
self.assertIsNot(a, a[()])
self.assertIsNot(a, a[...])
self.assertIsNot(a, a[:])
self.assertIs(a, a[()])
self.assertIs(a, a[...])
self.assertIs(a, a[:])
def test_broaderrors_indexing(self):
a = Tensor.zeros(5, 5)

View File

@@ -36,7 +36,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)),
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),

View File

@@ -1,7 +1,8 @@
# mixins add syntactic sugar to Tensor and UOp
import functools
from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix
from tinygrad.helpers import prod, argfix, flatten, dedup
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
sint:TypeAlias = UOp|int
@@ -41,10 +42,6 @@ class MovementMixin:
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
return dim + total if dim < 0 else dim
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""
return self.reshape(shape, *args)
def reshape(self, shape, *args) -> Self:
"""
Returns a tensor with the same data as the original tensor but with a different shape.
@@ -61,7 +58,131 @@ class MovementMixin:
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
ret = self._mop(Ops.RESHAPE, arg=new_shape)
return self if ret.shape == self.shape else ret
def shrink(self, arg:tuple[tuple["sint", "sint"]|None, ...]) -> Self:
"""
Returns a tensor that shrinks the each axis based on input arg.
`arg` must have the same length as `self.ndim`.
For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(3, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.shrink(((None, (1, 3)))).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.shrink((((0, 2), (0, 2)))).numpy())
```
"""
if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)])
return self if ret.shape == self.shape else ret
def permute(self, order, *args) -> Self:
"""
Returns a tensor that is a permutation of the original tensor.
The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
`order` can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 3, 5)
print(t.shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(t.permute(2, 0, 1).shape)
```
"""
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self
def flip(self, axis, *args) -> Self:
"""
Returns a tensor that reverses the order of the original tensor along given `axis`.
`axis` can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.flip(0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.flip((0, 1)).numpy())
```
"""
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self
# **** high level ****
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""
return self.reshape(shape, *args)
def squeeze(self, dim:int|None=None) -> Self:
"""
Returns a tensor with specified dimensions of input of size 1 removed.
If `dim` is not specified, all dimensions with size 1 are removed.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.zeros(2, 1, 2, 1, 2)
print(t.squeeze().shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(t.squeeze(0).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(t.squeeze(1).shape)
```
"""
if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
dim = self._resolve_dim(dim)
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
def unsqueeze(self, dim:int) -> Self:
"""
Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3, 4])
print(t.unsqueeze(0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.unsqueeze(1).numpy())
```
"""
dim = self._resolve_dim(dim, extra=True)
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
@property
def T(self) -> Self:
"""`.T` is an alias for `.transpose()`."""
return self.transpose()
def transpose(self, dim0=1, dim1=0) -> Self:
"""
Returns a tensor that is a transposed version of the original tensor.
The given dimensions `dim0` and `dim1` are swapped.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.transpose(0, 1).numpy())
```
"""
order = list(range(self.ndim))
order[dim0], order[dim1] = order[dim1], order[dim0]
return self.permute(order)
def flatten(self, start_dim=0, end_dim=-1) -> Self:
"""
@@ -78,3 +199,60 @@ class MovementMixin:
"""
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
def unflatten(self, dim:int, sizes:tuple[int,...]) -> Self:
"""
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
```
"""
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
def rearrange(self, formula:str, **sizes) -> Self:
"""
Rearranges input according to formula
See: https://einops.rocks/api/rearrange/
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
```
"""
def parse_formula(formula: str):
tokens = f" {formula} ".replace("", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
pairs = list(zip(lparens, rparens))
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
assert formula.count("->") == 1, 'need exactly one "->" in formula'
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
# resolve ellipsis
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
t = t.permute([lhs.index(name) for name in rhs])
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)

View File

@@ -68,7 +68,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
# allgather
copied_chunks = []
for i,c in enumerate(reduced_chunks):
this_chunk = [None] * len(buf.device)
this_chunk: list[UOp|None] = [None] * len(buf.device)
this_chunk[(i+len(buf.device)-1)%n_lbs] = c
for step in range(n_lbs-1):
dest = (i+step)%n_lbs

View File

@@ -5,7 +5,7 @@ from contextlib import ContextDecorator
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC
from tinygrad.helpers import suppress_finalizing
from tinygrad.gradient import compute_gradient
@@ -1055,65 +1055,6 @@ class Tensor(OpMixin):
new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
return self._broadcast_to(new_shape)
def permute(self, order, *args) -> Tensor:
"""
Returns a tensor that is a permutation of the original tensor.
The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
`order` can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 3, 5)
print(t.shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(t.permute(2, 0, 1).shape)
```
"""
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
return self._apply_uop(UOp.permute, arg=order_arg)
def flip(self, axis, *args) -> Tensor:
"""
Returns a tensor that reverses the order of the original tensor along given `axis`.
`axis` can be passed as a tuple or as separate arguments.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.flip(0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.flip((0, 1)).numpy())
```
"""
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
"""
Returns a tensor that shrinks the each axis based on input arg.
`arg` must have the same length as `self.ndim`.
For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(3, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.shrink(((None, (1, 3)))).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.shrink((((0, 2), (0, 2)))).numpy())
```
"""
if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
"""
Returns a tensor with padding applied based on the input `padding`.
@@ -1535,80 +1476,6 @@ class Tensor(OpMixin):
output_shape = _broadcast_shape(*(t.shape for t in tensors))
return tuple(t._broadcast_to(output_shape) for t in tensors)
def squeeze(self, dim:int|None=None) -> Tensor:
"""
Returns a tensor with specified dimensions of input of size 1 removed.
If `dim` is not specified, all dimensions with size 1 are removed.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.zeros(2, 1, 2, 1, 2)
print(t.squeeze().shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(t.squeeze(0).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(t.squeeze(1).shape)
```
"""
if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
dim = self._resolve_dim(dim)
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
def unsqueeze(self, dim:int) -> Tensor:
"""
Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3, 4])
print(t.unsqueeze(0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.unsqueeze(1).numpy())
```
"""
dim = self._resolve_dim(dim, extra=True)
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
@property
def T(self) -> Tensor:
"""`.T` is an alias for `.transpose()`."""
return self.transpose()
def transpose(self, dim0=1, dim1=0) -> Tensor:
"""
Returns a tensor that is a transposed version of the original tensor.
The given dimensions `dim0` and `dim1` are swapped.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.transpose(0, 1).numpy())
```
"""
order = list(range(self.ndim))
order[dim0], order[dim1] = order[dim1], order[dim0]
return self.permute(order)
def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
"""
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
```
"""
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
def diag(self) -> Tensor:
"""
Returns a 2-D square tensor with the elements of input as the main diagonal.
@@ -1654,46 +1521,6 @@ class Tensor(OpMixin):
for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]
def rearrange(self, formula:str, **sizes) -> Tensor:
"""
Rearranges input according to formula
See: https://einops.rocks/api/rearrange/
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
```
"""
def parse_formula(formula: str):
tokens = f" {formula} ".replace("", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
pairs = list(zip(lparens, rparens))
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
assert formula.count("->") == 1, 'need exactly one "->" in formula'
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
# resolve ellipsis
if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
t = t.permute([lhs.index(name) for name in rhs])
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
def masked_select(self, mask):
"""
Selects elements from `self` based on the boolean `mask`.

View File

@@ -549,12 +549,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
# in these two, we have custom logic to check if they are a no-op
def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self
#def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
#def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self
# *** uop UNIQUE ***
@@ -1322,7 +1322,8 @@ renderer_infer = PatternMatcher([
def srcs(ctx, src): return f"({ctx[src[0]]},)" if len(src) == 1 else f"({', '.join([ctx[x] for x in src])})"
def render_marg(ctx,x:UOp):
if x.op in {Ops.PERMUTE, Ops.FLIP}: return str(x.marg)
if x.op is Ops.PERMUTE: return str(x.marg)
if x.op is Ops.FLIP: return str(tuple([i for i,x in enumerate(x.marg) if x]))
pieces = []
if x.op in {Ops.RESHAPE, Ops.EXPAND}:
pieces = [f"{ctx[a] if isinstance(a, UOp) else str(a)}" for a in x.marg]