swap first and last global in hcopt / hc tc path (#5566)

This commit is contained in:
chenyu
2024-07-18 18:54:44 -04:00
committed by GitHub
parent 946da97820
commit abe29a05b0

View File

@@ -406,7 +406,8 @@ class Kernel:
if self.full_shape[tc_opts.axes[0]] % upc == 0:
self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], upc))
break
# SWAP global
if self.global_dims > 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
return True
except KernelOptError:
return False
@@ -537,6 +538,8 @@ class Kernel:
if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
# SWAP global
if self.global_dims > 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
return
if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
@@ -569,7 +572,10 @@ class Kernel:
self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
# no more opt if we are grouping
if self.group_for_reduces: return
if self.group_for_reduces:
# SWAP global
if self.global_dims > 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
return
# **** below this line need to be optional and benchmarked ****
@@ -643,6 +649,9 @@ class Kernel:
self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
if will_delete_shape: deleted_shape += 1
# SWAP global
if self.global_dims > 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
# **** kernel outputs ****
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)