mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
minor hcopt cleanup [pr] (#11231)
This commit is contained in:
@@ -49,11 +49,9 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
|
||||
# this can be made much smarter
|
||||
to_upcast: list[int] = []
|
||||
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
|
||||
for axis in range(k.first_reduce):
|
||||
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
|
||||
# for now skip upcasting here if there is a symbolic axis
|
||||
if isinstance(k.full_shape[axis], int) and k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
|
||||
prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
||||
@@ -108,8 +106,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
k.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
- local_axis_ranking = [(any(k.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(k.sts))), axis) \
|
||||
- for axis in range(len(k.full_shape[:k.first_reduce]))]
|
||||
+ local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in range(k.first_reduce)]
|
||||
to_local: list[tuple[int, int]] = []
|
||||
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
|
||||
local_size = prod(sz for _, sz in to_local)
|
||||
|
||||
Reference in New Issue
Block a user