mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
update float4 condition in hcopt (#11211)
don't need all upcast candidates to be upcast-able, only check the actual one
This commit is contained in:
@@ -39,11 +39,11 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
# upcast float4 images
|
||||
for buf_index,buf in enumerate(k.bufs):
|
||||
unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]
|
||||
if buf.src[0].dtype.__class__ is ImageDType and len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
|
||||
if unit_stride_axes_mul_4[0] < k.first_reduce:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
if buf.src[0].dtype.__class__ is ImageDType and len(unit_stride_axes_mul_4) and (axis:=unit_stride_axes_mul_4[0]) < k.first_upcast:
|
||||
if axis < k.first_reduce:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
|
||||
else:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, axis-k.first_reduce, 4))
|
||||
|
||||
# no more opt if we are grouping
|
||||
if k.group_for_reduces: return k.applied_opts
|
||||
|
||||
Reference in New Issue
Block a user