minor hcopt cleanup [pr] (#11231)

This commit is contained in:
chenyu
2025-07-14 09:36:25 -04:00
committed by GitHub
parent 756ba1a5f9
commit da219199f5

View File

@@ -49,11 +49,9 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
# **** below this line need to be optional and benchmarked ****
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
# this can be made much smarter
to_upcast: list[int] = []
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
for axis in range(k.first_reduce):
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
# for now skip upcasting here if there is a symbolic axis
if isinstance(k.full_shape[axis], int) and k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
@@ -108,8 +106,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
k.apply_opt(Opt(OptOps.NOLOCALS))
else:
# prioritize making expand axes local
local_axis_ranking = [(any(k.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(k.sts))), axis) \
for axis in range(len(k.full_shape[:k.first_reduce]))]
local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in range(k.first_reduce)]
to_local: list[tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
local_size = prod(sz for _, sz in to_local)