From 946da97820eb14f6df0ee6236c6be91ddc7be686 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 18 Jul 2024 15:19:40 -0700
Subject: [PATCH] swap action (#5565)

* swap action

* don't allow same action expressed differently

* oops, was reversed

* one line is fine

* only swap
---
 tinygrad/codegen/kernel.py | 12 +++++++++---
 tinygrad/engine/search.py  |  1 +
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index a0f22f7626..163a1ea769 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -22,7 +22,7 @@ from enum import Enum, auto
 
 class OptOps(Enum):
   TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
-  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); MERGE = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); MERGE = auto(); SWAP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 
 class KernelOptError(Exception): pass
@@ -425,7 +425,8 @@ class Kernel:
     axis = opt.real_axis(self)
     check(axis < len(self.full_shape), "invalid axis")
 
-    if opt.amt is not None:
+    if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs
+    elif opt.amt is not None:
       amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
       check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
       if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
@@ -478,6 +479,11 @@
       check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
+    elif opt.op is OptOps.SWAP:
+      check(axis < amt and amt < self.global_dims, "swap is only for globals with axis < amt")
+      permute = list(range(self.shape_len))
+      permute[axis], permute[amt] = permute[amt], permute[axis]
+      self.reshape_and_permute(None, tuple(permute))
     elif opt.op is OptOps.MERGE:
       check(axis >= self.shape_len-self.upcasted, "only merge upcasted")
       check(self.full_shape[axis:axis+2] == self.output_shape[axis:axis+2], "can't merge reduces")
@@ -495,7 +501,7 @@
       for i,st in enumerate(self.sts):
         if self.sts[i].shape[axis] == 1: continue # reduced
         check(self.sts[i].shape[axis] > amt//4, f"pad adds more than quadruple the work {self.sts[i].shape[axis]=} > {amt//4=}")
-        if (ru := round_up(cast(int, self.sts[i].shape[axis]), cast(int, amt)) - self.sts[i].shape[axis]):
+        if (ru := round_up(cast(int, self.sts[i].shape[axis]), amt) - self.sts[i].shape[axis]):
           # pad right seems to be faster
           self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
           padded = True
diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py
index 501393b09a..2bd08d5017 100644
--- a/tinygrad/engine/search.py
+++ b/tinygrad/engine/search.py
@@ -22,6 +22,7 @@ actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for a
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
 actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
 def _get_test_global_size(global_size, max_global_size, var_vals):