From 8cbef912d2db6794c10ff0f9e1df83d158fc7d42 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 2 Nov 2025 12:56:15 +0800 Subject: [PATCH] move reshape to MathTraits (#13054) * move reshape to MathTraits * confirm it works in amd_uop_matmul --- extra/gemm/amd_uop_matmul.py | 16 ++++++++-------- tinygrad/tensor.py | 25 ++----------------------- tinygrad/uop/mathtraits.py | 35 ++++++++++++++++++++++++++++++++++- tinygrad/uop/ops.py | 2 +- 4 files changed, 45 insertions(+), 33 deletions(-) diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index 1637a59987..1ac8c24bd7 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -88,15 +88,15 @@ def hand_spec_kernel3(): # --------------------------- # GLOBAL -> LOCAL (As, Bs) # --------------------------- - b = b.reshape((N // BLOCK_K, BLOCK_K, - N // BLOCK_N, BLOCK_N)) + b = b.reshape(N // BLOCK_K, BLOCK_K, + N // BLOCK_N, BLOCK_N) i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1) index_x = tid % BLOCK_N index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i) - a = a.reshape((N // BLOCK_M, BLOCK_M, - N // BLOCK_K, BLOCK_K)) + a = a.reshape(N // BLOCK_M, BLOCK_M, + N // BLOCK_K, BLOCK_K) i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2) index_x = tid % BLOCK_K index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i @@ -113,12 +113,12 @@ def hand_spec_kernel3(): # --------------------------- # LOCAL -> REG (per-wave tiles) # --------------------------- - Bs_view = Bs.reshape((BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)) + Bs_view = Bs.reshape(BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN) iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4) i = UOp.range(TN, 5) B_row = B_row[iterWaveN, i].set(Bs_view[k, waveIdx, iterWaveN, idxInWave, i], end=(iterWaveN, i)) - As_view = As.reshape((BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)) + As_view = As.reshape(BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM) iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6) i = UOp.range(TM, 7) A_col = A_col[iterWaveM, i].set(As_view[k, waveIdy, iterWaveM, idyInWave, i], end=(iterWaveM, i)) @@ -139,8 +139,8 @@ def hand_spec_kernel3(): # --------------------------- # REG -> GLOBAL (epilogue) # --------------------------- - c = c.reshape((N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM, - N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)) + c = c.reshape(N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM, + N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN) iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000) yt = UOp.range(TM, 1001) iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 29d91a1b98..ae59849198 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -10,7 +10,7 @@ from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, p from tinygrad.helpers import suppress_finalizing from tinygrad.gradient import compute_gradient from tinygrad.uop.mathtraits import MathTrait -from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender +from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Device, Buffer from tinygrad.engine.realize import run_schedule @@ -1038,28 +1038,7 @@ class Tensor(MathTrait): # ***** movement low level ops ***** - def view(self, shape:tuple[sint, ...], *args) -> Tensor: - """`.view` is an alias for `.reshape`.""" - return self.reshape(shape, *args) - - def reshape(self, shape, *args) -> Tensor: - """ - Returns a tensor with the same data as the original tensor but with a different shape. - `shape` can be passed as a tuple or as separate arguments. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(6) - print(t.reshape(2, 3).numpy()) - ``` - """ - # resolve None and args - new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))]) - # resolve -1 - if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") - if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) - if resolve(prod(self.shape) != prod(new_shape), True): - raise ValueError(f"size mismatch, can't reshape ({', '.join(srender(d) for d in self.shape)}) -> ({', '.join(srender(d) for d in new_shape)})") - return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self + def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg) def expand(self, shape, *args) -> Tensor: """ diff --git a/tinygrad/uop/mathtraits.py b/tinygrad/uop/mathtraits.py index 27c3beeb45..03008db4bc 100644 --- a/tinygrad/uop/mathtraits.py +++ b/tinygrad/uop/mathtraits.py @@ -1,6 +1,10 @@ -from typing import TypeVar +from typing import TypeVar, TypeAlias, TYPE_CHECKING from tinygrad.uop import Ops from tinygrad.dtype import dtypes, ConstType +from tinygrad.helpers import prod, argfix +if TYPE_CHECKING: + from tinygrad.uop.ops import UOp + sint:TypeAlias = UOp|int TMT = TypeVar("TMT", bound="MathTrait") class MathTrait: @@ -171,3 +175,32 @@ class MathTrait: def exp2(self): return self.alu(Ops.EXP2) def pow(self:TMT, x:TMT|ConstType): return self.alu(Ops.POW, self.ufix(x)) def __pow__(self:TMT, x:TMT|ConstType): return self.pow(x) + + # **** movement ops **** + + # required to implement + def _mop(self:TMT, op:Ops, arg) -> TMT: raise NotImplementedError + @property + def shape(self) -> tuple["sint", ...]: raise NotImplementedError + + def view(self:TMT, shape, *args) -> TMT: + """`.view` is an alias for `.reshape`.""" + return self.reshape(shape, *args) + + def reshape(self:TMT, shape, *args) -> TMT: + """ + Returns a tensor with the same data as the original tensor but with a different shape. + `shape` can be passed as a tuple or as separate arguments. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(6) + print(t.reshape(2, 3).numpy()) + ``` + """ + # resolve None and args + new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))]) + # resolve -1 + if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") + if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) + if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") + return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index f404df9c3f..1afebc7910 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -533,7 +533,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # in these four, if the shape doesn't change we can return self def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False) - def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True) + #def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True) def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True) def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True) def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)