diff --git a/tinygrad/opt/heuristic.py b/tinygrad/opt/heuristic.py index 02dd3c7c53..be7a990991 100644 --- a/tinygrad/opt/heuristic.py +++ b/tinygrad/opt/heuristic.py @@ -49,11 +49,9 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: # **** below this line need to be optional and benchmarked **** # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack) - # this can be made much smarter to_upcast: list[int] = [] # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first) for axis in range(k.first_reduce): - # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent # for now skip upcasting here if there is a symbolic axis if isinstance(k.full_shape[axis], int) and k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \ prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7: @@ -108,8 +106,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: k.apply_opt(Opt(OptOps.NOLOCALS)) else: # prioritize making expand axes local - local_axis_ranking = [(any(k.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(k.sts))), axis) \ - for axis in range(len(k.full_shape[:k.first_reduce]))] + local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in range(k.first_reduce)] to_local: list[tuple[int, int]] = [] for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): local_size = prod(sz for _, sz in to_local)