remove final range in heuristic [pr] (#11251)

all dims are based on AxisType now
This commit is contained in:
chenyu
2025-07-15 11:39:15 -04:00
committed by GitHub
parent d7adc24083
commit 629fa21b6b

View File

@@ -17,7 +17,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
strides0, strides1 = st0.real_strides(), st1.real_strides()
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
if strides0[k.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
for global_idx in range(k.global_dims):
for global_idx in k.axes_of(AxisType.GLOBAL):
if k.full_shape[k.first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if DEBUG >= 3:
print(f"MATVEC: {k.full_shape=} {k.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
@@ -63,7 +63,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
upcasted_axis: set[int] = set()
while resolve(prod(k.sts[0].shape[i] for i in k.upcastable_dims) >= 1024):
xb_choices = []
# consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
# consider all upcastable axes with 3 or 4 upcast (128 on the DSP)
for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
# if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
@@ -73,7 +73,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
if xb_choices:
xb_choices = sorted(xb_choices)
if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
upcasted_axis.add(xb_choices[0][2])
else: break