update float4 condition in hcopt (#11211)

Not every upcast candidate axis needs to be upcast-able (below `k.first_upcast`); only check the first candidate, which is the axis actually upcast or unrolled.
This commit is contained in:
chenyu
2025-07-13 09:51:45 -04:00
committed by GitHub
parent 55c54d9745
commit e11ccf2342

View File

@@ -39,11 +39,11 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
# upcast float4 images
for buf_index,buf in enumerate(k.bufs):
unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]
if buf.src[0].dtype.__class__ is ImageDType and len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
if unit_stride_axes_mul_4[0] < k.first_reduce:
k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
if buf.src[0].dtype.__class__ is ImageDType and len(unit_stride_axes_mul_4) and (axis:=unit_stride_axes_mul_4[0]) < k.first_upcast:
if axis < k.first_reduce:
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
else:
k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))
k.apply_opt(Opt(OptOps.UNROLL, axis-k.first_reduce, 4))
# no more opt if we are grouping
if k.group_for_reduces: return k.applied_opts