clean up hcopt [pr] (#11205)

removed one condition that's always true
This commit is contained in:
chenyu
2025-07-12 23:10:27 -04:00
committed by GitHub
parent 2b48b961be
commit d90d837013

View File

@@ -28,8 +28,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
if k.opts.has_local and k.opts.has_shared and all_int(k.sts[0].shape[:k.first_reduce]):
# are we grouping? (requires local shape support)
if not [x for x in k.sts[0].unit_stride_axes() if x >= k.first_upcast and k.sts[0].shape[x]%4 == 0] and \
k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
if k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]):
try: # may fail due to excessive smem usage
@@ -40,13 +39,11 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
# upcast float4 images
for buf_index,buf in enumerate(k.bufs):
unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]
if buf.src[0].dtype.__class__ is ImageDType:
#assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {k.bufs[buf_index]}"
if len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
if unit_stride_axes_mul_4[0] < k.first_reduce:
k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
else:
k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))
if buf.src[0].dtype.__class__ is ImageDType and len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
if unit_stride_axes_mul_4[0] < k.first_reduce:
k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
else:
k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))
# no more opt if we are grouping
if k.group_for_reduces: return k.applied_opts