clean up reshape_and_permute (#7488)

probably will rewrite it later as reshape and permute function on Kernel, but for now it's shorter with better types
This commit is contained in:
chenyu
2024-11-02 13:44:14 -04:00
committed by GitHub
parent 74c7b9d84a
commit 55bd136746

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import itertools, functools
from dataclasses import dataclass
from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
from enum import Enum, auto
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
@@ -198,13 +198,10 @@ class Kernel:
# ******************** base simplifiers ********************
# apply reshape and permute to all shapetrackers
def reshape_and_permute(self, new_shape_fxn, axis):
new_sts = []
for st in self.sts:
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
if axis is not None: st = st.permute(tuple(axis))
new_sts.append(st)
self.sts = new_sts
def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
self.sts = [permute(reshape(st)) for st in self.sts]
# drops the final dimension
def upcast(self):