swap action (#5565)

* swap action

* don't allow same action expressed differently

* oops, was reversed

* one line is fine

* only swap
Author: George Hotz
Date: 2024-07-18 15:19:40 -07:00
Committed by: GitHub
Parent: e7a057c20f
Commit: 946da97820
2 changed files with 10 additions and 3 deletions


@@ -22,7 +22,7 @@ from enum import Enum, auto
 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); MERGE = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); MERGE = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 class KernelOptError(Exception): pass
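
A minimal, hypothetical stand-in for how a SWAP action is expressed with the Opt container these hunks use (the real Opt lives elsewhere in the codebase; field names are assumed from the usage below). For SWAP, amt carries a second axis index rather than a size:

from dataclasses import dataclass
from enum import Enum, auto

class OptOps(Enum):  # abbreviated mirror of the enum above
  PADTO = auto(); MERGE = auto(); SWAP = auto()  # noqa: E702

@dataclass(frozen=True)
class Opt:
  op: OptOps
  axis: int
  amt: int | None = None

print(Opt(op=OptOps.SWAP, axis=0, amt=2))  # swap global axis 0 with global axis 2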
@@ -425,7 +425,8 @@ class Kernel:
     axis = opt.real_axis(self)
     check(axis < len(self.full_shape), "invalid axis")
-    if opt.amt is not None:
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
+    elif opt.amt is not None:
       amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
       check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
       if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
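
A standalone sketch of the amt resolution above, under the assumed semantics: SWAP reuses amt as an axis index, while the other ops treat it as a size where 0 selects the full extent of the chosen axis:

def resolve_amt(is_swap: bool, amt: int, axis_size: int) -> int:
  if is_swap: return amt                 # an axis index, used as-is
  return amt if amt != 0 else axis_size  # a size; 0 means the whole axis

print(resolve_amt(False, 0, 64))  # 64 (full axis)
print(resolve_amt(True, 2, 64))   # 2 (the second axis of the swap)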
@@ -478,6 +479,11 @@ class Kernel:
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
+    elif opt.op is OptOps.SWAP:
+      check(axis < amt and amt < self.global_dims, "swap is only for globals with axis < amt")
+      permute = list(range(self.shape_len))
+      permute[axis], permute[amt] = permute[amt], permute[axis]
+      self.reshape_and_permute(None, tuple(permute))
     elif opt.op is OptOps.MERGE:
       check(axis >= self.shape_len-self.upcasted, "only merge upcasted")
       check(self.full_shape[axis:axis+2] == self.output_shape[axis:axis+2], "can't merge reduces")
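
A runnable sketch of the permutation the SWAP branch builds: the identity permutation with two entries exchanged, applied here to a toy shape (example values assumed):

shape = (4, 8, 16)           # three global dims, for illustration
axis, amt = 0, 2             # swap axis 0 with axis 2
permute = list(range(len(shape)))
permute[axis], permute[amt] = permute[amt], permute[axis]
print(tuple(shape[p] for p in permute))  # (16, 8, 4)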
@@ -495,7 +501,7 @@ class Kernel:
       for i,st in enumerate(self.sts):
         if self.sts[i].shape[axis] == 1: continue # reduced
         check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
-        if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
+        if (ru := round_up(cast(int, self.sts[i].shape[axis]), amt) - self.sts[i].shape[axis]):
           # pad right seems to be faster
           self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
           padded = True
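
A worked example of the padding arithmetic in the PADTO branch above, with a local stand-in for tinygrad's round_up helper (round x up to the next multiple of y):

def round_up(x: int, y: int) -> int: return (x + y - 1) // y * y

axis_size, amt = 50, 32                    # hypothetical axis extent and pad target
ru = round_up(axis_size, amt) - axis_size
print(ru)                                  # 14: pads 50 up to 64, the next multiple of 32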


@@ -22,6 +22,7 @@ actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for a
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
 actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 def _get_test_global_size(global_size, max_global_size, var_vals):
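
Enumerating the SWAP actions the added comprehension generates: every ordered pair of the first five axes with axis < amt, matching the check in apply_opt:

pairs = [(axis, amt) for axis in range(5) for amt in range(axis + 1, 5)]
print(len(pairs), pairs)  # 10 pairs: (0, 1), (0, 2), ..., (3, 4)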