This reverts commit 8cf0bb9351.
@@ -472,12 +472,12 @@ class TestLinearizerOpts(unittest.TestCase):
       [Opt(OptOps.UPCAST, 1, 4)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
       [Opt(OptOps.UNROLL, 0, 2)], # check last unroll
-      [Opt(OptOps.LOCAL, 0, 4)], # check last local
+      [Opt(OptOps.LASTLOCAL, 0, 4)], # check last local
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
-      # [Opt(OptOps.GROUPTOP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
+      # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
     ], apply_tc=True)
@@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
 from enum import Enum, auto

 class OptOps(Enum):
-  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); GROUPTOP = auto() # noqa: E702
+  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value

 @dataclass(frozen=True, order=True)
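The hunk above restores LASTLOCAL and GROUP as members of the OptOps enum. As a reading aid, here is a minimal, self-contained Python sketch (not the tinygrad source) of the post-revert enum, together with an assumed Opt dataclass whose op/axis/amt fields are inferred from the Opt(...) calls elsewhere in this diff:

from dataclasses import dataclass
from enum import Enum, auto

# illustrative mock of the post-revert OptOps enum from the hunk above
class OptOps(Enum):
  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto()  # noqa: E702
  def __lt__(self, x: "OptOps"): return self.value < x.value

# assumed field layout, inferred from calls like Opt(op=OptOps.LOCAL, axis=0, amt=32)
@dataclass(frozen=True, order=True)
class Opt:
  op: OptOps
  axis: int
  amt: int

# e.g. the "check last local" case from the test hunk above
print(Opt(OptOps.LASTLOCAL, 0, 4))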
@@ -197,7 +197,7 @@ class OptimizedKernel(Kernel):
       self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
       self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
       for (tc_dim, tc_amt) in tc.threads:
-        fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
+        fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)

       # assert tensor core and prevent extra_opts from altering the key shape structure
       if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
@@ -216,7 +216,7 @@ class OptimizedKernel(Kernel):
     if self.tensor_core and s0_exists:
       for upc in [4,2]:
         if self.full_shape[s0] % upc == 0:
-          self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
+          self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
           break

     # alias buffer
@@ -228,16 +228,26 @@ class OptimizedKernel(Kernel):
   def apply_opt(self, opt:Opt):
     self.applied_opts.append(opt)
-    axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUPTOP else 0))
+    axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0))
     amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
     assert self.full_shape[axis] % amt == 0, "no longer valid shift"
     assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
     if opt.op == OptOps.LOCAL: # cyan
-      assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "local is for non-reduce that aren't TC dims"
+      assert axis < self.first_reduce, "can't local a reduce"
+      assert not(self.tensor_core), "can't local with tensor cores"
       self.shift_to(axis, amt, insert_before=self.first_reduce)
       self.local_dims += 1
+    elif opt.op == OptOps.LASTLOCAL: # cyan
+      assert axis < self.first_reduce, "can't local a reduce"
+      self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
+      self.local_dims += 1
-    elif opt.op == OptOps.GROUPTOP: # green
-      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "group is for reduce dims"
+    elif opt.op == OptOps.GROUP: # green
+      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
+      assert not(self.tensor_core), "can't group with tensor cores"
+      self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
+      self.group_for_reduce.append(amt)
+    elif opt.op == OptOps.GROUPTOP: # green
+      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
+      assert not(self.tensor_core), "can't group with tensor cores"
       self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
       self.group_for_reduce.append(amt)
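As a reading aid for the new apply_opt control flow, here is a hedged, self-contained sketch of just the axis-remapping step; real_axis and its plain integer arguments are illustrative stand-ins, not tinygrad API:

from enum import Enum, auto

class OptOps(Enum): UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto()  # noqa: E702

# mirrors the axis computation in the hunk above: UNROLL axes count from
# first_reduce, GROUP/GROUPTOP axes count from after the dims already in
# group_for_reduce, and the remaining ops use the raw axis
def real_axis(op: OptOps, axis: int, first_reduce: int, n_group_for_reduce: int) -> int:
  if op == OptOps.UNROLL: return axis + first_reduce
  if op in (OptOps.GROUP, OptOps.GROUPTOP): return axis + first_reduce + n_group_for_reduce
  return axis

# e.g. with two non-reduce dims and one group dim already created,
# Opt(OptOps.GROUP, 0, ...) lands on full_shape index 3
assert real_axis(OptOps.GROUP, 0, first_reduce=2, n_group_for_reduce=1) == 3

# note from the hunk itself: LOCAL inserts the new dim at first_reduce (just
# before the reduce dims), while LASTLOCAL uses first_reduce - local_dims,
# placing it ahead of the locals created so far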
@@ -247,7 +257,7 @@ class OptimizedKernel(Kernel):
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op == OptOps.UPCAST: # yellow
-      assert axis < self.first_reduce-(len(self.tensor_core.threads) if self.tensor_core else 0), "upcast is for non-reduce that aren't TC dims"
+      assert axis < self.first_reduce, "upcast is for non-reduce"
       assert amt <= 8, "don't upcast more than 8"
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
@@ -292,7 +302,7 @@ class OptimizedKernel(Kernel):
     if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
       if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}")
       if MV_THREADS_PER_ROW > 1:
-        self.apply_opt(Opt(OptOps.GROUPTOP, 0, MV_THREADS_PER_ROW))
+        self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
       if MV_BLOCKSIZE > 1:
         self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
       if MV_ROWS_PER_THREAD > 1:
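For orientation only, a tiny hedged sketch of the MATVEC guard from this hunk; the MV_* values and shapes below are made-up example numbers, not tinygrad defaults:

MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = 4, 8, 4  # assumed example values
full_shape = (1, 2048, 4096)                                    # hypothetical (batch, output, reduce) sizes
first_reduce, global_idx = 2, 1

# the path above only fires when both divisibility checks pass; it then asks
# for GROUP (rather than GROUPTOP) on reduce axis 0 and LOCAL on the output axis
if full_shape[first_reduce] % MV_THREADS_PER_ROW == 0 and full_shape[global_idx] % (MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
  print("matvec path taken")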
@@ -13,7 +13,7 @@ actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,
 actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
 actions += [
   Opt(op=OptOps.LOCAL, axis=0, amt=32),
-  Opt(op=OptOps.GROUPTOP, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=8), Opt(op=OptOps.GROUPTOP, axis=1, amt=8),
+  Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
   Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
 ]
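Finally, a hedged sketch of how this search-space actions pattern expands; flatten here is a local stand-in helper, and the OptOps/Opt mocks repeat the assumptions from the first sketch above:

from dataclasses import dataclass
from enum import Enum, auto
from itertools import chain

class OptOps(Enum): UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto()  # noqa: E702

@dataclass(frozen=True)
class Opt: op: OptOps; axis: int; amt: int  # assumed field layout

def flatten(ls): return list(chain.from_iterable(ls))  # local stand-in, not necessarily tinygrad's helper

actions = flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
actions += [
  Opt(op=OptOps.LOCAL, axis=0, amt=32),
  Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
  Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
]
print(len(actions))  # 15 generated GROUPTOP actions + 5 hand-listed actions = 20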