diff --git a/tinygrad/opt/heuristic.py b/tinygrad/opt/heuristic.py index 2b6a90df8f..f6ab24ba8b 100644 --- a/tinygrad/opt/heuristic.py +++ b/tinygrad/opt/heuristic.py @@ -28,8 +28,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: if k.opts.has_local and k.opts.has_shared and all_int(k.sts[0].shape[:k.first_reduce]): # are we grouping? (requires local shape support) - if not [x for x in k.sts[0].unit_stride_axes() if x >= k.first_upcast and k.sts[0].shape[x]%4 == 0] and \ - k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048: + if k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048: # TODO: use 1024 if it's allowed in a smarter way for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]): try: # may fail due to excessive smem usage @@ -40,13 +39,11 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]: # upcast float4 images for buf_index,buf in enumerate(k.bufs): unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0] - if buf.src[0].dtype.__class__ is ImageDType: - #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {k.bufs[buf_index]}" - if len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4): - if unit_stride_axes_mul_4[0] < k.first_reduce: - k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) - else: - k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4)) + if buf.src[0].dtype.__class__ is ImageDType and len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4): + if unit_stride_axes_mul_4[0] < k.first_reduce: + k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) + else: + k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4)) # no more opt if we are grouping if k.group_for_reduces: return k.applied_opts