From 55bd136746c2fa09de71ac7d76e686c834d11bcf Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 2 Nov 2024 13:44:14 -0400 Subject: [PATCH] clean up reshape_and_permute (#7488) probably will rewrite it later as reshape and permute function on Kernel, but for now it's shorter with better types --- tinygrad/codegen/kernel.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 0114d0e9e4..ee0d57e3be 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -2,7 +2,7 @@ from __future__ import annotations import itertools, functools from dataclasses import dataclass from collections import defaultdict -from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict +from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence from enum import Enum, auto from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \ @@ -198,13 +198,10 @@ class Kernel: # ******************** base simplifiers ******************** # apply reshape and permute to all shapetrackers - def reshape_and_permute(self, new_shape_fxn, axis): - new_sts = [] - for st in self.sts: - if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape))) - if axis is not None: st = st.permute(tuple(axis)) - new_sts.append(st) - self.sts = new_sts + def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]): + def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st + def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st + self.sts = [permute(reshape(st)) for st in self.sts] # drops the final dimension def upcast(self):