Revert "optimizer: simplify GROUP and LOCAL to have one of each (#2162)" (#2182)

This reverts commit 8cf0bb9351.
George Hotz authored 2023-10-30 10:22:26 -07:00 (committed by GitHub)
parent 95f7183c3a
commit 194e4ad6f8
3 changed files with 23 additions and 13 deletions


@@ -472,12 +472,12 @@ class TestLinearizerOpts(unittest.TestCase):
[Opt(OptOps.UPCAST, 1, 4)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
[Opt(OptOps.UNROLL, 0, 2)], # check last unroll
-[Opt(OptOps.LOCAL, 0, 4)], # check last local
+[Opt(OptOps.LASTLOCAL, 0, 4)], # check last local
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
-[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
-# [Opt(OptOps.GROUPTOP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
+[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
+# [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
], apply_tc=True)
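Each inner list above is one schedule of hand-picked optimizations that the test applies in order on top of the tensor-core layout. A minimal sketch of that driving loop (the kernel object and helper name are illustrative, not the test's actual fixture):

def apply_schedule(k, schedule):
  # apply each Opt in order; apply_opt itself asserts that the action is still legal
  for opt in schedule:
    k.apply_opt(opt)

# e.g. the "check last local" case restored by this revert:
# apply_schedule(k, [Opt(OptOps.LASTLOCAL, 0, 4)])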


@@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
from enum import Enum, auto
class OptOps(Enum):
-UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); GROUPTOP = auto() # noqa: E702
+UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)
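The restored enum keeps LOCAL/LASTLOCAL and GROUP/GROUPTOP as distinct ops. Because Opt is a frozen, ordered dataclass, individual actions are hashable and comparable, which is handy when building or deduplicating action lists; a small sketch using the Opt(op, axis, amt) shape seen elsewhere in this diff:

a = Opt(op=OptOps.GROUP, axis=0, amt=8)
b = Opt(op=OptOps.GROUPTOP, axis=0, amt=8)
assert a != b                           # distinct actions again after the revert
assert OptOps.GROUP < OptOps.GROUPTOP   # __lt__ compares enum declaration order
unique = sorted({a, b})                 # frozen dataclass is hashable; order=True makes it sortable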
@@ -197,7 +197,7 @@ class OptimizedKernel(Kernel):
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
for (tc_dim, tc_amt) in tc.threads:
-fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
+fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
# assert tensor core and prevent extra_opts from altering the key shape structure
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
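As a worked example of the arithmetic above (the numbers are hypothetical, not any particular backend's tensor-core spec): with tc.dims = (8, 8, 8) and tc.threads giving 32 threads, the reduce dim is unrolled by dims[2] = 8, the upcast amount works out to 2, and each (tc_dim, tc_amt) pair then becomes one LASTLOCAL on s0 or s1:

from math import prod
# hypothetical tensor-core spec, for illustration only
tc_dims, tc_threads = (8, 8, 8), [(0, 2), (1, 2), (0, 2), (1, 2), (0, 2)]
unroll_amt = tc_dims[2]                                                   # UNROLL amount -> 8
upcast_amt = (tc_dims[0] * tc_dims[2]) // prod(a for _, a in tc_threads)  # (8*8)//32 -> 2
# each (tc_dim, tc_amt) in tc_threads then becomes one LASTLOCAL of size 2 on s0 or s1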
@@ -216,7 +216,7 @@ class OptimizedKernel(Kernel):
if self.tensor_core and s0_exists:
for upc in [4,2]:
if self.full_shape[s0] % upc == 0:
-self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
+self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
break
# alias buffer
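The loop above simply prefers the larger split that divides evenly: with a hypothetical full_shape[s0] of 12 it applies LASTLOCAL(s0, 4) and breaks, with 6 it falls through to LASTLOCAL(s0, 2), and with 5 it adds nothing.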
@@ -228,16 +228,26 @@ class OptimizedKernel(Kernel):
def apply_opt(self, opt:Opt):
self.applied_opts.append(opt)
-axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUPTOP else 0))
+axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
if opt.op == OptOps.LOCAL: # cyan
-assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "local is for non-reduce that aren't TC dims"
+assert axis < self.first_reduce, "can't local a reduce"
+assert not(self.tensor_core), "can't local with tensor cores"
self.shift_to(axis, amt, insert_before=self.first_reduce)
self.local_dims += 1
+elif opt.op == OptOps.LASTLOCAL: # cyan
+assert axis < self.first_reduce, "can't local a reduce"
+self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
+self.local_dims += 1
-elif opt.op == OptOps.GROUPTOP: # green
-assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "group is for reduce dims"
+elif opt.op == OptOps.GROUP: # green
+assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
+assert not(self.tensor_core), "can't group with tensor cores"
+self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
+self.group_for_reduce.append(amt)
+elif opt.op == OptOps.GROUPTOP: # green
+assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
+assert not(self.tensor_core), "can't group with tensor cores"
self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(amt)
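To make the restored distinctions concrete, here is where each op puts the new dimension for a hypothetical kernel state with first_reduce=2, local_dims=1 and an empty group_for_reduce (the positions follow directly from the shift_to calls above; note that LOCAL, GROUP and GROUPTOP now refuse to run once a tensor core is in play, while LASTLOCAL, used by the tensor-core path, carries no such assert):

# hypothetical shape layout: [global0, local0 | reduce0], first_reduce=2, local_dims=1
Opt(OptOps.LOCAL, 0, 4)      # split axis 0; new local inserted at first_reduce, i.e. after local0
Opt(OptOps.LASTLOCAL, 0, 4)  # split axis 0; new local inserted at first_reduce - local_dims, i.e. before local0
Opt(OptOps.GROUP, 0, 8)      # axis is offset by first_reduce + len(group_for_reduce), so this targets reduce0;
                             # the new dim is placed at the front of the reduce block and appended to group_for_reduce
Opt(OptOps.GROUPTOP, 0, 8)   # same placement as GROUP, but shift_to is called with top=True
Opt(OptOps.UNROLL, 0, 2)     # axis is offset by first_reduce, so this also targets reduce0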
@@ -247,7 +257,7 @@ class OptimizedKernel(Kernel):
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op == OptOps.UPCAST: # yellow
-assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "upcast is for non-reduce that aren't TC dims"
+assert axis < self.first_reduce, "upcast is for non-reduce"
assert amt <= 8, "don't upcast more than 8"
self.shift_to(axis, amt, insert_before=None)
self.upcast()
@@ -292,7 +302,7 @@ class OptimizedKernel(Kernel):
if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
if MV_THREADS_PER_ROW > 1:
-self.apply_opt(Opt(OptOps.GROUPTOP, 0, MV_THREADS_PER_ROW))
+self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1:
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1:
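For a sense of what this heuristic does, take a hypothetical matvec with 1024 output rows and a reduce length of 2048, with illustrative values MV_BLOCKSIZE=4, MV_THREADS_PER_ROW=8, MV_ROWS_PER_THREAD=4:

# hypothetical sizes: out rows = 1024, reduce length = 2048
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = 4, 8, 4
assert 2048 % MV_THREADS_PER_ROW == 0 and 1024 % (MV_BLOCKSIZE * MV_ROWS_PER_THREAD) == 0
# -> GROUP(0, 8): 8 threads cooperate on each row's reduction (axis 0 here is the first reduce axis)
# -> LOCAL(global_idx, 4): each workgroup covers a block of 4 output rows
# (the MV_ROWS_PER_THREAD branch is cut off in this excerpt)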


@@ -13,7 +13,7 @@ actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
actions += [
Opt(op=OptOps.LOCAL, axis=0, amt=32),
-Opt(op=OptOps.GROUPTOP, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=8), Opt(op=OptOps.GROUPTOP, axis=1, amt=8),
+Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
]
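These lists define the candidate actions for the automatic kernel search. A minimal sketch of a greedy loop over them, assuming a timing callback (time_fn is a stand-in, not the search module's real helper) and relying on apply_opt's asserts to reject illegal actions:

import copy

def greedy_search(kernel, actions, time_fn):
  best_time = time_fn(kernel)
  while True:
    best = None
    for act in actions:
      k2 = copy.deepcopy(kernel)     # sketch only: the real search has its own way to copy kernels
      try:
        k2.apply_opt(act)            # asserts inside apply_opt reject actions that no longer fit
      except AssertionError:
        continue
      t = time_fn(k2)
      if t < best_time:
        best_time, best = t, k2
    if best is None:
      return kernel                  # no action improved the runtime
    kernel = best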