remove Kernel.upcasted_axis [pr] (#11175)

This commit is contained in:
chenyu
2025-07-10 23:19:21 -04:00
committed by GitHub
parent ccd382bc6f
commit b219e47bef
2 changed files with 3 additions and 8 deletions

View File

@@ -79,7 +79,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
for axis, upcast_amount in itertools.product(range(k.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
# candidate axis must: not be upcasted yet, have a non-symbolic (int) size that is divisible by upcast_amount,
# and have some buffer whose stride on this axis is 0 while none of that buffer's already-upcasted axes has stride 0
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and \
any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index, st in enumerate(k.sts)):
any(st.views[-1].strides[axis] == 0 and not any(x == 0 for x in k.sts[buf_index].real_strides()[k.first_upcast:]) \
for buf_index, st in enumerate(k.sts)):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
if xb_choices:

View File

@@ -11,7 +11,7 @@ from tinygrad.device import Device
from tinygrad.opt.tc import TensorCore
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, all_int, to_function_name, unwrap, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction
from tinygrad.kernelize.kernelize import view_left
@@ -116,12 +116,6 @@ class Kernel:
return ret
# Describe the upcasted axes (those from index self.first_upcast onward) of buffer i.
# Returns one (size, stride, flag) tuple per such axis, where:
#   size   - entry of self.sts[i].shape for that axis (asserted to be a plain int, not symbolic)
#   stride - matching entry of self.sts[i].real_strides()
#   flag   - True when self.sts[0].shape differs from self.full_shape on that axis
#            (presumably marks a reduce axis, with sts[0] being the output buffer -- TODO confirm)
# Fails the assert with "cannot upcast a symbolic amount" if any upcasted size is not an int.
def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
# slice both shape and real strides of buffer i down to the upcasted tail
upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
# the third zip operand is computed from buffer 0 and full_shape, independent of i
return list(zip(upcasted_shape, upcasted_stride,
[x!=y for x,y in zip(self.sts[0].shape[self.first_upcast:], self.full_shape[self.first_upcast:])]))
@property
def first_reduce(self) -> int:
for i in range(self.first_upcast):