swap action (#5565)

* swap action
* don't allow same action expressed differently
* oops, was reversed
* one line is fine
* only swap
@@ -22,7 +22,7 @@ from enum import Enum, auto

 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); MERGE = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); MERGE = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value

 class KernelOptError(Exception): pass
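For reference, a minimal standalone sketch (not tinygrad code) of the Enum pattern above: auto() assigns increasing integer values in declaration order, and the __lt__ override makes members orderable by that value, so the new SWAP member simply sorts after MERGE:

from enum import Enum, auto

class Op(Enum):  # hypothetical stand-in for OptOps
  A = auto(); B = auto(); C = auto()  # noqa: E702
  def __lt__(self, x:"Op"): return self.value < x.value

assert Op.A.value == 1 and Op.A < Op.C  # auto() counts up from 1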
@@ -425,7 +425,8 @@ class Kernel:
     axis = opt.real_axis(self)
     check(axis < len(self.full_shape), "invalid axis")

-    if opt.amt is not None:
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
+    elif opt.amt is not None:
       amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
       check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
       if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
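To make the branching above concrete, a rough standalone model (resolve_amt is a hypothetical name, not a tinygrad function): for SWAP, amt names the partner axis of the swap rather than a shift amount, while for the other ops amt=0 expands to the full extent of the axis:

def resolve_amt(op:str, amt, axis_size:int):
  # hypothetical model of the logic above, not tinygrad code
  if op == "SWAP": return amt  # amt is the partner axis, not a shift amount
  if amt is not None: return amt if amt != 0 else axis_size  # 0 means whole axis
  return None

assert resolve_amt("SWAP", 3, 256) == 3      # swap with axis 3
assert resolve_amt("UPCAST", 0, 256) == 256  # 0 expands to the axis size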
@@ -478,6 +479,11 @@ class Kernel:
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
+    elif opt.op is OptOps.SWAP:
+      check(axis < amt and amt < self.global_dims, "swap is only for globals with axis < amt")
+      permute = list(range(self.shape_len))
+      permute[axis], permute[amt] = permute[amt], permute[axis]
+      self.reshape_and_permute(None, tuple(permute))
     elif opt.op is OptOps.MERGE:
       check(axis >= self.shape_len-self.upcasted, "only merge upcasted")
       check(self.full_shape[axis:axis+2] == self.output_shape[axis:axis+2], "can't merge reduces")
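The new SWAP branch is a plain transposition: it builds the identity permutation over all shape dimensions and exchanges two entries. A worked example with five dimensions, swapping axis 0 with axis 2 (values chosen for illustration):

shape_len, axis, amt = 5, 0, 2     # amt is the partner axis here
permute = list(range(shape_len))   # [0, 1, 2, 3, 4]
permute[axis], permute[amt] = permute[amt], permute[axis]
assert permute == [2, 1, 0, 3, 4]  # handed to reshape_and_permute as a tuple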
@@ -495,7 +501,7 @@ class Kernel:
       for i,st in enumerate(self.sts):
         if self.sts[i].shape[axis] == 1: continue # reduced
         check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
-        if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
+        if (ru := round_up(cast(int, self.sts[i].shape[axis]), amt) - self.sts[i].shape[axis]):
           # pad right seems to be faster
           self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
           padded = True
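For context on the PADTO hunk above: round_up rounds its first argument up to the next multiple of the second, so ru is the number of zeros padded onto the right of the axis. A re-sketch of the helper under that assumption (tinygrad ships its own; this is an illustrative model, not the library source):

def round_up(num:int, amt:int) -> int:
  # round num up to the next multiple of amt
  return (num + amt - 1) // amt * amt

size, amt = 100, 32
ru = round_up(size, amt) - size
assert (round_up(size, amt), ru) == (128, 28)  # pad 28 zeros on the right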
@@ -22,6 +22,7 @@ actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for a
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
 actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

 def _get_test_global_size(global_size, max_global_size, var_vals):
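The added SWAP line enumerates each unordered pair of the first five axes exactly once (axis < amt), which is what the "don't allow same action expressed differently" note in the commit message refers to: swapping (0,2) and (2,0) are the same action, so only the canonical form enters the search space. Enumerated:

pairs = [(axis, amt) for axis in range(5) for amt in range(axis+1, 5)]
assert pairs == [(0,1),(0,2),(0,3),(0,4),(1,2),(1,3),(1,4),(2,3),(2,4),(3,4)]
assert len(pairs) == 10  # 5 choose 2 swap actions added to the BEAM search space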