remove a TODO prod(k.full_shape[k.first_upcast:]) (#11191)

IMAGE=2 test/test_ops.py works now
This commit is contained in:
chenyu
2025-07-12 10:16:56 -04:00
committed by GitHub
parent 6f5250d158
commit 12b04efd69

View File

@@ -53,9 +53,6 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
# **** below this line need to be optional and benchmarked ****
# TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
# to trigger the above bug, remove prod(k.full_shape[k.first_upcast:]) from the below
# expression and run test/test_ops.py with IMAGE=2
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
# this can be made much smarter
to_upcast: list[int] = []
@@ -64,7 +61,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
# for now skip upcasting here if there is a symbolic axis
if isinstance(k.full_shape[axis], int) and k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
prod(k.full_shape[k.first_upcast:]) * prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))