diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 0114d0e9e4..ee0d57e3be 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -2,7 +2,7 @@ from __future__ import annotations import itertools, functools from dataclasses import dataclass from collections import defaultdict -from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict +from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence from enum import Enum, auto from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \ @@ -198,13 +198,10 @@ class Kernel: # ******************** base simplifiers ******************** # apply reshape and permute to all shapetrackers - def reshape_and_permute(self, new_shape_fxn, axis): - new_sts = [] - for st in self.sts: - if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape))) - if axis is not None: st = st.permute(tuple(axis)) - new_sts.append(st) - self.sts = new_sts + def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]): + def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st + def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st + self.sts = [permute(reshape(st)) for st in self.sts] # drops the final dimension def upcast(self):