mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 08:48:15 -05:00
clean up reshape_and_permute (#7488)
probably will rewrite it later as reshape and permute function on Kernel, but for now it's shorter with better types
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import itertools, functools
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
|
||||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
|
||||
@@ -198,13 +198,10 @@ class Kernel:
|
||||
# ******************** base simplifiers ********************
|
||||
|
||||
# apply reshape and permute to all shapetrackers
|
||||
def reshape_and_permute(self, new_shape_fxn, axis):
|
||||
new_sts = []
|
||||
for st in self.sts:
|
||||
if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape)))
|
||||
if axis is not None: st = st.permute(tuple(axis))
|
||||
new_sts.append(st)
|
||||
self.sts = new_sts
|
||||
def reshape_and_permute(self, new_shape_fxn:Optional[Callable[[Tuple[sint, ...]], Sequence[sint]]], axis:Optional[Sequence[int]]):
|
||||
def reshape(st:ShapeTracker): return st.reshape(tuple(new_shape_fxn(st.shape))) if new_shape_fxn is not None else st
|
||||
def permute(st:ShapeTracker): return st.permute(tuple(axis)) if axis is not None else st
|
||||
self.sts = [permute(reshape(st)) for st in self.sts]
|
||||
|
||||
# drops the final dimension
|
||||
def upcast(self):
|
||||
|
||||
Reference in New Issue
Block a user