mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
with unroll, the action space goes from 161 -> 127 (#2060)
* with unroll, the action space goes from 161 -> 127 * more reliable instrumentation * beam search is so op * beam bugfix
This commit is contained in:
@@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps, Device, Compiled
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
|
||||
from tinygrad.helpers import ansilen, DEBUG, getenv
|
||||
from tinygrad.helpers import ansilen, DEBUG, getenv, flatten
|
||||
from tinygrad.graph import print_tree
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.shape.symbolic import sym_infer
|
||||
@@ -52,20 +52,24 @@ if __name__ == "__main__":
|
||||
if lin.apply_tensor_cores():
|
||||
lins.append(lin)
|
||||
|
||||
# try a greedy search
|
||||
if getenv("GREEDY"):
|
||||
# try a beam search
|
||||
if getenv("BEAM"):
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
if str(lin.ast) in global_db:
|
||||
for ao in global_db[str(lin.ast)]:
|
||||
lin.apply_opt(ao)
|
||||
else:
|
||||
best_tm = float('inf')
|
||||
beam = [lin]
|
||||
while 1:
|
||||
acted_lins = get_linearizer_actions(lin)
|
||||
timed_lins = {k:time_linearizer(v, rawbufs) for k,v in acted_lins.items()}
|
||||
opts = sorted(timed_lins.items(), key=lambda x: x[1])
|
||||
if opts[0][0] == 0: break # we are done
|
||||
lin = acted_lins[opts[0][0]]
|
||||
if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", lin.colored_shape())
|
||||
acted_lins = flatten([get_linearizer_actions(lin).items() for lin in beam])
|
||||
timed_lins = [(v,time_linearizer(v, rawbufs)) for k,v in acted_lins if k != 0]
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
|
||||
best_tm = opts[0][1]
|
||||
beam = [x[0] for x in opts[:getenv("BEAM")]]
|
||||
if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
|
||||
lin = beam[0]
|
||||
global_db[str(lin.ast)] = lin.applied_opts
|
||||
lins.append(lin)
|
||||
|
||||
|
||||
15
extra/optimization/get_action_space.py
Normal file
15
extra/optimization/get_action_space.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from tqdm import tqdm
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
||||
from tinygrad.codegen.search import actions
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds(False, False, False)
|
||||
tactions = set()
|
||||
for ast_str in tqdm(ast_strs):
|
||||
lin = ast_str_to_lin(ast_str)
|
||||
lin.hand_coded_optimizations()
|
||||
for o in lin.applied_opts:
|
||||
assert o in actions
|
||||
tactions.add(o)
|
||||
print(len(tactions))
|
||||
print(sorted(list(tactions)))
|
||||
21
extra/optimization/test_time_linearizer.py
Normal file
21
extra/optimization/test_time_linearizer.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
||||
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds()
|
||||
for i, ast_str in enumerate(ast_strs):
|
||||
lin = ast_str_to_lin(ast_str)
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
test_tm = time_linearizer(lin, rawbufs)
|
||||
if test_tm < 1e-2: continue
|
||||
print(f"EXAMPLE {i}")
|
||||
acted_lins = get_linearizer_actions(lin)
|
||||
ok_avg, short_avg = 0, 0
|
||||
for k,v in acted_lins.items():
|
||||
tm1 = time_linearizer(v, rawbufs)
|
||||
tm2 = time_linearizer(v, rawbufs)
|
||||
tm3 = time_linearizer(v, rawbufs, False)
|
||||
print(v.colored_shape(50), f"{tm1*1e3:10.2f} {tm2*1e3:10.2f} {tm3*1e3:10.2f} : {((tm1-tm2)/tm1)*100:5.2f}% vs {((tm1-tm3)/tm1)*100:5.2f}%")
|
||||
ok_avg += (tm1-tm2)/tm1
|
||||
short_avg += (tm1-tm3)/tm1
|
||||
print(f"{ok_avg/len(acted_lins)*100:5.2f}% vs {short_avg/len(acted_lins)*100:5.2f}%")
|
||||
@@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
|
||||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
UPCAST = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
|
||||
UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
@@ -322,20 +322,25 @@ class OptimizedKernel(Kernel):
|
||||
|
||||
def apply_opt(self, opt:Opt):
|
||||
self.applied_opts.append(opt)
|
||||
assert self.full_shape[opt.axis] % opt.amt == 0, "no longer valid shift"
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else 0)
|
||||
assert self.full_shape[axis] % opt.amt == 0, "no longer valid shift"
|
||||
if opt.op == OptOps.LOCAL: # cyan
|
||||
assert opt.axis < (self.first_reduce-self.local_dims), "can't local a local or reduce"
|
||||
self.shift_to(opt.axis, opt.amt, insert_before=self.first_reduce)
|
||||
assert axis < (self.first_reduce-self.local_dims), "can't local a local or reduce"
|
||||
self.shift_to(axis, opt.amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.GROUP: # green
|
||||
self.shift_to(opt.axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.shift_to(axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(opt.amt)
|
||||
elif opt.op == OptOps.GROUPTOP: # green
|
||||
self.shift_to(opt.axis, opt.amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.shift_to(axis, opt.amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(opt.amt)
|
||||
elif opt.op == OptOps.UPCAST: # yellow (or purple if it's a reduce axis)
|
||||
assert opt.axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
|
||||
self.shift_to(opt.axis, opt.amt, insert_before=None if opt.axis < self.first_reduce else len(self.full_unupcasted_shape))
|
||||
elif opt.op == OptOps.UNROLL: # purple
|
||||
assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
|
||||
self.shift_to(axis, opt.amt, insert_before=len(self.full_unupcasted_shape))
|
||||
self.upcast()
|
||||
elif opt.op == OptOps.UPCAST: # yellow
|
||||
assert axis < self.first_reduce, "upcast is for non-reduce"
|
||||
self.shift_to(axis, opt.amt, insert_before=None)
|
||||
self.upcast()
|
||||
self.simplify_ones()
|
||||
|
||||
@@ -345,7 +350,8 @@ class OptimizedKernel(Kernel):
|
||||
if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType:
|
||||
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
|
||||
if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
if unit_stride_axes_mul_4[0] < self.first_reduce: self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
|
||||
else: self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
|
||||
|
||||
def hand_coded_optimizations(self):
|
||||
# if there's images in the earlybufs, we have to make an axis the 4 loading one
|
||||
@@ -432,14 +438,14 @@ class OptimizedKernel(Kernel):
|
||||
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
|
||||
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
|
||||
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, s))
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, s))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
|
||||
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, s2))
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, s2))
|
||||
else:
|
||||
for splits in [4]:
|
||||
if self.full_unupcasted_shape[-1]%splits == 0:
|
||||
self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
|
||||
break
|
||||
|
||||
# if nothing at all is upcasted and it's easy to, do an upcast
|
||||
|
||||
@@ -9,20 +9,22 @@ from collections import defaultdict
|
||||
|
||||
from tinygrad.codegen.optimizer import Opt, OptOps
|
||||
actions = [
|
||||
Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=5), Opt(op=OptOps.UPCAST, axis=0, amt=6), Opt(op=OptOps.UPCAST, axis=0, amt=7), Opt(op=OptOps.UPCAST, axis=0, amt=9), Opt(op=OptOps.UPCAST, axis=0, amt=10), Opt(op=OptOps.UPCAST, axis=0, amt=12), Opt(op=OptOps.UPCAST, axis=0, amt=24),
|
||||
Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=5), Opt(op=OptOps.UPCAST, axis=1, amt=6), Opt(op=OptOps.UPCAST, axis=1, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=9), Opt(op=OptOps.UPCAST, axis=1, amt=10), Opt(op=OptOps.UPCAST, axis=1, amt=11), Opt(op=OptOps.UPCAST, axis=1, amt=12), Opt(op=OptOps.UPCAST, axis=1, amt=13), Opt(op=OptOps.UPCAST, axis=1, amt=14), Opt(op=OptOps.UPCAST, axis=1, amt=15), Opt(op=OptOps.UPCAST, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=17), Opt(op=OptOps.UPCAST, axis=1, amt=20), Opt(op=OptOps.UPCAST, axis=1, amt=21), Opt(op=OptOps.UPCAST, axis=1, amt=24), Opt(op=OptOps.UPCAST, axis=1, amt=25), Opt(op=OptOps.UPCAST, axis=1, amt=28), Opt(op=OptOps.UPCAST, axis=1, amt=32),
|
||||
Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=5), Opt(op=OptOps.UPCAST, axis=2, amt=6), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=9), Opt(op=OptOps.UPCAST, axis=2, amt=10), Opt(op=OptOps.UPCAST, axis=2, amt=11), Opt(op=OptOps.UPCAST, axis=2, amt=12), Opt(op=OptOps.UPCAST, axis=2, amt=14), Opt(op=OptOps.UPCAST, axis=2, amt=15), Opt(op=OptOps.UPCAST, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=2, amt=17), Opt(op=OptOps.UPCAST, axis=2, amt=20), Opt(op=OptOps.UPCAST, axis=2, amt=21), Opt(op=OptOps.UPCAST, axis=2, amt=24), Opt(op=OptOps.UPCAST, axis=2, amt=25), Opt(op=OptOps.UPCAST, axis=2, amt=27), Opt(op=OptOps.UPCAST, axis=2, amt=28), Opt(op=OptOps.UPCAST, axis=2, amt=30), Opt(op=OptOps.UPCAST, axis=2, amt=31), Opt(op=OptOps.UPCAST, axis=2, amt=32),
|
||||
Opt(op=OptOps.UPCAST, axis=3, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=5), Opt(op=OptOps.UPCAST, axis=3, amt=6), Opt(op=OptOps.UPCAST, axis=3, amt=7), Opt(op=OptOps.UPCAST, axis=3, amt=8), Opt(op=OptOps.UPCAST, axis=3, amt=9), Opt(op=OptOps.UPCAST, axis=3, amt=10), Opt(op=OptOps.UPCAST, axis=3, amt=11), Opt(op=OptOps.UPCAST, axis=3, amt=12), Opt(op=OptOps.UPCAST, axis=3, amt=13), Opt(op=OptOps.UPCAST, axis=3, amt=14), Opt(op=OptOps.UPCAST, axis=3, amt=15), Opt(op=OptOps.UPCAST, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=24), Opt(op=OptOps.UPCAST, axis=3, amt=26), Opt(op=OptOps.UPCAST, axis=3, amt=27), Opt(op=OptOps.UPCAST, axis=3, amt=28), Opt(op=OptOps.UPCAST, axis=3, amt=32),
|
||||
Opt(op=OptOps.UPCAST, axis=4, amt=2), Opt(op=OptOps.UPCAST, axis=4, amt=3), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=4, amt=5), Opt(op=OptOps.UPCAST, axis=4, amt=6), Opt(op=OptOps.UPCAST, axis=4, amt=7), Opt(op=OptOps.UPCAST, axis=4, amt=8), Opt(op=OptOps.UPCAST, axis=4, amt=9), Opt(op=OptOps.UPCAST, axis=4, amt=10), Opt(op=OptOps.UPCAST, axis=4, amt=11), Opt(op=OptOps.UPCAST, axis=4, amt=12), Opt(op=OptOps.UPCAST, axis=4, amt=13), Opt(op=OptOps.UPCAST, axis=4, amt=14), Opt(op=OptOps.UPCAST, axis=4, amt=15), Opt(op=OptOps.UPCAST, axis=4, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=20), Opt(op=OptOps.UPCAST, axis=4, amt=24), Opt(op=OptOps.UPCAST, axis=4, amt=26), Opt(op=OptOps.UPCAST, axis=4, amt=27), Opt(op=OptOps.UPCAST, axis=4, amt=28), Opt(op=OptOps.UPCAST, axis=4, amt=30), Opt(op=OptOps.UPCAST, axis=4, amt=32),
|
||||
Opt(op=OptOps.UPCAST, axis=5, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=3), Opt(op=OptOps.UPCAST, axis=5, amt=4), Opt(op=OptOps.UPCAST, axis=5, amt=5), Opt(op=OptOps.UPCAST, axis=5, amt=6), Opt(op=OptOps.UPCAST, axis=5, amt=7), Opt(op=OptOps.UPCAST, axis=5, amt=8), Opt(op=OptOps.UPCAST, axis=5, amt=9), Opt(op=OptOps.UPCAST, axis=5, amt=11), Opt(op=OptOps.UPCAST, axis=5, amt=13), Opt(op=OptOps.UPCAST, axis=5, amt=16), Opt(op=OptOps.UPCAST, axis=5, amt=24), Opt(op=OptOps.UPCAST, axis=5, amt=26), Opt(op=OptOps.UPCAST, axis=5, amt=27),
|
||||
Opt(op=OptOps.UPCAST, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=6, amt=3), Opt(op=OptOps.UPCAST, axis=6, amt=4), Opt(op=OptOps.UPCAST, axis=6, amt=5), Opt(op=OptOps.UPCAST, axis=6, amt=6), Opt(op=OptOps.UPCAST, axis=6, amt=7), Opt(op=OptOps.UPCAST, axis=6, amt=13), Opt(op=OptOps.UPCAST, axis=6, amt=16), Opt(op=OptOps.UPCAST, axis=6, amt=24), Opt(op=OptOps.UPCAST, axis=6, amt=26), Opt(op=OptOps.UPCAST, axis=6, amt=27), Opt(op=OptOps.UPCAST, axis=6, amt=31),
|
||||
Opt(op=OptOps.UPCAST, axis=7, amt=2), Opt(op=OptOps.UPCAST, axis=7, amt=3), Opt(op=OptOps.UPCAST, axis=7, amt=4), Opt(op=OptOps.UPCAST, axis=7, amt=7),
|
||||
Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=5), Opt(op=OptOps.UPCAST, axis=0, amt=6), Opt(op=OptOps.UPCAST, axis=0, amt=7),
|
||||
Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=5), Opt(op=OptOps.UPCAST, axis=1, amt=6), Opt(op=OptOps.UPCAST, axis=1, amt=7),
|
||||
Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=5), Opt(op=OptOps.UPCAST, axis=2, amt=6), Opt(op=OptOps.UPCAST, axis=2, amt=7),
|
||||
Opt(op=OptOps.UPCAST, axis=3, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=5), Opt(op=OptOps.UPCAST, axis=3, amt=6), Opt(op=OptOps.UPCAST, axis=3, amt=7),
|
||||
Opt(op=OptOps.UPCAST, axis=4, amt=2), Opt(op=OptOps.UPCAST, axis=4, amt=3), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=4, amt=5),
|
||||
Opt(op=OptOps.UPCAST, axis=5, amt=3), Opt(op=OptOps.UPCAST, axis=5, amt=4),
|
||||
Opt(op=OptOps.UNROLL, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=5), Opt(op=OptOps.UNROLL, axis=0, amt=6), Opt(op=OptOps.UNROLL, axis=0, amt=7), Opt(op=OptOps.UNROLL, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=0, amt=9), Opt(op=OptOps.UNROLL, axis=0, amt=10), Opt(op=OptOps.UNROLL, axis=0, amt=11), Opt(op=OptOps.UNROLL, axis=0, amt=12), Opt(op=OptOps.UNROLL, axis=0, amt=13), Opt(op=OptOps.UNROLL, axis=0, amt=14), Opt(op=OptOps.UNROLL, axis=0, amt=15), Opt(op=OptOps.UNROLL, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=17), Opt(op=OptOps.UNROLL, axis=0, amt=20), Opt(op=OptOps.UNROLL, axis=0, amt=21), Opt(op=OptOps.UNROLL, axis=0, amt=24), Opt(op=OptOps.UNROLL, axis=0, amt=25), Opt(op=OptOps.UNROLL, axis=0, amt=27), Opt(op=OptOps.UNROLL, axis=0, amt=28), Opt(op=OptOps.UNROLL, axis=0, amt=30), Opt(op=OptOps.UNROLL, axis=0, amt=31), Opt(op=OptOps.UNROLL, axis=0, amt=32),
|
||||
Opt(op=OptOps.UNROLL, axis=1, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=3), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=5), Opt(op=OptOps.UNROLL, axis=1, amt=6), Opt(op=OptOps.UNROLL, axis=1, amt=7), Opt(op=OptOps.UNROLL, axis=1, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=9), Opt(op=OptOps.UNROLL, axis=1, amt=10), Opt(op=OptOps.UNROLL, axis=1, amt=11), Opt(op=OptOps.UNROLL, axis=1, amt=16), Opt(op=OptOps.UNROLL, axis=1, amt=24), Opt(op=OptOps.UNROLL, axis=1, amt=26), Opt(op=OptOps.UNROLL, axis=1, amt=27), Opt(op=OptOps.UNROLL, axis=1, amt=28), Opt(op=OptOps.UNROLL, axis=1, amt=32),
|
||||
Opt(op=OptOps.UNROLL, axis=2, amt=2), Opt(op=OptOps.UNROLL, axis=2, amt=3), Opt(op=OptOps.UNROLL, axis=2, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=5), Opt(op=OptOps.UNROLL, axis=2, amt=6), Opt(op=OptOps.UNROLL, axis=2, amt=7), Opt(op=OptOps.UNROLL, axis=2, amt=8), Opt(op=OptOps.UNROLL, axis=2, amt=9), Opt(op=OptOps.UNROLL, axis=2, amt=11), Opt(op=OptOps.UNROLL, axis=2, amt=12), Opt(op=OptOps.UNROLL, axis=2, amt=13), Opt(op=OptOps.UNROLL, axis=2, amt=14), Opt(op=OptOps.UNROLL, axis=2, amt=16), Opt(op=OptOps.UNROLL, axis=2, amt=24), Opt(op=OptOps.UNROLL, axis=2, amt=26), Opt(op=OptOps.UNROLL, axis=2, amt=27), Opt(op=OptOps.UNROLL, axis=2, amt=28), Opt(op=OptOps.UNROLL, axis=2, amt=30), Opt(op=OptOps.UNROLL, axis=2, amt=31), Opt(op=OptOps.UNROLL, axis=2, amt=32),
|
||||
Opt(op=OptOps.UNROLL, axis=3, amt=2), Opt(op=OptOps.UNROLL, axis=3, amt=3), Opt(op=OptOps.UNROLL, axis=3, amt=11),
|
||||
Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=3), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
||||
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=3), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=1, amt=16),
|
||||
Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=16),
|
||||
Opt(op=OptOps.LOCAL, axis=3, amt=2), Opt(op=OptOps.LOCAL, axis=3, amt=3), Opt(op=OptOps.LOCAL, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=8), Opt(op=OptOps.LOCAL, axis=3, amt=16),
|
||||
Opt(op=OptOps.LOCAL, axis=4, amt=2), Opt(op=OptOps.LOCAL, axis=4, amt=3), Opt(op=OptOps.LOCAL, axis=4, amt=16),
|
||||
Opt(op=OptOps.GROUP, axis=1, amt=4), Opt(op=OptOps.GROUP, axis=1, amt=8),
|
||||
Opt(op=OptOps.GROUP, axis=1, amt=4), Opt(op=OptOps.GROUP, axis=1, amt=8), Opt(op=OptOps.GROUP, axis=2, amt=8),
|
||||
Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.GROUPTOP, axis=0, amt=256),
|
||||
Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256),
|
||||
Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256)]
|
||||
@@ -30,7 +32,7 @@ device:Compiled = cast(Compiled, Device[Device.DEFAULT])
|
||||
|
||||
# returns time in seconds
|
||||
logtm = open(getenv("LOGTM", ""),"a") if getenv("LOGTM", "") else None
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, cnt=3, should_copy=True) -> float:
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
|
||||
if should_copy: lin = deepcopy(lin) # TODO: remove the need for this
|
||||
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
|
||||
try:
|
||||
@@ -39,21 +41,22 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
|
||||
real_global_size = prg.global_size[:]
|
||||
if allow_test_size:
|
||||
test_global_size = prg.global_size[:]
|
||||
while prod(test_global_size) > 16384:
|
||||
while prod(test_global_size) > max_global_size:
|
||||
for j in range(2,-1,-1):
|
||||
if test_global_size[j] > 1:
|
||||
if test_global_size[j] > 16:
|
||||
test_global_size[j] //= 2
|
||||
break
|
||||
factor = prod(prg.global_size) / prod(test_global_size)
|
||||
prg.global_size = test_global_size
|
||||
#print(real_global_size, test_global_size, factor)
|
||||
else:
|
||||
factor = 1
|
||||
tms = [prg(rawbufs, var_vals, force_wait=True)*factor for _ in range(cnt)]
|
||||
prg.global_size = real_global_size
|
||||
except Exception:
|
||||
print("FAILED")
|
||||
print(lin.ast)
|
||||
print(lin.applied_opts)
|
||||
#print("FAILED")
|
||||
#print(lin.ast)
|
||||
#print(lin.applied_opts)
|
||||
tms = [float('inf')]
|
||||
if logtm: logtm.write(str((lin.ast, lin.applied_opts, tms))+"\n")
|
||||
return min(tms)
|
||||
|
||||
Reference in New Issue
Block a user